29 #include <type_traits> 32 #include "torch/script.h" 35 #define CHECK_CUDA(x) \ 37 TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") \ 40 #define CHECK_CONTIGUOUS(x) \ 42 TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") \ 45 #define CHECK_TYPE(x, type) \ 47 TORCH_CHECK(x.dtype() == torch::type, #x " must have type " #type) \ 50 #define CHECK_SAME_DEVICE_TYPE(...) \ 52 if (!SameDeviceType({__VA_ARGS__})) { \ 56 " must all have the same device type but got " + \ 57 TensorInfoStr({__VA_ARGS__})) \ 61 #define CHECK_SAME_DTYPE(...) \ 63 if (!SameDtype({__VA_ARGS__})) { \ 66 " must all have the same dtype but got " + \ 67 TensorInfoStr({__VA_ARGS__})) \ 75 TORCH_CHECK(
false,
"Unsupported type");
99 return torch::kFloat32;
103 return torch::kFloat64;
107 template <
class T,
class TDtype>
109 return ToTorchDtype<T>() == t;
114 if (tensors.size()) {
115 auto device_type = tensors.begin()->device().type();
116 for (
auto t : tensors) {
117 if (device_type != t.device().type()) {
126 inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
127 if (tensors.size()) {
128 auto device_type = tensors.begin()->dtype();
129 for (
auto t : tensors) {
130 if (device_type != t.dtype()) {
138 inline std::string
TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
139 std::stringstream sstr;
141 for (
const auto t : tensors) {
142 sstr << t.sizes() <<
" " << t.toString() <<
" " << t.device();
144 if (count < tensors.size()) sstr <<
", ";
151 const torch::Device& device,
152 void** ptr =
nullptr) {
153 torch::Tensor tensor = torch::empty(
156 *ptr = tensor.data_ptr<uint8_t>();
162 torch::Tensor tensor) {
165 std::vector<DimValue> shape;
166 const int rank = tensor.dim();
167 for (
int i = 0; i < rank; ++i) {
168 shape.push_back(tensor.size(i));
176 std::tuple<bool, std::string>
CheckShape(torch::Tensor tensor,
179 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(tensor),
180 std::forward<TDimX>(dimex),
181 std::forward<TArgs>(args)...);
204 #define CHECK_SHAPE(tensor, ...) \ 207 std::string cs_errstr_; \ 208 std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \ 209 TORCH_CHECK(cs_success_, \ 210 "invalid shape for '" #tensor "', " + cs_errstr_) \ 213 #define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...) \ 216 std::string cs_errstr_; \ 217 std::tie(cs_success_, cs_errstr_) = \ 218 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \ 219 TORCH_CHECK(cs_success_, \ 220 "invalid shape for '" #tensor "', " + cs_errstr_) \ 223 #define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...) \ 226 std::string cs_errstr_; \ 227 std::tie(cs_success_, cs_errstr_) = \ 228 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \ 229 TORCH_CHECK(cs_success_, \ 230 "invalid shape for '" #tensor "', " + cs_errstr_) \ 233 #define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...) \ 236 std::string cs_errstr_; \ 237 std::tie(cs_success_, cs_errstr_) = \ 238 CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \ 239 TORCH_CHECK(cs_success_, \ 240 "invalid shape for '" #tensor "', " + cs_errstr_) \ 243 #define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...) \ 246 std::string cs_errstr_; \ 247 std::tie(cs_success_, cs_errstr_) = \ 248 CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \ 249 TORCH_CHECK(cs_success_, \ 250 "invalid shape for '" #tensor "', " + cs_errstr_) \ std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:161
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:86
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:78
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:98
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:94
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:150
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:113
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:82
int size
Definition: FilePCD.cpp:59
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:138
int count
Definition: FilePCD.cpp:61
char type
Definition: FilePCD.cpp:60
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:72
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:108
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:102
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:74
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:90
Definition: ShapeChecking.h:35
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:126