Open3D (C++ API)  0.18.0+252c867
VoxelizeOpKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 //#include "open3d/ml/impl/misc/VoxelPooling.h"
12 #include "tensorflow/core/framework/op.h"
13 #include "tensorflow/core/framework/op_kernel.h"
14 #include "tensorflow/core/lib/core/errors.h"
15 
17 // namespace for code that is common for all kernels
18 namespace voxelize_opkernel {
19 
20 class OutputAllocator {
21 public:
22  OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
23 
24  void AllocVoxelCoords(int32_t** ptr, int64_t rows, int64_t cols) {
25  using namespace tensorflow;
26  *ptr = nullptr;
27  Tensor* tensor = 0;
28  TensorShape shape({rows, cols});
29  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
30  auto flat_tensor = tensor->flat<int32_t>();
31  *ptr = flat_tensor.data();
32  }
33 
34  void AllocVoxelPointIndices(int64_t** ptr, int64_t num) {
35  using namespace tensorflow;
36  *ptr = nullptr;
37  Tensor* tensor = 0;
38  TensorShape shape({num});
39  OP_REQUIRES_OK(context, context->allocate_output(1, shape, &tensor));
40  auto flat_tensor = tensor->flat<int64>();
41  *ptr = (int64_t*)flat_tensor.data();
42  }
43 
44  void AllocVoxelPointRowSplits(int64_t** ptr, int64_t num) {
45  using namespace tensorflow;
46  *ptr = nullptr;
47  Tensor* tensor = 0;
48  TensorShape shape({num});
49  OP_REQUIRES_OK(context, context->allocate_output(2, shape, &tensor));
50  auto flat_tensor = tensor->flat<int64>();
51  *ptr = (int64_t*)flat_tensor.data();
52  }
53 
54  void AllocVoxelBatchSplits(int64_t** ptr, int64_t num) {
55  using namespace tensorflow;
56  *ptr = nullptr;
57  Tensor* tensor = 0;
58  TensorShape shape({num});
59  OP_REQUIRES_OK(context, context->allocate_output(3, shape, &tensor));
60  auto flat_tensor = tensor->flat<int64>();
61  *ptr = (int64_t*)flat_tensor.data();
62  }
63 
64 private:
65  tensorflow::OpKernelContext* context;
66 };
67 
68 // Base class with common code for the OpKernel implementations
69 class VoxelizeOpKernel : public tensorflow::OpKernel {
70 public:
71  explicit VoxelizeOpKernel(tensorflow::OpKernelConstruction* construction)
72  : OpKernel(construction) {
73  OP_REQUIRES_OK(construction,
74  construction->GetAttr("max_points_per_voxel",
75  &max_points_per_voxel));
76  OP_REQUIRES_OK(construction,
77  construction->GetAttr("max_voxels", &max_voxels));
78  }
79 
80  void Compute(tensorflow::OpKernelContext* context) override {
81  using namespace tensorflow;
82  const Tensor& points = context->input(0);
83  const Tensor& row_splits = context->input(1);
84  const Tensor& voxel_size = context->input(2);
85  const Tensor& points_range_min = context->input(3);
86  const Tensor& points_range_max = context->input(4);
87 
88  {
89  using namespace open3d::ml::op_util;
90  Dim num_points("num_points");
91  Dim ndim("ndim");
92  CHECK_SHAPE(context, points, num_points, ndim);
93  CHECK_SHAPE(context, voxel_size, ndim);
94  CHECK_SHAPE(context, points_range_min, ndim);
95  CHECK_SHAPE(context, points_range_max, ndim);
96  OP_REQUIRES(
97  context, ndim.value() > 0 && ndim.value() < 9,
98  errors::InvalidArgument(
99  "the number of dimensions must be in [1,..,8]"));
100  }
101 
102  Kernel(context, points, row_splits, voxel_size, points_range_min,
103  points_range_max);
104  }
105 
106  // Function with the device specific code
107  virtual void Kernel(tensorflow::OpKernelContext* context,
108  const tensorflow::Tensor& points,
109  const tensorflow::Tensor& row_splits,
110  const tensorflow::Tensor& voxel_size,
111  const tensorflow::Tensor& points_range_min,
112  const tensorflow::Tensor& points_range_max) = 0;
113 
114 protected:
115  tensorflow::int64 max_points_per_voxel;
116  tensorflow::int64 max_voxels;
117 };
118 
119 } // namespace voxelize_opkernel
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:186
ImGuiContext * context
Definition: Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
int points
Definition: FilePCD.cpp:54
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t int32_t
Definition: K4aPlugin.cpp:395
Definition: ShapeChecking.h:16