41 const std::vector<Tensor>& index_tensors)
62 const std::vector<Tensor>& index_tensors);
67 const Tensor& tensor,
const std::vector<Tensor>& index_tensors);
71 static std::pair<std::vector<Tensor>,
SizeVector>
104 const std::vector<Tensor>& index_tensors);
140 const std::vector<Tensor>& index_tensors,
145 if (indexed_shape.size() != indexed_strides.size()) {
147 "Internal error: indexed_shape's ndim {} does not equal to " 148 "indexd_strides' ndim {}",
149 indexed_shape.size(), indexed_strides.size());
151 num_indices_ = indexed_shape.size();
154 std::vector<Tensor> inputs;
155 inputs.push_back(src);
156 for (
const Tensor& index_tensor : index_tensors) {
157 if (index_tensor.NumDims() != 0) {
158 inputs.push_back(index_tensor);
164 if (num_indices_ != static_cast<int64_t>(indexed_strides.size())) {
166 "Internal error: indexed_shape's ndim {} does not equal to " 167 "indexd_strides' ndim {}",
168 num_indices_, indexed_strides.size());
170 for (int64_t i = 0; i < num_indices_; ++i) {
176 if (src.
GetDtype() != dst.GetDtype()) {
178 "src's dtype {} is not the same as dst's dtype {}.",
186 char* ptr = indexer_.GetInputPtr(0, workload_idx);
187 ptr += GetIndexedOffset(workload_idx) * element_byte_size_ *
188 (mode_ == AdvancedIndexerMode::GET);
193 char* ptr = indexer_.GetOutputPtr(workload_idx);
194 ptr += GetIndexedOffset(workload_idx) * element_byte_size_ *
195 (mode_ == AdvancedIndexerMode::SET);
202 for (int64_t i = 0; i < num_indices_; ++i) {
203 int64_t index = *(
reinterpret_cast<int64_t*
>(
204 indexer_.GetInputPtr(i + 1, workload_idx)));
206 "Index out of bounds");
SizeVector output_shape_
Output shape.
Definition: AdvancedIndexing.h:114
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t timeout_in_ms capture_handle capture_handle capture_handle image_handle temperature_c k4a_image_t image_handle uint8_t image_handle image_handle image_handle image_handle image_handle timestamp_usec white_balance image_handle k4a_device_configuration_t config device_handle char size_t serial_number_size bool int32_t int32_t int32_t int32_t k4a_color_control_mode_t default_mode mode
Definition: K4aPlugin.cpp:676
Dtype GetDtype() const
Definition: Tensor.h:742
std::vector< Tensor > index_tensors_
The processed index tensors.
Definition: AdvancedIndexing.h:111
int64_t num_indices_
Definition: AdvancedIndexing.h:218
OPEN3D_HOST_DEVICE int64_t GetIndexedOffset(int64_t workload_idx) const
Definition: AdvancedIndexing.h:200
Tensor GetTensor() const
Definition: AdvancedIndexing.h:46
AdvancedIndexerMode mode_
Definition: AdvancedIndexing.h:217
SizeVector GetIndexedShape() const
Definition: AdvancedIndexing.h:54
int offset
Definition: FilePCD.cpp:62
void LogError(const char *format, const Args &... args)
Definition: Console.h:174
Indexer indexer_
Definition: AdvancedIndexing.h:216
void RunPreprocess()
Preprocess tensor and index tensors.
Definition: AdvancedIndexing.cpp:128
static std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:59
static std::vector< Tensor > ExpandBoolTensors(const std::vector< Tensor > &index_tensors)
Expand boolean tensor to integer index.
Definition: AdvancedIndexing.cpp:248
static Tensor RestrideIndexTensor(const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)
Definition: AdvancedIndexing.cpp:118
int64_t NumWorkloads() const
Definition: AdvancedIndexing.h:213
Definition: AdvancedIndexing.h:134
SizeVector indexed_strides_
Definition: AdvancedIndexing.h:122
#define OPEN3D_HOST_DEVICE
Definition: CUDAUtils.h:54
Definition: SizeVector.h:40
static int64_t ByteSize(const Dtype &dtype)
Definition: Dtype.h:61
SizeVector indexed_shape_
Definition: AdvancedIndexing.h:118
static bool IsIndexSplittedBySlice(const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:35
static Tensor RestrideTensor(const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)
Definition: AdvancedIndexing.cpp:103
Tensor tensor_
Definition: AdvancedIndexing.h:108
SizeVector GetIndexedStrides() const
Definition: AdvancedIndexing.h:56
SizeVector GetOutputShape() const
Definition: AdvancedIndexing.h:52
OPEN3D_HOST_DEVICE char * GetInputPtr(int64_t workload_idx) const
Definition: AdvancedIndexing.h:185
Definition: Open3DViewer.h:29
Definition: Indexer.h:260
static std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim(const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.cpp:81
std::vector< Tensor > GetIndexTensors() const
Definition: AdvancedIndexing.h:48
AdvancedIndexerMode
Definition: AdvancedIndexing.h:136
This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.
Definition: AdvancedIndexing.h:38
OPEN3D_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
Definition: AdvancedIndexing.h:192
AdvancedIndexPreprocessor(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
Definition: AdvancedIndexing.h:40
static std::string ToString(const Dtype &dtype)
Definition: Dtype.h:97
AdvancedIndexer(const Tensor &src, const Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides, AdvancedIndexerMode mode)
Definition: AdvancedIndexing.h:138
int64_t element_byte_size_
Definition: AdvancedIndexing.h:219