make_criterion_tuple_aware
- skactiveml.utils.make_criterion_tuple_aware(criterion, criterion_output_keys=None, forward_outputs=None)
Create a loss class (or wrap an existing instance) that selects part of a model’s (possibly tuple-valued) output before passing it to the criterion.
This utility generates (and caches) a dynamic subclass named TupleAware<LossName>Idx<…> of the given base loss class. The subclass overrides forward so that, if the input argument is a tuple (e.g., (logits, embeddings, …)), only the element(s) selected by criterion_output_keys are passed to the base class’s forward. If input is not a tuple, it is forwarded unchanged.
If no selection is required (i.e., criterion_output_keys and forward_outputs are both None, so the full input would be passed unchanged), the original criterion is returned without wrapping.
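The mechanism described above can be illustrated with a minimal, PyTorch-free sketch. `ToyMSELoss`, `tuple_aware_class`, and `wrap_criterion` are hypothetical names, not part of skactiveml. The toy class stands in for a `torch.nn.Module` loss (calling it dispatches to `forward`, as `nn.Module.__call__` does), the `Idx…` name suffix format is an assumption, and the deep copy followed by a `__class__` swap is only one plausible way to produce "a new instance whose class is that subclass". The sketch shows subclass generation with `type()`, caching per (base class, index) pair, tuple selection inside `forward`, and the no-selection short-circuit.

```python
import copy


class ToyMSELoss:
    """Stand-in for a torch.nn loss so the sketch runs without PyTorch (assumption)."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, *args, **kwargs):
        # Like nn.Module.__call__, calling the object dispatches to forward().
        return self.forward(*args, **kwargs)

    def forward(self, input, target):
        sq = [(a - b) ** 2 for a, b in zip(input, target)]
        return self.scale * sum(sq) / len(sq)


_SUBCLASS_CACHE = {}  # (base class, idx) -> generated subclass


def tuple_aware_class(base_cls, idx):
    """Return a cached subclass of base_cls whose forward selects input[idx].

    idx is an int (single output) or a tuple of ints (packed in that order).
    The "Idx..." suffix format is a hypothetical choice for illustration.
    """
    key = (base_cls, idx)
    if key in _SUBCLASS_CACHE:
        return _SUBCLASS_CACHE[key]

    def forward(self, input, *args, **kwargs):
        if isinstance(input, tuple):  # e.g. (logits, embeddings)
            if isinstance(idx, int):
                input = input[idx]
            else:
                input = tuple(input[i] for i in idx)
        # Non-tuple inputs reach the base forward unchanged.
        return base_cls.forward(self, input, *args, **kwargs)

    suffix = str(idx) if isinstance(idx, int) else "_".join(map(str, idx))
    name = f"TupleAware{base_cls.__name__}Idx{suffix}"
    cls = type(name, (base_cls,), {"forward": forward})
    _SUBCLASS_CACHE[key] = cls
    return cls


def wrap_criterion(criterion, idx):
    """Hypothetical wrapper accepting either a loss class or a loss instance."""
    if idx is None:
        return criterion  # no selection required: return the original unchanged
    if isinstance(criterion, type):
        return tuple_aware_class(criterion, idx)
    wrapped = copy.deepcopy(criterion)  # the original instance is not modified
    wrapped.__class__ = tuple_aware_class(type(criterion), idx)
    return wrapped


# Usage: the model returns (logits, embeddings); the loss should see only logits.
base = ToyMSELoss(scale=2.0)
loss = wrap_criterion(base, 0)
logits, embeddings = [1.0, 2.0], [9.0, 9.0]
print(type(loss).__name__)                      # TupleAwareToyMSELossIdx0
print(loss((logits, embeddings), [1.0, 0.0]))   # 4.0
print(loss(logits, [1.0, 0.0]))                 # 4.0 (non-tuple passes through)
print(type(base).__name__)                      # ToyMSELoss
```

In this sketch, caching means that wrapping the same base class with the same selection twice yields the identical generated class, so class identity stays stable across wrappers.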
- Parameters:
- criterion : torch.nn.Module.__class__ or torch.nn.Module
Either a loss class (subclass of torch.nn.Module), e.g. nn.CrossEntropyLoss, or a loss instance, e.g. nn.CrossEntropyLoss().
- criterion_output_keys : str or sequence of str or None, default=None
Name or names of the forward outputs that are passed to the loss / criterion during training. Use this when module.forward returns multiple outputs (e.g. (logits, embeddings, …)), but the criterion expects a single tensor input or a specific tuple of inputs. The names must refer to keys of forward_outputs. If criterion_output_keys is not None and forward_outputs is None, a ValueError is raised because the names cannot be resolved.
If a str, the corresponding named output of module.forward (i.e., the raw tensor selected via its index in forward_outputs before applying any transform) is passed to the criterion (e.g. “logits” to use only the class scores).
If a sequence of str, the selected named outputs are packed into a tuple and passed to the criterion in that order. Each raw forward output index may appear at most once: using multiple names that resolve to the same underlying index (e.g. “proba” and “logits” both pointing to index 0) is not allowed and results in a ValueError.
If None:
  - if forward_outputs is not None, the first output defined by forward_outputs is used as the criterion input;
  - if forward_outputs is None, the full input is passed unchanged. In this case, the caller is responsible for ensuring that module.forward returns a single tensor if the criterion does not accept tuples.
- forward_outputs : dict[str, tuple[int, Callable | None]] or None, default=None
Dictionary from output names to (idx, transform) tuples, as used in the estimator’s forward_outputs parameter. Only the keys and their associated indices idx are used here to resolve criterion_output_keys into raw output positions; the transform part is ignored by this helper. If criterion_output_keys is given as a string or sequence of strings, forward_outputs must be provided.
- Returns:
- torch.nn.Module.__class__ or torch.nn.Module
If criterion is a class, returns the generated subclass TupleAware<LossName>Idx<…> (or the original class if no selection is required).
If criterion is an instance, returns a new instance (deep copy) whose class is that subclass. The original instance is not modified. If no selection is required, the original instance is returned unchanged.
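The parameter rules for criterion_output_keys and forward_outputs can be condensed into a small resolution step that turns names into raw output positions. The sketch below is hypothetical: `resolve_output_indices` is not a documented skactiveml function, it assumes that "the first output defined by forward_outputs" means the first key in dict insertion order, and it lets unknown keys raise a plain `KeyError`, which may differ from the library's behavior. As described above, only the index part of each (idx, transform) pair is read.

```python
def resolve_output_indices(criterion_output_keys=None, forward_outputs=None):
    """Map criterion_output_keys to raw forward-output indices (hypothetical helper).

    Returns None (pass the full input unchanged), an int, or a tuple of ints.
    """
    if criterion_output_keys is None:
        if forward_outputs is None:
            return None  # no selection: the full input is passed unchanged
        # Assumption: "first output" = first entry in dict insertion order.
        first_idx, _transform = next(iter(forward_outputs.values()))
        return first_idx
    if forward_outputs is None:
        raise ValueError(
            "criterion_output_keys cannot be resolved without forward_outputs."
        )
    if isinstance(criterion_output_keys, str):
        idx, _transform = forward_outputs[criterion_output_keys]  # transform ignored
        return idx
    # Sequence of names: keep the given order, reject duplicate raw indices.
    idxs = tuple(forward_outputs[key][0] for key in criterion_output_keys)
    if len(set(idxs)) != len(idxs):
        raise ValueError(
            f"Keys {list(criterion_output_keys)} resolve to duplicate indices {idxs}."
        )
    return idxs


# Usage with a forward_outputs dict in the (idx, transform) format.
forward_outputs = {
    "logits": (0, None),
    "proba": (0, lambda z: z),  # placeholder transform (e.g. a softmax); ignored here
    "embeddings": (1, None),
}
print(resolve_output_indices("logits", forward_outputs))                  # 0
print(resolve_output_indices(["embeddings", "logits"], forward_outputs))  # (1, 0)
print(resolve_output_indices(None, forward_outputs))                      # 0
print(resolve_output_indices(None, None))                                 # None
try:
    resolve_output_indices(["proba", "logits"], forward_outputs)
except ValueError as err:
    print("ValueError:", err)  # both names point to raw index 0
```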