Source code for skactiveml.classifier.multiannotator._annot_mix_classifier

try:
    import math
    import numpy as np
    import torch

    from sklearn.utils.validation import check_array
    from torch import nn
    from torch.nn import KLDivLoss
    from torch.nn import functional as F
    from torch.utils.data import default_collate

    from ...base import SkactivemlClassifier
    from ...utils import (
        MISSING_LABEL,
        check_n_features,
        check_scalar,
    )
    from ._utils import (
        _MultiAnnotatorClassificationModule,
        _SkorchMultiAnnotatorClassifier,
    )

    class AnnotMixClassifier(_SkorchMultiAnnotatorClassifier):
        """Annot-Mix

        Annot-Mix [1]_ trains a multi-annotator classifier using an
        extension of MixUp [2]_. The main idea is to apply MixUp not only
        to samples and class labels, but to sample-annotator pairs: it
        convexly combines inputs and their annotator-specific noisy labels
        and trains a one-stage model that jointly estimates the true label
        distribution and each annotator's reliability. In this way,
        Annot-Mix can handle multiple, potentially conflicting labels per
        sample while using MixUp-style regularization to become more
        robust to label noise.

        Parameters
        ----------
        clf_module : nn.Module or nn.Module.__class__
            PyTorch module serving as the classification model, taking
            samples as input and outputting class-membership logits. In
            general, the uninstantiated class should be passed, although
            instantiated modules will also work. The module's `forward`
            must return logits as first element and optional sample
            embeddings as second element. If no sample embeddings are
            returned, the implementation uses the original samples.
        alpha : float, default=0.5
            MixUp concentration parameter. The mixing coefficient `lambda`
            is drawn from `Beta(alpha, alpha)`. Use `alpha=0` to disable
            MixUp.
        sample_embed_dim : int, default=0
            Dimensionality of an optional learnable sample embedding used
            to model sample-specific behavior of each annotator. If
            `sample_embed_dim=0`, the annotator performances are only
            modeled as class-specific.
        annotator_embed_dim : int, default=16
            Dimensionality of the annotator embedding used to model
            annotator-specific behavior.
        hidden_dim : int or None, default=None
            Hidden size of the fusion multi-layer perceptron that
            propagates sample and annotator representations. If `None`, a
            sensible default is used, which depends on the other input
            parameters. Note that this parameter has no effect for
            `n_hidden_layers=0`.
        n_hidden_layers : int, default=0
            Number of hidden layers in the fusion multi-layer perceptron.
        hidden_dropout : float, default=0.1
            Dropout probability applied in the fusion multi-layer
            perceptron. Note that this parameter has no effect for
            `n_hidden_layers=0`.
        eta : float in (0, 1), default=0.9
            Prior annotator performance, i.e., the probability of
            obtaining a correct annotation from an arbitrary annotator for
            an arbitrary sample of an arbitrary class.
        n_annotators : int, default=None
            Number of annotators. If `n_annotators=None`, the number of
            annotators is inferred from the shape of `y` during training.
        neural_net_param_dict : dict, default=None
            Additional arguments for `skorch.net.NeuralNet`. If
            `neural_net_param_dict` is `None`, no extra arguments are
            added. `module`, `criterion`, `predict_nonlinearity`, and
            `train_split` are not allowed in this dictionary.
        sample_dtype : str or type, default=np.float32
            Dtype to which input samples are cast inside the estimator. If
            set to `None`, the input dtype is preserved.
        classes : array-like of shape (n_classes,), default=None
            Holds the label for each class. If `None`, the classes are
            determined during the fit.
        missing_label : scalar or string or np.nan or None, default=np.nan
            Value to represent a missing label.
        cost_matrix : array-like of shape (n_classes, n_classes), default=None
            Cost matrix with `cost_matrix[i, j]` indicating the cost of
            predicting class `classes[j]` for a sample of class
            `classes[i]`. Can only be set if `classes` is not `None`.
        random_state : int or RandomState instance or None, default=None
            Determines random number generation for the `predict` method.
            Pass an int for reproducible results across multiple method
            calls.

        References
        ----------
        .. [1] Herde, M., Lührs, L., Huseljic, D., & Sick, B. (2024).
           Annot-Mix: Learning with Noisy Class Labels from Multiple
           Annotators via a Mixup Extension. Eur. Conf. Artif. Intell.
        .. [2] Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D.
           (2018). mixup: Beyond Empirical Risk Minimization. Int. Conf.
           Learn. Represent.
        """

        _ALLOWED_EXTRA_OUTPUTS = {
            "logits",
            "embeddings",
            "annotator_perf",
            "annotator_class",
            "annotator_embeddings",
        }

        def __init__(
            self,
            clf_module,
            alpha=0.5,
            sample_embed_dim=0,
            annotator_embed_dim=16,
            hidden_dim=None,
            n_hidden_layers=0,
            hidden_dropout=0.1,
            eta=0.9,
            n_annotators=None,
            neural_net_param_dict=None,
            sample_dtype=np.float32,
            classes=None,
            cost_matrix=None,
            missing_label=MISSING_LABEL,
            random_state=None,
        ):
            super(AnnotMixClassifier, self).__init__(
                multi_annotator_module=_AnnotMixModule,
                clf_module=clf_module,
                criterion=KLDivLoss,
                classes=classes,
                missing_label=missing_label,
                cost_matrix=cost_matrix,
                random_state=random_state,
                neural_net_param_dict=neural_net_param_dict,
                sample_dtype=sample_dtype,
            )
            self.clf_module = clf_module
            self.alpha = alpha
            self.sample_embed_dim = sample_embed_dim
            self.annotator_embed_dim = annotator_embed_dim
            self.hidden_dim = hidden_dim
            self.n_hidden_layers = n_hidden_layers
            self.hidden_dropout = hidden_dropout
            self.eta = eta
            self.n_annotators = n_annotators
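
        # Example (illustrative sketch, not part of the original module):
        # the backbone below and the training setup are assumptions; in
        # practice, training options (e.g., epochs, optimizer) are
        # typically supplied via `neural_net_param_dict`.
        #
        #     >>> import numpy as np
        #     >>> from torch import nn
        #     >>> backbone = nn.Sequential(
        #     ...     nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3)
        #     ... )
        #     >>> clf = AnnotMixClassifier(
        #     ...     clf_module=backbone,
        #     ...     classes=[0, 1, 2],
        #     ...     missing_label=-1,
        #     ...     random_state=0,
        #     ... )
        #     >>> X = np.random.rand(10, 2).astype(np.float32)
        #     >>> y = np.random.randint(0, 3, size=(10, 4))  # 4 annotators
        #     >>> clf = clf.fit(X, y)
        #     >>> clf.predict(X).shape
        #     (10,)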

        def predict(
            self,
            X,
            extra_outputs=None,
        ):
            """Return class predictions for the test samples `X`.

            By default, this method returns only the class predictions
            `y_pred`. If `extra_outputs` is provided, a tuple is returned
            whose first element is `y_pred` and whose remaining elements
            are the requested additional forward outputs, in the order
            specified by `extra_outputs`.

            Parameters
            ----------
            X : array-like of shape (n_samples, ...)
                Test samples.
            extra_outputs : None or str or sequence of str, default=None
                Names of additional outputs to return next to `y_pred`.
                The names must be a subset of the following keys:

                - "logits" : Additionally return the class-membership
                  logits `L_class` for the samples in `X`.
                - "embeddings" : Additionally return the learned
                  embeddings `X_embed` for the samples in `X`.
                - "annotator_perf" : Additionally return the estimated
                  annotator performance probabilities `P_perf` for each
                  sample–annotator pair.
                - "annotator_class" : Additionally return the
                  annotator–class probability estimates `P_annot` for
                  each sample, class, and annotator.
                - "annotator_embeddings" : Additionally return the learned
                  embeddings `A_embed` for the annotators as the next
                  element of the output tuple.

            Returns
            -------
            y_pred : numpy.ndarray of shape (n_samples,)
                Class labels of the test samples.
            *extras : numpy.ndarray, optional
                Only returned if `extra_outputs` is not `None`. In that
                case, the method returns a tuple whose first element is
                `y_pred` and whose remaining elements correspond to the
                requested forward outputs in the order given by
                `extra_outputs`. Potential outputs are:

                - `L_class` : `np.ndarray` of shape
                  `(n_samples, n_classes)`, where `L_class[n, c]` is the
                  logit for the class `classes_[c]` of sample `X[n]`.
                - `X_embed` : `np.ndarray` of shape `(n_samples, ...)`,
                  where `X_embed[n]` refers to the learned embedding for
                  sample `X[n]`.
                - `P_perf` : `np.ndarray` of shape
                  `(n_samples, n_annotators)`, where `P_perf[n, m]` refers
                  to the estimated label correctness probability
                  (performance) of annotator `m` when labeling sample
                  `X[n]`.
                - `P_annot` : `np.ndarray` of shape
                  `(n_samples, n_annotators, n_classes)`, where
                  `P_annot[n, m, c]` refers to the probability that
                  annotator `m` provides the class label `c` for sample
                  `X[n]`.
                - `A_embed` : `np.ndarray` of shape
                  `(n_annotators, annotator_embed_dim)`, where
                  `A_embed[m]` refers to the learned embedding for
                  annotator `m`.
            """
            return SkactivemlClassifier.predict(
                self,
                X=X,
                extra_outputs=extra_outputs,
            )
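
        # Example (sketch, assuming `clf` has been fitted as in the
        # example above with 10 samples and 4 annotators):
        #
        #     >>> y_pred, P_perf = clf.predict(
        #     ...     X, extra_outputs=["annotator_perf"]
        #     ... )
        #     >>> y_pred.shape, P_perf.shape
        #     ((10,), (10, 4))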

        def predict_proba(
            self,
            X,
            extra_outputs=None,
        ):
            """Return class probability estimates for the test samples
            `X`.

            By default, this method returns only the class probabilities
            `P`. If `extra_outputs` is provided, a tuple is returned whose
            first element is `P` and whose remaining elements are the
            requested additional forward outputs, in the order specified
            by `extra_outputs`.

            Parameters
            ----------
            X : array-like of shape (n_samples, ...)
                Test samples.
            extra_outputs : None or str or sequence of str, default=None
                Names of additional outputs to return next to `P`. The
                names must be a subset of the following keys:

                - "logits" : Additionally return the class-membership
                  logits `L_class` for the samples in `X`.
                - "embeddings" : Additionally return the learned
                  embeddings `X_embed` for the samples in `X`.
                - "annotator_perf" : Additionally return the estimated
                  annotator performance probabilities `P_perf` for each
                  sample–annotator pair.
                - "annotator_class" : Additionally return the
                  annotator–class probability estimates `P_annot` for
                  each sample, class, and annotator.
                - "annotator_embeddings" : Additionally return the learned
                  embeddings `A_embed` for the annotators as the next
                  element of the output tuple.

            Returns
            -------
            P : numpy.ndarray of shape (n_samples, n_classes)
                Class probabilities of the test samples. Classes are
                ordered according to `self.classes_`.
            *extras : numpy.ndarray, optional
                Only returned if `extra_outputs` is not `None`. In that
                case, the method returns a tuple whose first element is
                `P` and whose remaining elements correspond to the
                requested forward outputs in the order given by
                `extra_outputs`. Potential outputs are:

                - `L_class` : `np.ndarray` of shape
                  `(n_samples, n_classes)`, where `L_class[n, c]` is the
                  logit for the class `classes_[c]` of sample `X[n]`.
                - `X_embed` : `np.ndarray` of shape `(n_samples, ...)`,
                  where `X_embed[n]` refers to the learned embedding for
                  sample `X[n]`.
                - `P_perf` : `np.ndarray` of shape
                  `(n_samples, n_annotators)`, where `P_perf[n, m]` refers
                  to the estimated label correctness probability
                  (performance) of annotator `m` when labeling sample
                  `X[n]`.
                - `P_annot` : `np.ndarray` of shape
                  `(n_samples, n_annotators, n_classes)`, where
                  `P_annot[n, m, c]` refers to the probability that
                  annotator `m` provides the class label `c` for sample
                  `X[n]`.
                - `A_embed` : `np.ndarray` of shape
                  `(n_annotators, annotator_embed_dim)`, where
                  `A_embed[m]` refers to the learned embedding for
                  annotator `m`.
            """
            # Check input parameters.
            self._validate_data_kwargs()
            X = check_array(X, **self.check_X_dict_)
            check_n_features(
                self, X, reset=not hasattr(self, "n_features_in_")
            )
            extra_outputs = self._normalize_extra_outputs(
                extra_outputs=extra_outputs,
                allowed_names=AnnotMixClassifier._ALLOWED_EXTRA_OUTPUTS,
            )

            # Initialize module, if not done yet.
            if not hasattr(self, "neural_net_"):
                self.initialize()

            # Set forward options to obtain the different outputs required
            # by the input parameters.
            net = self.neural_net_.module_
            old_forward_return = net.forward_return
            forward_outputs = {"probas": (0, nn.Softmax(dim=-1))}
            forward_returns = ["logits_class"]
            out_idx = 1
            if "logits" in extra_outputs:
                forward_outputs["logits"] = (0, None)
            if "embeddings" in extra_outputs:
                forward_outputs["embeddings"] = (out_idx, None)
                forward_returns.append("x_embed")
                out_idx += 1
            if "annotator_perf" in extra_outputs:

                def _transform_annotator_perf(P_perf):
                    P_perf = P_perf.exp()
                    return P_perf.reshape(-1, self.n_annotators_)

                forward_outputs["annotator_perf"] = (
                    out_idx,
                    _transform_annotator_perf,
                )
                forward_returns.append("log_p_annotator_perf")
                out_idx += 1
            if "annotator_class" in extra_outputs:

                def _transform_annotator_class(P_annot):
                    P_annot = P_annot.exp()
                    return P_annot.reshape(
                        -1, self.n_annotators_, len(self.classes_)
                    )

                forward_outputs["annotator_class"] = (
                    out_idx,
                    _transform_annotator_class,
                )
                forward_returns.append("log_p_annotator_class")
                out_idx += 1
            if "annotator_embeddings" in extra_outputs:

                def _transform_annotator_embeddings(A_embed):
                    return A_embed[: self.n_annotators_]

                forward_outputs["annotator_embeddings"] = (
                    out_idx,
                    _transform_annotator_embeddings,
                )
                forward_returns.append("a_embed")

            # Compute predictions for the different outputs required
            # by the input parameters.
            try:
                net.set_forward_return(forward_returns)
                fw_out = self._forward_with_named_outputs(
                    X=X,
                    forward_outputs=forward_outputs,
                    extra_outputs=extra_outputs,
                )
            finally:
                net.set_forward_return(old_forward_return)

            # Initialize fallbacks if the classifier hasn't been fitted
            # before.
            self._initialize_fallbacks(
                fw_out[0] if isinstance(fw_out, tuple) else fw_out
            )
            return fw_out
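
        # Example (sketch, continuing the fitted `clf` from above):
        #
        #     >>> P, P_annot = clf.predict_proba(
        #     ...     X, extra_outputs=["annotator_class"]
        #     ... )
        #     >>> P.shape, P_annot.shape
        #     ((10, 3), (10, 4, 3))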

        def _build_neural_net_param_overrides(self, X, y):
            """Build parameter overrides for the internal `sklearn`
            wrapper provided by `skorch`."""
            # Check parameters specific to `AnnotMixClassifier`.
            check_scalar(
                self.alpha,
                name="alpha",
                target_type=float,
                min_val=0.0,
                min_inclusive=True,
            )
            check_scalar(
                self.sample_embed_dim,
                name="sample_embed_dim",
                target_type=int,
                min_val=0,
                min_inclusive=True,
            )
            check_scalar(
                self.annotator_embed_dim,
                name="annotator_embed_dim",
                target_type=int,
                min_val=1,
                min_inclusive=True,
            )
            hidden_dim = self.hidden_dim
            if hidden_dim is None:
                hidden_dim = min(
                    4 * len(self.classes_),
                    max(
                        128,
                        2
                        * (self.annotator_embed_dim + self.sample_embed_dim),
                    ),
                )
            check_scalar(
                hidden_dim,
                name="hidden_dim",
                target_type=int,
                min_val=1,
                min_inclusive=True,
            )
            check_scalar(
                self.n_hidden_layers,
                name="n_hidden_layers",
                target_type=int,
                min_val=0,
                min_inclusive=True,
            )
            check_scalar(
                self.hidden_dropout,
                name="hidden_dropout",
                target_type=float,
                min_val=0.0,
                min_inclusive=True,
                max_val=1.0,
                max_inclusive=False,
            )
            check_scalar(
                self.eta,
                name="eta",
                target_type=float,
                min_val=1 / len(self.classes_),
                min_inclusive=False,
                max_val=1.0,
                max_inclusive=False,
            )
            collate_fn = _MixUpCollate(
                n_classes=len(self.classes_),
                n_annotators=self.n_annotators_,
                alpha=self.alpha,
                missing_label=-1,
            )
            return {
                "criterion__reduction": "batchmean",
                "module__n_classes": len(self.classes_),
                "module__n_annotators": self.n_annotators_,
                "module__sample_embed_dim": self.sample_embed_dim,
                "module__annotator_embed_dim": self.annotator_embed_dim,
                "module__hidden_dim": hidden_dim,
                "module__n_hidden_layers": self.n_hidden_layers,
                "module__hidden_dropout": self.hidden_dropout,
                "module__eta": self.eta,
                "iterator_train__collate_fn": collate_fn,
            }
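
    # Worked example (illustrative numbers) for the default `hidden_dim`
    # computed above: with `len(self.classes_) == 10`,
    # `annotator_embed_dim == 16`, and `sample_embed_dim == 0`, the default
    # is min(4 * 10, max(128, 2 * (16 + 0))) = min(40, 128) = 40.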

    class _AnnotMixModule(_MultiAnnotatorClassificationModule):
        """
        Auxiliary module for Annot-Mix [1]_ that produces class logits and
        annotator-conditioned outputs, while training with MixUp [2]_.

        Parameters
        ----------
        n_classes : int
            Number of classes.
        n_annotators : int
            Number of annotators.
        clf_module : nn.Module or nn.Module.__class__
            Classifier backbone/head that maps `x -> logits_class` or
            `(logits_class, x_embed)`. If it returns only logits,
            `x_embed` is set to the input `x` (or to `None` if `x` is not
            an embedding).
        clf_module_param_dict : dict
            Keyword arguments for constructing `clf_module` if a class is
            passed.
        sample_embed_dim : int or None
            Dimensionality of an optional learnable sample embedding used
            to model sample-specific behavior of each annotator. If
            `sample_embed_dim=0`, no additional sample embedding is
            learned.
        annotator_embed_dim : int
            Dimensionality of the annotator embedding used to model
            annotator-specific behavior.
        hidden_dim : int or None
            Hidden size of the fusion multi-layer perceptron that
            propagates sample and annotator representations. If `None`, a
            sensible default is used, which depends on the other input
            parameters.
        n_hidden_layers : int
            Number of hidden layers in the fusion multi-layer perceptron.
        hidden_dropout : float
            Dropout probability applied in the fusion multi-layer
            perceptron.
        eta : float in (0, 1)
            Prior annotator performance, i.e., the probability of
            obtaining a correct annotation from an arbitrary annotator for
            an arbitrary sample.

        References
        ----------
        .. [1] Herde, M., Lührs, L., Huseljic, D., & Sick, B. (2024).
           Annot-Mix: Learning with Noisy Class Labels from Multiple
           Annotators via a Mixup Extension. Eur. Conf. Artif. Intell.
        .. [2] Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D.
           (2018). mixup: Beyond Empirical Risk Minimization. Int. Conf.
           Learn. Represent.
        """

        # Optional names that can be returned *after* `logits_class`.
        OUTPUTS = (
            "logits_class",
            "x_embed",
            "a_embed",
            "log_p_annotator_class",
            "log_p_annotator_perf",
        )

        def __init__(
            self,
            n_classes,
            n_annotators,
            clf_module,
            clf_module_param_dict,
            sample_embed_dim,
            annotator_embed_dim,
            hidden_dim,
            n_hidden_layers,
            hidden_dropout,
            eta,
        ):
            super().__init__(
                clf_module=clf_module,
                clf_module_param_dict=clf_module_param_dict,
                default_forward_outputs="log_p_annotator_class",
                full_forward_outputs=[
                    "logits_class",
                    "x_embed",
                    "a_embed",
                    "log_p_annotator_class",
                    "log_p_annotator_perf",
                ],
            )

            # Define integer variables.
            self.n_classes = n_classes
            self.annotator_embed_dim = annotator_embed_dim

            # Set up layer to learn annotator embeddings.
            self.register_buffer(
                "a", torch.eye(n_annotators, dtype=torch.float32)
            )
            self.sample_embed = None
            if sample_embed_dim > 0:
                self.sample_embed = nn.LazyLinear(
                    out_features=sample_embed_dim,
                )
            self.annotator_embed = nn.Linear(
                in_features=n_annotators,
                out_features=annotator_embed_dim,
            )

            # Post-scale diagonal bump as inductive bias.
            eta = math.log(eta / (1.0 - eta)) + math.log(n_classes - 1.0)
            prior_conf = nn.Parameter(
                eta * torch.eye(n_classes, dtype=torch.float32).flatten()
            )

            # Set up annotator confusion head.
            full_dim = sample_embed_dim + annotator_embed_dim
            blocks, dim = [], full_dim
            for _ in range(n_hidden_layers):
                blocks += [
                    nn.Dropout(hidden_dropout),
                    nn.Linear(dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.SiLU(),
                ]
                dim = hidden_dim
            out = nn.Linear(dim, n_classes * n_classes)
            out.bias = prior_conf
            blocks += [out]
            self.annotator_confusion_head = nn.Sequential(*blocks)

        def forward(self, x, a=None):
            """
            Parameters
            ----------
            x : torch.Tensor of shape (batch_size, ...)
                Input batch. Shape depends on `clf_module`.
            a : torch.Tensor of shape (batch_size, ...) or None
                Annotator features/IDs. Needed if any of `"a_embed"`,
                `"log_p_annotator_class"`, or `"log_p_annotator_perf"` is
                requested.

            Returns
            -------
            out : torch.Tensor or tuple
                Given `set_forward_return`, tensors are appended in the
                order:

                - `"logits_class"`,
                - `"x_embed"`,
                - `"log_p_annotator_perf"`,
                - `"log_p_annotator_class"`,
                - `"a_embed"`.
            """
            # Obtain classifier outputs.
            logits_class, x_embed = self.clf_module_forward(x)

            # Append classifier output if required.
            out = []
            if "logits_class" in self.forward_return:
                out.append(logits_class)
            if "x_embed" in self.forward_return:
                out.append(x_embed)

            need_annotator_output = any(
                k in self.forward_return
                for k in (
                    "a_embed",
                    "log_p_annotator_class",
                    "log_p_annotator_perf",
                )
            )
            if need_annotator_output:
                a = a if a is not None else self.a

                # Sample/annotator embeddings for annotator head.
                if self.sample_embed:
                    x_embed = self.sample_embed(
                        x_embed.detach().flatten(start_dim=1)
                    )
                a_embed = self.annotator_embed(a)

                # Generate pairs of samples and annotators if not done
                # yet.
                if not self.training:
                    combs = torch.cartesian_prod(
                        torch.arange(len(x), device=x.device),
                        torch.arange(len(a_embed), device=a_embed.device),
                    )
                    if self.sample_embed:
                        x_embed = x_embed[combs[:, 0]]
                    a_embed_return = a_embed.clone().detach()
                    a_embed = a_embed[combs[:, 1]]
                    logits_class = logits_class[combs[:, 0]]
                else:
                    a_embed_return = a_embed

                # Compute confusion matrix logits per sample-annotator
                # pair.
                annot_head_input = a_embed
                if self.sample_embed:
                    annot_head_input = torch.cat(
                        [x_embed, a_embed], dim=-1
                    )
                logits_conf = self.annotator_confusion_head(
                    annot_head_input
                )
                logits_conf = logits_conf.view(
                    -1, self.n_classes, self.n_classes
                )

                # Compute log-probabilities for class and confusion
                # matrices.
                p_conf_log = F.log_softmax(logits_conf, dim=-1)
                p_class_log = F.log_softmax(logits_class, dim=-1)

                # Compute and append annotator correctness
                # log-probabilities.
                if "log_p_annotator_perf" in self.forward_return:
                    log_diag_conf = torch.diagonal(
                        p_conf_log, dim1=-2, dim2=-1
                    )
                    p_perf = torch.logsumexp(
                        p_class_log + log_diag_conf, dim=-1
                    )
                    out.append(p_perf)

                # Compute and append annotator class log-probabilities.
                if "log_p_annotator_class" in self.forward_return:
                    log_p_annotator_class = torch.logsumexp(
                        p_class_log[:, :, None] + p_conf_log, dim=1
                    )
                    out.append(log_p_annotator_class)

                if "a_embed" in self.forward_return:
                    out.append(a_embed_return)

            return out[0] if len(out) == 1 else tuple(out)
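
    # Sketch of the annotator-performance computation in `forward` above:
    # with class probabilities p(y=c | x) and an annotator confusion
    # matrix whose diagonal entries are p(annotation=c | y=c), the
    # correctness probability of an annotator for sample x is
    #
    #     p(correct | x, annotator) = sum_c p(y=c | x) * p(annotation=c | y=c),
    #
    # which the code evaluates in log-space via
    # `torch.logsumexp(p_class_log + log_diag_conf, dim=-1)` for numerical
    # stability.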

    class _MixUpCollate:
        """
        Collate function that expands a batch into all (sample, annotator)
        pairs and optionally applies MixUp [1]_ jointly to samples,
        annotators, and labels [2]_.

        Parameters
        ----------
        n_classes : int
            Number of classes (for one-hot encoding).
        n_annotators : int
            Number of annotators (for one-hot encoding).
        alpha : float, default=1.0
            MixUp `Beta(alpha, alpha)` parameter. If <= 0, no MixUp is
            applied.
        missing_label : int or float, default=-1
            Value in `y` indicating an unlabeled sample. Pairs whose label
            equals `missing_label` are excluded from the
            (sample, annotator) pairs. If set to `float('nan')` or
            `numpy.nan`, NaN labels are treated as missing.

        Notes
        -----
        Labels are returned as one-hot encoded vectors of length
        `n_classes`.

        References
        ----------
        .. [1] Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D.
           (2018). mixup: Beyond Empirical Risk Minimization. Int. Conf.
           Learn. Represent.
        .. [2] Herde, M., Lührs, L., Huseljic, D., & Sick, B. (2024).
           Annot-Mix: Learning with Noisy Class Labels from Multiple
           Annotators via a Mixup Extension. Eur. Conf. Artif. Intell.
        """

        def __init__(
            self, n_classes, n_annotators, alpha=1.0, missing_label=-1
        ):
            if n_classes <= 0:
                raise ValueError("`n_classes` must be a positive integer.")
            if n_annotators <= 0:
                raise ValueError(
                    "`n_annotators` must be a positive integer."
                )
            alpha = float(alpha)
            if alpha < 0:
                raise ValueError("`alpha` must be >= 0 for MixUp.")
            self.n_classes = int(n_classes)
            self.n_annotators = int(n_annotators)
            self.a = torch.eye(self.n_annotators, dtype=torch.float32)
            self.alpha = alpha
            self.missing_label = missing_label

        def __call__(self, batch):
            # 1) Basic collation (supports tensors/ndarrays/nested dicts)
            # of samples X and labels y; annotator one-hot vectors are
            # generated below.
            x = default_collate([b[0] for b in batch])
            y = default_collate([b[1] for b in batch])

            # Expect labels of shape (n_samples, n_annotators).
            if y.dim() != 2 or y.shape[1] != self.n_annotators:
                raise ValueError(
                    f"y must have shape (n_samples, {self.n_annotators}), "
                    f"got {tuple(y.shape)}."
                )
            n_samples, _ = y.shape

            # Flatten labels to (n_samples * n_annotators,).
            y = y.view(-1)

            # 2) Build all (sample, annotator) combinations.
            # Sample indices: 0..n_samples-1, each repeated for every
            # annotator.
            idx_s = torch.arange(
                n_samples, dtype=torch.long
            ).repeat_interleave(self.n_annotators)
            # Annotator indices: 0..n_annotators-1, tiled n_samples times.
            idx_a = torch.arange(
                self.n_annotators, dtype=torch.long
            ).repeat(n_samples)

            # Mask out pairs whose sample is unlabeled.
            if isinstance(self.missing_label, float) and (
                self.missing_label != self.missing_label
            ):  # NaN
                mask = ~torch.isnan(y.to(torch.float32))
            else:
                mask = y != self.missing_label
            idx_s = idx_s[mask]
            idx_a = idx_a[mask]
            y_pairs = y[mask]

            # 3) Select data per pair.
            x_pairs = x.index_select(0, idx_s)
            a_pairs = self.a.index_select(0, idx_a)

            # One-hot labels (ensure integer dtype for F.one_hot).
            y_pairs = y_pairs.to(torch.long)
            y_oh = F.one_hot(y_pairs, num_classes=self.n_classes).to(
                dtype=torch.float32
            )

            # 4) Optional MixUp across pairs (jointly mixing x, a, and y).
            if self.alpha > 0:
                x_pairs, a_pairs, y_oh, _, _ = _mix_up(
                    x_pairs, a_pairs, y_oh, alpha=self.alpha
                )

            x_out = {"x": x_pairs, "a": a_pairs}
            return x_out, y_oh
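
    # Worked example (illustrative) of the pair expansion in `__call__`
    # above: a batch of 2 samples with 3 annotators yields the candidate
    # pairs (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2); pairs whose
    # label equals `missing_label` are dropped, so a label matrix
    # [[2, -1, 0], [1, 1, -1]] with `missing_label=-1` keeps the 4 pairs
    # (0, 0), (0, 2), (1, 0), (1, 1).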

    def _mix_up(*arrays, alpha=1.0, lmbda=None, permute_indices=None):
        """
        MixUp [1]_ multiple arrays using the same permutation and mixing
        coefficients (lambdas).

        Parameters
        ----------
        arrays : sequence of torch.Tensor
            Tensors with the same length `N` along the first dimension.
            Each will be mixed with the same permutation and mixing
            coefficients.
        alpha : float, default=1.0
            `Beta(alpha, alpha)` parameter. Used only if `lmbda is None`.
            If `alpha == 0`, returns the inputs unchanged (with `lmbda`
            all ones). If `alpha < 0`, a `ValueError` is raised.
        lmbda : torch.Tensor of shape (N,), default=None
            Precomputed mixing coefficients in [0, 1]. If not provided,
            sampled from `Beta(alpha, alpha)` on the same device as the
            first array when `alpha > 0`, or set to ones if `alpha == 0`.
        permute_indices : torch.Tensor of shape (N,), default=None
            Precomputed permutation indices. If not provided, a random
            permutation is generated on the same device as the first
            array.

        Returns
        -------
        outputs : tuple
            Tuple of mixed tensors in the same order as `arrays`, followed
            by `(lmbda, permute_indices)`.

        References
        ----------
        .. [1] Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D.
           (2018). mixup: Beyond Empirical Risk Minimization. Int. Conf.
           Learn. Represent.
        """
        if len(arrays) == 0:
            raise ValueError(
                "At least one array must be provided to _mix_up."
            )

        # All arrays must share the same leading dimension.
        N = arrays[0].shape[0]
        for arr in arrays[1:]:
            if arr.shape[0] != N:
                raise ValueError(
                    "All arrays must have the same length in dim 0."
                )

        first = arrays[0]
        device = first.device
        alpha = float(alpha)
        if alpha < 0:
            raise ValueError("alpha must be >= 0 for MixUp.")

        # Handle lambda.
        if lmbda is None:
            if alpha == 0:
                lmbda = torch.ones(N, device=device, dtype=torch.float32)
            else:
                lmbda = (
                    torch.distributions.Beta(alpha, alpha)
                    .sample((N,))
                    .to(device=device, dtype=torch.float32)
                )
        else:
            lmbda = torch.as_tensor(
                lmbda, device=device, dtype=torch.float32
            )
            if lmbda.dim() != 1 or lmbda.shape[0] != N:
                raise ValueError(
                    f"`lmbda` must have shape ({N},), "
                    f"got {tuple(lmbda.shape)}."
                )

        # Handle permutation.
        if permute_indices is None:
            permute_indices = torch.randperm(N, device=device)
        else:
            permute_indices = torch.as_tensor(
                permute_indices, device=device, dtype=torch.long
            )
            if (
                permute_indices.dim() != 1
                or permute_indices.shape[0] != N
            ):
                raise ValueError(
                    f"`permute_indices` must have shape ({N},), "
                    f"got {tuple(permute_indices.shape)}."
                )

        # Broadcast lmbda to array shapes and mix.
        outputs = []
        for arr in arrays:
            view_shape = (N,) + (1,) * (arr.dim() - 1)
            lam_view = lmbda.view(view_shape)
            mixed = lam_view * arr + (
                1.0 - lam_view
            ) * arr.index_select(0, permute_indices)
            outputs.append(mixed.to(arr.dtype))
        outputs.extend([lmbda, permute_indices])
        return tuple(outputs)

except ImportError:  # pragma: no cover
    pass
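
# Example (sketch) of `_mix_up` above, assuming `torch` is available:
# mixing a feature tensor and its one-hot labels with shared coefficients.
#
#     >>> import torch
#     >>> x = torch.randn(8, 5)
#     >>> y = torch.eye(3)[torch.randint(0, 3, (8,))]
#     >>> x_mix, y_mix, lmbda, perm = _mix_up(x, y, alpha=0.5)
#     >>> x_mix.shape, y_mix.shape, lmbda.shape, perm.shape
#     (torch.Size([8, 5]), torch.Size([8, 3]), torch.Size([8]), torch.Size([8]))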