make_criterion_tuple_aware

skactiveml.utils.make_criterion_tuple_aware(criterion, criterion_output_keys=None, forward_outputs=None)

Create a loss class (or wrap an existing instance) that selects part of a model’s (possibly tuple-valued) output before passing it to the criterion.

This utility generates (and caches) a dynamic subclass named TupleAware<LossName>Idx<…> of the given base loss class. The subclass overrides forward so that, if the input argument is a tuple (e.g., (logits, embeddings, …)), only the element(s) selected by criterion_output_keys are passed to the base class’s forward. If input is not a tuple, it is forwarded unchanged.
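The mechanism described above (a cached, dynamically created subclass whose forward selects tuple elements) can be sketched in plain Python. This is a minimal illustration, not skactiveml's implementation. To keep it runnable without PyTorch, it uses a hypothetical stand-in class ToyLoss that, like torch.nn.Module, dispatches calls to forward. The helper tuple_aware_subclass is also hypothetical: it builds the subclass with type() and caches it with functools.lru_cache. The real Idx<…> name suffix is not documented here, so the suffix format below is made up.

```python
from functools import lru_cache


class ToyLoss:
    """Hypothetical stand-in for a loss class such as nn.CrossEntropyLoss.

    Like torch.nn.Module, calling an instance dispatches to forward().
    """

    def __call__(self, input, target):
        return self.forward(input, target)

    def forward(self, input, target):
        # Toy "loss": absolute difference between two numbers.
        return abs(input - target)


@lru_cache(maxsize=None)
def tuple_aware_subclass(base_cls, idx):
    """Return a cached subclass of base_cls whose forward() selects from tuples.

    idx is an int (select one element) or a tuple of ints (pack a sub-tuple).
    Hypothetical helper; the "Idx..." suffix format is invented for this sketch.
    """

    def forward(self, input, *args, **kwargs):
        if isinstance(input, tuple):
            if isinstance(idx, int):
                input = input[idx]
            else:
                input = tuple(input[i] for i in idx)
        # Non-tuple inputs reach the base forward unchanged.
        return base_cls.forward(self, input, *args, **kwargs)

    suffix = str(idx) if isinstance(idx, int) else "_".join(map(str, idx))
    name = f"TupleAware{base_cls.__name__}Idx{suffix}"
    return type(name, (base_cls,), {"forward": forward})


Loss = tuple_aware_subclass(ToyLoss, 0)
loss = Loss()
print(Loss.__name__)                             # TupleAwareToyLossIdx0
print(loss((3.0, "embeddings"), 1.0))            # 2.0 (only element 0 is used)
print(loss(3.0, 1.0))                            # 2.0 (non-tuple passed through)
print(tuple_aware_subclass(ToyLoss, 0) is Loss)  # True (class is cached)
```

Caching on (base class, indices) means repeated calls return the same class object, so isinstance checks and pickling behave consistently across calls.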

If no selection is required (i.e., criterion_output_keys is None and forward_outputs is None so that the full input is passed unchanged), the original criterion is returned without wrapping.

Parameters:
criterion : torch.nn.Module.__class__ or torch.nn.Module
  • Either a loss class (subclass of torch.nn.Module), e.g. nn.CrossEntropyLoss, or

  • a loss instance, e.g. nn.CrossEntropyLoss().

criterion_output_keys : str or sequence of str or None, default=None

Name or names of the forward outputs that are passed to the loss / criterion during training. Use this when module.forward returns multiple outputs (e.g. (logits, embeddings, …)), but the criterion expects a single tensor input or a specific tuple of inputs. The names must refer to keys of forward_outputs. If criterion_output_keys is not None and forward_outputs is None, a ValueError is raised because the names cannot be resolved.

  • If a str, the corresponding named output of module.forward (i.e., the raw tensor selected via its index in forward_outputs before applying any transform) is passed to the criterion (e.g. “logits” to use only the class scores).

  • If a sequence of str, the selected named outputs are packed into a tuple and passed to the criterion in that order. Each raw forward output index may appear at most once: using multiple names that resolve to the same underlying index (e.g. “proba” and “logits” both pointing to index 0) is not allowed and results in a ValueError.

  • If None:

    • and forward_outputs is not None, the first output defined by forward_outputs is used as criterion input;

    • and forward_outputs is None, the full input is passed unchanged. In this case, the caller is responsible for ensuring that module.forward returns a single tensor if the criterion does not accept tuples.

forward_outputs : dict[str, tuple[int, Callable | None]] or None, default=None

Dictionary from output names to (idx, transform) tuples, as used in the estimator’s forward_outputs parameter. Only the keys and their associated indices idx are used here to resolve criterion_output_keys into raw output positions; the transform part is ignored by this helper. If criterion_output_keys is given as a string or sequence of strings, forward_outputs must be provided.
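The resolution rules for criterion_output_keys and forward_outputs can be summarized in a small function. This is a sketch written from the rules stated above, not the library's code; resolve_criterion_indices is a hypothetical name, and the forward_outputs dictionary below is an invented example (the identity lambda stands in for a transform such as a softmax, which this step ignores anyway).

```python
def resolve_criterion_indices(criterion_output_keys=None, forward_outputs=None):
    """Map criterion_output_keys to raw forward-output indices.

    Returns None (no selection), an int, or a tuple of ints.
    Hypothetical helper following the documented rules.
    """
    if criterion_output_keys is None:
        if forward_outputs is None:
            return None  # pass the full input through unchanged
        # Use the first output defined by forward_outputs (dict order).
        first_key = next(iter(forward_outputs))
        return forward_outputs[first_key][0]

    if forward_outputs is None:
        raise ValueError(
            "criterion_output_keys requires forward_outputs to resolve names."
        )

    def index_of(key):
        # Unknown names raise KeyError in this sketch; only idx is used,
        # the transform is ignored.
        idx, _transform = forward_outputs[key]
        return idx

    if isinstance(criterion_output_keys, str):
        return index_of(criterion_output_keys)

    indices = tuple(index_of(k) for k in criterion_output_keys)
    if len(set(indices)) != len(indices):
        raise ValueError("Each raw forward output index may appear at most once.")
    return indices


forward_outputs = {
    "proba": (0, lambda t: t),  # placeholder transform, ignored here
    "logits": (0, None),
    "embeddings": (1, None),
}

print(resolve_criterion_indices())                               # None
print(resolve_criterion_indices(None, forward_outputs))          # 0
print(resolve_criterion_indices("embeddings", forward_outputs))  # 1
print(resolve_criterion_indices(["logits", "embeddings"], forward_outputs))  # (0, 1)
try:
    resolve_criterion_indices(["proba", "logits"], forward_outputs)
except ValueError as exc:
    print("ValueError:", exc)
```

Because "proba" and "logits" share index 0, selecting both is rejected, matching the rule that each raw forward output index may appear at most once.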

Returns:
torch.nn.Module.__class__ or torch.nn.Module
  • If criterion is a class, returns the generated subclass TupleAware<LossName>Idx<…> (or the original class if no selection is required).

  • If criterion is an instance, returns a new instance (deep copy) whose class is that subclass. The original instance is not modified. If no selection is required, the original instance is returned unchanged.
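For the instance case, one way to obtain "a new instance (deep copy) whose class is that subclass" is to deep-copy the instance and reassign its __class__. The sketch below assumes that approach for illustration; it is not confirmed as skactiveml's implementation. ToyLoss, TupleAwareToyLoss and wrap_instance are hypothetical names, and passing subclass=None models the "no selection required" case.

```python
import copy


class ToyLoss:
    """Hypothetical stand-in for a configured loss instance."""

    def __init__(self, weight=1.0):
        self.weight = weight

    def forward(self, input, target):
        return self.weight * abs(input - target)


class TupleAwareToyLoss(ToyLoss):
    """Hypothetical subclass that keeps only element 0 of tuple inputs."""

    def forward(self, input, target):
        if isinstance(input, tuple):
            input = input[0]
        return super().forward(input, target)


def wrap_instance(instance, subclass):
    """Return a deep copy of instance whose class is subclass.

    If subclass is None (no selection required), return instance itself.
    The original instance and its configuration are never modified.
    """
    if subclass is None:
        return instance
    wrapped = copy.deepcopy(instance)
    wrapped.__class__ = subclass
    return wrapped


original = ToyLoss(weight=0.5)
wrapped = wrap_instance(original, TupleAwareToyLoss)
print(type(original).__name__, type(wrapped).__name__)  # ToyLoss TupleAwareToyLoss
print(wrapped.weight)                                   # 0.5 (configuration kept)
print(wrapped.forward((4.0, "emb"), 2.0))               # 1.0
```

Reassigning __class__ on a copy keeps all constructor arguments (here weight) without re-running __init__, which is why a deep copy is used rather than constructing a fresh subclass instance.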