Open3D (C++ API)  0.18.0+252c867
RaggedTensor.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
#pragma once

#include <torch/custom_class.h>
#include <torch/script.h>

#include <limits>
#include <string>
#include <utility>
#include <vector>
12 
14 
19 struct RaggedTensor : torch::CustomClassHolder {
20 public:
22 
24  RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
25  : _values(values), _row_splits(row_splits) {}
26 
34  c10::intrusive_ptr<RaggedTensor> FromRowSplits(torch::Tensor values,
35  torch::Tensor row_splits,
36  bool validate = true) const;
37 
39  torch::Tensor GetValues() const;
40 
42  torch::Tensor GetRowSplits() const;
43 
45  std::string ToString() const;
46 
52  torch::Tensor GetItem(int key) const;
53 
57  int64_t Len() const;
58 
60  c10::intrusive_ptr<RaggedTensor> Clone() const;
61 
62  c10::intrusive_ptr<RaggedTensor> Concat(
63  c10::intrusive_ptr<RaggedTensor> r_tensor, int64_t axis) const;
64 
65  template <typename T>
66  c10::intrusive_ptr<RaggedTensor> Add(T value) const {
67  return FromRowSplits(_values + value, _row_splits, false);
68  }
69 
70  template <typename T>
71  c10::intrusive_ptr<RaggedTensor> Add_(T value) {
72  _values += value;
73  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
74  }
75 
76  template <typename T>
77  c10::intrusive_ptr<RaggedTensor> Sub(T value) const {
78  return FromRowSplits(_values - value, _row_splits, false);
79  }
80 
81  template <typename T>
82  c10::intrusive_ptr<RaggedTensor> Sub_(T value) {
83  _values -= value;
84  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
85  }
86 
87  template <typename T>
88  c10::intrusive_ptr<RaggedTensor> Mul(T value) const {
89  return FromRowSplits(_values * value, _row_splits, false);
90  }
91 
92  template <typename T>
93  c10::intrusive_ptr<RaggedTensor> Mul_(T value) {
94  _values *= value;
95  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
96  }
97 
98  template <typename T>
99  c10::intrusive_ptr<RaggedTensor> Div(T value) const {
100  return FromRowSplits(_values / value, _row_splits, false);
101  }
102 
103  template <typename T>
104  c10::intrusive_ptr<RaggedTensor> Div_(T value) {
105  _values /= value;
106  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
107  }
108 
109  template <typename T>
110  c10::intrusive_ptr<RaggedTensor> FloorDiv(T value) const {
111  return FromRowSplits(_values.floor_divide(value), _row_splits, false);
112  }
113 
114  template <typename T>
115  c10::intrusive_ptr<RaggedTensor> FloorDiv_(T value) {
116  _values.floor_divide_(value);
117  return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
118  }
119 
120 private:
121  torch::Tensor _values, _row_splits;
122 };
123 
124 static auto registry =
125  torch::class_<RaggedTensor>("my_classes", "RaggedTensor")
126  .def(torch::init<>())
127  .def("from_row_splits", &RaggedTensor::FromRowSplits)
128  .def("get_values", &RaggedTensor::GetValues)
129  .def("get_row_splits", &RaggedTensor::GetRowSplits)
130  .def("__repr__",
131  [](const c10::intrusive_ptr<RaggedTensor>& self) {
132  return self->ToString();
133  })
134  .def("__str__",
135  [](const c10::intrusive_ptr<RaggedTensor>& self) {
136  return self->ToString();
137  })
138  .def("__getitem__",
139  [](const c10::intrusive_ptr<RaggedTensor>& self,
140  int64_t key) { return self->GetItem(key); })
141  .def("__len__", &RaggedTensor::Len)
142  .def("clone", &RaggedTensor::Clone)
143  .def("concat", &RaggedTensor::Concat)
144 
145  .def("add",
146  [](const c10::intrusive_ptr<RaggedTensor>& self,
147  torch::Tensor value) { return self->Add(value); })
148  .def("add_",
149  [](const c10::intrusive_ptr<RaggedTensor>& self,
150  torch::Tensor value) { return self->Add_(value); })
151  .def("__add__",
152  [](const c10::intrusive_ptr<RaggedTensor>& self,
153  torch::Tensor value) { return self->Add(value); })
154  .def("__iadd__",
155  [](const c10::intrusive_ptr<RaggedTensor>& self,
156  torch::Tensor value) { return self->Add_(value); })
157 
158  .def("sub",
159  [](const c10::intrusive_ptr<RaggedTensor>& self,
160  torch::Tensor value) { return self->Sub(value); })
161  .def("sub_",
162  [](const c10::intrusive_ptr<RaggedTensor>& self,
163  torch::Tensor value) { return self->Sub_(value); })
164  .def("__sub__",
165  [](const c10::intrusive_ptr<RaggedTensor>& self,
166  torch::Tensor value) { return self->Sub(value); })
167  .def("__isub__",
168  [](const c10::intrusive_ptr<RaggedTensor>& self,
169  torch::Tensor value) { return self->Sub_(value); })
170 
171  .def("mul",
172  [](const c10::intrusive_ptr<RaggedTensor>& self,
173  torch::Tensor value) { return self->Mul(value); })
174  .def("mul_",
175  [](const c10::intrusive_ptr<RaggedTensor>& self,
176  torch::Tensor value) { return self->Mul_(value); })
177  .def("__mul__",
178  [](const c10::intrusive_ptr<RaggedTensor>& self,
179  torch::Tensor value) { return self->Mul(value); })
180  .def("__imul__",
181  [](const c10::intrusive_ptr<RaggedTensor>& self,
182  torch::Tensor value) { return self->Mul_(value); })
183 
184  .def("div",
185  [](const c10::intrusive_ptr<RaggedTensor>& self,
186  torch::Tensor value) { return self->Div(value); })
187  .def("div_",
188  [](const c10::intrusive_ptr<RaggedTensor>& self,
189  torch::Tensor value) { return self->Div_(value); })
190  .def("__truediv__",
191  [](const c10::intrusive_ptr<RaggedTensor>& self,
192  torch::Tensor value) { return self->Div(value); })
193  .def("__itruediv__",
194  [](const c10::intrusive_ptr<RaggedTensor>& self,
195  torch::Tensor value) { return self->Div_(value); })
196  .def("__floordiv__",
197  [](const c10::intrusive_ptr<RaggedTensor>& self,
198  torch::Tensor value) { return self->FloorDiv(value); })
199  .def("__ifloordiv__",
200  [](const c10::intrusive_ptr<RaggedTensor>& self,
201  torch::Tensor value) {
202  return self->FloorDiv_(value);
203  });
Definition: RaggedTensor.h:19
c10::intrusive_ptr< RaggedTensor > FloorDiv_(T value)
Definition: RaggedTensor.h:115
c10::intrusive_ptr< RaggedTensor > Div_(T value)
Definition: RaggedTensor.h:104
RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
Constructs a RaggedTensor from the given values and row_splits tensors.
Definition: RaggedTensor.h:24
RaggedTensor()
Definition: RaggedTensor.h:21
c10::intrusive_ptr< RaggedTensor > Add_(T value)
Definition: RaggedTensor.h:71
torch::Tensor GetValues() const
Returns the _values tensor.
Definition: RaggedTensor.cpp:39
c10::intrusive_ptr< RaggedTensor > Sub_(T value)
Definition: RaggedTensor.h:82
c10::intrusive_ptr< RaggedTensor > FloorDiv(T value) const
Definition: RaggedTensor.h:110
c10::intrusive_ptr< RaggedTensor > FromRowSplits(torch::Tensor values, torch::Tensor row_splits, bool validate=true) const
Definition: RaggedTensor.cpp:12
int64_t Len() const
Definition: RaggedTensor.cpp:54
c10::intrusive_ptr< RaggedTensor > Div(T value) const
Definition: RaggedTensor.h:99
c10::intrusive_ptr< RaggedTensor > Clone() const
Returns a copy of the RaggedTensor on the same device.
Definition: RaggedTensor.cpp:56
torch::Tensor GetRowSplits() const
Returns the _row_splits tensor.
Definition: RaggedTensor.cpp:40
c10::intrusive_ptr< RaggedTensor > Sub(T value) const
Definition: RaggedTensor.h:77
std::string ToString() const
Returns a string representation.
Definition: RaggedTensor.cpp:42
c10::intrusive_ptr< RaggedTensor > Mul_(T value)
Definition: RaggedTensor.h:93
c10::intrusive_ptr< RaggedTensor > Add(T value) const
Definition: RaggedTensor.h:66
c10::intrusive_ptr< RaggedTensor > Concat(c10::intrusive_ptr< RaggedTensor > r_tensor, int64_t axis) const
Definition: RaggedTensor.cpp:60
c10::intrusive_ptr< RaggedTensor > Mul(T value) const
Definition: RaggedTensor.h:88
torch::Tensor GetItem(int key) const
Definition: RaggedTensor.cpp:49