Open3D (C++ API)  0.11.0
TorchHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2020 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 #include <sstream>
29 #include <type_traits>
30 
32 #include "torch/script.h"
33 
// Macros for checking tensor properties.
// Each macro raises a torch error via TORCH_CHECK, naming the offending
// expression, when its check fails. The argument is parenthesized so that
// expressions such as `*ptr` are checked as a whole, and every TORCH_CHECK is
// terminated with ';' so the macros do not depend on how TORCH_CHECK expands.

// Fails unless tensor `x` is a CUDA tensor.
#define CHECK_CUDA(x)                                               \
    do {                                                            \
        TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor");    \
    } while (0)

// Fails unless tensor `x` is contiguous.
#define CHECK_CONTIGUOUS(x)                                         \
    do {                                                            \
        TORCH_CHECK((x).is_contiguous(), #x " must be contiguous"); \
    } while (0)

// Fails unless tensor `x` has dtype torch::<type>,
// e.g. CHECK_TYPE(points, kFloat32).
#define CHECK_TYPE(x, type)                                         \
    do {                                                            \
        TORCH_CHECK((x).dtype() == torch::type,                     \
                    #x " must have type " #type);                   \
    } while (0)
49 
// Fails (via TORCH_CHECK) unless all tensors passed as arguments report the
// same device type. The error message names the arguments and lists each
// tensor's sizes, type string and device as produced by TensorInfoStr().
#define CHECK_SAME_DEVICE_TYPE(...)                                       \
    do {                                                                  \
        if (SameDeviceType({__VA_ARGS__})) break;                         \
        TORCH_CHECK(false,                                                \
                    #__VA_ARGS__                                          \
                    " must all have the same device type but got " +      \
                            TensorInfoStr({__VA_ARGS__}));                \
    } while (0)
60 
// Fails (via TORCH_CHECK) unless all tensors passed as arguments share one
// dtype. The error message names the arguments and lists each tensor's
// sizes, type string and device as produced by TensorInfoStr().
#define CHECK_SAME_DTYPE(...)                                             \
    do {                                                                  \
        if (SameDtype({__VA_ARGS__})) break;                              \
        TORCH_CHECK(false,                                                \
                    #__VA_ARGS__                                          \
                    " must all have the same dtype but got " +            \
                            TensorInfoStr({__VA_ARGS__}));                \
    } while (0)
70 
71 // Conversion from standard types to torch types
73 template <class T>
75  TORCH_CHECK(false, "Unsupported type");
76 }
77 template <>
79  return torch::kUInt8;
80 }
81 template <>
83  return torch::kInt8;
84 }
85 template <>
87  return torch::kInt16;
88 }
89 template <>
91  return torch::kInt32;
92 }
93 template <>
95  return torch::kInt64;
96 }
97 template <>
99  return torch::kFloat32;
100 }
101 template <>
103  return torch::kFloat64;
104 }
105 
106 // convenience function for comparing standard types with torch types
107 template <class T, class TDtype>
108 inline bool CompareTorchDtype(const TDtype& t) {
109  return ToTorchDtype<T>() == t;
110 }
111 
112 // convenience function to check if all tensors have the same device type
113 inline bool SameDeviceType(std::initializer_list<torch::Tensor> tensors) {
114  if (tensors.size()) {
115  auto device_type = tensors.begin()->device().type();
116  for (auto t : tensors) {
117  if (device_type != t.device().type()) {
118  return false;
119  }
120  }
121  }
122  return true;
123 }
124 
125 // convenience function to check if all tensors have the same dtype
126 inline bool SameDtype(std::initializer_list<torch::Tensor> tensors) {
127  if (tensors.size()) {
128  auto device_type = tensors.begin()->dtype();
129  for (auto t : tensors) {
130  if (device_type != t.dtype()) {
131  return false;
132  }
133  }
134  }
135  return true;
136 }
137 
138 inline std::string TensorInfoStr(std::initializer_list<torch::Tensor> tensors) {
139  std::stringstream sstr;
140  size_t count = 0;
141  for (const auto t : tensors) {
142  sstr << t.sizes() << " " << t.toString() << " " << t.device();
143  ++count;
144  if (count < tensors.size()) sstr << ", ";
145  }
146  return sstr.str();
147 }
148 
149 // convenience function for creating a tensor for temp memory
150 inline torch::Tensor CreateTempTensor(const int64_t size,
151  const torch::Device& device,
152  void** ptr = nullptr) {
153  torch::Tensor tensor = torch::empty(
154  {size}, torch::dtype(ToTorchDtype<uint8_t>()).device(device));
155  if (ptr) {
156  *ptr = tensor.data_ptr<uint8_t>();
157  }
158  return tensor;
159 }
160 
161 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
162  torch::Tensor tensor) {
163  using namespace open3d::ml::op_util;
164 
165  std::vector<DimValue> shape;
166  const int rank = tensor.dim();
167  for (int i = 0; i < rank; ++i) {
168  shape.push_back(tensor.size(i));
169  }
170  return shape;
171 }
172 
174  class TDimX,
175  class... TArgs>
// Torch-tensor overload of open3d::ml::op_util::CheckShape.
// Converts the shape of `tensor` with GetShapeVector() and forwards it,
// together with `dimex`, `args...` and the option `Opt`, to
// op_util::CheckShape. Returns the (success, error message) tuple produced by
// that call; see ShapeChecking.h for the accepted dimension expressions.
// NOTE(review): `Opt` is declared on the template-header line that precedes
// this one. CHECK_SHAPE calls this function without an explicit option, so
// that parameter presumably has a default. Confirm against the full header.
176 std::tuple<bool, std::string> CheckShape(torch::Tensor tensor,
177  TDimX&& dimex,
178  TArgs&&... args) {
179  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
180  std::forward<TDimX>(dimex),
181  std::forward<TArgs>(args)...);
182 }
183 
184 //
185 // Macros for checking the shape of Tensors.
186 // Usage:
187 // {
188 // using namespace open3d::ml::op_util;
189 // Dim w("w");
190 // Dim h("h");
191 // CHECK_SHAPE(tensor1, 10, w, h); // checks if the first dim is 10
192 // // and assigns w and h based on
193 // // the shape of tensor1
194 //
195 // CHECK_SHAPE(tensor2, 10, 20, h); // this checks if the last dim
196 // // of tensor2 matches the last dim
197 // // of tensor1. The first two dims
198 // // must match 10, 20.
199 // }
200 //
201 //
202 // See "../ShapeChecking.h" for more info and limitations.
203 //
// Shared implementation of the CHECK_SHAPE* macros below.
// `tensor_name` is the stringified tensor expression (a string literal), and
// `check_expr` is a call returning std::tuple<bool, std::string>. If the
// check fails, TORCH_CHECK raises an error with the message
// "invalid shape for '<tensor_name>', <details>".
#define OPEN3D_TORCH_CHECK_SHAPE_IMPL_(tensor_name, check_expr)            \
    do {                                                                   \
        bool cs_success_;                                                  \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) = check_expr;                    \
        TORCH_CHECK(cs_success_,                                           \
                    "invalid shape for '" tensor_name "', " + cs_errstr_); \
    } while (0)

