"""
Module implementing discriminative active learning.
"""
# Authors: Marek Herde <marek.herde@uni-kassel.de>
import numpy as np
from sklearn import clone
from ..base import SingleAnnotatorPoolQueryStrategy, SkactivemlClassifier
from ..utils import (
MISSING_LABEL,
rand_argmax,
is_unlabeled,
simple_batch,
check_type,
)
class DiscriminativeAL(SingleAnnotatorPoolQueryStrategy):
    """Discriminative Active Learning.

    This class implements the "Discriminative Active Learning" (DAL)
    strategy [1]. Its idea is to solve a binary classification task to
    choose samples for labeling such that the labeled set and the unlabeled
    pool are indistinguishable. A discriminator is trained to separate
    labeled samples (class 0) from unlabeled samples (class 1). The
    candidates with the highest predicted probability of being unlabeled
    are queried.

    Parameters
    ----------
    greedy_selection : bool, optional (default=False)
        This parameter is only relevant for `batch_size>1`. If
        `greedy_selection=False`, the classifying discriminator is refitted
        after each sample selection within a batch, and the already selected
        samples are treated as labeled. Otherwise, the discriminator is
        fitted once and kept fixed, and the top samples with the highest
        utilities are selected.
    missing_label : scalar or string or np.nan or None, optional
            (default=np.nan)
        Value to represent a missing label.
    random_state : None or int or np.random.RandomState, optional
            (default=None)
        The random state to use.

    References
    ----------
    [1] Gissin D, Shalev-Shwartz S. "Discriminative active learning."
        arXiv:1907.06347. 2019.
    """

    def __init__(
        self,
        greedy_selection=False,
        missing_label=MISSING_LABEL,
        random_state=None,
    ):
        super().__init__(
            missing_label=missing_label, random_state=random_state
        )
        self.greedy_selection = greedy_selection

    def query(
        self,
        X,
        y,
        discriminator,
        candidates=None,
        batch_size=1,
        return_utilities=False,
    ):
        """Determines for which candidate samples labels are to be queried.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data set, usually complete, i.e., including the labeled
            and unlabeled samples.
        y : array-like of shape (n_samples)
            Labels of the training data set (possibly including unlabeled ones
            indicated by `self.missing_label`).
        discriminator : skactiveml.base.SkactivemlClassifier
            Model implementing the methods `fit` and `predict_proba`. It is
            cloned, so the passed instance is not modified. On the clone, the
            parameters `classes` and `missing_label` are internally redefined
            to `[0, 1]` and `-1`, respectively.
        candidates : None or array-like of shape (n_candidates), dtype=int or
                array-like of shape (n_candidates, n_features),
                optional (default=None)
            If `candidates` is `None`, the unlabeled samples from `(X, y)` are
            considered as candidates.
            If `candidates` is of shape `(n_candidates,)` and of type int,
            `candidates` is considered as the indices of the samples in
            `(X, y)`.
            If `candidates` is of shape `(n_candidates, n_features)`, the
            candidates must be contained in `X`, because this strategy
            enforces a mapping of the candidates onto the samples in `X`.
        batch_size : int, optional (default=1)
            The number of samples to be selected in one AL cycle.
        return_utilities : bool, optional (default=False)
            If true, also return the utilities based on the query strategy.

        Returns
        -------
        query_indices : numpy.ndarray of shape (batch_size,)
            The `query_indices` indicate for which candidate sample a label is
            to be queried, e.g., `query_indices[0]` indicates the index of
            the first selected sample. Because a mapping onto `X` is
            enforced, the indexing always refers to samples in `X`.
        utilities : numpy.ndarray of shape (batch_size, n_samples)
            The utilities of samples after each selected sample of the batch,
            e.g., `utilities[0]` indicates the utilities used for selecting
            the first sample (with index `query_indices[0]`) of the batch.
            The indexing refers to samples in `X`. Samples that are not
            candidates (e.g., the labeled samples when `candidates=None`)
            have np.nan utilities.
        """
        # Validate parameters.
        X, y, candidates, batch_size, return_utilities = self._validate_data(
            X, y, candidates, batch_size, return_utilities, reset=True
        )
        check_type(discriminator, "discriminator", SkactivemlClassifier)
        check_type(self.greedy_selection, "greedy_selection", bool)

        # Retransform candidates and create the mapping to the samples in
        # `X`. With `enforce_mapping=True`, `mapping[j]` is the index in `X`
        # of the j-th candidate.
        X_cand, mapping = self._transform_candidates(
            candidates, X, y, enforce_mapping=True
        )

        # Re-define discriminator to fit the setting of classifying
        # labeled (y=0) and unlabeled samples (y=1).
        discriminator = clone(discriminator)
        discriminator.classes = [0, 1]
        discriminator.missing_label = -1

        # Binary discriminator targets: 1 for unlabeled, 0 for labeled
        # samples. They do not depend on the batch iteration, so they are
        # computed only once (`astype` returns a private copy).
        y_discriminator = is_unlabeled(
            y, missing_label=self.missing_label
        ).astype(int)

        if self.greedy_selection:
            # Fit the discriminator once and return the top samples with the
            # highest probabilities of being unlabeled, which correspond to
            # their utilities.
            discriminator.fit(X, y_discriminator)
            utilities_cand = discriminator.predict_proba(X_cand)[:, 1]
            # Remap candidate utilities onto `X`; non-candidates stay NaN.
            utilities = np.full(len(X), np.nan)
            utilities[mapping] = utilities_cand
            return simple_batch(
                utilities,
                self.random_state_,
                batch_size=batch_size,
                return_utilities=return_utilities,
            )

        # Refit the binary classifier, i.e., the discriminator, after each
        # selected sample in a batch.
        query_indices_cand = []
        utilities_cand = np.empty((batch_size, len(X_cand)), dtype=float)
        for i in range(batch_size):
            # Fit discriminator to classify unlabeled vs. labeled samples.
            discriminator.fit(X, y_discriminator)
            # Compute utilities as probabilities of being unlabeled.
            utilities_cand[i] = discriminator.predict_proba(X_cand)[:, 1]
            # Already selected candidates must not be selected again.
            utilities_cand[i, query_indices_cand] = np.nan
            selected = rand_argmax(utilities_cand[i], self.random_state_)[0]
            query_indices_cand.append(selected)
            # Mark the newly selected sample as labeled for the next refit.
            y_discriminator[mapping[selected]] = 0

        # Remapping of `utilities` and `query_indices` onto `X`.
        utilities = np.full((batch_size, len(X)), np.nan)
        utilities[:, mapping] = utilities_cand
        query_indices = mapping[query_indices_cand]
        if return_utilities:
            return query_indices, utilities
        return query_indices