Open3D (C++ API)  0.18.0
InterpolateOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
15 class ThreeNNOpKernel : public tensorflow::OpKernel {
16 public:
17  explicit ThreeNNOpKernel(tensorflow::OpKernelConstruction* construction)
18  : OpKernel(construction) {}
19 
20  void Compute(tensorflow::OpKernelContext* context) override {
21  using namespace tensorflow;
22 
23  const Tensor& inp_tensor = context->input(0);
24  OP_REQUIRES(
25  context,
26  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
27  errors::InvalidArgument("ThreeNN expects "
28  "(batch_size,num_points,3) inp shape"));
29  int batch_size = inp_tensor.shape().dim_size(0);
30  int pts_num_out = inp_tensor.shape().dim_size(1);
31  auto inp_flat = inp_tensor.flat<float>();
32  const float* inp = &(inp_flat(0));
33 
34  const Tensor& data_tensor = context->input(1);
35  OP_REQUIRES(
36  context,
37  data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
38  errors::InvalidArgument(
39  "ThreeNN expects "
40  "(batch_size,num_points,3) data shape"));
41  int pts_num_in = data_tensor.shape().dim_size(1);
42  auto data_flat = data_tensor.flat<float>();
43  const float* data = &(data_flat(0));
44 
45  Tensor* out_dist;
46  OP_REQUIRES_OK(
47  context,
48  context->allocate_output(
49  0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
50  auto out_flat0 = out_dist->flat<float>();
51  float* out0 = &(out_flat0(0));
52 
53  Tensor* out_idx;
54  OP_REQUIRES_OK(
55  context,
56  context->allocate_output(
57  1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
58  auto out_flat1 = out_idx->flat<int>();
59  int* out1 = &(out_flat1(0));
60 
61  Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
62  out1);
63  }
64 
65  virtual void Kernel(tensorflow::OpKernelContext* context,
66  int b,
67  int n,
68  int m,
69  const float* unknown,
70  const float* known,
71  float* dist2,
72  int* idx) = 0;
73 };
74 
75 class ThreeInterpolateOpKernel : public tensorflow::OpKernel {
76 public:
78  tensorflow::OpKernelConstruction* construction)
79  : OpKernel(construction) {}
80 
81  void Compute(tensorflow::OpKernelContext* context) override {
82  using namespace tensorflow;
83 
84  const Tensor& inp_tensor = context->input(0);
85  OP_REQUIRES(
86  context, inp_tensor.dims() == 3,
87  errors::InvalidArgument("ThreeInterpolate expects "
88  "(batch_size,num_points,3) inp shape"));
89  int batch_size = inp_tensor.shape().dim_size(0);
90  int C = inp_tensor.shape().dim_size(1);
91  int M = inp_tensor.shape().dim_size(2);
92  auto inp_flat = inp_tensor.flat<float>();
93  const float* inp = &(inp_flat(0));
94 
95  const Tensor& idx_tensor = context->input(1);
96  OP_REQUIRES(
97  context, idx_tensor.dims() == 3,
98  errors::InvalidArgument("ThreeInterpolate expects "
99  "(batch_size,num_points,3) idx shape"));
100  int N = idx_tensor.shape().dim_size(1);
101  auto idx_flat = idx_tensor.flat<int>();
102  const int* idx = &(idx_flat(0));
103 
104  const Tensor& weights_tensor = context->input(2);
105  OP_REQUIRES(context, weights_tensor.dims() == 3,
106  errors::InvalidArgument(
107  "ThreeInterpolate expects "
108  "(batch_size,num_points,3) weights shape"));
109  auto weights_flat = weights_tensor.flat<float>();
110  const float* weights = &(weights_flat(0));
111 
112  Tensor* out_tensor;
113  OP_REQUIRES_OK(context,
114  context->allocate_output(
115  0, TensorShape{batch_size, C, N}, &out_tensor));
116  auto out_flat = out_tensor->flat<float>();
117  float* out = &(out_flat(0));
118 
119  Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
120  }
121 
122  virtual void Kernel(tensorflow::OpKernelContext* context,
123  int b,
124  int c,
125  int m,
126  int n,
127  const float* points,
128  const int* idx,
129  const float* weight,
130  float* out) = 0;
131 };
132 
133 class ThreeInterpolateGradOpKernel : public tensorflow::OpKernel {
134 public:
136  tensorflow::OpKernelConstruction* construction)
137  : OpKernel(construction) {
138  OP_REQUIRES_OK(construction, construction->GetAttr("M", &M));
139  }
140 
141  void Compute(tensorflow::OpKernelContext* context) override {
142  using namespace tensorflow;
143 
144  const Tensor& inp_tensor = context->input(0);
145  OP_REQUIRES(
146  context, inp_tensor.dims() == 3,
147  errors::InvalidArgument("ThreeInterpolateGrad expects "
148  "(batch_size,num_points,3) inp shape"));
149  int batch_size = inp_tensor.shape().dim_size(0);
150  int C = inp_tensor.shape().dim_size(1);
151  int N = inp_tensor.shape().dim_size(2);
152  auto inp_flat = inp_tensor.flat<float>();
153  const float* inp = &(inp_flat(0));
154 
155  const Tensor& idx_tensor = context->input(1);
156  OP_REQUIRES(
157  context, idx_tensor.dims() == 3,
158  errors::InvalidArgument("ThreeInterpolateGrad expects "
159  "(batch_size,num_points,3) idx shape"));
160  auto idx_flat = idx_tensor.flat<int>();
161  const int* idx = &(idx_flat(0));
162 
163  const Tensor& weights_tensor = context->input(2);
164  OP_REQUIRES(context, weights_tensor.dims() == 3,
165  errors::InvalidArgument(
166  "ThreeInterpolateGrad expects "
167  "(batch_size,num_points,3) weights shape"));
168  auto weights_flat = weights_tensor.flat<float>();
169  const float* weights = &(weights_flat(0));
170 
171  Tensor* out_tensor;
172  OP_REQUIRES_OK(context,
173  context->allocate_output(
174  0, TensorShape{batch_size, C, M}, &out_tensor));
175  auto out_flat = out_tensor->flat<float>();
176  float* out = &(out_flat(0));
177 
178  Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
179  }
180 
181  virtual void Kernel(tensorflow::OpKernelContext* context,
182  int b,
183  int c,
184  int n,
185  int m,
186  const float* grad_out,
187  const int* idx,
188  const float* weight,
189  float* grad_points) = 0;
190 
191 protected:
192  int M;
193 };
Eigen::Matrix3Xd M
Definition: PointCloudPlanarPatchDetection.cpp:507
ImGuiContext * context
Definition: Window.cpp:76
Definition: InterpolateOpKernel.h:133
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:135
int M
Definition: InterpolateOpKernel.h:192
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:141
Definition: InterpolateOpKernel.h:75
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:77
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:81
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)=0
Definition: InterpolateOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:20
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:17
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
int points
Definition: FilePCD.cpp:54
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:269