open3d.ml.tf.models.PointRCNN
-
class open3d.ml.tf.models.PointRCNN(*args, **kwargs)
Object detection model based on the PointRCNN architecture (https://github.com/sshaoshuai/PointRCNN).
The network is not trainable end-to-end: the RPN module must be pre-trained first, and the RCNN module is trained afterwards. To train the RPN, set the mode to ‘RPN’; in this mode the network outputs only intermediate results. Once the RPN module is trained, set the mode to ‘RCNN’ (the default) to train the second module; in this mode the network outputs the final predictions.
For inference, use the ‘RCNN’ mode.
- Parameters
name (string) – Name of the model. Defaults to “PointRCNN”.
device (string) – ‘cuda’ or ‘cpu’. Defaults to ‘cuda’.
classes (string[]) – List of classes used for object detection. Defaults to [‘Car’].
score_thres (float) – Minimum confidence score for a prediction. Defaults to 0.3.
npoints (int) – Number of input points processed. Defaults to 16384.
rpn (dict) – Config of the RPN module. Defaults to {}.
rcnn (dict) – Config of the RCNN module. Defaults to {}.
mode (string) – Execution mode, ‘RPN’ or ‘RCNN’. Defaults to ‘RCNN’.
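The two-stage schedule described above can be sketched as plain Python data. This is a minimal sketch, not open3d.ml code: `stage_kwargs` and `training_plan` are hypothetical helpers whose keys mirror the documented constructor parameters and defaults, and how the trained RPN weights are carried into the second stage (for example, by restoring a checkpoint) is not shown.

```python
# Hypothetical helpers: plain-dict constructor kwargs for the two training stages.
STAGES = ("RPN", "RCNN")


def stage_kwargs(mode, rpn_cfg=None, rcnn_cfg=None, classes=("Car",)):
    """Return constructor kwargs for one stage, using the documented defaults."""
    if mode not in STAGES:
        raise ValueError(f"mode must be one of {STAGES}, got {mode!r}")
    return {
        "name": "PointRCNN",
        "classes": list(classes),
        "score_thres": 0.3,
        "npoints": 16384,
        # Fresh dicts per stage, so the stages never share a mutable config.
        "rpn": dict(rpn_cfg or {}),
        "rcnn": dict(rcnn_cfg or {}),
        "mode": mode,
    }


def training_plan(rpn_cfg=None, rcnn_cfg=None):
    """Stage 1 pre-trains the RPN; stage 2 trains the RCNN module."""
    return [
        stage_kwargs("RPN", rpn_cfg, rcnn_cfg),
        stage_kwargs("RCNN", rpn_cfg, rcnn_cfg),
    ]


for i, kwargs in enumerate(training_plan(), start=1):
    print(f"stage {i}: mode={kwargs['mode']}")
# Inference uses the 'RCNN' mode, i.e. the same kwargs as the last stage.
print("inference:", stage_kwargs("RCNN")["mode"])
```

With open3d.ml installed, each dict could then be passed as `open3d.ml.tf.models.PointRCNN(**kwargs)`; copying the `rpn` and `rcnn` dicts avoids sharing the mutable `{}` defaults between stages.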
-
__init__(name='PointRCNN', classes=['Car'], score_thres=0.3, npoints=16384, rpn={}, rcnn={}, mode='RCNN', **kwargs)
Initialize self. See help(type(self)) for accurate signature.
-
call(inputs, training=True)
Calls the model on new inputs.
In this case, call simply reapplies all ops in the graph to the new inputs (e.g. builds a new computational graph from the provided inputs).
Note: This method should not be called directly. It is only meant to be overridden when subclassing tf.keras.Model. To call a model on an input, always use the __call__ method, i.e. model(inputs), which relies on the underlying call method.
- Parameters
inputs – A tensor or list of tensors.
training – Boolean or boolean scalar tensor, indicating whether to run the network in training mode or inference mode.
mask – A mask or list of masks. A mask can be either a tensor or None (no mask). Not accepted by this override, whose signature is call(inputs, training=True).
- Returns
A tensor if there is a single output, or a list of tensors if there is more than one output.
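The note about `call` versus `__call__` reflects the usual tf.keras subclassing split: subclasses override `call`, and users invoke the model object. The sketch below imitates that split in plain Python so it runs without TensorFlow; `TinyModel` and `Doubler` are hypothetical stand-ins, and the real `tf.keras.Model.__call__` does considerably more (for example, building the model on first use).

```python
class TinyModel:
    """Plain-Python stand-in for the tf.keras.Model __call__/call split."""

    def __init__(self):
        self.n_calls = 0

    def __call__(self, inputs, training=True):
        # Public entry point: framework bookkeeping (here just a counter)
        # happens here, then the work is delegated to `call`.
        self.n_calls += 1
        return self.call(inputs, training=training)

    def call(self, inputs, training=True):
        # Subclasses override this method; users never invoke it directly.
        raise NotImplementedError


class Doubler(TinyModel):
    def call(self, inputs, training=True):
        # Single input list in, single output list out.
        return [2 * x for x in inputs]


model = Doubler()
print(model([1, 2, 3]))  # [2, 4, 6]
```

Calling `model.call(...)` directly would skip whatever `__call__` does first, which is why the documentation recommends `model(inputs)`.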
-
filter_objects(bbox_objs)
Filters objects, keeping only those that belong to the classes to train on.
- Parameters
bbox_objs – Bounding box objects from dataset class.
- Returns
Filtered bounding box objects.
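A minimal sketch of class-based filtering, assuming each bounding box object exposes its class name through an attribute called `label_class`; the `Box` dataclass and that attribute name are hypothetical stand-ins for the dataset's box objects.

```python
from dataclasses import dataclass


@dataclass
class Box:
    """Hypothetical minimal box; only the class label matters here."""
    label_class: str


def filter_objects(bbox_objs, classes=("Car",)):
    """Keep only boxes whose label_class is one of `classes`."""
    wanted = set(classes)
    return [b for b in bbox_objs if b.label_class in wanted]


boxes = [Box("Car"), Box("Pedestrian"), Box("Car"), Box("Cyclist")]
print([b.label_class for b in filter_objects(boxes)])  # ['Car', 'Car']
```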
-
static generate_rpn_training_labels(points, bboxes, bboxes_world, calib=None)
Generates labels for the RPN network.
Classifies each point as foreground or background depending on whether it lies inside a bounding box. Ambiguous points that lie just outside a bounding box (determined using extended_boxes) are not used for training. Also computes regression labels for the bounding box proposals (in the bounding box frame).
- Parameters
points – Input point cloud.
bboxes – Bounding boxes in the camera frame.
bboxes_world – Bounding boxes in the world frame.
calib – Calibration file for the cam_to_world matrix.
- Returns
Classification and regression labels.
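The labeling rule above can be sketched in plain Python under simplifying assumptions: boxes are axis-aligned (real PointRCNN boxes also carry a heading angle), the extended box is each box grown by a hypothetical margin `extra_width=0.2` on every side, ignored points get the label -1, and only the center-offset part of the regression target is computed. `AABox` and `rpn_labels` are illustrative names, not the library's functions.

```python
from dataclasses import dataclass


@dataclass
class AABox:
    """Axis-aligned box: center (cx, cy, cz) and size (dx, dy, dz).

    Simplification: real PointRCNN boxes also have a heading angle.
    """
    cx: float
    cy: float
    cz: float
    dx: float
    dy: float
    dz: float

    def contains(self, p, margin=0.0):
        center = (self.cx, self.cy, self.cz)
        size = (self.dx, self.dy, self.dz)
        return all(abs(pc - c) <= d / 2 + margin
                   for pc, c, d in zip(p, center, size))


def rpn_labels(points, boxes, extra_width=0.2):
    """Return (cls_labels, reg_labels), one entry per point.

    cls: 1 = foreground (inside a box), 0 = background,
         -1 = ignored (inside an extended box but outside the box itself).
    reg: offset from the point to its box center, or None.
    """
    cls_labels, reg_labels = [], []
    for p in points:
        label, reg = 0, None
        for b in boxes:
            if b.contains(p):
                label = 1
                reg = (b.cx - p[0], b.cy - p[1], b.cz - p[2])
                break
            if b.contains(p, margin=extra_width):
                label = -1  # ambiguous point: excluded from training
        cls_labels.append(label)
        reg_labels.append(reg)
    return cls_labels, reg_labels


box = AABox(0, 0, 0, 2, 2, 2)
pts = [(0.5, 0.0, 0.0), (1.1, 0.0, 0.0), (3.0, 0.0, 0.0)]
print(rpn_labels(pts, [box])[0])  # [1, -1, 0]
```

A point inside any box wins over an ambiguous match with another box, which is why the loop breaks on the first containing box but keeps scanning after an extended-box hit.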
-
get_batch_gen(dataset, steps_per_epoch=None, batch_size=1)
-
get_optimizer(cfg)
Returns an optimizer object for the model.
- Parameters
cfg – A Config object with the configuration of the pipeline.
- Returns
Returns a new optimizer object.
-
inference_end(results, inputs)
This function is called after inference.
This function can be implemented to apply post-processing on the network outputs.
- Parameters
results – The model outputs as returned by the call() function. Post-processing is applied to this object.
inputs – The input to the model.
- Returns
Returns True if the inference is complete and otherwise False. Returning False can be used to implement inference for large point clouds which require multiple passes.
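The True/False contract of inference_end can be illustrated with a hypothetical driver loop for multi-pass inference. In this sketch, `ChunkedPostprocessor`, `run_inference`, and `fake_model` are illustrative names (not the open3d.ml pipeline code), each pass handles one chunk of a large cloud, and the post-processing step is assumed to drop detections whose score is below `score_thres` (the inclusive comparison is an assumption).

```python
class ChunkedPostprocessor:
    """Hypothetical model-side state for multi-pass inference."""

    def __init__(self, n_chunks, score_thres=0.3):
        self.n_chunks = n_chunks
        self.score_thres = score_thres
        self.done = 0
        self.detections = []

    def inference_end(self, results, inputs):
        # Post-processing: keep detections at or above the score threshold.
        self.detections.extend(
            d for d in results if d["score"] >= self.score_thres)
        self.done += 1
        # True only once every chunk has been processed.
        return self.done >= self.n_chunks


def run_inference(model_fn, post, chunks):
    """Run one pass per chunk until inference_end reports completion."""
    for chunk in chunks:
        results = model_fn(chunk)
        if post.inference_end(results, chunk):
            break
    return post.detections


def fake_model(chunk):
    # Stand-in network: each chunk is just a list of scores.
    return [{"score": s} for s in chunk]


chunks = [[0.9, 0.1], [0.5], [0.2, 0.35]]
post = ChunkedPostprocessor(n_chunks=len(chunks))
print([d["score"] for d in run_inference(fake_model, post, chunks)])
# [0.9, 0.5, 0.35]
```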
-
loss(results, inputs, training=True)
Computes the loss given the network inputs and outputs.
- Parameters
results – The output of the model.
inputs – The input to the model.
training – Whether the loss is computed in training mode. Defaults to True.
- Returns
Returns the loss value.
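The loss itself is not specified on this page. As a hedged illustration of how the ambiguous points from generate_rpn_training_labels can be kept out of training, the sketch below computes a mean binary cross-entropy over points whose label is not -1. Plain binary cross-entropy and the -1 convention are assumptions made for illustration; this is not necessarily the classification loss the library uses, and `masked_bce` is a hypothetical name.

```python
import math


def masked_bce(probs, labels, eps=1e-7):
    """Mean binary cross-entropy over points whose label is not -1."""
    total, count = 0.0, 0
    for p, y in zip(probs, labels):
        if y == -1:
            continue  # ambiguous point: excluded from the loss
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
        count += 1
    return total / count if count else 0.0


print(round(masked_bce([0.9, 0.5, 0.2], [1, -1, 0]), 4))  # 0.1643
```

The prediction for the ignored point (0.5 above) has no effect on the result, which is the intended behavior for points near box boundaries.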
-
preprocess(data, attr)
Data preprocessing function.
This function is called before training to preprocess the data from a dataset.
- Parameters
data – A sample from the dataset.
attr – The corresponding attributes.
- Returns
Returns the preprocessed data.
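This page does not say what preprocess does internally. Given the npoints parameter, one plausible step is resampling every cloud to a fixed size; the sketch below shows such a step as a hypothetical helper, `sample_fixed_npoints`, not the library's implementation. Larger clouds are subsampled without replacement, and smaller ones keep every point once and are padded with randomly repeated points.

```python
import random


def sample_fixed_npoints(points, npoints=16384, seed=None):
    """Return exactly `npoints` points drawn from `points` (hypothetical helper)."""
    if not points:
        raise ValueError("cannot sample from an empty point cloud")
    rng = random.Random(seed)
    n = len(points)
    if n >= npoints:
        # Enough points: subsample without replacement.
        idx = rng.sample(range(n), npoints)
    else:
        # Too few points: keep every point once, pad with random repeats.
        idx = list(range(n)) + [rng.randrange(n) for _ in range(npoints - n)]
        rng.shuffle(idx)
    return [points[i] for i in idx]


cloud = [(float(i), 0.0, 0.0) for i in range(10)]
print(len(sample_fixed_npoints(cloud, npoints=4, seed=0)))   # 4
print(len(sample_fixed_npoints(cloud, npoints=16, seed=0)))  # 16
```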
-
transform(data, attr)
Transform function for the point cloud and features.
- Parameters
args – A list of tf Tensors.