{ "cells": [ { "cell_type": "markdown", "id": "33f39229508719ea", "metadata": { "collapsed": false }, "source": [ "# Active Image Classification via Self-Supervised Learning" ] }, { "cell_type": "markdown", "id": "fd4eb6e7", "metadata": {}, "source": [ "> **_Google Colab Note:_** If the notebook fails to run after installing the needed packages, try restarting the runtime (Ctrl + M) under Runtime -> Restart session.\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/scikit-activeml/scikit-activeml-docs/blob/gh-pages/latest/generated/tutorials_colab//05_pool_al_with_self_supervised_learning.ipynb)" ] }, { "cell_type": "markdown", "id": "79402ebb", "metadata": {}, "source": [ "**Notebook Dependencies**\n", "\n", "Uncomment the following cell to install all dependencies for this tutorial." ] }, { "cell_type": "code", "execution_count": 1, "id": "0c5bb17e", "metadata": {}, "outputs": [], "source": [ "# !pip install scikit-activeml torch torchvision torchaudio tqdm" ] }, { "cell_type": "markdown", "id": "a773fbde", "metadata": {}, "source": [ "
" ] }, { "cell_type": "markdown", "id": "fa4583e83ede8da9", "metadata": { "collapsed": false }, "source": [ "This tutorial demonstrates a practical comparison study using our `scikit-activeml` library. We use the self-supervised learning model `DINOv2` [1] to generate embeddings for the Flowers-102 dataset [2] and then employ various active learning strategies to select samples for labeling.\n", "\n", "**Key Steps:**\n", "1. **Self-Supervised Learning Model:** Use the DINOv2 model to create an embedding dataset for the Flowers-102 dataset.\n", "\n", "2. **Active Learning Strategies:** Employ different active learning strategies provided by our library, including:\n", "   - Random Sampling\n", "   - Uncertainty Sampling\n", "   - Discriminative Active Learning (DiscriminativeAL)\n", "   - CoreSet\n", "   - TypiClust\n", "   - Badge\n", "\n", "3. **Batch Sample Selection:** Use each active learning strategy to select a batch of samples for labeling in every cycle.\n", "\n", "4. **Plotting the Results:** Compare the test accuracy of the active learning strategies above over all cycles.\n", "\n", "**References:**\n", "\n", "[1] M. Oquab et al., ‘DINOv2: Learning Robust Visual Features without Supervision’. Transactions on Machine Learning Research (TMLR).\n", "\n", "[2] M. E. Nilsback and A. Zisserman, ‘Automated Flower Classification over a Large Number of Classes’. Indian Conference on Computer Vision, Graphics and Image Processing (ICVGIP)." ] }, { "cell_type": "code", "execution_count": 2, "id": "3a4b6dbb9143a5eb", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T17:58:49.361696Z", "start_time": "2024-02-16T17:58:47.694499Z" }, "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "from skactiveml.classifier import SklearnClassifier\n", "from skactiveml.pool import UncertaintySampling, RandomSampling, DiscriminativeAL, CoreSet, TypiClust, Badge\n", "from skactiveml.utils import call_func, MISSING_LABEL\n", "\n", "import torch\n", "import torchvision.datasets as datasets\n", "import torchvision.transforms as transforms\n", "from tqdm import tqdm\n", "\n", "import warnings\n", "mpl.rcParams[\"figure.facecolor\"] = \"white\"\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "id": "3ac32ea901e08c01", "metadata": { "collapsed": false }, "source": [ "## Prepare Data with DINOv2" ] }, { "cell_type": "markdown", "id": "7a4e3bf805669b1c", "metadata": { "collapsed": false }, "source": [ "In this step, we prepare the datasets using the self-supervised learning model DINOv2, the successor of DINO (short for \"self-distillation with no labels\"). DINOv2 learns general-purpose visual features from unlabeled images, so its embeddings can serve directly as input features for a simple classifier.\n", "\n", "If you have already completed these steps, you can skip ahead to loading your data." ] }, { "cell_type": "markdown", "id": "9c0819f4-b45f-4285-aa1d-84469764b6a5", "metadata": {}, "source": [ "**Step 1: Transformation**\n", "\n", "Apply the necessary transformations to the images: resize them, center-crop them to 224 × 224 pixels, convert them to tensors, and normalize each color channel. This ensures that the inputs have the size and value range expected by the DINOv2 model." 
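, "\n", "\n", "As a short worked example of what these transforms compute: `ToTensor` scales pixel values to $[0, 1]$, and `Normalize` then standardizes each color channel $c$ with the per-channel mean $\\mu_c$ and standard deviation $\\sigma_c$ used in the next cell (the widely used ImageNet statistics). The 224-pixel crop also divides evenly by the 14-pixel patch size of the ViT-B/14 model loaded in Step 2 (`kernel_size=(14, 14)` in its printed architecture), so each image becomes a $16 \\times 16$ grid of patch tokens:\n", "\n", "$$\n", "x'_c = \\frac{x_c - \\mu_c}{\\sigma_c}, \\qquad \\mu = (0.485, 0.456, 0.406), \\qquad \\sigma = (0.229, 0.224, 0.225)\n", "$$\n", "\n", "$$\n", "\\frac{224}{14} = 16 \\quad \\Rightarrow \\quad 16 \\times 16 = 256 \\text{ patch tokens per image}\n", "$$"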
] }, { "cell_type": "code", "execution_count": 3, "id": "3b27a5c395f5897e", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T17:58:55.815281Z", "start_time": "2024-02-16T17:58:55.810506Z" }, "collapsed": false }, "outputs": [], "source": [ "# Note: this rebinds the name `transforms` (the imported module) to the composed transform,\n", "# so re-running this cell requires re-running the import cell first.\n", "transforms = transforms.Compose([\n", "    transforms.Resize(256),  # resize the shorter image side to 256 pixels\n", "    transforms.CenterCrop(224),  # crop the central 224 x 224 region\n", "    transforms.ToTensor(),  # convert to a float tensor with values in [0, 1]\n", "    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # per-channel ImageNet mean and std\n", "])" ] }, { "cell_type": "markdown", "id": "0493d78a-c3d5-42b3-b5b6-b83191d4bbe1", "metadata": {}, "source": [ "**Step 2: Load Pretrained Model**\n", "\n", "To compute the embeddings, we use DINOv2. Below, we load the second-smallest DINOv2 model (ViT-B/14, `dinov2_vitb14`) via `torch.hub` to generate the embedding datasets for Flowers-102." ] }, { "cell_type": "code", "execution_count": 4, "id": "387db871-7e26-46d0-9722-45c71ed40014", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:04:27.290216Z", "start_time": "2024-02-16T18:03:24.949961Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: \"https://github.com/facebookresearch/dinov2/zipball/main\" to cache/main.zip\n", "Downloading: \"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth\" to cache/checkpoints\\dinov2_vitb14_pretrain.pth\n", "100%|██████████| 330M/330M [00:03<00:00, 113MB/s] \n" ] }, { "data": { "text/plain": [ "DinoVisionTransformer(\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))\n", " (norm): Identity()\n", " )\n", " (blocks): ModuleList(\n", " (0-11): 12 x NestedTensorBlock(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): MemEffAttention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): LayerScale()\n", " (drop_path1): 
Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): LayerScale()\n", " (drop_path2): Identity()\n", " )\n", " )\n", " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (head): Identity()\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Cache the downloaded repository and weights in a local directory\n", "torch.hub.set_dir('cache/')\n", "\n", "# Load the ViT-B/14 variant of DINOv2 from torch.hub\n", "dinov2_vitb14 = torch.hub.load(\"facebookresearch/dinov2\", \"dinov2_vitb14\")\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "# Switch to evaluation mode and move the model to the selected device\n", "dinov2_vitb14.eval()\n", "dinov2_vitb14.to(device)" ] }, { "cell_type": "markdown", "id": "7668a073b61085ea", "metadata": { "collapsed": false }, "source": [ "**Step 3: Load Datasets and Generate Embeddings**\n", "\n", "First, we load the training and validation splits of the Flowers-102 dataset.\n", "\n", "Then, we use the pretrained DINOv2 model to generate an embedding for each image and save the embeddings and labels as `.npy` files." 
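, "\n", "\n", "For orientation, here is a sketch of the resulting array shapes. We assume, based on our reading of the DINOv2 reference implementation, that the model's forward pass returns the normalized class-token embedding; its width of 768 matches the final `LayerNorm((768,))` printed in Step 2. Under this assumption, each image $x_i$ is mapped to one 768-dimensional vector, and the $N$ embeddings of a split are stacked row-wise:\n", "\n", "$$\n", "f_\\theta : \\mathbb{R}^{3 \\times 224 \\times 224} \\to \\mathbb{R}^{768}, \\qquad X = \\big[ f_\\theta(x_1), \\dots, f_\\theta(x_N) \\big]^\\top \\in \\mathbb{R}^{N \\times 768}\n", "$$\n", "\n", "Both the training and the validation split of Flowers-102 contain $N = 1020$ images (10 per class), which matches the 255 batches of size 4 reported by the progress bars below ($255 \\times 4 = 1020$)."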
] }, { "cell_type": "code", "execution_count": 5, "id": "8c0165be-d3ce-4eb2-9309-45ba6c08c8d3", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:08:01.968032Z", "start_time": "2024-02-16T18:08:01.959921Z" } }, "outputs": [], "source": [ "def load_and_process_dataset(root_dir, is_train, batch_size=4):\n", "    \"\"\"\n", "    Load the Flowers-102 training or validation split and compute its DINOv2 embeddings.\n", "\n", "    Parameters:\n", "    - root_dir (str): Root directory where the dataset will be stored.\n", "    - is_train (bool): Whether to load the training split (True) or the validation split (False).\n", "    - batch_size (int): Batch size used by the DataLoader.\n", "\n", "    Returns:\n", "    - X (numpy.ndarray): Concatenated embeddings of the dataset.\n", "    - y_true (numpy.ndarray): Concatenated true labels of the dataset.\n", "    \"\"\"\n", "\n", "    # Select the dataset split\n", "    if is_train:\n", "        split = 'train'\n", "    else:\n", "        split = 'val'\n", "\n", "    dataset = datasets.Flowers102(root=root_dir, split=split, download=True, transform=transforms)\n", "\n", "    # Create the DataLoader\n", "    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=2)\n", "\n", "    embedding_list = []\n", "    label_list = []\n", "\n", "    # Iterate through the DataLoader and extract embeddings without tracking gradients\n", "    with torch.no_grad():\n", "        for image, label in tqdm(dataloader, total=len(dataloader), desc=f\"Flowers102 {split}\"):\n", "            embeddings = dinov2_vitb14(image.to(device)).cpu()\n", "            embedding_list.append(embeddings)\n", "            label_list.append(label)\n", "\n", "    # Concatenate embeddings and labels into NumPy arrays\n", "    X = torch.cat(embedding_list, dim=0).numpy()\n", "    y_true = torch.cat(label_list, dim=0).numpy()\n", "\n", "    return X, y_true" ] }, { "cell_type": "markdown", "id": "b0ac7d46-ab15-4aaf-ad72-2f600c227cd9", "metadata": {}, "source": [ "Apply this function to the Flowers-102 training and validation splits. The validation split serves as the held-out test set in the experiments below." ] }, { "cell_type": "code", "execution_count": 6, "id": 
"b9c0fa197cbde9a", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:08:03.651976Z", "start_time": "2024-02-16T18:08:03.649154Z" }, "collapsed": false }, "outputs": [], "source": [ "data_dir = \"./data\"" ] }, { "cell_type": "code", "execution_count": 7, "id": "4abc295a-f41d-42aa-9b02-3a9fffc28b1d", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:15:55.521694Z", "start_time": "2024-02-16T18:12:23.843887Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading https://thor.robots.ox.ac.uk/flowers/102/102flowers.tgz to data\\flowers-102\\102flowers.tgz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 344862509/344862509 [00:05<00:00, 68322052.63it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data\\flowers-102\\102flowers.tgz to data\\flowers-102\n", "Downloading https://thor.robots.ox.ac.uk/flowers/102/imagelabels.mat to data\\flowers-102\\imagelabels.mat\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 502/502 [00:00<00:00, 334318.93it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading https://thor.robots.ox.ac.uk/flowers/102/setid.mat to data\\flowers-102\\setid.mat\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 14989/14989 [00:00<00:00, 7487901.70it/s]\n", "Flowers102 train: 100%|██████████| 255/255 [00:09<00:00, 25.93it/s]\n", "Flowers102 val: 100%|██████████| 255/255 [00:09<00:00, 27.97it/s]\n" ] } ], "source": [ "# Flowers-102\n", "flowers102_X_train, flowers102_y_train_true = load_and_process_dataset(data_dir, True)\n", "flowers102_X_test, flowers102_y_test_true = load_and_process_dataset(data_dir, False)\n", "\n", "np.save(f'{data_dir}/flowers102_dinov2B_X_train.npy', flowers102_X_train)\n", "np.save(f'{data_dir}/flowers102_dinov2B_y_train.npy', flowers102_y_train_true)\n", "np.save(f'{data_dir}/flowers102_dinov2B_X_test.npy', flowers102_X_test)\n", 
"np.save(f'{data_dir}/flowers102_dinov2B_y_test.npy', flowers102_y_test_true)" ] }, { "cell_type": "markdown", "id": "fe0757f8-9261-4510-90cd-8a2a8707add7", "metadata": {}, "source": [ "## Load Your Preprocessed Dataset\n", "\n", "If you have already processed your data with DINOv2, use the following code to load it. We also define the number of classes in the Flowers-102 dataset." ] }, { "cell_type": "code", "execution_count": 8, "id": "8fd85d019561dd8d", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:16:00.131931Z", "start_time": "2024-02-16T18:16:00.127759Z" }, "collapsed": false }, "outputs": [], "source": [ "data_dir = \"./data\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "e29fc457-2fb1-4824-8cd5-b637e44795a8", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:16:09.760397Z", "start_time": "2024-02-16T18:16:09.744836Z" } }, "outputs": [], "source": [ "# Flowers-102\n", "X_train = np.load(f'{data_dir}/flowers102_dinov2B_X_train.npy')\n", "y_train_true = np.load(f'{data_dir}/flowers102_dinov2B_y_train.npy')\n", "X_test = np.load(f'{data_dir}/flowers102_dinov2B_X_test.npy')\n", "y_test_true = np.load(f'{data_dir}/flowers102_dinov2B_y_test.npy')\n", "\n", "dataset_classes = 102" ] }, { "cell_type": "markdown", "id": "cf4565232ae31432", "metadata": { "collapsed": false }, "source": [ "## Random Seed Management" ] }, { "cell_type": "markdown", "id": "b1ce38a2-0305-44eb-abf5-cb69c41c3f98", "metadata": {}, "source": [ "To make the experiment reproducible, we set the random states of all components that use randomness. For simplicity, we fix a single master random state and use helper functions to derive new seeds and random states from it. Note that `master_random_state` should only be used to create new random states or random seeds." 
] }, { "cell_type": "code", "execution_count": 10, "id": "11e6ce648a9ab110", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:16:16.503717Z", "start_time": "2024-02-16T18:16:16.497010Z" }, "collapsed": false }, "outputs": [], "source": [ "master_random_state = np.random.RandomState(0)\n", "\n", "def gen_seed(random_state: np.random.RandomState):\n", "    \"\"\"\n", "    Generate a seed for a random number generator.\n", "\n", "    Parameters:\n", "    - random_state (np.random.RandomState): Random state object.\n", "\n", "    Returns:\n", "    - int: Generated seed in the range [0, 2**31).\n", "    \"\"\"\n", "    return random_state.randint(0, 2**31)\n", "\n", "def gen_random_state(random_state: np.random.RandomState):\n", "    \"\"\"\n", "    Generate a new random state object based on a given random state.\n", "\n", "    Parameters:\n", "    - random_state (np.random.RandomState): Random state object.\n", "\n", "    Returns:\n", "    - np.random.RandomState: New random state object.\n", "    \"\"\"\n", "    return np.random.RandomState(gen_seed(random_state))" ] }, { "cell_type": "markdown", "id": "24cc0c5019b852b5", "metadata": { "collapsed": false }, "source": [ "## Classification Models and Query Strategies" ] }, { "cell_type": "markdown", "id": "91a3743e-cb9e-49a9-a239-30c9f54839fa", "metadata": {}, "source": [ "The computed embeddings serve as input features for a classification model. In this guide, we use `LogisticRegression` from `sklearn`, wrapped in `SklearnClassifier` so that it can be trained on partially labeled data. We create the query strategies with factory functions so that each repetition starts with a fresh query strategy instance and its own random state." 
] }, { "cell_type": "code", "execution_count": 11, "id": "d356c3752b0a58a1", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:16:19.280308Z", "start_time": "2024-02-16T18:16:19.273518Z" }, "collapsed": false }, "outputs": [], "source": [ "clf = SklearnClassifier(LogisticRegression(), classes=np.arange(dataset_classes), random_state=gen_seed(master_random_state))\n", "\n", "def create_query_strategy(name, random_state):\n", "    return query_strategy_factory_functions[name](random_state)\n", "\n", "query_strategy_factory_functions = {\n", "    'RandomSampling': lambda random_state: RandomSampling(random_state=gen_seed(random_state)),\n", "    'UncertaintySampling': lambda random_state: UncertaintySampling(random_state=gen_seed(random_state)),\n", "    'DiscriminativeAL': lambda random_state: DiscriminativeAL(random_state=gen_seed(random_state)),\n", "    'CoreSet': lambda random_state: CoreSet(random_state=gen_seed(random_state)),\n", "    'TypiClust': lambda random_state: TypiClust(random_state=gen_seed(random_state)),\n", "    'Badge': lambda random_state: Badge(random_state=gen_seed(random_state))\n", "}" ] }, { "cell_type": "markdown", "id": "9023c37048ec8f54", "metadata": { "collapsed": false }, "source": [ "## Experiment Parameters\n", "\n", "To compare the strategies against one another, we define the number of repetitions (`n_reps`), the number of active learning cycles per repetition (`n_cycles`), and the number of samples queried in each cycle (`query_batch_size`).
" ] }, { "cell_type": "code", "execution_count": 12, "id": "e103670190e73a3f", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:16:22.477680Z", "start_time": "2024-02-16T18:16:22.469634Z" }, "collapsed": false }, "outputs": [], "source": [ "n_reps = 3\n", "n_cycles = 30\n", "query_batch_size = 8\n", "query_strategy_names = query_strategy_factory_functions.keys()" ] }, { "cell_type": "markdown", "id": "66b8fff1b12211ed", "metadata": { "collapsed": false }, "source": [ "## Experiment Loop\n", "\n", "The experiment loops over all query strategies and repetitions. Each repetition starts from a completely unlabeled training pool. In every cycle, the query strategy selects `query_batch_size` samples, their true labels are revealed, and the classifier is refit. With the settings above, each repetition therefore ends with 30 × 8 = 240 labeled samples. The accuracy on the test set is stored for each cycle and repetition in the `results` dictionary." ] }, { "cell_type": "code", "execution_count": 13, "id": "718629bb-1c4f-4707-9eb1-9d1df901cd84", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:33:37.013276Z", "start_time": "2024-02-16T18:16:23.899324Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Repeat 1 with RandomSampling: 100%|██████████| 30/30 [00:01<00:00, 21.70it/s]\n", "Repeat 2 with RandomSampling: 100%|██████████| 30/30 [00:01<00:00, 23.39it/s]\n", "Repeat 3 with RandomSampling: 100%|██████████| 30/30 [00:01<00:00, 25.08it/s]\n", "Repeat 1 with UncertaintySampling: 100%|██████████| 30/30 [00:03<00:00, 8.37it/s]\n", "Repeat 2 with UncertaintySampling: 100%|██████████| 30/30 [00:03<00:00, 9.47it/s]\n", "Repeat 3 with UncertaintySampling: 100%|██████████| 30/30 [00:02<00:00, 10.31it/s]\n", "Repeat 1 with DiscriminativeAL: 100%|██████████| 30/30 [00:13<00:00, 2.17it/s]\n", "Repeat 2 with DiscriminativeAL: 100%|██████████| 30/30 [00:13<00:00, 2.21it/s]\n", "Repeat 3 with DiscriminativeAL: 100%|██████████| 30/30 [00:13<00:00, 2.18it/s]\n", "Repeat 1 with CoreSet: 100%|██████████| 30/30 [00:05<00:00, 5.06it/s]\n", "Repeat 2 with CoreSet: 100%|██████████| 30/30 [00:05<00:00, 5.06it/s]\n", "Repeat 3 with CoreSet: 100%|██████████| 30/30 [00:05<00:00, 5.22it/s]\n", "Repeat 1 with TypiClust: 100%|██████████| 30/30 
[00:14<00:00, 2.06it/s]\n", "Repeat 2 with TypiClust: 100%|██████████| 30/30 [00:14<00:00, 2.05it/s]\n", "Repeat 3 with TypiClust: 100%|██████████| 30/30 [00:14<00:00, 2.14it/s]\n", "Repeat 1 with Badge: 100%|██████████| 30/30 [00:37<00:00, 1.24s/it]\n", "Repeat 2 with Badge: 100%|██████████| 30/30 [00:37<00:00, 1.26s/it]\n", "Repeat 3 with Badge: 100%|██████████| 30/30 [00:38<00:00, 1.29s/it]\n" ] } ], "source": [ "results = {}\n", "\n", "for qs_name in query_strategy_names:\n", "    accuracies = np.full((n_reps, n_cycles), np.nan)\n", "    for i_rep in range(n_reps):\n", "        # Start each repetition from a completely unlabeled training pool\n", "        y_train = np.full(shape=y_train_true.shape, fill_value=MISSING_LABEL)\n", "\n", "        # The strategy's seed depends only on the repetition index\n", "        qs = create_query_strategy(qs_name, random_state=gen_random_state(np.random.RandomState(i_rep)))\n", "        clf.fit(X_train, y_train)\n", "\n", "        for c in tqdm(range(n_cycles), desc=f'Repeat {i_rep + 1} with {qs_name}'):\n", "            # call_func passes only the keyword arguments that qs.query accepts\n", "            query_idx = call_func(qs.query, X=X_train, y=y_train, batch_size=query_batch_size, clf=clf, discriminator=clf)\n", "            # Reveal the true labels of the queried samples and refit the classifier\n", "            y_train[query_idx] = y_train_true[query_idx]\n", "            clf.fit(X_train, y_train)\n", "            score = clf.score(X_test, y_test_true)\n", "            accuracies[i_rep, c] = score\n", "\n", "    results[qs_name] = accuracies" ] }, { "cell_type": "markdown", "id": "998e52cb1a2107b", "metadata": { "collapsed": false }, "source": [ "## Plotting the Results\n", "\n", "We compare the strategies using learning curves that show, for each cycle, the mean test accuracy over all repetitions, with error bars indicating the standard deviation. The number in parentheses in each legend entry is the area under the learning curve, computed here as the mean accuracy over all cycles." ] }, { "cell_type": "code", "execution_count": 14, "id": "7ed6ebe52431f489", "metadata": { "ExecuteTime": { "end_time": "2024-02-16T18:33:44.893333Z", "start_time": "2024-02-16T18:33:44.750552Z" }, "collapsed": false }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for qs_name in query_strategy_names:\n", "    # results[qs_name] has shape (n_reps, n_cycles)\n", "    result = results[qs_name]\n", "    # Mean and standard deviation of the test accuracy across repetitions, per cycle\n", "    errorbar_mean = np.mean(result, axis=0)\n", "    errorbar_std = np.std(result, axis=0)\n", "    # The legend shows the area under the learning curve (mean accuracy over all cycles)\n", "    plt.errorbar(np.arange(1, n_cycles + 1), errorbar_mean, errorbar_std, label=f\"({np.mean(errorbar_mean):.4f}) {qs_name}\", alpha=0.5)\n", "plt.title(f\"LogisticRegression with query batch size {query_batch_size}\")\n", "plt.legend(loc='lower right')\n", "plt.xlabel('cycle')\n", "plt.ylabel('accuracy')\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }