Paper Annotation Tool using Text Embeddings#

Google Colab Note: If the notebook fails to run after installing the needed packages, try restarting the runtime via Runtime -> Restart session (Ctrl + M).

Open in Colab

Notebook Dependencies

Uncomment the following cell to install all dependencies for this tutorial.

[1]:
# !pip install scikit-activeml datasets jupyter_ui_poll

This notebook showcases an example annotation tool to classify arXiv papers. The overall task is to classify arXiv papers into machine learning papers (cs.AI, cs.LG, cs.CV) and non-machine learning papers (cs.LO, cs.NI, cs.CR, cs.PL), which correspond to class 1 and class 0, respectively. We use ipywidgets to create the interface, such that users can annotate samples directly within the notebook. The dataset is a subset of the subset_arxiv_papers_with_embeddings dataset by MongoDB on Hugging Face, which already provides precomputed embeddings.

NOTE: For testing, execute this notebook on your local machine or click the Open in Google Colab button above.

[2]:
import datasets
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import time
import textwrap
from sklearn.model_selection import train_test_split
from IPython.display import display
from jupyter_ui_poll import ui_events

Dataset#

First, we create the dataset that we want to annotate. The papers are filtered according to their assigned categories. To simplify the task, we exclude all papers that would belong to both classes. Based on the filtered papers, we subsample a balanced dataset.

[3]:
random_state = np.random.RandomState(0)


# functions to check if the provided category is indicative of class 0 or 1
def is_c0_category(category):
    return category in ["cs.LO", "cs.NI", "cs.CR", "cs.PL"]


def is_c1_category(category):
    return category in ["cs.AI", "cs.LG", "cs.CV"]


# check whether at least one category assigned to a paper belongs to a class
def is_valid_paper(category_field, category_fn):
    categories = category_field.split()
    valid_category = [category_fn(c) for c in categories]
    return any(valid_category)


# download dataset from huggingface
cache_dir = "huggingface_cache/datasets/"
arxiv = datasets.load_dataset(
    "MongoDB/subset_arxiv_papers_with_embeddings", cache_dir=cache_dir
)

# filter papers according to their categories
is_c1 = np.array(
    [is_valid_paper(c, is_c1_category) for c in arxiv["train"]["categories"]]
)
is_c0 = np.array(
    [is_valid_paper(c, is_c0_category) for c in arxiv["train"]["categories"]]
)

# exclude papers that belong to both classes
is_exclusive_c1 = np.logical_and(is_c1, ~is_c0)
is_exclusive_c0 = np.logical_and(is_c0, ~is_c1)

# calculate maximum subset size per class
subset_size = min(np.sum(is_exclusive_c0), np.sum(is_exclusive_c1))

# draw subset per class
is_exclusive_c1_subset = random_state.choice(
    np.where(is_exclusive_c1)[0], subset_size, replace=False
)
is_exclusive_c0_subset = random_state.choice(
    np.where(is_exclusive_c0)[0], subset_size, replace=False
)

# get embeddings for subsets
c1_embeddings = np.array(arxiv["train"]["embedding"])[is_exclusive_c1_subset]
c0_embeddings = np.array(arxiv["train"]["embedding"])[is_exclusive_c0_subset]
print(f"c0: {len(c0_embeddings)}")
print(f"c1: {len(c1_embeddings)}")

# create dataset as needed for scikit-activeml and the annotation tool
X_indices = np.concatenate(
    [is_exclusive_c0_subset, is_exclusive_c1_subset], dtype=int
)
X_embeddings = np.concatenate([c0_embeddings, c1_embeddings])
y = np.array([0] * len(c0_embeddings) + [1] * len(c1_embeddings), dtype=float)
Repo card metadata block was not found. Setting CardData to empty.
c0: 267
c1: 267
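
To get a feel for the data, you can peek at a single record of the drawn subsets and double-check the filter logic. The snippet below is only an illustrative sketch; the category string "cs.AI cs.LO" is a made-up example.

# inspect one record of the machine-learning subset
record = arxiv["train"][int(is_exclusive_c1_subset[0])]
print(record["title"])
print(record["categories"])
print(len(record["embedding"]))  # dimensionality of the precomputed embedding

# a paper tagged with categories from both groups matches both checks
# and is therefore excluded from the subsets above
print(is_valid_paper("cs.AI cs.LO", is_c1_category))  # True
print(is_valid_paper("cs.AI cs.LO", is_c0_category))  # True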

User Interface#

We need to define the user interface, such that the notebook can interact with the user. We use ipywidgets to create the widgets and jupyter_ui_poll to poll for the button events. The user interface iteratively presents a single sample, and the user presses a button to indicate which class the sample belongs to. We also offer an additional button that always annotates the shown sample with its correct class. An output widget is updated after every annotation and shows the current performance of the trained classifier (with user annotations and with ideal annotations). We also plot the learning curve, which shows the performance over time.
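
The interplay between ipywidgets and jupyter_ui_poll boils down to the following pattern: a button callback flips a flag, and ui_events keeps processing pending UI events until the flag is set. The snippet below is a standalone sketch of this pattern, not part of the annotation tool itself.

import time
import ipywidgets as widgets
from IPython.display import display
from jupyter_ui_poll import ui_events

clicked = False

def on_click(_):
    global clicked
    clicked = True

button = widgets.Button(description="Click me")
button.on_click(on_click)
display(button)

# block until the callback has fired by repeatedly processing UI events
with ui_events() as poll:
    while not clicked:
        poll(10)          # handle up to 10 pending UI events
        time.sleep(0.01)  # avoid busy waiting
print("Button was pressed.")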

The process_batch_annotation_UI function offers two boolean flags to customize the user interface. If show_correct_class is set to True, the user interface also shows the real class of the presented sample. If skip_interaction is set to True, samples are annotated with the real class automatically.

[4]:
# this function removes whitespace and wraps the text again so it fits into the annotation tool
def normalize_text(text):
    text_lines = text.replace("\t", " ").replace("\n", " ").strip().split()
    text_single_line = " ".join(text_lines)
    wrapped_text = textwrap.fill(text_single_line, width=70)
    return wrapped_text

# this function blocks the execution to wait for button presses
def wait_for_interaction():
    global button_pressed
    with ui_events() as poll:
        poll(1000)
        while not button_pressed:
            poll(10)
            time.sleep(0.01)

# a button callback used to save the annotation
def button_clicked(i, arg):
    global button_pressed
    global results
    if not button_pressed:
        results.append(i)
        button_pressed = True

# create buttons and connect the callbacks
button_c0 = widgets.Button(
    description="Other",
    tooltip="Paper belongs to cs.LO, cs.NI, cs.CR, or cs.PL.",
)
button_c0.on_click(lambda arg: button_clicked(0, arg))
button_c1 = widgets.Button(
    description="Machine Learning",
    tooltip="Paper belongs to cs.AI, cs.LG, or cs.CV.",
)
button_c1.on_click(lambda arg: button_clicked(1, arg))
button_creal = widgets.Button(
    description="Correct Class",
    tooltip="Use this button to always annotate with the real class.",
)
button_creal.on_click(lambda arg: button_clicked(-1, arg))

# output to display the sample that should be annotated
output_sample = widgets.Output()
# output to show a summary of the performance and annotation
output_stats = widgets.Output()
# output to show learning curves
output_plot = widgets.Output()

# create the layout for the annotation tool
buttons = widgets.VBox([button_c1, button_c0, button_creal])
buttons.layout.width = "200px"

output_sample.layout.width = "590px"

output_stats.layout.height = "450px"
output_stats.layout.width = "250px"

output_plot.layout.height = "450px"
output_plot.layout.width = "540px"

stats_box = widgets.HBox([output_stats, output_plot])
sample_box = widgets.HBox([buttons, output_sample])
output_box = widgets.VBox([stats_box, sample_box])


# this function handles the updates for the annotation tool
def process_batch_annotation_UI(
    cycle,
    learning_curve,
    learning_curve_ideal,
    query_indices,
    X_indices_train,
    y_train_true,
    num_annot_errors,
    num_total_annot,
    batch_size,
    # set this flag to True to additionally show the real class
    show_correct_class=False,
    # set this flag if the presented papers should be annotated automatically
    skip_interaction=False,
):
    # initialize global variables that are used by the callbacks
    global button_pressed
    global results
    button_pressed = False
    results = []
    # draw plot
    with output_plot:
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_xlabel("Cycle")
        ax.set_ylabel("Accuracy")
        ax.plot(learning_curve, label="with annotated labels")
        ax.plot(learning_curve_ideal, label="with ideally annotated labels")
        ax.set_xlim(0, len(X_indices_train) // batch_size)
        ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.175), ncol=2)
        output_plot.clear_output(wait=True)
        display(fig)
        plt.close(fig)

    # update the stats text
    with output_stats:
        output_stats.clear_output(wait=True)
        print(f"Cycle: {cycle}")
        print(f"Accuracy:          {learning_curve[-1]:.3f}")
        print(f"Accuracy (ideal):  {learning_curve_ideal[-1]:.3f}")
        print(f"AULC:              {np.mean(learning_curve):.3f}")
        print(f"AULC (ideal):      {np.mean(learning_curve_ideal):.3f}")
        print(f"Annotation errors: {num_annot_errors}/{num_total_annot}")

    # iterate over all queried instances, process button presses, and update
    # the outputs after each annotation
    for batch_index, query_index in enumerate(query_indices):
        # reset button_pressed variable, such that waiting for button presses
        # works
        button_pressed = False
        # only do loop once if skip_interaction is True
        if skip_interaction and batch_index > 0:
            break
        # update the sample text which should be annotated
        with output_sample:
            output_sample.clear_output(wait=True)
            title = normalize_text(
                arxiv["train"][int(X_indices_train[int(query_index)])]["title"]
            )
            abstract = normalize_text(
                arxiv["train"][int(X_indices_train[int(query_index)])][
                    "abstract"
                ]
            )
            print(f"Batch index: {batch_index}/{batch_size}")
            if show_correct_class:
                correct_class_str = (
                    "Machine Learning"
                    if y_train_true[query_index] > 0
                    else "Other"
                )
                print(f"Correct class: {correct_class_str}")
            print("Title:")
            print(title)
            print()
            print("Abstract:")
            print(abstract)
        # replace all labels with the real labels if skip_interaction=True,
        # otherwise wait for the respective button input
        if skip_interaction:
            results.extend(y_train_true[query_indices])
        else:
            wait_for_interaction()

Annotation & Active Learning Cycle#

The following cell shows the active learning cycle. The classifier (clf), the query strategy (qs), and the batch size (batch_size) are marked with TODO comments and can be changed freely. Note that you might have to adjust the call_func call to pass additional parameters to the query strategy's query function.

After one of the annotation buttons is pressed, the UI updates itself and shows the next sample to annotate. The learning curves are only updated once per batch.
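
For example, random sampling could be replaced by uncertainty sampling from scikit-activeml's pool module. This is just one possible choice; it works with the unchanged call_func call in the cell below, which already passes clf and fit_clf.

from skactiveml.pool import UncertaintySampling

# example: least-confident uncertainty sampling instead of RandomSampling
qs = UncertaintySampling(method="least_confident", random_state=0)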

[5]:
from sklearn.linear_model import LogisticRegression

from skactiveml.classifier import SklearnClassifier
from skactiveml.pool import RandomSampling
from skactiveml.utils import call_func

import warnings
warnings.filterwarnings("ignore")

# split data into a train and test set
X_indices_train, X_indices_test, X_train, X_test, y_train_true, y_test = (
    train_test_split(X_indices, X_embeddings, y, shuffle=True, random_state=0)
)
# initialize the labels for the training set
# y_train will consist of the annotations the user does
y_train = np.full_like(y_train_true, np.nan)
# y_train_ideal is labeled under a perfect annotator
y_train_ideal = np.full_like(y_train_true, np.nan)

# number of queries per batch
# TODO: change the batch size to see how it affects various query strategies
batch_size = 5

# initialize the classifier
# TODO: replace with a classifier of choice
clf = SklearnClassifier(
    LogisticRegression(random_state=0), classes=[0, 1], random_state=0
)
# initialize the query strategy
# TODO: replace with a query strategy of choice
qs = RandomSampling(random_state=0)

# the accuracies per cycle
learning_curve = []
learning_curve_ideal = []
# number of annotation errors
num_annotation_errors = 0
# number of total annotations
num_total_annotations = 0

# clear all outputs
output_sample.clear_output()
output_stats.clear_output()
output_plot.clear_output()

# display the UI
display(output_box)

# iterate over all cycles
for cycle in range(0, len(X_train) // batch_size):
    # calculate accuracy for this cycle with ideal annotations
    clf.fit(X_train, y_train_ideal)
    score_ideal = clf.score(X_test, y_test)
    learning_curve_ideal.append(score_ideal)
    # calculate accuracy for this cycle with user annotations
    clf.fit(X_train, y_train)
    score = clf.score(X_test, y_test)
    learning_curve.append(score)

    # get indices to query labels for from the query strategy
    query_indices = call_func(
        qs.query,
        X=X_train,
        y=y_train,
        batch_size=batch_size,
        clf=clf,
        fit_clf=False,
    )

    # update UI
    process_batch_annotation_UI(
        cycle,
        learning_curve,
        learning_curve_ideal,
        query_indices,
        X_indices_train,
        y_train_true,
        num_annotation_errors,
        num_total_annotations,
        batch_size,
        show_correct_class=False,
        skip_interaction=False,
    )

    # Update labels with new annotations
    # replace all -1 labels (Correct Class button) with the correct label
    results = [
        y_train_true[i] if r == -1 else r
        for i, r in zip(query_indices, results)
    ]
    # set label in y_train and y_train_ideal
    y_train[query_indices] = results
    y_train_ideal[query_indices] = y_train_true[query_indices]
    # increase number of errors and total annotations
    num_annotation_errors += np.sum(y_train_true[query_indices] != results)
    num_total_annotations += len(results)

The cell above produces an output that looks like the following image:

[Screenshot of the annotation tool: accuracy statistics and learning curves in the top row, annotation buttons and the queried paper's title and abstract below.]
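
If you want to keep the collected annotations after the session, you can store them with numpy; the file names below are just placeholders.

# persist the user annotations and the corresponding arXiv indices
np.save("y_train_annotated.npy", y_train)
np.save("X_indices_train_annotated.npy", X_indices_train)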