10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
18 tensorflow::OpKernelConstruction* construction)
19 : OpKernel(construction) {
20 using namespace tensorflow;
22 OP_REQUIRES_OK(construction,
23 construction->GetAttr(
"sample_size", &
sample_size));
25 errors::InvalidArgument(
26 "FurthestPointSampling expects positive npoint"));
30 using namespace tensorflow;
32 const Tensor& inp_tensor =
context->input(0);
35 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36 errors::InvalidArgument(
"FurthestPointSampling expects "
37 "(batch_size,num_points,3) inp shape"));
38 int batch_size = inp_tensor.shape().dim_size(0);
39 int pts_size = inp_tensor.shape().dim_size(1);
40 auto inp_flat = inp_tensor.flat<
float>();
41 const float* inp = &(inp_flat(0));
45 0, TensorShape{batch_size, sample_size},
47 auto out_flat = out_tensor->flat<
int>();
48 int* out = &(out_flat(0));
52 context->allocate_temp(DataTypeToEnum<float>::value,
53 TensorShape{batch_size, pts_size},
55 auto temp_flat = temp_tensor.flat<
float>();
56 float* temp = &(temp_flat(0));
ImGuiContext * context
Definition: Window.cpp:76
Definition: SamplingOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition: SamplingOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *dataset, float *temp, int *idxs)=0
int sample_size
Definition: SamplingOpKernel.h:70
FurthestPointSamplingOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: SamplingOpKernel.h:17