10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 : OpKernel(construction) {}
21 using namespace tensorflow;
23 const Tensor& inp_tensor =
context->input(0);
26 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
27 errors::InvalidArgument(
"ThreeNN expects "
28 "(batch_size,num_points,3) inp shape"));
29 int batch_size = inp_tensor.shape().dim_size(0);
30 int pts_num_out = inp_tensor.shape().dim_size(1);
31 auto inp_flat = inp_tensor.flat<
float>();
32 const float* inp = &(inp_flat(0));
34 const Tensor& data_tensor =
context->input(1);
37 data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
38 errors::InvalidArgument(
40 "(batch_size,num_points,3) data shape"));
41 int pts_num_in = data_tensor.shape().dim_size(1);
42 auto data_flat = data_tensor.flat<
float>();
43 const float*
data = &(data_flat(0));
49 0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
50 auto out_flat0 = out_dist->flat<
float>();
51 float* out0 = &(out_flat0(0));
57 1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
58 auto out_flat1 = out_idx->flat<
int>();
59 int* out1 = &(out_flat1(0));
78 tensorflow::OpKernelConstruction* construction)
79 : OpKernel(construction) {}
82 using namespace tensorflow;
84 const Tensor& inp_tensor =
context->input(0);
86 context, inp_tensor.dims() == 3,
87 errors::InvalidArgument(
"ThreeInterpolate expects "
88 "(batch_size,num_points,3) inp shape"));
89 int batch_size = inp_tensor.shape().dim_size(0);
90 int C = inp_tensor.shape().dim_size(1);
91 int M = inp_tensor.shape().dim_size(2);
92 auto inp_flat = inp_tensor.flat<
float>();
93 const float* inp = &(inp_flat(0));
95 const Tensor& idx_tensor =
context->input(1);
97 context, idx_tensor.dims() == 3,
98 errors::InvalidArgument(
"ThreeInterpolate expects "
99 "(batch_size,num_points,3) idx shape"));
100 int N = idx_tensor.shape().dim_size(1);
101 auto idx_flat = idx_tensor.flat<
int>();
102 const int* idx = &(idx_flat(0));
104 const Tensor& weights_tensor =
context->input(2);
105 OP_REQUIRES(
context, weights_tensor.dims() == 3,
106 errors::InvalidArgument(
107 "ThreeInterpolate expects "
108 "(batch_size,num_points,3) weights shape"));
109 auto weights_flat = weights_tensor.flat<
float>();
110 const float* weights = &(weights_flat(0));
115 0, TensorShape{batch_size, C, N}, &out_tensor));
116 auto out_flat = out_tensor->flat<
float>();
117 float* out = &(out_flat(0));
136 tensorflow::OpKernelConstruction* construction)
137 : OpKernel(construction) {
138 OP_REQUIRES_OK(construction, construction->GetAttr(
"M", &
M));
142 using namespace tensorflow;
144 const Tensor& inp_tensor =
context->input(0);
146 context, inp_tensor.dims() == 3,
147 errors::InvalidArgument(
"ThreeInterpolateGrad expects "
148 "(batch_size,num_points,3) inp shape"));
149 int batch_size = inp_tensor.shape().dim_size(0);
150 int C = inp_tensor.shape().dim_size(1);
151 int N = inp_tensor.shape().dim_size(2);
152 auto inp_flat = inp_tensor.flat<
float>();
153 const float* inp = &(inp_flat(0));
155 const Tensor& idx_tensor =
context->input(1);
157 context, idx_tensor.dims() == 3,
158 errors::InvalidArgument(
"ThreeInterpolateGrad expects "
159 "(batch_size,num_points,3) idx shape"));
160 auto idx_flat = idx_tensor.flat<
int>();
161 const int* idx = &(idx_flat(0));
163 const Tensor& weights_tensor =
context->input(2);
164 OP_REQUIRES(
context, weights_tensor.dims() == 3,
165 errors::InvalidArgument(
166 "ThreeInterpolateGrad expects "
167 "(batch_size,num_points,3) weights shape"));
168 auto weights_flat = weights_tensor.flat<
float>();
169 const float* weights = &(weights_flat(0));
174 0, TensorShape{batch_size, C, M}, &out_tensor));
175 auto out_flat = out_tensor->flat<
float>();
176 float* out = &(out_flat(0));
186 const float* grad_out,
189 float* grad_points) = 0;
Eigen::Matrix3Xd M
Definition: PointCloudPlanarPatchDetection.cpp:507
ImGuiContext * context
Definition: Window.cpp:76
Definition: InterpolateOpKernel.h:133
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:135
int M
Definition: InterpolateOpKernel.h:192
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:141
Definition: InterpolateOpKernel.h:75
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:77
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:81
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)=0
Definition: InterpolateOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:20
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:17
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:269