Open3D (C++ API)  0.18.0+252c867
VoxelPoolingGradOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
16 // namespace for code that is common for all kernels
17 namespace voxel_pooling_opkernel {
18 
19 template <class TReal, class TFeat>
20 class OutputAllocator {
21 public:
22  OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
23 
24  void AllocPooledPositions(TReal** ptr, size_t num) {
25  using namespace tensorflow;
26  *ptr = nullptr;
27  Tensor* tensor = 0;
28  TensorShape shape({int64_t(num), 3});
29  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
30  auto flat_tensor = tensor->flat<TReal>();
31  *ptr = flat_tensor.data();
32  }
33 
34  void AllocPooledFeatures(TFeat** ptr, size_t num, int channels) {
35  using namespace tensorflow;
36  *ptr = nullptr;
37  Tensor* tensor = 0;
38  TensorShape shape({int64_t(num), channels});
39  OP_REQUIRES_OK(context, context->allocate_output(1, shape, &tensor));
40  auto flat_tensor = tensor->flat<TFeat>();
41  *ptr = flat_tensor.data();
42  }
43 
44 private:
45  tensorflow::OpKernelContext* context;
46 };
47 
48 // Base class with common code for the OpKernel implementations
49 class VoxelPoolingGradOpKernel : public tensorflow::OpKernel {
50 public:
51  explicit VoxelPoolingGradOpKernel(
52  tensorflow::OpKernelConstruction* construction)
53  : OpKernel(construction) {
54  using namespace tensorflow;
55  using namespace open3d::ml::impl;
56  std::string pos_fn_str;
57  OP_REQUIRES_OK(construction,
58  construction->GetAttr("position_fn", &pos_fn_str));
59 
60  if (pos_fn_str == "average")
61  position_fn = AVERAGE;
62  else if (pos_fn_str == "nearest_neighbor")
63  position_fn = NEAREST_NEIGHBOR;
64  else
65  position_fn = CENTER;
66 
67  std::string feat_fn_str;
68  OP_REQUIRES_OK(construction,
69  construction->GetAttr("feature_fn", &feat_fn_str));
70 
71  if (feat_fn_str == "average")
72  feature_fn = AVERAGE;
73  else if (feat_fn_str == "nearest_neighbor")
74  feature_fn = NEAREST_NEIGHBOR;
75  else
76  feature_fn = MAX;
77  }
78 
79  void Compute(tensorflow::OpKernelContext* context) override {
80  using namespace tensorflow;
81  using namespace open3d::ml::impl;
82 
83  const Tensor& positions = context->input(0);
84  OP_REQUIRES(
85  context, positions.shape().dims() == 2,
86  errors::InvalidArgument("positions must be a rank 2 tensor"));
87 
88  const Tensor& features = context->input(1);
89  OP_REQUIRES(
90  context, features.shape().dims() == 2,
91  errors::InvalidArgument("features must be a rank 2 tensor"));
92 
93  const Tensor& voxel_size = context->input(2);
94  OP_REQUIRES(
95  context, TensorShapeUtils::IsScalar(voxel_size.shape()),
96  errors::InvalidArgument("voxel_size must be a scalar, but is ",
97  voxel_size.shape().DebugString()));
98 
99  const Tensor& pooled_positions = context->input(3);
100  OP_REQUIRES(context, pooled_positions.shape().dims() == 2,
101  errors::InvalidArgument(
102  "pooled_positions must be a rank 2 tensor"));
103 
104  const Tensor& pooled_features_gradient = context->input(4);
105  OP_REQUIRES(
106  context, pooled_features_gradient.shape().dims() == 2,
107  errors::InvalidArgument(
108  "pooled_features_gradient must be a rank 2 tensor"));
109 
110  Tensor* features_backprop = nullptr;
111  OP_REQUIRES_OK(context, context->allocate_output(0, features.shape(),
112  &features_backprop));
113 
114  Kernel(context, *features_backprop, positions, features,
115  pooled_positions, pooled_features_gradient, voxel_size);
116  }
117 
118  // Function with the device specific code
119  virtual void Kernel(tensorflow::OpKernelContext* context,
120  tensorflow::Tensor& features_backprop,
121  const tensorflow::Tensor& positions,
122  const tensorflow::Tensor& features,
123  const tensorflow::Tensor& pooled_positions,
124  const tensorflow::Tensor& pooled_features_gradient,
125  const tensorflow::Tensor& voxel_size) = 0;
126 
127 protected:
130 };
131 
132 } // namespace voxel_pooling_opkernel
ImGuiContext * context
Definition: Window.cpp:76
Definition: ContinuousConv.h:16
AccumulationFn
Definition: VoxelPooling.h:21
@ CENTER
Definition: VoxelPooling.h:21
@ NEAREST_NEIGHBOR
Definition: VoxelPooling.h:21
@ MAX
Definition: VoxelPooling.h:21
@ AVERAGE
Definition: VoxelPooling.h:21