# Source code for skactiveml.pool._discriminative_al

"""
Module implementing discriminative active learning.
"""

# Authors: Marek Herde <marek.herde@uni-kassel.de>

import numpy as np
from sklearn import clone

from ..base import SingleAnnotatorPoolQueryStrategy, SkactivemlClassifier
from ..utils import (
    MISSING_LABEL,
    rand_argmax,
    is_unlabeled,
    simple_batch,
    check_type,
)


class DiscriminativeAL(SingleAnnotatorPoolQueryStrategy):
    """Discriminative Active Learning.

    This class implements the "Discriminative Active Learning" (DAL)
    strategy. Its idea is to solve a binary classification task to choose
    samples for labeling such that the labeled set and the unlabeled pool are
    indistinguishable.

    Parameters
    ----------
    greedy_selection : bool, optional (default=False)
        This parameter is only relevant for `batch_size>1`. If
        `greedy_selection=False`, the classifying discriminator is refitted
        after each sample selection within a batch. Otherwise, the
        discriminator is fitted once and kept fixed for the whole batch.
    missing_label : scalar or string or np.nan or None, optional
        (default=np.nan)
        Value to represent a missing label.
    random_state : None or int or np.random.RandomState, optional
        (default=None)
        The random state to use.

    References
    ----------
    [1] Gissin D, Shalev-Shwartz S. "Discriminative active learning."
        arXiv:1907.06347. 2019.
    """

    def __init__(
        self,
        greedy_selection=False,
        missing_label=MISSING_LABEL,
        random_state=None,
    ):
        super().__init__(
            missing_label=missing_label, random_state=random_state
        )
        self.greedy_selection = greedy_selection

    def query(
        self,
        X,
        y,
        discriminator,
        candidates=None,
        batch_size=1,
        return_utilities=False,
    ):
        """Determines for which candidate samples labels are to be queried.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data set, usually complete, i.e., including the labeled
            and unlabeled samples.
        y : array-like of shape (n_samples)
            Labels of the training data set (possibly including unlabeled
            ones indicated by `self.missing_label`).
        discriminator : skactiveml.base.SkactivemlClassifier
            Model implementing the methods `fit` and `predict_proba`. The
            passed object is cloned, and the parameters `classes` and
            `missing_label` of the clone are internally redefined.
        candidates : None or array-like of shape (n_candidates), dtype=int or
            array-like of shape (n_candidates, n_features),
            optional (default=None)
            If `candidates` is `None`, the unlabeled samples from `(X, y)`
            are considered as candidates.
            If `candidates` is of shape `(n_candidates,)` and of type int,
            `candidates` is considered as the indices of the samples in
            `(X, y)`.
            If `candidates` is of shape `(n_candidates, n_features)`, the
            candidates are directly given in candidates (not necessarily
            contained in `X`).
            NOTE(review): the candidates are transformed with
            `enforce_mapping=True` and this method always indexes `X`
            through the resulting mapping, so candidates given as a feature
            matrix presumably have to be mappable to samples in `X` --
            confirm against `_transform_candidates`.
        batch_size : int, optional (default=1)
            The number of samples to be selected in one AL cycle.
        return_utilities : bool, optional (default=False)
            If true, also return the utilities based on the query strategy.

        Returns
        -------
        query_indices : numpy.ndarray of shape (batch_size,)
            The `query_indices` indicate for which candidate sample a label
            is to be queried, e.g., `query_indices[0]` indicates the index of
            the first selected sample.
            If `candidates` is `None` or of shape `(n_candidates,)`, the
            indexing refers to samples in `X`.
            If `candidates` is of shape (n_candidates, n_features), the
            indexing refers to samples in `candidates`.
        utilities : numpy.ndarray of shape (batch_size, n_samples) or
            numpy.ndarray of shape (batch_size, n_candidates)
            The utilities of samples after each selected sample of the batch,
            e.g., `utilities[0]` indicates the utilities used for selecting
            the first sample (with index `query_indices[0]`) of the batch.
            Utilities for labeled samples will be set to np.nan.
            If `candidates` is `None` or of shape `(n_candidates,)`, the
            indexing refers to samples in `X`.
            If `candidates` is of shape `(n_candidates, n_features)`, the
            indexing refers to samples in `candidates`.
            Only returned if `return_utilities` is true.
        """
        # Validate parameters.
        X, y, candidates, batch_size, return_utilities = self._validate_data(
            X, y, candidates, batch_size, return_utilities, reset=True
        )
        check_type(discriminator, "discriminator", SkactivemlClassifier)
        check_type(self.greedy_selection, "greedy_selection", bool)

        # Retransform candidates and create a mapping to the samples in `X`.
        X_cand, mapping = self._transform_candidates(
            candidates, X, y, enforce_mapping=True
        )

        # Re-define a copy of the discriminator to fit the setting of
        # classifying labeled (y=0) and unlabeled samples (y=1).
        discriminator = clone(discriminator)
        discriminator.classes = [0, 1]
        discriminator.missing_label = -1

        # The labeled-vs-unlabeled targets depend only on `y`, so they are
        # computed once instead of once per selected sample.
        y_unlabeled = is_unlabeled(y, missing_label=self.missing_label)
        y_unlabeled = y_unlabeled.astype(int)

        if self.greedy_selection:
            # Fit once and return the top samples with the highest
            # probabilities of being unlabeled, which correspond to their
            # utilities.
            discriminator.fit(X, y_unlabeled)
            utilities_cand = discriminator.predict_proba(X_cand)[:, 1]

            # Remap the candidate utilities onto the samples in `X`.
            utilities = np.full(len(X), np.nan)
            utilities[mapping] = utilities_cand

            # Return `query_indices` and potential `utilities`.
            return simple_batch(
                utilities,
                self.random_state_,
                batch_size=batch_size,
                return_utilities=return_utilities,
            )

        # Refit the binary classifier, i.e., the discriminator, after each
        # selected sample in a batch.
        query_indices_cand = []
        utilities_cand = np.empty((batch_size, len(X_cand)), dtype=float)
        for i in range(batch_size):
            # Start from a fresh copy of the targets and mark the samples
            # already selected in this batch as labeled.
            y_discriminator = y_unlabeled.copy()
            y_discriminator[mapping[query_indices_cand]] = 0

            # Fit discriminator to classify unlabeled vs. labeled samples.
            discriminator.fit(X, y_discriminator)

            # Compute utilities as probabilities of being unlabeled and
            # exclude already selected candidates from being chosen again.
            utilities_cand[i] = discriminator.predict_proba(X_cand)[:, 1]
            utilities_cand[i, query_indices_cand] = np.nan
            query_indices_cand.append(
                rand_argmax(utilities_cand[i], self.random_state_)[0]
            )

        # Remap `utilities` and `query_indices` onto the samples in `X`.
        utilities = np.full((batch_size, len(X)), np.nan)
        utilities[:, mapping] = utilities_cand
        query_indices = mapping[query_indices_cand]

        # Check whether `utilities` are to be returned.
        if return_utilities:
            return query_indices, utilities
        return query_indices