Open3D (C++ API)  0.18.0
TorchHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 #include <torch/script.h>
10 
11 #include <sstream>
12 #include <type_traits>
13 
15 
16 // Macros for checking tensor properties
// Fails a TORCH_CHECK (message: "<x> must be a CUDA tensor") unless the
// tensor expression `x` reports is_cuda(). The argument is parenthesized so
// expressions such as `*ptr` or `cond ? a : b` are evaluated before the
// member call instead of binding to it.
#define CHECK_CUDA(x)                                            \
    do {                                                         \
        TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor"); \
    } while (0)
21 
// Fails a TORCH_CHECK (message: "<x> must be contiguous") unless the tensor
// expression `x` reports is_contiguous(). The argument is parenthesized so
// that expressions like `*ptr` expand correctly.
#define CHECK_CONTIGUOUS(x)                                         \
    do {                                                            \
        TORCH_CHECK((x).is_contiguous(), #x " must be contiguous"); \
    } while (0)
26 
// Fails a TORCH_CHECK (message: "<x> must have type <type>") unless the
// dtype of tensor expression `x` equals torch::<type>. `type` must be an
// unqualified torch dtype constant name, e.g. kFloat32. The tensor argument
// is parenthesized so that expressions like `*ptr` expand correctly.
#define CHECK_TYPE(x, type)                                     \
    do {                                                        \
        TORCH_CHECK((x).dtype() == torch::type,                 \
                    #x " must have type " #type);               \
    } while (0)
31 
// Fails a TORCH_CHECK unless all given tensors share one device type (as
// decided by SameDeviceType). The message lists the argument spelling and,
// via TensorInfoStr, the sizes/type/device of each tensor.
#define CHECK_SAME_DEVICE_TYPE(...)                                    \
    do {                                                               \
        TORCH_CHECK(SameDeviceType({__VA_ARGS__}),                     \
                    #__VA_ARGS__                                       \
                    " must all have the same device type but got " +   \
                            TensorInfoStr({__VA_ARGS__}));             \
    } while (0)
42 
// Fails a TORCH_CHECK unless all given tensors share one dtype (as decided
// by SameDtype). The message lists the argument spelling and, via
// TensorInfoStr, the sizes/type/device of each tensor.
#define CHECK_SAME_DTYPE(...)                                          \
    do {                                                               \
        TORCH_CHECK(SameDtype({__VA_ARGS__}),                          \
                    #__VA_ARGS__                                       \
                    " must all have the same dtype but got " +         \
                            TensorInfoStr({__VA_ARGS__}));             \
    } while (0)
52 
53 // Conversion from standard types to torch types
54 typedef std::remove_const<decltype(torch::kInt32)>::type TorchDtype_t;
55 template <class T>
57  TORCH_CHECK(false, "Unsupported type");
58 }
59 template <>
61  return torch::kUInt8;
62 }
63 template <>
65  return torch::kInt8;
66 }
67 template <>
69  return torch::kInt16;
70 }
71 template <>
73  return torch::kInt32;
74 }
75 template <>
77  return torch::kInt64;
78 }
79 template <>
81  return torch::kFloat32;
82 }
83 template <>
85  return torch::kFloat64;
86 }
87 
88 // convenience function for comparing standard types with torch types
89 template <class T, class TDtype>
90 inline bool CompareTorchDtype(const TDtype& t) {
91  return ToTorchDtype<T>() == t;
92 }
93 
94 // convenience function to check if all tensors have the same device type
95 inline bool SameDeviceType(std::initializer_list<torch::Tensor> tensors) {
96  if (tensors.size()) {
97  auto device_type = tensors.begin()->device().type();
98  for (auto t : tensors) {
99  if (device_type != t.device().type()) {
100  return false;
101  }
102  }
103  }
104  return true;
105 }
106 
107 // convenience function to check if all tensors have the same dtype
108 inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
109  if (tensors.size()) {
110  auto dtype = tensors.begin()->dtype();
111  for (auto t : tensors) {
112  if (dtype != t.dtype()) {
113  return false;
114  }
115  }
116  }
117  return true;
118 }
119 
120 inline std::string TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
121  std::stringstream sstr;
122  size_t count = 0;
123  for (const auto t : tensors) {
124  sstr << t.sizes() << " " << t.toString() << " " << t.device();
125  ++count;
126  if (count < tensors.size()) sstr << ", ";
127  }
128  return sstr.str();
129 }
130 
131 // convenience function for creating a tensor for temp memory
132 inline torch::Tensor CreateTempTensor(const int64_t size,
133  const torch::Device& device,
134  void** ptr = nullptr) {
135  torch::Tensor tensor = torch::empty(
136  {size}, torch::dtype(ToTorchDtype<uint8_t>()).device(device));
137  if (ptr) {
138  *ptr = tensor.data_ptr<uint8_t>();
139  }
140  return tensor;
141 }
142 
143 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
144  torch::Tensor tensor) {
145  using namespace open3d::ml::op_util;
146 
147  std::vector<DimValue> shape;
148  const int rank = tensor.dim();
149  for (int i = 0; i < rank; ++i) {
150  shape.push_back(tensor.size(i));
151  }
152  return shape;
153 }
154 
156  class TDimX,
157  class... TArgs>
158 std::tuple<bool, std::string> CheckShape(torch::Tensor tensor,
159  TDimX&& dimex,
160  TArgs&&... args) {
161  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
162  std::forward<TDimX>(dimex),
163  std::forward<TArgs>(args)...);
164 }
165 
166 //
167 // Macros for checking the shape of Tensors.
168 // Usage:
169 // {
170 // using namespace open3d::ml::op_util;
171 // Dim w("w");
172 // Dim h("h");
173 // CHECK_SHAPE(tensor1, 10, w, h); // checks if the first dim is 10
174 // // and assigns w and h based on
175 // // the shape of tensor1
176 //
177 // CHECK_SHAPE(tensor2, 10, 20, h); // this checks if the last dim
178 // // of tensor2 matches the last dim
179 // // of tensor1. The first two dims
180 // // must match 10, 20.
181 // }
182 //
183 //
184 // See "../ShapeChecking.h" for more info and limitations.
185 //
// Shared implementation of the CHECK_SHAPE* macros below.
//   tensor_name      - string literal naming the tensor (the caller's #tensor,
//                      so the spelling in error messages is unchanged)
//   check_shape_call - expression yielding std::tuple<bool, std::string>,
//                      i.e. a CheckShape(...) call
// Fails a TORCH_CHECK with "invalid shape for '<name>', <details>" when the
// returned flag is false.
#define CHECK_SHAPE_IMPL_(tensor_name, check_shape_call)                   \
    do {                                                                   \
        bool cs_success_ = false;                                          \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) = (check_shape_call);            \
        TORCH_CHECK(cs_success_,                                           \
                    "invalid shape for '" tensor_name "', " + cs_errstr_); \
    } while (0)

