10 #include "torch/script.h"
15 template <
class T,
class TIndex>
19 : device_type(device_type), device_idx(device_idx) {}
22 neighbors_index = torch::empty(
23 {int64_t(num)}, torch::dtype(ToTorchDtype<TIndex>())
24 .device(device_type, device_idx));
25 *ptr = neighbors_index.data_ptr<TIndex>();
29 neighbors_distance = torch::empty(
30 {int64_t(num)}, torch::dtype(ToTorchDtype<T>())
31 .device(device_type, device_idx));
32 *ptr = neighbors_distance.data_ptr<T>();
36 return neighbors_index.data_ptr<TIndex>();
39 const T*
DistancesPtr()
const {
return neighbors_distance.data_ptr<T>(); }
43 return neighbors_distance;
47 torch::Tensor neighbors_index;
48 torch::Tensor neighbors_distance;
49 torch::DeviceType device_type;
Definition: NeighborSearchAllocator.h:16
const TIndex * IndicesPtr() const
Definition: NeighborSearchAllocator.h:35
const torch::Tensor & NeighborsDistance() const
Definition: NeighborSearchAllocator.h:42
void AllocIndices(TIndex **ptr, size_t num)
Definition: NeighborSearchAllocator.h:21
NeighborSearchAllocator(torch::DeviceType device_type, int device_idx)
Definition: NeighborSearchAllocator.h:18
const T * DistancesPtr() const
Definition: NeighborSearchAllocator.h:39
void AllocDistances(T **ptr, size_t num)
Definition: NeighborSearchAllocator.h:28
const torch::Tensor & NeighborsIndex() const
Definition: NeighborSearchAllocator.h:41