Open3D (C++ API)  0.18.0
NeighborSearchAllocator.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 //
8 
10 #include "torch/script.h"
11 
12 // This class implements an allocator object that can be passed to the neighbor
13 // search functions; it allocates the index and distance buffers as torch tensors.
14 
15 template <class T, class TIndex>
17 public:
18  NeighborSearchAllocator(torch::DeviceType device_type, int device_idx)
19  : device_type(device_type), device_idx(device_idx) {}
20 
21  void AllocIndices(TIndex** ptr, size_t num) {
22  neighbors_index = torch::empty(
23  {int64_t(num)}, torch::dtype(ToTorchDtype<TIndex>())
24  .device(device_type, device_idx));
25  *ptr = neighbors_index.data_ptr<TIndex>();
26  }
27 
28  void AllocDistances(T** ptr, size_t num) {
29  neighbors_distance = torch::empty(
30  {int64_t(num)}, torch::dtype(ToTorchDtype<T>())
31  .device(device_type, device_idx));
32  *ptr = neighbors_distance.data_ptr<T>();
33  }
34 
35  const TIndex* IndicesPtr() const {
36  return neighbors_index.data_ptr<TIndex>();
37  }
38 
39  const T* DistancesPtr() const { return neighbors_distance.data_ptr<T>(); }
40 
41  const torch::Tensor& NeighborsIndex() const { return neighbors_index; }
42  const torch::Tensor& NeighborsDistance() const {
43  return neighbors_distance;
44  }
45 
46 private:
47  torch::Tensor neighbors_index;
48  torch::Tensor neighbors_distance;
49  torch::DeviceType device_type;
50  int device_idx;
51 };
NeighborSearchAllocator
Definition: NeighborSearchAllocator.h:16
const TIndex * IndicesPtr() const
Definition: NeighborSearchAllocator.h:35
const torch::Tensor & NeighborsDistance() const
Definition: NeighborSearchAllocator.h:42
void AllocIndices(TIndex **ptr, size_t num)
Definition: NeighborSearchAllocator.h:21
NeighborSearchAllocator(torch::DeviceType device_type, int device_idx)
Definition: NeighborSearchAllocator.h:18
const T * DistancesPtr() const
Definition: NeighborSearchAllocator.h:39
void AllocDistances(T **ptr, size_t num)
Definition: NeighborSearchAllocator.h:28
const torch::Tensor & NeighborsIndex() const
Definition: NeighborSearchAllocator.h:41