// Checks the shape of `tensor` against the given dims with default options.
#define CHECK_SHAPE(tensor, ...) \
    CHECK_SHAPE_IMPL_(#tensor, CheckShape(tensor, __VA_ARGS__))

// As CHECK_SHAPE, with option CSOpt::COMBINE_FIRST_DIMS (see ShapeChecking.h).
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...) \
    CHECK_SHAPE_IMPL_(                              \
            #tensor,                                \
            CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__))

// As CHECK_SHAPE, with option CSOpt::IGNORE_FIRST_DIMS (see ShapeChecking.h).
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...) \
    CHECK_SHAPE_IMPL_(                             \
            #tensor,                               \
            CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__))

// As CHECK_SHAPE, with option CSOpt::COMBINE_LAST_DIMS (see ShapeChecking.h).
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...) \
    CHECK_SHAPE_IMPL_(                             \
            #tensor,                               \
            CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__))

// As CHECK_SHAPE, with option CSOpt::IGNORE_LAST_DIMS (see ShapeChecking.h).
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...) \
    CHECK_SHAPE_IMPL_(                            \
            #tensor,                              \
            CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__))
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:158
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:76
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:60
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:120
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:143
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:68
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:64
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:84
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:108
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:95
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:54
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:56
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:132
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:72
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:90
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:80
int size
Definition: FilePCD.cpp:40
int count
Definition: FilePCD.cpp:42
char type
Definition: FilePCD.cpp:41
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405