theseus.Objective.update
- Objective.update(input_tensors: Optional[Dict[str, Tensor]] = None, batch_ignore_mask: Optional[Tensor] = None, _update_vectorization: bool = True)
Updates the objective's variables with the tensors in the given input dictionary.
The behavior of this method can be summarized by the following pseudocode:
    for name, tensor in input_tensors.items():
        var = self.get_var_with_name(name)
        var.update(tensor)
    check_batch_size_consistency(self.all_variables)
Any variables not included in the input tensors dictionary will retain their current tensors.
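The replacement loop in the pseudocode can be made concrete with a toy model. This is a hypothetical sketch that does not use Theseus: nested lists stand in for tensors (axis 0 is the batch axis), and ToyVariable and ToyObjective are illustrative classes, not library types. It covers only the loop. Registered variables named in the dictionary get their new tensor, every other variable keeps its old one, and the batch-size check is left out. Skipping names that are not registered is an assumption of the sketch.

```python
from typing import Dict, List, Optional

# Hypothetical toy types: nested lists stand in for tensors (axis 0 = batch).
Tensor = List[List[float]]


class ToyVariable:
    """Illustrative stand-in for a variable holding one batched tensor."""

    def __init__(self, name: str, tensor: Tensor) -> None:
        self.name = name
        self.tensor = tensor

    def update(self, tensor: Tensor) -> None:
        # Replace the stored tensor with the new one.
        self.tensor = tensor


class ToyObjective:
    """Illustrative stand-in for an objective; models only the update loop."""

    def __init__(self, *variables: ToyVariable) -> None:
        self._vars: Dict[str, ToyVariable] = {v.name: v for v in variables}

    def get_var_with_name(self, name: str) -> Optional[ToyVariable]:
        return self._vars.get(name)

    def update(self, input_tensors: Optional[Dict[str, Tensor]] = None) -> None:
        # None behaves like an empty dictionary: nothing is replaced.
        for name, tensor in (input_tensors or {}).items():
            var = self.get_var_with_name(name)
            if var is not None:  # assumption: unregistered names are skipped
                var.update(tensor)
        # Variables absent from input_tensors are never touched, so they keep
        # their tensors. The batch-size check is not modeled here.


x = ToyVariable("x", [[0.0, 0.0]])
target = ToyVariable("target", [[1.0, 2.0]])
objective = ToyObjective(x, target)
objective.update({"x": [[3.0, 4.0], [5.0, 6.0]]})
print(x.tensor)       # [[3.0, 4.0], [5.0, 6.0]]
print(target.tensor)  # [[1.0, 2.0]]
```

After the call, x holds a batch of 2 while target still holds its original batch of 1, which is the kind of mix the batch-size rule permits.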
After updating, the objective sets its batch size property from the resulting tensors. Therefore, once the update completes, all variable tensors must have a consistent batch size: each must be either 1 or the same value as the others. Note that this includes variables not referenced in the input_tensors dictionary.
- Parameters
input_tensors (Dict[str, torch.Tensor], optional) – if given, it must be a dictionary mapping variable names to tensors; if a variable with the given name is registered in the objective, its tensor will be replaced with the one in the dictionary. Defaults to None, in which case nothing will be updated. In both cases, the objective will resolve the batch size from whatever tensors are stored after updating.
batch_ignore_mask (torch.Tensor, optional) – an optional boolean tensor of shape (batch_size,). A True entry at index i means that batch index i remains unchanged in all variables. Defaults to None.
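A minimal sketch of the row-wise meaning of batch_ignore_mask, assuming the old tensor, the new tensor, and the mask all share one batch size. masked_update is a hypothetical helper, lists stand in for tensors, and this is not how Theseus implements the mask. With real tensors, the same selection could be expressed with torch.where after reshaping the mask so it broadcasts over the non-batch dimensions.

```python
from typing import List, Sequence

# Hypothetical toy type: nested lists stand in for tensors (axis 0 = batch).
Tensor = List[List[float]]


def masked_update(old: Tensor, new: Tensor, batch_ignore_mask: Sequence[bool]) -> Tensor:
    """Return the updated tensor, keeping old rows wherever the mask is True.

    Assumes old, new, and the mask all have the same batch size.
    """
    if not (len(old) == len(new) == len(batch_ignore_mask)):
        raise ValueError("old, new, and batch_ignore_mask must share a batch size")
    return [
        # Copy each selected row so the result does not alias the inputs.
        list(old_row) if ignore else list(new_row)
        for old_row, new_row, ignore in zip(old, new, batch_ignore_mask)
    ]


old = [[0.0], [0.0], [0.0]]
new = [[1.0], [2.0], [3.0]]
print(masked_update(old, new, [False, True, False]))  # [[1.0], [0.0], [3.0]]
```

Only batch index 1 is masked, so it keeps its old value while indices 0 and 2 take the new rows.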
- Raises
ValueError – if tensors with inconsistent batch dimensions are given.
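The batch-size rule behind this error can be sketched as a standalone function. resolve_batch_size is a hypothetical helper, not part of Theseus: it takes the batch size of every variable, including ones the update did not touch, and returns the shared size. Treating size-1 variables as compatible with any size follows the "either 1 or the same value" rule; returning 1 when every variable has batch size 1 is an assumption of the sketch.

```python
from typing import Dict


def resolve_batch_size(batch_sizes: Dict[str, int]) -> int:
    """Resolve one batch size from every variable's batch size.

    Each size must be 1 or equal to a single shared value; size-1 variables
    are treated as compatible with any batch size.
    """
    larger = set(batch_sizes.values()) - {1}
    if len(larger) > 1:
        # Report every variable so the conflicting ones are easy to spot.
        detail = ", ".join(f"{name}={size}" for name, size in sorted(batch_sizes.items()))
        raise ValueError(f"Inconsistent batch sizes: {detail}")
    # Assumption: if every variable has batch size 1, the result is 1.
    return larger.pop() if larger else 1


# "target" was not part of the update, but it still takes part in the check.
print(resolve_batch_size({"x": 4, "target": 1}))  # 4
```

By the same rule, {"x": 4, "target": 3} raises ValueError even if only x was updated, because the untouched target still counts.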