10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
// Constructor fragment of the forward TrilinearDevoxelize op kernel.
// NOTE(review): the line carrying the constructor name and the lines between
// the using-directive and the error message are absent from this listing.
// The construction context is forwarded unchanged to the tensorflow::OpKernel
// base class.
18 tensorflow::OpKernelConstruction*
context)
19 : tensorflow::OpKernel(
context) {
20 using namespace tensorflow;
// InvalidArgument status used when the resolution check fails. The check
// itself is on lines not shown here -- presumably it requires r > 0 after
// reading the op attribute (TODO confirm against the full header).
24 errors::InvalidArgument(
25 "TrilinearDevoxelize expects positive resolution"));
// Body fragment of the forward Compute(): validates the two inputs, derives
// the problem sizes, exposes raw pointers to the input/output buffers and
// hands everything to the device-specific Kernel() implementation.
// NOTE(review): the method signature, several macro openers, the output
// allocation calls and the branch around the two Kernel() calls are missing
// from this listing.
29 using namespace tensorflow;
// Input 0: point coordinates. Must be rank 3 with a middle dimension of 3,
// i.e. (batch_size, 3, N) as stated in the error message.
30 const Tensor& coords =
context->input(0);
32 context, coords.dims() == 3 && coords.shape().dim_size(1) == 3,
33 errors::InvalidArgument(
"TrilinearDevoxelize expects "
34 "(batch_size, 3, N) coordinate shape"));
// Input 1: voxel features. Only the rank (5) is checked in the lines visible
// here; no visible line compares its dimensions with coords or with r.
35 const Tensor& feat =
context->input(1);
36 OP_REQUIRES(
context, feat.dims() == 5,
37 errors::InvalidArgument(
"TrilinearDevoxelize expects "
38 "5 dimensions for features"));
// Problem sizes: batch and point count come from coords (dims 0 and 2),
// the feature/channel count from feat (dim 1). dim_size() values are stored
// in plain int here.
40 int batch_size = coords.shape().dim_size(0);
41 int num_points = coords.shape().dim_size(2);
42 int feat_dim = feat.shape().dim_size(1);
// Flattened float views of both inputs.
44 auto coords_flat = coords.flat<
float>();
45 auto feat_flat = feat.flat<
float>();
// Raw pointers to the first element of each flattened input.
// NOTE(review): taking the address of element 0 assumes the tensors are
// non-empty -- confirm that callers never pass zero-sized inputs.
47 const float* inp_coords = &(coords_flat(0));
48 const float* inp_feat = &(feat_flat(0));
// Output shapes (the enclosing allocation calls are not shown; presumably
// context->allocate_output wrapped in OP_REQUIRES_OK -- TODO confirm):
//   output 0: (batch_size, feat_dim, num_points)
//   output 1: (batch_size, 8, num_points)
//   output 2: (batch_size, 8, num_points)
53 0, TensorShape{batch_size, feat_dim, num_points},
58 1, TensorShape{batch_size, 8, num_points},
63 2, TensorShape{batch_size, 8, num_points},
// Flattened views of the outputs: 0 and 2 are float, 1 is int.
65 auto flat_0 = out_tensor_0->flat<
float>();
66 auto flat_1 = out_tensor_1->flat<
int>();
67 auto flat_2 = out_tensor_2->flat<
float>();
// Raw pointers to the first element of each output buffer.
69 float* out_0 = &(flat_0(0));
70 int* out_1 = &(flat_1(0));
71 float* out_2 = &(flat_2(0));
// Tail of a Kernel() call with the boolean flag set to true. Going by the
// Kernel declaration listed at the end of this file, r * r * r is the r3
// argument, the flag is `training`, and the outputs are passed in the order
// inds (out_1), wgts (out_2), outs (out_0).
// NOTE(review): the condition selecting this call over the next one is not
// visible; presumably it is the is_training member -- TODO confirm.
75 r *
r *
r,
true, inp_coords, inp_feat, out_1, out_2, out_0);
// Same call with the boolean flag set to false.
78 r *
r *
r,
false, inp_coords, inp_feat, out_1, out_2, out_0);
// Constructor fragment of the TrilinearDevoxelizeGrad (backward) op kernel.
// NOTE(review): the line carrying the constructor name and the lines between
// the using-directive and the error message are absent from this listing.
// The construction context is forwarded unchanged to the tensorflow::OpKernel
// base class.
103 tensorflow::OpKernelConstruction*
context)
104 : tensorflow::OpKernel(
context) {
105 using namespace tensorflow;
// InvalidArgument status used when the resolution check fails. The check
// itself is on lines not shown here -- presumably it requires r > 0 after
// reading the op attribute (TODO confirm against the full header).
109 errors::InvalidArgument(
110 "TrilinearDevoxelizeGrad expects positive resolution"));
// Body fragment of the backward Compute(): validates the three inputs,
// derives the problem sizes from the incoming gradient, exposes raw buffer
// pointers and hands them to the device-specific Kernel() implementation.
// NOTE(review): the method signature, several macro openers, the output
// allocation call and the start of the Kernel() call are missing from this
// listing.
114 using namespace tensorflow;
// Input 0: gradient with respect to the forward output. The condition of its
// shape check is not visible; the error message states the expected layout
// as (batch_size, C, N).
115 const Tensor& grad_y =
context->input(0);
118 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects "
119 "(batch_size, C, N) gradient shape"));
// Input 1: indices. Must be rank 3 with a middle dimension of 8.
120 const Tensor& inds =
context->input(1);
122 context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
123 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects "
124 "(batch_size, 8, N) indices shape"));
// Input 2: weights. Must be rank 3 with a middle dimension of 8.
125 const Tensor& wgts =
context->input(2);
127 context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
128 errors::InvalidArgument(
"TrilinearDevoxelizeGrad expects "
129 "(batch_size, 8, N) weights shape"));
// All problem sizes are taken from grad_y (dims 0, 2 and 1 respectively).
// No visible line checks that inds and wgts agree with these sizes.
131 int batch_size = grad_y.shape().dim_size(0);
132 int num_points = grad_y.shape().dim_size(2);
133 int feat_dim = grad_y.shape().dim_size(1);
// Flattened views of the inputs: grad_y and wgts as float, inds as int.
135 auto grad_y_flat = grad_y.flat<
float>();
136 auto inds_flat = inds.flat<
int>();
137 auto wgts_flat = wgts.flat<
float>();
// Raw pointers to the first element of each flattened input.
// NOTE(review): taking the address of element 0 assumes the tensors are
// non-empty -- confirm that callers never pass zero-sized inputs.
139 const float* inp_grad_y = &(grad_y_flat(0));
140 const int* inp_inds = &(inds_flat(0));
141 const float* inp_wgts = &(wgts_flat(0));
// Output 0 has shape (batch_size, feat_dim, r, r, r). The enclosing
// allocation call is not shown; presumably context->allocate_output wrapped
// in OP_REQUIRES_OK -- TODO confirm.
146 0, TensorShape{batch_size, feat_dim, r, r, r},
// Flattened float view of the output and a raw pointer to its first element.
148 auto flat_tensor = out_tensor->flat<
float>();
150 float* out = &(flat_tensor(0));
// Tail of the Kernel() call. Going by the Kernel declaration listed at the
// end of this file, these three arguments are wgts, grad_y and grad_x.
153 inp_wgts, inp_grad_y, out);
ImGuiContext * context
Definition: Window.cpp:76
Definition: TrilinearDevoxelizeKernel.h:100
int r
Definition: TrilinearDevoxelizeKernel.h:167
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:113
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r3, const int *inds, const float *wgts, const float *grad_y, float *grad_x)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:102
Definition: TrilinearDevoxelizeKernel.h:15
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:28
int r
Definition: TrilinearDevoxelizeKernel.h:96
bool is_training
Definition: TrilinearDevoxelizeKernel.h:97