{"cells": [{"cell_type": "markdown", "id": "33f39229508719ea", "metadata": {"collapsed": false}, "source": ["# Deep Active Learning with Frozen Vision Transformer"]}, {"cell_type": "markdown", "id": "fd4eb6e7", "metadata": {}, "source": ["> **_Google Colab Note:_** If the notebook fails to run after installing the needed packages, try restarting the runtime (Ctrl + M) under Runtime -> Restart session.\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/scikit-activeml/scikit-activeml.github.io/blob/gh-pages/development/generated/tutorials_colab//05_pool_al_with_self_supervised_learning.ipynb)"]}, {"cell_type": "markdown", "id": "79402ebb", "metadata": {}, "source": ["**Notebook Dependencies**\n", "\n", "Uncomment the following cell to install all dependencies for this tutorial."]}, {"cell_type": "code", "execution_count": 1, "id": "92efee18e6c2b8ce", "metadata": {}, "outputs": [], "source": ["# !pip install scikit-activeml[opt] torch torchvision tqdm datasets transformers"]}, {"cell_type": "markdown", "id": "a773fbde", "metadata": {}, "source": ["
"]}, {"cell_type": "markdown", "id": "fa4583e83ede8da9", "metadata": {"collapsed": false}, "source": ["This tutorial demonstrates a practical comparison study using `scikit-activeml`. The workflow uses the self-supervised `DINOv2` model [1] to compute frozen image embeddings for the CIFAR-100 dataset [2]. Based on these embeddings, several active learning strategies are compared in a pool-based classification setting.\n", "\n", "**Key Steps:**\n", "1. **Self-Supervised Feature Extraction:** Use frozen DINOv2 to create embedding representations for the CIFAR-100 train and test images.\n", "\n", "2. **Active Learning Strategies:** Compare different active learning strategies provided by our library, including:\n", " - Random Sampling,\n", " - Uncertainty Sampling,\n", " - Discriminative Active Learning,\n", " - CoreSet,\n", " - TypiClust,\n", " - Badge,\n", " - ProbCover,\n", " - DropQuery,\n", " - Falcun,\n", " - UHerding.\n", "\n", "3. **Batch Sample Selection:** Use each active learning strategy, wrapped by a subsampling step for efficiency, to select batches of samples for labeling.\n", "\n", "4. **Plotting the Results:** Compare both the absolute test-accuracy curves and the differences relative to random sampling.\n", "\n", "**References:**\n", "\n", "[1] Oquab, M., Darcet, T., Moutakanni, T., Vo, H. V., Szafraniec, M.,\n", "Khalidov, V., ... & Bojanowski, P. DINOv2: Learning Robust Visual Features\n", "without Supervision. Transactions on Machine Learning Research.\n", "\n", "[2] Krizhevsky, A., & Hinton, G. (2009). 
Learning Multiple Layers of\n", "Features from Tiny Images."]}, {"cell_type": "code", "execution_count": null, "id": "3a4b6dbb9143a5eb", "metadata": {"ExecuteTime": {"end_time": "2025-11-08T18:55:09.099806Z", "start_time": "2025-11-08T18:55:04.875268Z"}, "collapsed": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["cuML: Accelerator installed.\n"]}], "source": ["# Uncomment the next line for a speedup if you have cuML installed.\n", "# %load_ext cuml.accel\n", "import numpy as np\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import warnings\n", "\n", "from datasets import load_dataset\n", "from sklearn.linear_model import LogisticRegression\n", "from skactiveml.classifier import SklearnClassifier\n", "from skactiveml.pool import (\n", " UncertaintySampling,\n", " RandomSampling,\n", " DiscriminativeAL,\n", " CoreSet,\n", " TypiClust,\n", " Badge,\n", " DropQuery,\n", " ProbCover,\n", " Falcun,\n", " UHerding,\n", " SubSamplingWrapper,\n", ")\n", "from skactiveml.utils import call_func\n", "from transformers import AutoImageProcessor, Dinov2Model\n", "from tqdm import tqdm\n", "\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "mpl.rcParams[\"figure.facecolor\"] = \"white\"\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""]}, {"cell_type": "markdown", "id": "3ac32ea901e08c01", "metadata": {"collapsed": false}, "source": ["## Embed CIFAR-100 Images with DINOv2"]}, {"cell_type": "markdown", "id": "7a4e3bf805669b1c", "metadata": {"collapsed": false}, "source": ["In this step, we prepare the train and test sets using the self-supervised DINOv2 model. DINOv2, the second version of DINO (short for \"self-distillation with no labels\"), is a vision foundation model that provides meaningful frozen representations for image data. 
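Several of the strategies compared below (e.g., CoreSet, TypiClust, ProbCover) operate directly on distances between such embeddings, so the geometry of the embedding space matters. As a minimal, self-contained sketch of a common convention, the following shows how L2-normalizing embeddings turns dot products into cosine similarities (random vectors stand in for real CLS-token embeddings; 384 is the hidden size of the ViT-S backbone):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for DINOv2 CLS-token embeddings: 4 images, 384 dimensions each.
emb = rng.normal(size=(4, 384)).astype(np.float32)

# L2-normalize so that dot products equal cosine similarities.
emb /= np.linalg.norm(emb, axis=1, keepdims=True)

# Pairwise cosine-similarity matrix; each diagonal entry is 1 by construction.
sim = emb @ emb.T
print(sim.shape)  # -> (4, 4)
```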
Here, we use the CLS-token representation of `facebook/dinov2-small` as the embedding for each CIFAR-100 image."]}, {"cell_type": "code", "execution_count": 3, "id": "3b27a5c395f5897e", "metadata": {"ExecuteTime": {"end_time": "2025-11-08T18:55:25.655039Z", "start_time": "2025-11-08T18:55:09.144763Z"}, "collapsed": false}, "outputs": [], "source": ["# Download data.\n", "ds = load_dataset(\"uoft-cs/cifar100\")\n", "\n", "# Download DINOv2 ViT/S-14 as embedding model.\n", "processor = AutoImageProcessor.from_pretrained(\n", " \"facebook/dinov2-small\", use_fast=True\n", ")\n", "model = Dinov2Model.from_pretrained(\"facebook/dinov2-small\").to(device).eval()\n", "\n", "\n", "# Embed CIFAR-100 images.\n", "def embed(batch):\n", " inputs = processor(images=batch[\"img\"], return_tensors=\"pt\").to(device)\n", " with torch.no_grad():\n", " out = model(**inputs).last_hidden_state[:, 0]\n", " batch[\"emb\"] = out.cpu().numpy()\n", " return batch\n", "\n", "\n", "ds = ds.map(embed, batched=True, batch_size=32)\n", "X_pool = np.stack(ds[\"train\"][\"emb\"], dtype=np.float32)\n", "y_pool = np.array(ds[\"train\"][\"fine_label\"], dtype=np.int64)\n", "X_test = np.stack(ds[\"test\"][\"emb\"], dtype=np.float32)\n", "y_test = np.array(ds[\"test\"][\"fine_label\"], dtype=np.int64)"]}, {"cell_type": "markdown", "id": "cf4565232ae31432", "metadata": {"collapsed": false}, "source": ["## Random Seed Management"]}, {"cell_type": "markdown", "id": "b1ce38a2-0305-44eb-abf5-cb69c41c3f98", "metadata": {}, "source": ["To ensure experiment reproducibility, we define helper functions for generating seeds and `RandomState` objects. 
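The idea behind this pattern is that a single master random state deterministically spawns child seeds, so the whole experiment is reproducible while each component still receives its own independent stream. A minimal NumPy sketch of that behavior:

```python
import numpy as np

# Two master states with the same seed spawn identical child seeds ...
master_a = np.random.RandomState(0)
master_b = np.random.RandomState(0)
seeds_a = [master_a.randint(0, 2**31) for _ in range(3)]
seeds_b = [master_b.randint(0, 2**31) for _ in range(3)]
print(seeds_a == seeds_b)  # -> True

# ... and each child seed yields its own reproducible stream.
child = np.random.RandomState(seeds_a[0])
print(child.random_sample(2).shape)  # -> (2,)
```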
The classifier seed is derived from `master_random_state`, while the experiment loop also creates deterministic per-repetition random states for the query strategy and subsampling wrapper."]}, {"cell_type": "code", "execution_count": 4, "id": "11e6ce648a9ab110", "metadata": {"ExecuteTime": {"end_time": "2025-11-08T18:55:25.703150Z", "start_time": "2025-11-08T18:55:25.701134Z"}, "collapsed": false}, "outputs": [], "source": ["master_random_state = np.random.RandomState(0)\n", "\n", "\n", "def gen_seed(random_state: np.random.RandomState):\n", " \"\"\"\n", " Generate a seed for a random number generator.\n", "\n", " Parameters:\n", " - random_state (np.random.RandomState): Random state object.\n", "\n", " Returns:\n", " - int: Generated seed.\n", " \"\"\"\n", " return random_state.randint(0, 2**31)\n", "\n", "\n", "def gen_random_state(random_state: np.random.RandomState):\n", " \"\"\"\n", " Generate a new random state object based on a given random state.\n", "\n", " Parameters:\n", " - random_state (np.random.RandomState): Random state object.\n", "\n", " Returns:\n", " - np.random.RandomState: New random state object.\n", " \"\"\"\n", " return np.random.RandomState(gen_seed(random_state))"]}, {"cell_type": "markdown", "id": "24cc0c5019b852b5", "metadata": {"collapsed": false}, "source": ["## Classification Models and Query Strategies"]}, {"cell_type": "markdown", "id": "91a3743e-cb9e-49a9-a239-30c9f54839fa", "metadata": {}, "source": ["The embeddings computed above serve as input to a classification model. In this tutorial, we use `LogisticRegression` from `sklearn`, wrapped by `SklearnClassifier`, for both evaluation and strategy interfaces that require a classifier-like object. `UHerding` can use `decision_function` as a logits fallback for such linear classifiers, so it can be included here without switching to a neural network classifier. 
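To make the logits fallback concrete: a linear classifier's `decision_function` returns one score per class, and applying a softmax to those scores yields class probabilities. A minimal sketch on synthetic data (an illustration with plain scikit-learn, not the tutorial's wrapped classifier):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic 3-class problem.
X, y = make_classification(
    n_samples=200, n_features=10, n_informative=5, n_classes=3, random_state=0
)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# One score per class; for linear models these behave like logits.
logits = clf.decision_function(X[:5])

# Numerically stable softmax turns the scores into per-sample probabilities.
shifted = logits - logits.max(axis=1, keepdims=True)
proba = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
print(proba.shape)  # -> (5, 3)
```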
Query strategies are instantiated through factory functions to keep the repetition logic compact and reproducible."]}, {"cell_type": "code", "execution_count": null, "id": "d356c3752b0a58a1", "metadata": {"ExecuteTime": {"end_time": "2025-11-08T18:55:25.747222Z", "start_time": "2025-11-08T18:55:25.744228Z"}, "collapsed": false}, "outputs": [], "source": ["n_features, classes = X_pool.shape[1], np.unique(y_pool)\n", "missing_label = -1\n", "clf = SklearnClassifier(\n", " LogisticRegression(verbose=0, tol=1e-3, C=0.01, max_iter=10000),\n", " classes=classes,\n", " random_state=gen_seed(master_random_state),\n", " missing_label=-1,\n", ")\n", "\n", "\n", "def create_query_strategy(name, random_state):\n", " return query_strategy_factory_functions[name](random_state)\n", "\n", "\n", "query_strategy_factory_functions = {\n", " \"RandomSampling\": lambda random_state: RandomSampling(\n", " random_state=gen_seed(random_state), missing_label=missing_label\n", " ),\n", " \"UncertaintySampling\": lambda random_state: UncertaintySampling(\n", " random_state=gen_seed(random_state),\n", " missing_label=missing_label\n", " ),\n", " \"DiscriminativeAL\": lambda random_state: DiscriminativeAL(\n", " random_state=gen_seed(random_state), missing_label=missing_label\n", " ),\n", " \"CoreSet\": lambda random_state: CoreSet(\n", " random_state=gen_seed(random_state), missing_label=missing_label\n", " ),\n", " \"TypiClust\": lambda random_state: TypiClust(\n", " random_state=gen_seed(random_state), missing_label=missing_label\n", " ),\n", " \"Badge\": lambda random_state: Badge(\n", " random_state=gen_seed(random_state), missing_label=missing_label\n", " ),\n", " \"DropQuery\": lambda random_state: DropQuery(\n", " random_state=gen_seed(random_state), missing_label=missing_label,\n", " ),\n", " \"ProbCover\": lambda random_state: ProbCover(\n", " random_state=gen_seed(random_state), missing_label=missing_label\n", " ),\n", " \"Falcun\": lambda random_state: Falcun(\n", " 
random_state=gen_seed(random_state), missing_label=missing_label\n", " ),\n", " \"UHerding\": lambda random_state: UHerding(\n", " random_state=gen_seed(random_state),\n", " missing_label=missing_label,\n", " ),\n", "}"]}, {"cell_type": "markdown", "id": "9023c37048ec8f54", "metadata": {"collapsed": false}, "source": ["## Experiment Parameters\n", "\n", "To compare the strategies against one another, we fix the number of repetitions (`n_reps`), the number of query cycles (`n_cycles`), and the number of samples selected per query (`query_batch_size`). "]}, {"cell_type": "code", "execution_count": 6, "id": "e103670190e73a3f", "metadata": {"ExecuteTime": {"end_time": "2025-11-08T18:55:25.789678Z", "start_time": "2025-11-08T18:55:25.788190Z"}, "collapsed": false}, "outputs": [], "source": ["n_reps = 3\n", "n_cycles = 20\n", "query_batch_size = 100\n", "query_strategy_names = query_strategy_factory_functions.keys()"]}, {"cell_type": "markdown", "id": "66b8fff1b12211ed", "metadata": {"collapsed": false}, "source": ["The experiment loop iterates over all query strategies. 
For each repetition, the strategy is wrapped in `SubSamplingWrapper` to restrict the candidate pool for efficiency, the classifier is refit after every acquisition step, and the test accuracy is stored for each cycle in the `results` dictionary."]}, {"cell_type": "code", "execution_count": 7, "id": "718629bb-1c4f-4707-9eb1-9d1df901cd84", "metadata": {"ExecuteTime": {"end_time": "2025-11-08T19:13:43.735332Z", "start_time": "2025-11-08T18:55:25.831259Z"}}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Repeat 1 with RandomSampling: 0%| | 0/20 [00:00"]}, "metadata": {}, "output_type": "display_data"}], "source": ["x = np.arange(n_cycles + 1) * query_batch_size\n", "random_results = results[\"RandomSampling\"].reshape((-1, n_cycles + 1))\n", "colors = plt.rcParams[\"axes.prop_cycle\"].by_key()[\"color\"]\n", "strategy_colors = {\n", " qs_name: colors[i % len(colors)]\n", " for i, qs_name in enumerate(query_strategy_names)\n", "}\n", "\n", "fig, (ax_abs, ax_diff) = plt.subplots(1, 2, figsize=(22, 9), sharex=True)\n", "\n", "for qs_name in query_strategy_names:\n", " result = results[qs_name].reshape((-1, n_cycles + 1))\n", " mean = np.mean(result, axis=0)\n", " std = np.std(result, axis=0)\n", " color = strategy_colors[qs_name]\n", "\n", " ax_abs.errorbar(\n", " x,\n", " mean,\n", " std,\n", " label=f\"{qs_name}: ALCU={np.mean(mean):.3f}\",\n", " alpha=0.5,\n", " color=color,\n", " linewidth=3,\n", " elinewidth=3,\n", " )\n", "\n", " if qs_name == \"RandomSampling\":\n", " ax_diff.axhline(\n", " 0.0,\n", " color=color,\n", " linestyle=\"--\",\n", " linewidth=3,\n", " label=f\"{qs_name}: \u0394ALCU=0.000\",\n", " )\n", " continue\n", "\n", " delta_result = result - random_results\n", " delta_mean = np.mean(delta_result, axis=0)\n", " delta_std = np.std(delta_result, axis=0)\n", " ax_diff.errorbar(\n", " x,\n", " delta_mean,\n", " delta_std,\n", " label=f\"{qs_name}: \u0394ALCU={np.mean(delta_mean):.3f}\",\n", " alpha=0.5,\n", " color=color,\n", " 
linewidth=3,\n", " elinewidth=3,\n", " )\n", "\n", "ax_abs.set_title(\"Absolute accuracy\", fontsize=\"x-large\")\n", "ax_diff.set_title(\"Accuracy difference to random\", fontsize=\"x-large\")\n", "ax_abs.set_yticks(np.arange(0, 1.0, 0.1))\n", "ax_abs.grid()\n", "ax_diff.grid()\n", "ax_abs.legend(loc=\"lower right\", fontsize=\"large\")\n", "ax_diff.legend(loc=\"upper right\", fontsize=\"large\")\n", "ax_abs.set_xlabel(\"# labeled samples\", fontsize=\"x-large\")\n", "ax_diff.set_xlabel(\"# labeled samples\", fontsize=\"x-large\")\n", "ax_abs.set_ylabel(\"accuracy\", fontsize=\"x-large\")\n", "ax_diff.set_ylabel(\"accuracy - random accuracy\", fontsize=\"x-large\")\n", "fig.tight_layout()\n", "plt.show()"]}], "metadata": {"kernelspec": {"display_name": "scikit-activeml", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11"}, "nbsphinx": {"orphan": true}}, "nbformat": 4, "nbformat_minor": 5}