Open3D (C++ API)  0.11.0
Macros | Typedefs | Functions
TorchHelper.h File Reference

(9f2dd765 (Tue Sep 15 08:14:29 2020 -0700))

#include <sstream>
#include <type_traits>
#include "open3d/ml/ShapeChecking.h"
#include "torch/script.h"

Go to the source code of this file.

Macros

#define CHECK_CUDA(x)
 
#define CHECK_CONTIGUOUS(x)
 
#define CHECK_TYPE(x, type)
 
#define CHECK_SAME_DEVICE_TYPE(...)
 
#define CHECK_SAME_DTYPE(...)
 
#define CHECK_SHAPE(tensor, ...)
 
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)
 
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)
 
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)
 
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)
 

Typedefs

typedef std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t
 

Functions

template<class T >
TorchDtype_t ToTorchDtype ()
 
template<>
TorchDtype_t ToTorchDtype< uint8_t > ()
 
template<>
TorchDtype_t ToTorchDtype< int8_t > ()
 
template<>
TorchDtype_t ToTorchDtype< int16_t > ()
 
template<>
TorchDtype_t ToTorchDtype< int32_t > ()
 
template<>
TorchDtype_t ToTorchDtype< int64_t > ()
 
template<>
TorchDtype_t ToTorchDtype< float > ()
 
template<>
TorchDtype_t ToTorchDtype< double > ()
 
template<class T , class TDtype >
bool CompareTorchDtype (const TDtype &t)
 
bool SameDeviceType (std::initializer_list< torch::Tensor > tensors)
 
bool SameDtype (std::initializer_list< torch::Tensor > tensors)
 
std::string TensorInfoStr (std::initializer_list< torch::Tensor > tensors)
 
torch::Tensor CreateTempTensor (const int64_t size, const torch::Device &device, void **ptr=nullptr)
 
std::vector< open3d::ml::op_util::DimValue > GetShapeVector (torch::Tensor tensor)
 
template<open3d::ml::op_util::CSOpt Opt = open3d::ml::op_util::CSOpt::NONE, class TDimX , class... TArgs>
std::tuple< bool, std::string > CheckShape (torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
 

Macro Definition Documentation

◆ CHECK_CONTIGUOUS

#define CHECK_CONTIGUOUS (   x)
Value:
do { \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") \
} while (0)

◆ CHECK_CUDA

#define CHECK_CUDA (   x)
Value:
do { \
TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") \
} while (0)

◆ CHECK_SAME_DEVICE_TYPE

#define CHECK_SAME_DEVICE_TYPE (   ...)
Value:
do { \
if (!SameDeviceType({__VA_ARGS__})) { \
TORCH_CHECK( \
false, \
#__VA_ARGS__ \
" must all have the same device type but got " + \
TensorInfoStr({__VA_ARGS__})) \
} \
} while (0)
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:113
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:138

◆ CHECK_SAME_DTYPE

#define CHECK_SAME_DTYPE (   ...)
Value:
do { \
if (!SameDtype({__VA_ARGS__})) { \
TORCH_CHECK(false, \
#__VA_ARGS__ \
" must all have the same dtype but got " + \
TensorInfoStr({__VA_ARGS__})) \
} \
} while (0)
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:138
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition: TorchHelper.h:126

◆ CHECK_SHAPE

#define CHECK_SHAPE (   tensor,
  ... 
)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition: TorchHelper.h:176

◆ CHECK_SHAPE_COMBINE_FIRST_DIMS

#define CHECK_SHAPE_COMBINE_FIRST_DIMS (   tensor,
  ... 
)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_SHAPE_COMBINE_LAST_DIMS

#define CHECK_SHAPE_COMBINE_LAST_DIMS (   tensor,
  ... 
)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_SHAPE_IGNORE_FIRST_DIMS

#define CHECK_SHAPE_IGNORE_FIRST_DIMS (   tensor,
  ... 
)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_SHAPE_IGNORE_LAST_DIMS

#define CHECK_SHAPE_IGNORE_LAST_DIMS (   tensor,
  ... 
)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_TYPE

#define CHECK_TYPE (   x,
  type 
)
Value:
do { \
TORCH_CHECK(x.dtype() == torch::type, #x " must have type " #type) \
} while (0)
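char type
Definition: FilePCD.cpp:60
(Note: this cross-reference is a Doxygen auto-link artifact. In CHECK_TYPE, `type` is the macro parameter, expanded as `torch::type` to name a torch dtype constant (e.g. `kInt32`). It is unrelated to the `char type` declared in FilePCD.cpp.)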

Typedef Documentation

◆ TorchDtype_t

typedef std::remove_const<decltype(torch::kInt32)>::type TorchDtype_t

Function Documentation

◆ CheckShape()

template<open3d::ml::op_util::CSOpt Opt = open3d::ml::op_util::CSOpt::NONE, class TDimX , class... TArgs>
std::tuple<bool, std::string> CheckShape ( torch::Tensor  tensor,
TDimX &&  dimex,
TArgs &&...  args 
)

◆ CompareTorchDtype()

template<class T , class TDtype >
bool CompareTorchDtype ( const TDtype &  t)
inline

◆ CreateTempTensor()

torch::Tensor CreateTempTensor ( const int64_t  size,
const torch::Device &  device,
void **  ptr = nullptr 
)
inline

◆ GetShapeVector()

std::vector<open3d::ml::op_util::DimValue> GetShapeVector ( torch::Tensor  tensor)
inline

◆ SameDeviceType()

bool SameDeviceType ( std::initializer_list< torch::Tensor >  tensors)
inline

◆ SameDtype()

bool SameDtype ( std::initializer_list< torch::Tensor >  tensors)
inline

◆ TensorInfoStr()

std::string TensorInfoStr ( std::initializer_list< torch::Tensor >  tensors)
inline

◆ ToTorchDtype()

template<class T >
TorchDtype_t ToTorchDtype ( )
inline

◆ ToTorchDtype< double >()

template<>
TorchDtype_t ToTorchDtype< double > ( )
inline

◆ ToTorchDtype< float >()

template<>
TorchDtype_t ToTorchDtype< float > ( )
inline

◆ ToTorchDtype< int16_t >()

template<>
TorchDtype_t ToTorchDtype< int16_t > ( )
inline

◆ ToTorchDtype< int32_t >()

template<>
TorchDtype_t ToTorchDtype< int32_t > ( )
inline

◆ ToTorchDtype< int64_t >()

template<>
TorchDtype_t ToTorchDtype< int64_t > ( )
inline

◆ ToTorchDtype< int8_t >()

template<>
TorchDtype_t ToTorchDtype< int8_t > ( )
inline

◆ ToTorchDtype< uint8_t >()

template<>
TorchDtype_t ToTorchDtype< uint8_t > ( )
inline