Open3D (C++ API)  0.17.0
VoxelPoolingOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
16 // namespace for code that is common for all kernels
17 namespace voxel_pooling_opkernel {
18 
19 template <class TReal, class TFeat>
20 class OutputAllocator {
21 public:
22  OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
23 
24  void AllocPooledPositions(TReal** ptr, size_t num) {
25  using namespace tensorflow;
26  *ptr = nullptr;
27  Tensor* tensor = 0;
28  TensorShape shape({int64_t(num), 3});
29  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
30  auto flat_tensor = tensor->flat<TReal>();
31  *ptr = flat_tensor.data();
32  }
33 
34  void AllocPooledFeatures(TFeat** ptr, size_t num, int channels) {
35  using namespace tensorflow;
36  *ptr = nullptr;
37  Tensor* tensor = 0;
38  TensorShape shape({int64_t(num), channels});
39  OP_REQUIRES_OK(context, context->allocate_output(1, shape, &tensor));
40  auto flat_tensor = tensor->flat<TFeat>();
41  *ptr = flat_tensor.data();
42  }
43 
44 private:
45  tensorflow::OpKernelContext* context;
46 };
47 
48 // Base class with common code for the OpKernel implementations
49 class VoxelPoolingOpKernel : public tensorflow::OpKernel {
50 public:
51  explicit VoxelPoolingOpKernel(
52  tensorflow::OpKernelConstruction* construction)
53  : OpKernel(construction) {
54  using namespace tensorflow;
55  using namespace open3d::ml::impl;
56  std::string pos_fn_str;
57  OP_REQUIRES_OK(construction,
58  construction->GetAttr("position_fn", &pos_fn_str));
59 
60  if (pos_fn_str == "average")
61  position_fn = AVERAGE;
62  else if (pos_fn_str == "nearest_neighbor")
63  position_fn = NEAREST_NEIGHBOR;
64  else
65  position_fn = CENTER;
66 
67  std::string feat_fn_str;
68  OP_REQUIRES_OK(construction,
69  construction->GetAttr("feature_fn", &feat_fn_str));
70 
71  if (feat_fn_str == "average")
72  feature_fn = AVERAGE;
73  else if (feat_fn_str == "nearest_neighbor")
74  feature_fn = NEAREST_NEIGHBOR;
75  else
76  feature_fn = MAX;
77 
78  OP_REQUIRES_OK(construction, construction->GetAttr("debug", &debug));
79  }
80 
81  void Compute(tensorflow::OpKernelContext* context) override {
82  using namespace tensorflow;
83  using namespace open3d::ml::impl;
84  const Tensor& positions = context->input(0);
85  OP_REQUIRES(
86  context, positions.shape().dims() == 2,
87  errors::InvalidArgument("positions must be a rank 2 tensor"));
88 
89  const Tensor& features = context->input(1);
90  OP_REQUIRES(
91  context, features.shape().dims() == 2,
92  errors::InvalidArgument("features must be a rank 2 tensor"));
93 
94  const Tensor& voxel_size = context->input(2);
95  OP_REQUIRES(
96  context, TensorShapeUtils::IsScalar(voxel_size.shape()),
97  errors::InvalidArgument("voxel_size must be a scalar, but is ",
98  voxel_size.shape().DebugString()));
99 
100  Kernel(context, positions, features, voxel_size);
101  }
102 
103  // Function with the device specific code
104  virtual void Kernel(tensorflow::OpKernelContext* context,
105  const tensorflow::Tensor& positions,
106  const tensorflow::Tensor& features,
107  const tensorflow::Tensor& voxel_size) = 0;
108 
109 protected:
112  bool debug;
113 };
114 
115 } // namespace voxel_pooling_opkernel
ImGuiContext * context
Definition: Window.cpp:76
Definition: ContinuousConv.h:16
AccumulationFn
Definition: VoxelPooling.h:21
@ CENTER
Definition: VoxelPooling.h:21
@ NEAREST_NEIGHBOR
Definition: VoxelPooling.h:21
@ MAX
Definition: VoxelPooling.h:21
@ AVERAGE
Definition: VoxelPooling.h:21