9 #include <torch/script.h>
12 #include <type_traits>
// CHECK_CUDA(x): passes x.is_cuda() to TORCH_CHECK; the failure message is the
// stringified argument followed by " must be a CUDA tensor".
// NOTE(review): original lines 18 and 20 are absent from this listing, so any
// statement wrapper around the check is not visible here.
17 #define CHECK_CUDA(x) \
19 TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") \
// CHECK_CONTIGUOUS(x): passes x.is_contiguous() to TORCH_CHECK; the failure
// message is the stringified argument followed by " must be contiguous".
// NOTE(review): original lines 23 and 25 are absent from this listing, so any
// statement wrapper around the check is not visible here.
22 #define CHECK_CONTIGUOUS(x) \
24 TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") \
// CHECK_TYPE(x, type): passes `x.dtype() == torch::type` to TORCH_CHECK. The
// second macro argument is substituted after `torch::`, so it must name a
// dtype constant in that namespace. The failure message is the stringified
// tensor argument, " must have type ", and the stringified type argument.
// NOTE(review): original lines 28 and 30 are absent from this listing, so any
// statement wrapper around the check is not visible here.
27 #define CHECK_TYPE(x, type) \
29 TORCH_CHECK(x.dtype() == torch::type, #x " must have type " #type) \
// CHECK_SAME_DEVICE_TYPE(...): wraps the arguments in a braced list and tests
// them with SameDeviceType(); when that returns false, a message ending in
// " must all have the same device type but got " plus TensorInfoStr() of the
// same arguments is built.
// NOTE(review): original lines 33 and 35-37 are absent from this listing, so
// the call that actually reports the message is not visible here.
32 #define CHECK_SAME_DEVICE_TYPE(...) \
34 if (!SameDeviceType({__VA_ARGS__})) { \
38 " must all have the same device type but got " + \
39 TensorInfoStr({__VA_ARGS__})) \
// CHECK_SAME_DTYPE(...): wraps the arguments in a braced list and tests them
// with SameDtype(); when that returns false, a message ending in
// " must all have the same dtype but got " plus TensorInfoStr() of the same
// arguments is built.
// NOTE(review): original lines 44 and 46-47 are absent from this listing, so
// the call that actually reports the message is not visible here.
43 #define CHECK_SAME_DTYPE(...) \
45 if (!SameDtype({__VA_ARGS__})) { \
48 " must all have the same dtype but got " + \
49 TensorInfoStr({__VA_ARGS__})) \
// Fragment of the ToTorchDtype family; the function signatures are not shown
// in this listing.
// Unspecialized case (presumably the generic ToTorchDtype<T>() — TODO confirm):
// TORCH_CHECK is given a constant false condition, so reaching this statement
// always reports "Unsupported type".
57 TORCH_CHECK(
false,
"Unsupported type");
// Presumably the body of the float specialization: yields torch::kFloat32.
81 return torch::kFloat32;
// Presumably the body of the double specialization: yields torch::kFloat64.
85 return torch::kFloat64;
// Compares the torch dtype that ToTorchDtype<T>() yields for the C++ type T
// with the value `t` and returns the result of that equality test.
// NOTE(review): the signature (original line 90) is absent from this listing;
// `t` is presumably the function's TDtype parameter — confirm.
89 template <
class T,
class TDtype>
91 return ToTorchDtype<T>() == t;
// Fragment of a device-type comparison over `tensors` (presumably the body of
// SameDeviceType — TODO confirm; the signature is not shown in this listing).
// Remembers the device type of the first element and compares every element's
// device type against it.
// NOTE(review): the statements run on a mismatch and the final return are not
// part of this listing.
// NOTE(review): `auto t` copies each element per iteration; `const auto&`
// would avoid the copy.
97 auto device_type = tensors.begin()->device().type();
98 for (
auto t : tensors) {
99 if (device_type != t.device().type()) {
// SameDtype: for a non-empty list, remembers the dtype of the first tensor and
// compares every tensor's dtype against it.
// NOTE(review): the statements run on a mismatch and the final return are not
// part of this listing — presumably false on the first mismatch and true
// otherwise (including for an empty list); confirm against the full header.
// NOTE(review): `auto t` copies each torch::Tensor per iteration;
// `const auto&` would avoid the copy.
108 inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
109 if (tensors.size()) {
110 auto dtype = tensors.begin()->dtype();
111 for (
auto t : tensors) {
112 if (dtype != t.dtype()) {
// TensorInfoStr: streams, for each tensor, its sizes(), toString() and
// device() separated by single spaces into a std::stringstream, and appends
// ", " whenever `count` is still smaller than the number of tensors.
// NOTE(review): the declaration and update of `count` and the return
// statement are absent from this listing; presumably the accumulated string
// is returned — confirm against the full header.
// NOTE(review): `const auto t` copies each torch::Tensor per iteration;
// `const auto&` would avoid the copy.
120 inline std::string
TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
121 std::stringstream sstr;
123 for (
const auto t : tensors) {
124 sstr << t.sizes() <<
" " << t.toString() <<
" " << t.device();
126 if (
count < tensors.size()) sstr <<
", ";
// Fragment of a tensor-allocating helper (presumably CreateTempTensor — TODO
// confirm; the first signature line and the arguments to torch::empty are not
// shown in this listing).
// Allocates a tensor with torch::empty and writes its uint8_t data pointer to
// *ptr.
// NOTE(review): `ptr` defaults to nullptr and the listing shows no null check
// before `*ptr` is written (original line 137 is missing) — confirm the write
// is guarded.
133 const torch::Device& device,
134 void** ptr =
nullptr) {
135 torch::Tensor tensor = torch::empty(
138 *ptr = tensor.data_ptr<uint8_t>();
// Fragment of a shape-extraction helper (presumably GetShapeVector — TODO
// confirm; the first signature line is not shown in this listing).
// Appends tensor.size(i) for every i in [0, tensor.dim()) to a
// std::vector<DimValue>, in dimension order.
// NOTE(review): the return statement is absent from this listing; presumably
// `shape` is returned.
144 torch::Tensor tensor) {
147 std::vector<DimValue> shape;
148 const int rank = tensor.dim();
149 for (
int i = 0; i < rank; ++i) {
150 shape.push_back(tensor.size(i));
// CheckShape: thin adapter for torch tensors. Converts `tensor` to a shape
// vector with GetShapeVector() and forwards it, together with the perfectly
// forwarded `dimex` and `args`, to open3d::ml::op_util::CheckShape<Opt>,
// returning that call's (bool, std::string) tuple unchanged.
// NOTE(review): the template header and the remaining parameter lines are
// absent from this listing.
158 std::tuple<bool, std::string>
CheckShape(torch::Tensor tensor,
161 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(tensor),
162 std::forward<TDimX>(dimex),
163 std::forward<TArgs>(args)...);
// CHECK_SHAPE(tensor, ...): calls CheckShape(tensor, ...), unpacks the result
// into cs_success_ / cs_errstr_ with std::tie, and passes cs_success_ to
// TORCH_CHECK with the message "invalid shape for '<tensor>', " followed by
// cs_errstr_.
// NOTE(review): the lines that open the macro body and declare cs_success_,
// and the closing line, are absent from this listing.
186 #define CHECK_SHAPE(tensor, ...) \
189 std::string cs_errstr_; \
190 std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \
191 TORCH_CHECK(cs_success_, \
192 "invalid shape for '" #tensor "', " + cs_errstr_) \
// CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...): calls CheckShape with the
// template argument CSOpt::COMBINE_FIRST_DIMS, unpacks the result into
// cs_success_ / cs_errstr_, and passes cs_success_ to TORCH_CHECK with the
// message "invalid shape for '<tensor>', " followed by cs_errstr_.
// NOTE(review): the lines that open the macro body and declare cs_success_,
// and the closing line, are absent from this listing.
195 #define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...) \
198 std::string cs_errstr_; \
199 std::tie(cs_success_, cs_errstr_) = \
200 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
201 TORCH_CHECK(cs_success_, \
202 "invalid shape for '" #tensor "', " + cs_errstr_) \
// CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...): calls CheckShape with the
// template argument CSOpt::IGNORE_FIRST_DIMS, unpacks the result into
// cs_success_ / cs_errstr_, and passes cs_success_ to TORCH_CHECK with the
// message "invalid shape for '<tensor>', " followed by cs_errstr_.
// NOTE(review): the lines that open the macro body and declare cs_success_,
// and the closing line, are absent from this listing.
205 #define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...) \
208 std::string cs_errstr_; \
209 std::tie(cs_success_, cs_errstr_) = \
210 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
211 TORCH_CHECK(cs_success_, \
212 "invalid shape for '" #tensor "', " + cs_errstr_) \
// CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...): calls CheckShape with the
// template argument CSOpt::COMBINE_LAST_DIMS, unpacks the result into
// cs_success_ / cs_errstr_, and passes cs_success_ to TORCH_CHECK with the
// message "invalid shape for '<tensor>', " followed by cs_errstr_.
// NOTE(review): the lines that open the macro body and declare cs_success_,
// and the closing line, are absent from this listing.
215 #define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...) \
218 std::string cs_errstr_; \
219 std::tie(cs_success_, cs_errstr_) = \
220 CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
221 TORCH_CHECK(cs_success_, \
222 "invalid shape for '" #tensor "', " + cs_errstr_) \
// CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...): calls CheckShape with the
// template argument CSOpt::IGNORE_LAST_DIMS, unpacks the result into
// cs_success_ / cs_errstr_, and passes cs_success_ to TORCH_CHECK with the
// message "invalid shape for '<tensor>', " followed by cs_errstr_.
// NOTE(review): the lines that open the macro body and declare cs_success_,
// and the closing line, are absent from this listing.
225 #define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...) \
228 std::string cs_errstr_; \
229 std::tie(cs_success_, cs_errstr_) = \
230 CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
231 TORCH_CHECK(cs_success_, \
232 "invalid shape for '" #tensor "', " + cs_errstr_) \
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:158
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:76
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:60
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:120
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:143
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:68
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:64
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:84
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:108
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:95
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:54
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:56
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:132
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:72
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:90
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:80
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405