8 #include <torch/custom_class.h>
9 #include <torch/script.h>
25 : _values(values), _row_splits(row_splits) {}
34 c10::intrusive_ptr<RaggedTensor>
FromRowSplits(torch::Tensor values,
35 torch::Tensor row_splits,
36 bool validate =
true)
const;
52 torch::Tensor
GetItem(
int key)
const;
60 c10::intrusive_ptr<RaggedTensor>
Clone()
const;
62 c10::intrusive_ptr<RaggedTensor>
Concat(
63 c10::intrusive_ptr<RaggedTensor> r_tensor, int64_t axis)
const;
66 c10::intrusive_ptr<RaggedTensor>
Add(T value)
const {
71 c10::intrusive_ptr<RaggedTensor>
Add_(T value) {
73 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
77 c10::intrusive_ptr<RaggedTensor>
Sub(T value)
const {
82 c10::intrusive_ptr<RaggedTensor>
Sub_(T value) {
84 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
88 c10::intrusive_ptr<RaggedTensor>
Mul(T value)
const {
93 c10::intrusive_ptr<RaggedTensor>
Mul_(T value) {
95 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
99 c10::intrusive_ptr<RaggedTensor>
Div(T value)
const {
103 template <
typename T>
104 c10::intrusive_ptr<RaggedTensor>
Div_(T value) {
106 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
109 template <
typename T>
110 c10::intrusive_ptr<RaggedTensor>
FloorDiv(T value)
const {
111 return FromRowSplits(_values.floor_divide(value), _row_splits,
false);
114 template <
typename T>
116 _values.floor_divide_(value);
117 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
121 torch::Tensor _values, _row_splits;
124 static auto registry =
125 torch::class_<RaggedTensor>(
"my_classes",
"RaggedTensor")
126 .def(torch::init<>())
131 [](
const c10::intrusive_ptr<RaggedTensor>&
self) {
132 return self->ToString();
135 [](
const c10::intrusive_ptr<RaggedTensor>&
self) {
136 return self->ToString();
139 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
140 int64_t key) {
return self->GetItem(key); })
146 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
147 torch::Tensor value) {
return self->Add(value); })
149 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
150 torch::Tensor value) {
return self->Add_(value); })
152 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
153 torch::Tensor value) {
return self->Add(value); })
155 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
156 torch::Tensor value) {
return self->Add_(value); })
159 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
160 torch::Tensor value) {
return self->Sub(value); })
162 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
163 torch::Tensor value) {
return self->Sub_(value); })
165 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
166 torch::Tensor value) {
return self->Sub(value); })
168 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
169 torch::Tensor value) {
return self->Sub_(value); })
172 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
173 torch::Tensor value) {
return self->Mul(value); })
175 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
176 torch::Tensor value) {
return self->Mul_(value); })
178 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
179 torch::Tensor value) {
return self->Mul(value); })
181 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
182 torch::Tensor value) {
return self->Mul_(value); })
185 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
186 torch::Tensor value) {
return self->Div(value); })
188 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
189 torch::Tensor value) {
return self->Div_(value); })
191 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
192 torch::Tensor value) {
return self->Div(value); })
194 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
195 torch::Tensor value) {
return self->Div_(value); })
197 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
198 torch::Tensor value) {
return self->FloorDiv(value); })
199 .def(
"__ifloordiv__",
200 [](
const c10::intrusive_ptr<RaggedTensor>&
self,
201 torch::Tensor value) {
202 return self->FloorDiv_(value);
Definition: RaggedTensor.h:19
c10::intrusive_ptr< RaggedTensor > FloorDiv_(T value)
Definition: RaggedTensor.h:115
c10::intrusive_ptr< RaggedTensor > Div_(T value)
Definition: RaggedTensor.h:104
RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
Constructor for creating RaggedTensor with values and row_splits.
Definition: RaggedTensor.h:24
RaggedTensor()
Definition: RaggedTensor.h:21
c10::intrusive_ptr< RaggedTensor > Add_(T value)
Definition: RaggedTensor.h:71
torch::Tensor GetValues() const
Returns _values tensor.
Definition: RaggedTensor.cpp:39
c10::intrusive_ptr< RaggedTensor > Sub_(T value)
Definition: RaggedTensor.h:82
c10::intrusive_ptr< RaggedTensor > FloorDiv(T value) const
Definition: RaggedTensor.h:110
c10::intrusive_ptr< RaggedTensor > FromRowSplits(torch::Tensor values, torch::Tensor row_splits, bool validate=true) const
Definition: RaggedTensor.cpp:12
int64_t Len() const
Definition: RaggedTensor.cpp:54
c10::intrusive_ptr< RaggedTensor > Div(T value) const
Definition: RaggedTensor.h:99
c10::intrusive_ptr< RaggedTensor > Clone() const
Copy Tensor to the same device.
Definition: RaggedTensor.cpp:56
torch::Tensor GetRowSplits() const
Returns _row_splits tensor.
Definition: RaggedTensor.cpp:40
c10::intrusive_ptr< RaggedTensor > Sub(T value) const
Definition: RaggedTensor.h:77
std::string ToString() const
Returns string representation.
Definition: RaggedTensor.cpp:42
c10::intrusive_ptr< RaggedTensor > Mul_(T value)
Definition: RaggedTensor.h:93
c10::intrusive_ptr< RaggedTensor > Add(T value) const
Definition: RaggedTensor.h:66
c10::intrusive_ptr< RaggedTensor > Concat(c10::intrusive_ptr< RaggedTensor > r_tensor, int64_t axis) const
Definition: RaggedTensor.cpp:60
c10::intrusive_ptr< RaggedTensor > Mul(T value) const
Definition: RaggedTensor.h:88
torch::Tensor GetItem(int key) const
Definition: RaggedTensor.cpp:49