Open3D (C++ API)  0.18.0
ReduceSubarraysSumOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "tensorflow/core/framework/op.h"
11 #include "tensorflow/core/framework/op_kernel.h"
12 #include "tensorflow/core/lib/core/errors.h"
13 
15 // namespace for code that is common for all kernels
16 namespace reduce_subarrays_sum_opkernel {
17 
18 // Base class with common code for the OpKernel implementations
19 class ReduceSubarraysSumOpKernel : public tensorflow::OpKernel {
20 public:
21  explicit ReduceSubarraysSumOpKernel(
22  tensorflow::OpKernelConstruction* construction)
23  : OpKernel(construction) {}
24 
25  void Compute(tensorflow::OpKernelContext* context) override {
26  using namespace tensorflow;
27  static_assert(sizeof(int64) == sizeof(int64_t),
28  "int64 type is not compatible");
29 
30  const Tensor& values = context->input(0);
31  OP_REQUIRES(context, values.shape().dims() == 1,
32  errors::InvalidArgument("values must be a rank 1 tensor"));
33 
34  const Tensor& row_splits = context->input(1);
35  OP_REQUIRES(
36  context, row_splits.shape().dims() == 1,
37  errors::InvalidArgument("row_splits must be a rank 1 tensor"));
38 
39  // special treatment for empty values vector
40  if (values.shape().dim_size(0) == 0) {
41  Tensor* sums_tensor = 0;
42  OP_REQUIRES_OK(context, context->allocate_output(0, values.shape(),
43  &sums_tensor));
44  return;
45  }
46 
47  Tensor* sums_tensor = 0;
48  TensorShape sums_shape({row_splits.shape().dim_size(0) - 1});
49  OP_REQUIRES_OK(context,
50  context->allocate_output(0, sums_shape, &sums_tensor));
51 
52  Kernel(context, values, row_splits, *sums_tensor);
53  }
54 
55  // Function with the device specific code
56  virtual void Kernel(tensorflow::OpKernelContext* context,
57  const tensorflow::Tensor& values,
58  const tensorflow::Tensor& row_splits,
59  tensorflow::Tensor& sums) = 0;
60 
61 private:
62 };
63 
64 } // namespace reduce_subarrays_sum_opkernel
ImGuiContext * context
Definition: Window.cpp:76