(9238339 (Tue Mar 14 18:49:09 2023 -0700))
#include <torch/script.h>
#include <sstream>
#include <type_traits>
#include "open3d/ml/ShapeChecking.h"
Go to the source code of this file.
◆ CHECK_CONTIGUOUS
#define CHECK_CONTIGUOUS(x)
Value: do { \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") \
} while (0)
◆ CHECK_CUDA
#define CHECK_CUDA(x)
Value: do { \
TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") \
} while (0)
◆ CHECK_SAME_DEVICE_TYPE
#define CHECK_SAME_DEVICE_TYPE(...)
Value: do { \
if (!SameDeviceType({__VA_ARGS__})) { \
TORCH_CHECK( \
false, \
#__VA_ARGS__ \
" must all have the same device type but got " + \
TensorInfoStr({__VA_ARGS__})) \
} \
} while (0)
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:120
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:95
◆ CHECK_SAME_DTYPE
#define CHECK_SAME_DTYPE(...)
Value: do { \
if (!SameDtype({__VA_ARGS__})) { \
TORCH_CHECK(false, \
#__VA_ARGS__ \
" must all have the same dtype but got " + \
TensorInfoStr({__VA_ARGS__})) \
} \
} while (0)
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:108
◆ CHECK_SHAPE
#define CHECK_SHAPE(tensor, ...)
Value: do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:158
◆ CHECK_SHAPE_COMBINE_FIRST_DIMS
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)
Value: do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)
◆ CHECK_SHAPE_COMBINE_LAST_DIMS
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)
Value: do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)
◆ CHECK_SHAPE_IGNORE_FIRST_DIMS
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)
Value: do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)
◆ CHECK_SHAPE_IGNORE_LAST_DIMS
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)
Value: do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)
◆ CHECK_TYPE
#define CHECK_TYPE(x, type)
◆ TorchDtype_t
◆ CheckShape()
template<open3d::ml::op_util::CSOpt Opt = open3d::ml::op_util::CSOpt::NONE, class TDimX , class... TArgs>
std::tuple<bool, std::string> CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
◆ CompareTorchDtype()
template<class T , class TDtype >
inline bool CompareTorchDtype(const TDtype &t)
◆ CreateTempTensor()
inline torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr = nullptr)
◆ GetShapeVector()
◆ SameDeviceType()
inline bool SameDeviceType(std::initializer_list<torch::Tensor> tensors)
◆ SameDtype()
inline bool SameDtype(std::initializer_list<torch::Tensor> tensors)
◆ TensorInfoStr()
inline std::string TensorInfoStr(std::initializer_list<torch::Tensor> tensors)
◆ ToTorchDtype()
◆ ToTorchDtype< double >()
◆ ToTorchDtype< float >()
◆ ToTorchDtype< int16_t >()
◆ ToTorchDtype< int32_t >()
◆ ToTorchDtype< int64_t >()
◆ ToTorchDtype< int8_t >()
◆ ToTorchDtype< uint8_t >()