Source code for skactiveml.utils._functions

import inspect
from types import MethodType
from makefun import with_signature

successful_skorch_torch_import = False
try:
    import sys
    from collections.abc import Sequence
    from copy import deepcopy
    from torch import nn

    from ._validation import _check_forward_outputs

    successful_skorch_torch_import = True
except ImportError:  # pragma: no cover
    pass


def call_func(
    f_callable, only_mandatory=False, ignore_var_keyword=False, **kwargs
):
    """Calls a function with the parameters given in `kwargs`, if they exist
    as parameters in `f_callable`.

    Parameters
    ----------
    f_callable : callable
        The function or object that is to be called.
    only_mandatory : bool, default=False
        If `True`, only mandatory parameters are set.
    ignore_var_keyword : bool, default=False
        If `False`, all `kwargs` are passed when `f_callable` uses a
        parameter that is of kind `Parameter.VAR_KEYWORD`, i.e., `**kwargs`.
        For further reference see the `inspect` package.
    kwargs : kwargs
        All parameters that could be used for calling `f_callable`.

    Returns
    -------
    f_callable_result : return type of `f_callable`
        The return value of `f_callable`.
    """
    params = inspect.signature(f_callable).parameters
    param_keys = params.keys()
    if only_mandatory:
        param_keys = list(
            filter(lambda k: params[k].default == params[k].empty, param_keys)
        )
    has_var_keyword = any(
        filter(lambda p: p.kind == p.VAR_KEYWORD, params.values())
    )
    if has_var_keyword and not ignore_var_keyword and not only_mandatory:
        vars = kwargs
    else:
        vars = dict(filter(lambda e: e[0] in param_keys, kwargs.items()))

    return f_callable(**vars)
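
# Illustrative usage (a minimal sketch; `scale` is a hypothetical helper, not
# part of skactiveml): `call_func` forwards only those keyword arguments that
# `f_callable` actually accepts.
#
#     def scale(x, factor=2):
#         return x * factor
#
#     call_func(scale, x=3, factor=5, unused=0)             # -> 15
#     call_func(scale, only_mandatory=True, x=3, factor=5)  # -> 6
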
def match_signature(wrapped_obj_name, func_name):
    """A decorator that matches the signature of a wrapper's method to a
    given method of a reference object and hides it when the reference object
    does not have the wrapped function. This is especially helpful for
    wrapper classes whose methods should only appear, and appear with the
    correct signature, when the wrapped object provides them. This decorator
    is heavily inspired by the `available_if` decorator from
    `sklearn.utils.metaestimators`.

    Parameters
    ----------
    wrapped_obj_name : str
        The name of the object that will be wrapped.
    func_name : str
        The name of the function that will be wrapped.

    Returns
    -------
    wrapped_obj : callable
        The wrapped function.
    """

    class _MatchSignatureDescriptor:
        """_MatchSignatureDescriptor

        A descriptor that allows a wrapper to clone the signature of a method
        `func_name` from the wrapped object `wrapped_obj_name`. Furthermore,
        this extends upon the conditional property as implemented in
        `available_if` from `sklearn.utils.metaestimators`.

        Parameters
        ----------
        fn : MethodType
            The method that should be wrapped.
        wrapped_obj_name : str
            The name of the wrapped object within the wrapper class.
        func_name : str
            The method name of the function that should be wrapped.
        """

        def __init__(self, fn, wrapped_obj_name, func_name):
            self.fn = fn
            self.wrapped_obj_name = wrapped_obj_name
            self.func_name = func_name
            self.__name__ = func_name

        def __get__(self, obj, owner=None):
            """Wrap the method specified in `self.func_name` from the wrapped
            object `self.wrapped_obj_name` such that the signature will be
            the same.

            Parameters
            ----------
            obj : object
                The wrapper object. This parameter will be `None`, if the
                method is accessed via the class and not an instantiated
                object.
            owner : class, default=None
                The wrapper class.

            Returns
            -------
            The wrapped method.
            """
            if obj is not None:
                reference_object = getattr(obj, self.wrapped_obj_name)
                if not hasattr(reference_object, self.func_name):
                    raise AttributeError(
                        f"This {reference_object} has"
                        f" no method {self.func_name}."
                    )
                reference_function = getattr(reference_object, self.func_name)
                # Check if the referenced function is a method of that
                # object. If it is, use `__func__` to copy the name of the
                # `self` parameter. If it is not (i.e., it has been added as
                # a lambda function), add a provisional `self` argument in
                # the first position.
                if hasattr(reference_function, "__func__"):
                    new_sig = inspect.signature(reference_function.__func__)
                else:
                    reference_sig = inspect.signature(reference_function)
                    new_parameters = list(reference_sig.parameters.values())
                    new_parameters.insert(
                        0,
                        inspect.Parameter(
                            "self",
                            kind=inspect._ParameterKind.POSITIONAL_OR_KEYWORD,
                        ),
                    )
                    new_sig = inspect.Signature(
                        new_parameters,
                        return_annotation=reference_sig.return_annotation,
                    )
                # Create a wrapper with the new signature and the correct
                # function name.
                function_decorator = with_signature(
                    func_signature=new_sig,
                    func_name=self.fn.__name__,
                )
                fn = function_decorator(self.fn)
                out = MethodType(fn, obj)
            else:
                out = self.fn
            return out

    return lambda fn: _MatchSignatureDescriptor(
        fn, wrapped_obj_name, func_name=func_name
    )
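
# Illustrative usage (a minimal sketch; `Inner` and `Wrapper` are hypothetical
# classes, not part of skactiveml): `match_signature` lets a wrapper method
# expose the exact signature of the wrapped object's method.
#
#     class Inner:
#         def fit(self, X, y, sample_weight=None):
#             return self
#
#     class Wrapper:
#         def __init__(self, estimator):
#             self.estimator = estimator
#
#         @match_signature("estimator", "fit")
#         def fit(self, *args, **kwargs):
#             return self.estimator.fit(*args, **kwargs)
#
#     inspect.signature(Wrapper(Inner()).fit)  # -> (X, y, sample_weight=None)
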
if successful_skorch_torch_import:
    def make_criterion_tuple_aware(
        criterion,
        criterion_output_keys=None,
        forward_outputs=None,
    ):
        """
        Create a loss class (or wrap an existing instance) that selects part
        of a model's (possibly tuple-valued) output before passing it to the
        criterion.

        This utility generates (and caches) a dynamic subclass named
        `TupleAware<LossName>Idx<...>` of the given base loss class. The
        subclass overrides `forward` so that, if the `input` argument is a
        tuple (e.g., `(logits, embeddings, ...)`), only the element(s)
        selected by `criterion_output_keys` are passed to the base class's
        `forward`. If `input` is not a tuple, it is forwarded unchanged.

        If no selection is required (i.e., `criterion_output_keys` is `None`
        and `forward_outputs` is `None` so that the full input is passed
        unchanged), the original `criterion` is returned without wrapping.

        Parameters
        ----------
        criterion : torch.nn.Module.__class__ or torch.nn.Module
            - Either a loss class (subclass of `torch.nn.Module`), e.g.
              `nn.CrossEntropyLoss`, or
            - a loss instance, e.g. `nn.CrossEntropyLoss()`.
        criterion_output_keys : str or sequence of str or None, default=None
            Name or names of the forward outputs that are passed to the loss
            / criterion during training. Use this when `module.forward`
            returns multiple outputs (e.g. `(logits, embeddings, ...)`), but
            the criterion expects a single tensor input or a specific tuple
            of inputs. The names must refer to keys of `forward_outputs`. If
            `criterion_output_keys` is not `None` and `forward_outputs` is
            `None`, a `ValueError` is raised because the names cannot be
            resolved.

            - If a `str`, the corresponding named output of `module.forward`
              (i.e., the raw tensor selected via its index in
              `forward_outputs` before applying any transform) is passed to
              the criterion (e.g. `"logits"` to use only the class scores).
            - If a sequence of `str`, the selected named outputs are packed
              into a tuple and passed to the criterion in that order. Each
              raw forward output index may appear at most once: using
              multiple names that resolve to the same underlying index
              (e.g. `"proba"` and `"logits"` both pointing to index 0) is
              not allowed and results in a `ValueError`.
            - If `None`:

              - and `forward_outputs` is not `None`, the first output
                defined by `forward_outputs` is used as criterion input;
              - and `forward_outputs` is `None`, the full `input` is passed
                unchanged. In this case, the caller is responsible for
                ensuring that `module.forward` returns a single tensor if
                the criterion does not accept tuples.
        forward_outputs : dict[str, tuple[int, Callable | None]] or None,\
                default=None
            Dictionary from output names to `(idx, transform)` tuples, as
            used in the estimator's `forward_outputs` parameter. Only the
            keys and their associated indices `idx` are used here to resolve
            `criterion_output_keys` into raw output positions; the transform
            part is ignored by this helper. If `criterion_output_keys` is
            given as a string or sequence of strings, `forward_outputs` must
            be provided.

        Returns
        -------
        torch.nn.Module.__class__ or torch.nn.Module
            - If `criterion` is a class, returns the generated subclass
              `TupleAware<LossName>Idx<...>` (or the original class if no
              selection is required).
            - If `criterion` is an instance, returns a new instance (deep
              copy) whose class is that subclass. The original instance is
              not modified. If no selection is required, the original
              instance is returned unchanged.
        """
        # Validate criterion.
        if isinstance(criterion, nn.Module):
            base_cls = criterion.__class__
        elif isinstance(criterion, type) and issubclass(criterion, nn.Module):
            base_cls = criterion
        else:
            raise TypeError(
                "criterion must be an `nn.Module` subclass or instance."
            )

        # Check invalid combination.
        if criterion_output_keys is not None and forward_outputs is None:
            raise ValueError(
                "`criterion_output_keys` is not `None`, but "
                "`forward_outputs` is `None`. Pass the same "
                "`forward_outputs` mapping that describes `module.forward`."
            )

        # No selection at all: pass full input unchanged.
        if criterion_output_keys is None and forward_outputs is None:
            return criterion

        # Validate forward_outputs if it is provided.
        _check_forward_outputs(forward_outputs=forward_outputs)

        # Resolve criterion_output_keys -> list of names.
        if criterion_output_keys is None:
            selected_names = [next(iter(forward_outputs))]
        elif isinstance(criterion_output_keys, str):
            selected_names = [criterion_output_keys]
        elif isinstance(criterion_output_keys, Sequence) and not isinstance(
            criterion_output_keys, bytes
        ):
            if len(criterion_output_keys) == 0:
                raise ValueError(
                    "`criterion_output_keys` must not be an empty sequence."
                )
            # Ensure all names are strings.
            for name in criterion_output_keys:
                if not isinstance(name, str):
                    raise TypeError(
                        "All entries in `criterion_output_keys` must be "
                        "strings when a sequence is provided, got "
                        f"{type(name)} in {criterion_output_keys!r}."
                    )
            selected_names = list(criterion_output_keys)
        else:
            raise TypeError(
                "`criterion_output_keys` must be `None`, a string, or a "
                f"sequence of strings, got {type(criterion_output_keys)}."
            )

        # Map names -> raw indices using forward_outputs[name][0].
        indices = []
        for name in selected_names:
            spec = forward_outputs.get(name, None)
            if spec is None:
                raise ValueError(
                    f"Unknown forward output name {name!r} in "
                    "`criterion_output_keys`. Available names are: "
                    f"{list(forward_outputs.keys())}."
                )
            indices.append(spec[0])

        # Enforce: each raw index at most once.
        if len(set(indices)) != len(indices):
            raise ValueError(
                "`criterion_output_keys` must not contain multiple names "
                "that refer to the same raw forward output index. "
                f"Names {selected_names!r} resolve to indices {indices!r}."
            )

        # Build selector: int or tuple[int, ...].
        if len(indices) == 1:
            selector = indices[0]
            max_idx = selector
        else:
            selector = tuple(indices)
            max_idx = max(selector)

        # Build a class-name key from the selector for caching / pickling.
        if isinstance(selector, int):
            idx_key = str(selector)
        else:
            idx_key = "_".join(str(i) for i in selector)
        cls_name = f"TupleAware{base_cls.__name__}Idx{idx_key}"

        # Build the wrapper `forward`.
        def forward(self, input, target, *args, **kwargs):
            if isinstance(input, tuple):
                if self._criterion_output_max_selector >= len(input):
                    raise ValueError(
                        f"`forward_outputs` references raw output index "
                        f"{self._criterion_output_max_selector}, but "
                        f"`module.forward` returned only {len(input)} "
                        f"object(s)."
                    )
                if isinstance(self._criterion_output_selector, int):
                    input = input[self._criterion_output_selector]
                else:
                    # Selector is a tuple of indices.
                    input = tuple(
                        input[i] for i in self._criterion_output_selector
                    )
            return base_cls.forward(self, input, target, *args, **kwargs)

        # Keep signature/tool-tips identical.
        forward.__signature__ = inspect.signature(base_cls.forward)
        forward.__doc__ = base_cls.forward.__doc__

        # Create/reuse subclass and expose it under this module.
        mod = sys.modules[__name__]
        TupleAwareCls = getattr(mod, cls_name, None)
        if TupleAwareCls is None:
            TupleAwareCls = type(
                cls_name,
                (base_cls,),
                {
                    "forward": forward,
                    "__module__": mod.__name__,
                    "_criterion_output_selector": selector,
                    "_criterion_output_max_selector": max_idx,
                },
            )
            setattr(mod, TupleAwareCls.__name__, TupleAwareCls)

        if isinstance(criterion, nn.Module):
            wrapped = deepcopy(criterion)
            object.__setattr__(wrapped, "__class__", TupleAwareCls)
            return wrapped
        else:
            return TupleAwareCls
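
    # Illustrative usage (a minimal sketch, assuming torch/skorch are
    # installed; the `forward_outputs` mapping below is a hypothetical
    # example): a module whose `forward` returns `(logits, embeddings)` can
    # be trained with a standard loss by selecting only the logits.
    #
    #     forward_outputs = {"logits": (0, None), "embeddings": (1, None)}
    #     criterion = make_criterion_tuple_aware(
    #         nn.CrossEntropyLoss(),
    #         criterion_output_keys="logits",
    #         forward_outputs=forward_outputs,
    #     )
    #     # criterion((logits, embeddings), target) now applies the
    #     # cross-entropy loss to `logits` only.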