// Checks the shape of `tensor` against the given dimensions.
// The tensor expression is stringified here, in the user-facing macro, so the
// error message shows the argument exactly as written.
#define CHECK_SHAPE(tensor, ...) \
    OPEN3D_TORCH_CHECK_SHAPE_IMPL_(#tensor, CheckShape(tensor, __VA_ARGS__))

// The variants below forward the corresponding CSOpt option to CheckShape.
// See ShapeChecking.h for the exact semantics of each option.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...) \
    OPEN3D_TORCH_CHECK_SHAPE_IMPL_(                 \
            #tensor,                                \
            CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__))

#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...) \
    OPEN3D_TORCH_CHECK_SHAPE_IMPL_(                \
            #tensor,                               \
            CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__))

#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...) \
    OPEN3D_TORCH_CHECK_SHAPE_IMPL_(                \
            #tensor,                               \
            CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__))

#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...) \
    OPEN3D_TORCH_CHECK_SHAPE_IMPL_(               \
            #tensor,                              \
            CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__))
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(torch::Tensor tensor)
Definition: TorchHelper.h:161
TorchDtype_t ToTorchDtype< int16_t >()
Definition: TorchHelper.h:86
TorchDtype_t ToTorchDtype< uint8_t >()
Definition: TorchHelper.h:78
TorchDtype_t ToTorchDtype< float >()
Definition: TorchHelper.h:98
TorchDtype_t ToTorchDtype< int64_t >()
Definition: TorchHelper.h:94
torch::Tensor CreateTempTensor(const int64_t size, const torch::Device &device, void **ptr=nullptr)
Definition: TorchHelper.h:150
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:113
TorchDtype_t ToTorchDtype< int8_t >()
Definition: TorchHelper.h:82
int size
Definition: FilePCD.cpp:59
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:138
int count
Definition: FilePCD.cpp:61
char type
Definition: FilePCD.cpp:60
std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
Definition: TorchHelper.h:72
bool CompareTorchDtype(const TDtype &t)
Definition: TorchHelper.h:108
TorchDtype_t ToTorchDtype< double >()
Definition: TorchHelper.h:102
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
TorchDtype_t ToTorchDtype()
Definition: TorchHelper.h:74
TorchDtype_t ToTorchDtype< int32_t >()
Definition: TorchHelper.h:90
Definition: ShapeChecking.h:35
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:126