Open3D (C++ API)  0.17.0
SamplingOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
15 class FurthestPointSamplingOpKernel : public tensorflow::OpKernel {
16 public:
18  tensorflow::OpKernelConstruction* construction)
19  : OpKernel(construction) {
20  using namespace tensorflow;
21 
22  OP_REQUIRES_OK(construction,
23  construction->GetAttr("sample_size", &sample_size));
24  OP_REQUIRES(construction, sample_size > 0,
25  errors::InvalidArgument(
26  "FurthestPointSampling expects positive npoint"));
27  }
28 
29  void Compute(tensorflow::OpKernelContext* context) override {
30  using namespace tensorflow;
31 
32  const Tensor& inp_tensor = context->input(0);
33  OP_REQUIRES(
34  context,
35  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36  errors::InvalidArgument("FurthestPointSampling expects "
37  "(batch_size,num_points,3) inp shape"));
38  int batch_size = inp_tensor.shape().dim_size(0);
39  int pts_size = inp_tensor.shape().dim_size(1);
40  auto inp_flat = inp_tensor.flat<float>();
41  const float* inp = &(inp_flat(0));
42 
43  Tensor* out_tensor;
44  OP_REQUIRES_OK(context, context->allocate_output(
45  0, TensorShape{batch_size, sample_size},
46  &out_tensor));
47  auto out_flat = out_tensor->flat<int>();
48  int* out = &(out_flat(0));
49 
50  Tensor temp_tensor;
51  OP_REQUIRES_OK(context,
52  context->allocate_temp(DataTypeToEnum<float>::value,
53  TensorShape{batch_size, pts_size},
54  &temp_tensor));
55  auto temp_flat = temp_tensor.flat<float>();
56  float* temp = &(temp_flat(0));
57 
58  Kernel(context, batch_size, pts_size, sample_size, inp, temp, out);
59  }
60 
61  virtual void Kernel(tensorflow::OpKernelContext* context,
62  int b,
63  int n,
64  int m,
65  const float* dataset,
66  float* temp,
67  int* idxs) = 0;
68 
69 protected:
71 };
ImGuiContext * context
Definition: Window.cpp:76
FurthestPointSamplingOpKernel
Definition: SamplingOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition: SamplingOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *dataset, float *temp, int *idxs)=0
int sample_size
Definition: SamplingOpKernel.h:70
FurthestPointSamplingOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: SamplingOpKernel.h:17