10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 : OpKernel(construction) {
19 using namespace tensorflow;
20 OP_REQUIRES_OK(construction, construction->GetAttr(
"sampled_pts_num",
25 using namespace tensorflow;
27 const Tensor& inp_tensor =
context->input(0);
30 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
31 errors::InvalidArgument(
"RoiPool expects "
32 "(batch_size,num_points,3) inp shape"));
33 int batch_size = inp_tensor.shape().dim_size(0);
34 int pts_num = inp_tensor.shape().dim_size(1);
35 auto inp_flat = inp_tensor.flat<
float>();
36 const float* inp = &(inp_flat(0));
38 const Tensor& boxes3d_tensor =
context->input(1);
40 boxes3d_tensor.dims() == 3 &&
41 boxes3d_tensor.shape().dim_size(2) == 7,
42 errors::InvalidArgument(
44 "(batch_size,num_boxes,7) boxes3d shape"));
45 int boxes_num = boxes3d_tensor.shape().dim_size(1);
46 auto boxes3d_flat = boxes3d_tensor.flat<
float>();
47 const float* boxes3d = &(boxes3d_flat(0));
49 const Tensor& feats_tensor =
context->input(2);
51 feats_tensor.dims() == 3 &&
52 feats_tensor.shape().dim_size(1) == pts_num,
53 errors::InvalidArgument(
55 "(batch_size,num_points,feats) feats shape"));
56 int feature_in_len = feats_tensor.shape().dim_size(2);
57 auto feats_flat = feats_tensor.flat<
float>();
58 const float* feats = &(feats_flat(0));
64 TensorShape{batch_size, boxes_num,
65 sampled_pts_num, 3 + feature_in_len},
67 auto out_flat0 = out_feats->flat<
float>();
68 float* out0 = &(out_flat0(0));
72 1, TensorShape{batch_size, boxes_num},
74 auto out_flat1 = out_flags->flat<
int>();
75 int* out1 = &(out_flat1(0));
77 Kernel(
context, batch_size, pts_num, boxes_num, feature_in_len,
89 const float* pts_feature,
90 float* pooled_features,
91 int* pooled_empty_flag) = 0;
ImGuiContext * context
Definition: Window.cpp:76
Definition: RoiPoolOpKernel.h:15
RoiPoolOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: RoiPoolOpKernel.h:17
int sampled_pts_num
Definition: RoiPoolOpKernel.h:94
virtual void Kernel(tensorflow::OpKernelContext *context, int batch_size, int pts_num, int boxes_num, int feature_in_len, int sampled_pts_num, const float *xyz, const float *boxes3d, const float *pts_feature, float *pooled_features, int *pooled_empty_flag)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition: RoiPoolOpKernel.h:24