10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 : OpKernel(construction) {
19 using namespace tensorflow;
21 OP_REQUIRES_OK(construction,
22 construction->GetAttr(
"nsample", &
nsample));
23 OP_REQUIRES_OK(construction, construction->GetAttr(
"radius", &
radius));
26 errors::InvalidArgument(
"BallQuery expects positive nsample"));
30 using namespace tensorflow;
32 const Tensor& inp_tensor =
context->input(0);
35 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36 errors::InvalidArgument(
"BallQuery expects "
37 "(batch_size,num_points,3) inp shape"));
38 int batch_size = inp_tensor.shape().dim_size(0);
39 int pts_size = inp_tensor.shape().dim_size(1);
40 auto inp_flat = inp_tensor.flat<
float>();
41 const float* inp = &(inp_flat(0));
43 const Tensor& center_tensor =
context->input(1);
45 center_tensor.dims() == 3 &&
46 center_tensor.shape().dim_size(2) == 3,
47 errors::InvalidArgument(
49 "(batch_size,num_points,3) center shape"));
50 int ball_size = center_tensor.shape().dim_size(1);
51 auto center_flat = center_tensor.flat<
float>();
52 const float* center = &(center_flat(0));
57 0, TensorShape{batch_size, ball_size, nsample},
59 auto out_flat = out_tensor->flat<
int>();
60 int* out = &(out_flat(0));
ImGuiContext * context
Definition: Window.cpp:76
Definition: BallQueryOpKernel.h:15
int nsample
Definition: BallQueryOpKernel.h:77
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: BallQueryOpKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: BallQueryOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
float radius
Definition: BallQueryOpKernel.h:78