Open3D (C++ API)  0.18.0
InvertNeighborsListOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
16 // namespace for code that is common for all kernels
17 namespace invert_neighbors_list_opkernel {
18 
19 // Base class with common code for the OpKernel implementations
20 class InvertNeighborsListOpKernel : public tensorflow::OpKernel {
21 public:
22  explicit InvertNeighborsListOpKernel(
23  tensorflow::OpKernelConstruction* construction)
24  : OpKernel(construction) {}
25 
26  void Compute(tensorflow::OpKernelContext* context) override {
27  using namespace tensorflow;
28  static_assert(sizeof(int64) == sizeof(int64_t),
29  "int64 type is not compatible");
30 
31  const Tensor& num_points_tensor = context->input(0);
32  OP_REQUIRES(context,
33  TensorShapeUtils::IsScalar(num_points_tensor.shape()),
34  errors::InvalidArgument(
35  "num_points must be scalar, got shape ",
36  num_points_tensor.shape().DebugString()));
37  const int64 num_points = num_points_tensor.scalar<int64>()();
38 
39  const Tensor& inp_neighbors_index = context->input(1);
40 
41  const Tensor& inp_neighbors_row_splits = context->input(2);
42 
43  const Tensor& inp_neighbors_attributes = context->input(3);
44 
45  // check input shapes
46  {
47  using namespace open3d::ml::op_util;
48  Dim num_neighbors("num_neighbors");
49 
50  CHECK_SHAPE(context, inp_neighbors_index, num_neighbors);
51  CHECK_SHAPE_IGNORE_LAST_DIMS(context, inp_neighbors_attributes,
52  num_neighbors || 0);
53  CHECK_SHAPE(context, inp_neighbors_row_splits, Dim());
54  }
55 
56  // compute the number of attributes for each neighbor
57  int num_attributes;
58  if (inp_neighbors_attributes.shape().dim_size(0) == 0) {
59  num_attributes = 0;
60  } else {
61  num_attributes = 1;
62  for (int i = 1; i < inp_neighbors_attributes.shape().dims(); ++i)
63  num_attributes *= inp_neighbors_attributes.shape().dim_size(i);
64  }
65 
66  Tensor* neighbors_index = 0;
67  TensorShape neighbors_index_shape(inp_neighbors_index.shape());
68  OP_REQUIRES_OK(context,
69  context->allocate_output(0, neighbors_index_shape,
70  &neighbors_index));
71 
72  Tensor* neighbors_row_splits = 0;
73  TensorShape neighbors_row_splits_shape({num_points + 1});
74  OP_REQUIRES_OK(context,
75  context->allocate_output(1, neighbors_row_splits_shape,
76  &neighbors_row_splits));
77 
78  Tensor* neighbors_attributes = 0;
79  TensorShape neighbors_attributes_shape(
80  inp_neighbors_attributes.shape());
81  OP_REQUIRES_OK(context,
82  context->allocate_output(2, neighbors_attributes_shape,
83  &neighbors_attributes));
84 
85  Kernel(context, inp_neighbors_index, inp_neighbors_row_splits,
86  inp_neighbors_attributes, num_attributes, *neighbors_index,
87  *neighbors_row_splits, *neighbors_attributes);
88  }
89 
90  // Function with the device specific code
91  virtual void Kernel(tensorflow::OpKernelContext* context,
92  const tensorflow::Tensor& inp_neighbors_index,
93  const tensorflow::Tensor& inp_neighbors_row_splits,
94  const tensorflow::Tensor& inp_neighbors_attributes,
95  const int num_attributes,
96  tensorflow::Tensor& neighbors_index,
97  tensorflow::Tensor& neighbors_row_splits,
98  tensorflow::Tensor& neighbors_attributes) = 0;
99 
100 private:
101 };
102 
103 } // namespace invert_neighbors_list_opkernel
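#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor,...)
Definition: TorchHelper.h:225 (Doxygen cross-reference; the calls in this file pass the kernel context as the first argument and the file includes TensorFlowHelper.h, so they presumably resolve to the TensorFlow variant of this macro)
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:186 (same caveat: this file includes TensorFlowHelper.h and calls CHECK_SHAPE(context, tensor, ...))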
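ImGuiContext * context
Definition: Window.cpp:76 (Doxygen mis-link: in this file `context` is the tensorflow::OpKernelContext* parameter of Compute() and Kernel())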
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
Definition: ShapeChecking.h:16