Query-by-Committee (QBC) with Empirical Variance

Idea: QBC maintains a committee of models and selects the unlabeled samples on which the committee disagrees most, thereby targeting epistemic uncertainty. In batch mode, it ranks the candidates by their disagreement score and selects the top batch_size samples. For regression, disagreement is measured by the empirical variance of the committee members’ real-valued predictions, so samples with a higher variance are preferred.

Google Colab Note: If the notebook fails to run after installing the needed packages, try restarting the runtime via Runtime -> Restart session (keyboard shortcut: Ctrl + M, then .).
https://colab.research.google.com/assets/colab-badge.svg
Notebook Dependencies
Uncomment the following cell to install all dependencies for this tutorial.
# !pip install scikit-activeml
import numpy as np
from matplotlib import pyplot as plt, animation
from scipy.stats import uniform

from skactiveml.utils import MISSING_LABEL, is_labeled

from sklearn.gaussian_process import GaussianProcessRegressor
from skactiveml.pool import QueryByCommittee
from skactiveml.regressor import SklearnRegressor

# One shared, seeded generator drives every random draw of the data generation
# below, so repeated runs of this example produce the same dataset.
random_state = np.random.RandomState(seed=0)


def true_function(X_):
    """Evaluate the noise-free target polynomial x**3 + 2*x**2 + x - 1.

    Parameters
    ----------
    X_ : numpy.ndarray
        Input values; in this example a column vector of sample positions.

    Returns
    -------
    numpy.ndarray
        The polynomial evaluated element-wise and flattened to 1-D.
    """
    cubic_term = X_**3
    quadratic_term = 2 * X_**2
    # Summed left to right, matching the original evaluation order exactly.
    polynomial = cubic_term + quadratic_term + X_ - 1
    return polynomial.flatten()


# Generate samples: 90 % of them uniformly in [0, 1.5] and the remaining 10 %
# uniformly in [1.5, 2], so the right end of the input range is sparse.
n_samples = 100
X = np.concatenate(
    [
        uniform.rvs(0, 1.5, 9 * n_samples // 10, random_state=random_state),
        uniform.rvs(1.5, 0.5, n_samples // 10, random_state=random_state),
    ]
).reshape(-1, 1)


# Define noise: higher noise for X < 1 and lower otherwise.
def noise(X_):
    """Draw non-negative uniform noise for every entry of ``X_``.

    Entries below 1 receive noise from [0, 1.5), all other entries from
    [0, 0.5). Exactly one value is drawn from ``random_state`` per entry and
    the result has the same shape as ``X_``.
    """
    X_ = np.asarray(X_)
    # Vectorised on purpose: ``np.vectorize`` without ``otypes`` calls the
    # wrapped function an extra time on the first element to infer the output
    # dtype, which silently consumed an additional random draw.
    scale = np.where(X_ < 1, 1.5, 0.5)
    return random_state.rand(*X_.shape) * scale


# Build the dataset. All labels start out missing; the active learning loop
# reveals entries of ``y_true`` one query at a time.
y_true = true_function(X) + noise(X).flatten()
y = np.full(shape=y_true.shape, fill_value=MISSING_LABEL)
# The test set is evaluated against the noise-free function.
X_test = np.linspace(0, 2, num=100).reshape(-1, 1)
y_test = true_function(X_test)

# Initialise the regressor.
reg = SklearnRegressor(GaussianProcessRegressor())

# Initialise the query strategy. The committee's predictions are obtained from
# the regressor's ``sample_y`` method with ``n_samples=100``. Passing the shared
# ``random_state`` seeds any randomness inside the strategy (e.g. tie-breaking
# between equal utilities), so the queried sequence is reproducible as the
# seed comment at the top of the script promises.
qs = QueryByCommittee(
    sample_predictions_method_name='sample_y',
    sample_predictions_dict={'n_samples': 100},
    random_state=random_state,
)

# Prepare the plotting area. The axis labels are identical in every frame, so
# they are set once here instead of on every active learning cycle.
fig, (ax_1, ax_2) = plt.subplots(2, 1, sharex=True)
ax_1.set_xlabel('Sample')
ax_1.set_ylabel('Target Value')
ax_2.set_xlabel('Sample')
ax_2.set_ylabel('Utility')
artists = []

# Active learning cycle: each iteration fits the regressor, queries a sample,
# draws one animation frame, and finally reveals the queried label.
n_cycles = 20
for c in range(n_cycles):
    # Fit the regressor using the current labels.
    reg.fit(X, y)

    # Query the next sample(s).
    query_idx = qs.query(X=X, y=y, ensemble=reg)

    # Remember the collections that already exist, so that only the artists
    # created in this cycle are assigned to this animation frame.
    coll_old = list(ax_1.collections) + list(ax_2.collections)
    title = ax_1.text(
        0.5, 1.05,
        f"Prediction after acquiring {c} labels\n"
        f"Test R-squared score: {reg.score(X_test, y_test):.4f}",
        size=plt.rcParams["axes.titlesize"],
        ha="center",
        transform=ax_1.transAxes,
    )

    # Compute utility values for the test candidates and rescale them to
    # [0, 1]. A constant utility curve is only shifted to zero, which avoids
    # dividing by a zero maximum.
    _, utilities_test = qs.query(
        X=X, y=y, ensemble=reg, candidates=X_test, return_utilities=True
    )
    utilities_test = (utilities_test - utilities_test.min()).flatten()
    if np.any(utilities_test != utilities_test[0]):
        utilities_test /= utilities_test.max()

    # Plot utility information on the second axis. The fill is drawn on
    # ``ax_2`` explicitly rather than on pyplot's implicit current axes. Being
    # a collection, it is picked up by the frame diff below, so it must not be
    # appended to the frame a second time (blitting would draw it twice).
    (utility_line,) = ax_2.plot(X_test, utilities_test, c="green")
    utility_fill = ax_2.fill_between(
        X_test.flatten(), utilities_test, color="green", alpha=0.3
    )

    # Plot the samples: unlabeled ones with their (hidden) true targets,
    # labeled ones with their acquired labels.
    is_lbld = is_labeled(y)
    ax_1.scatter(X[~is_lbld], y_true[~is_lbld], c="lightblue")
    ax_1.scatter(X[is_lbld], y[is_lbld], c="orange")

    # Predict and plot the regressor's output.
    y_pred = reg.predict(X_test)
    (prediction_line,) = ax_1.plot(X_test, y_pred, c="black")

    # Capture the new plot elements: new collections (scatters, utility fill)
    # and the title via the diff; lines are not collections, so add them.
    coll_new = list(ax_1.collections) + list(ax_2.collections)
    coll_new.append(title)
    artists.append(
        [x for x in coll_new if (x not in coll_old)]
        + [utility_line, prediction_line]
    )

    # Update labels for the queried sample.
    y[query_idx] = y_true[query_idx]

# Create an animation from the collected frames.
ani = animation.ArtistAnimation(fig, artists, interval=1000, blit=True)
/home/runner/work/scikit-activeml.github.io/scikit-activeml.github.io/scikit-activeml/skactiveml/regressor/_wrapper.py:184: UserWarning: The 'estimator' could not be fitted because of 'Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required by GaussianProcessRegressor.'. Therefore, the empirical label mean `_label_mean=0` and the empirical label standard deviation `_label_std=1` will be used to make predictions.
  warnings.warn(
/home/runner/work/scikit-activeml.github.io/scikit-activeml.github.io/scikit-activeml/skactiveml/regressor/_wrapper.py:184: UserWarning: The 'estimator' could not be fitted because of 'Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required by GaussianProcessRegressor.'. Therefore, the empirical label mean `_label_mean=0` and the empirical label standard deviation `_label_std=1` will be used to make predictions.
  warnings.warn(
/home/runner/work/scikit-activeml.github.io/scikit-activeml.github.io/scikit-activeml/skactiveml/regressor/_wrapper.py:221: UserWarning: Since the 'estimator' could not be fitted when calling the `fit` method, the label mean `_label_mean=0` and optionally the label standard deviation `_label_std=1` is used to make the predictions.
  warnings.warn(
/home/runner/work/scikit-activeml.github.io/scikit-activeml.github.io/scikit-activeml/skactiveml/regressor/_wrapper.py:184: UserWarning: The 'estimator' could not be fitted because of 'Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required by GaussianProcessRegressor.'. Therefore, the empirical label mean `_label_mean=0` and the empirical label standard deviation `_label_std=1` will be used to make predictions.
  warnings.warn(
/home/runner/work/scikit-activeml.github.io/scikit-activeml.github.io/scikit-activeml/skactiveml/regressor/_wrapper.py:221: UserWarning: Since the 'estimator' could not be fitted when calling the `fit` method, the label mean `_label_mean=0` and optionally the label standard deviation `_label_std=1` is used to make the predictions.
  warnings.warn(
../../../_images/pool_regression_legend1.png

References:

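The implementation of this strategy is based on Seung et al. [1] and Burbidge et al. [2].

[1] H. S. Seung, M. Opper, and H. Sompolinsky. Query by Committee. In Proceedings of the Fifth Annual Workshop on Computational Learning Theory (COLT), 1992.
[2] R. Burbidge, J. J. Rowland, and R. D. King. Active Learning for Regression Based on Query by Committee. In Intelligent Data Engineering and Automated Learning (IDEAL), 2007.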

Total running time of the script: (0 minutes 6.012 seconds)

Gallery generated by Sphinx-Gallery