Open3D (C++ API)  0.18.0
BallQueryOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
15 class BallQueryOpKernel : public tensorflow::OpKernel {
16 public:
17  explicit BallQueryOpKernel(tensorflow::OpKernelConstruction* construction)
18  : OpKernel(construction) {
19  using namespace tensorflow;
20 
21  OP_REQUIRES_OK(construction,
22  construction->GetAttr("nsample", &nsample));
23  OP_REQUIRES_OK(construction, construction->GetAttr("radius", &radius));
24  OP_REQUIRES(
25  construction, nsample > 0,
26  errors::InvalidArgument("BallQuery expects positive nsample"));
27  }
28 
29  void Compute(tensorflow::OpKernelContext* context) override {
30  using namespace tensorflow;
31 
32  const Tensor& inp_tensor = context->input(0);
33  OP_REQUIRES(
34  context,
35  inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36  errors::InvalidArgument("BallQuery expects "
37  "(batch_size,num_points,3) inp shape"));
38  int batch_size = inp_tensor.shape().dim_size(0);
39  int pts_size = inp_tensor.shape().dim_size(1);
40  auto inp_flat = inp_tensor.flat<float>();
41  const float* inp = &(inp_flat(0));
42 
43  const Tensor& center_tensor = context->input(1);
44  OP_REQUIRES(context,
45  center_tensor.dims() == 3 &&
46  center_tensor.shape().dim_size(2) == 3,
47  errors::InvalidArgument(
48  "BallQuery expects "
49  "(batch_size,num_points,3) center shape"));
50  int ball_size = center_tensor.shape().dim_size(1);
51  auto center_flat = center_tensor.flat<float>();
52  const float* center = &(center_flat(0));
53 
54  Tensor* out_tensor;
55  OP_REQUIRES_OK(context,
56  context->allocate_output(
57  0, TensorShape{batch_size, ball_size, nsample},
58  &out_tensor));
59  auto out_flat = out_tensor->flat<int>();
60  int* out = &(out_flat(0));
61 
62  Kernel(context, batch_size, pts_size, ball_size, radius, nsample,
63  center, inp, out);
64  }
65 
66  virtual void Kernel(tensorflow::OpKernelContext* context,
67  int b,
68  int n,
69  int m,
70  float radius,
71  int nsample,
72  const float* new_xyz,
73  const float* xyz,
74  int* idx) = 0;
75 
76 protected:
77  int nsample;
78  float radius;
79 };
tensorflow::OpKernelContext * context
Definition: BallQueryOpKernel.h:29
class BallQueryOpKernel
Definition: BallQueryOpKernel.h:15
int nsample
Definition: BallQueryOpKernel.h:77
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: BallQueryOpKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: BallQueryOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
float radius
Definition: BallQueryOpKernel.h:78