Open In Colab

Hybrid Classification & Embedding

In this example we’ll set up a hybrid paraDime routine that can both classify high-dimensional data and embed it in a two dimensional space. Our routine will learn both tasks simultaneously with a shared latent space of the model.

We start by importing some packages and the relevant paraDime submodules. We also call paraDime’s seed utility for reproducibility reasons.

[1]:
import copy
import torch
from matplotlib import pyplot as plt

import paradime.dr
import paradime.loss
import paradime.models
import paradime.routines
import paradime.utils

paradime.utils.seed.seed_all(42);

We test our hybrid model on the MNIST handwritten image dataset available via torchvision.

[2]:
import torchvision

mnist = torchvision.datasets.MNIST(
    '../data',
    train=True,
    download=True,
)
mnist_data = mnist.data.reshape(-1, 28*28) / 255.
num_items = 5000

mnist_subset = mnist_data[:num_items]
target_subset = mnist.targets[:num_items]

Defining Our Custom Model

Let’s now define our hybrid model. In our model’s __init__ we simply create a list of fully connected layers depending on our input and hidden layer dimensions. From the final hidden layer, the model branches out into an embedding part with default dimensionality 2 and another part for classification, which depends on the number of target classes.

[3]:
class HybridEmbeddingModel(paradime.models.Model):
    """A fully connected network for hybrid embedding and classification.

    Args:
        in_dim: Input dimension.
        hidden_dims: List of hidden layer dimensions.
        num_classes: Number of target classes.
        emb_dim: Embedding dimension (2 by default).
    """

    def __init__(self,
        in_dim: int,
        hidden_dims: list[int],
        num_classes: int,
        emb_dim: int = 2,
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList()

        cur_dim = in_dim
        for hdim in hidden_dims:
            self.layers.append(torch.nn.Linear(cur_dim, hdim))
            cur_dim = hdim

        self.emb_layer = torch.nn.Linear(cur_dim, emb_dim)
        self.class_layer = torch.nn.Linear(cur_dim, num_classes)

        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def common_forward(self, x):
        for layer in self.layers:
            # x = torch.sigmoid(layer(x))
            x = layer(x)
            x = torch.nn.functional.softplus(x)
        return x

    def embed(self, x):
        x = self.common_forward(x)
        x = self.emb_layer(x)
        return x

    def classify(self, x):
        x = self.common_forward(x)
        x = self.class_layer(x)
        return x

The model has a common_forward method, which propagates the input until the final hidden layer. It also has an embed method that takes the common forward output (i.e., latent representation) and embeds it, and a classify method that takes the same representation and turns it into prediction scores.

Defining the paraDime Routine

Now we can define our paraDime routine. We’ll borrow most settings from paraDime’s built-in ParametricTSNE class.

To do this, we first define a dummy parametric TSNE routine. The built-in TSNE class defers adding the training phases until the training actually begins. Only then does it call the _prepare_training method, which adds the training phases. Since we want to copy the training phases, we have to call this method ourselves:

[4]:
tsne = paradime.routines.ParametricTSNE(dataset=mnist_subset)
tsne._prepare_training()

Now we can acces our dummy method’s relations, training phases, the main embedding loss, and the dataset (which now already has PCA added to it, in case we want to use an initialization phase later):

[5]:
global_rel = tsne.global_relations
batch_rel = tsne.batch_relations
init_phase, main_phase = tsne.training_phases
rel_loss = main_phase.loss
dataset = tsne.dataset

With these building blocks, we can go on to define our routine. In the following cell, quite a lot of things happen. Let’s break it down one by one:

  • First we define a list of weights, which we will later use in a compound loss to run a bunch of experiments at once.

  • Now we create our actual hybrid_loss. It is a compound loss consisting of (1) the orignal loss that we copied from the built-in TSNE implementation, and (2) a simple classification loss. Because we implemented our custom model with a classify method, we don’t have to specify anything here. We weigh our losses differently in each iteration of our experiment loop. (Note: We create a deepcopy of the original loss, so that a new instance is created in each iteration. This makes sure that the loss’s history is not messed up, in case we want to inspect it later.)

  • We’re ready to define our hybrid_tsne routine now. As a model we use our HybridEmbeddingModel. The rest of the settings are the ones we copied from the built-in class above.

  • Because the classification component of our compound loss require labels, we now add the labels to the dataset using the routine’s add_to_dataset method.

  • Now would be the time to add an optional initialization phase, reusing the one from the built-in method. This is commented out, since in this example we don’t want to affect the classification with an initialization of the embedding.

  • Next, we add the main trianing phase with our hybrid_loss.

  • And finally, we can call the routine’s train method.

The three blocks of code that follow the training are just there for saving the embedded MNIST subsets and our classifier’s accuracies for each weight.

[6]:
embeddings = []
train_accuracies = []
test_accuracies = []

weights = [0.0, 5.0, 20.0, 50.0, 100.0, 200.0, 300.0, 500.0]

for w in weights:
    hybrid_loss = paradime.loss.CompoundLoss(
        [copy.deepcopy(rel_loss), paradime.loss.ClassificationLoss()],
        weights=[w, 1.0],
    )

    paradime.utils.logging.log(f"Weight: {w}")

    hybrid_tsne = paradime.dr.ParametricDR(
        model=HybridEmbeddingModel(
            in_dim=28 * 28, hidden_dims=[100, 50], num_classes=10, emb_dim=2,
        ),
        dataset=dataset,
        global_relations=global_rel,
        batch_relations=batch_rel,
        use_cuda=True,
        verbose=True,
    )
    hybrid_tsne.add_to_dataset({"labels": target_subset})
    # hybrid_tsne.add_training_phase(copy.deepcopy(init_phase))
    hybrid_tsne.add_training_phase(
        epochs=50,
        batch_size=500,
        learning_rate=0.001,
        loss=hybrid_loss,
    )
    hybrid_tsne.train()

    embeddings.append(hybrid_tsne.apply(mnist_subset, "embed"))

    train_logits = hybrid_tsne.apply(mnist_subset, "classify")
    train_prediction = torch.argmax(train_logits, dim=1)
    train_accuracies.append(
        torch.sum(train_prediction == target_subset) / num_items
    )

    test_logits = hybrid_tsne.apply(
        mnist_data[num_items : 2 * num_items], "classify"
    )
    test_prediction = torch.argmax(test_logits, dim=1)
    test_accuracies.append(
        torch.sum(test_prediction == mnist.targets[num_items : 2 * num_items])
        / num_items
    )

2022-08-30 20:37:26,942: Weight: 0.0
2022-08-30 20:37:26,944: Registering dataset.
2022-08-30 20:37:27,042: Adding entry 'labels' to dataset.
2022-08-30 20:37:27,043: Computing global relations 'rel'.
2022-08-30 20:37:27,044: Indexing nearest neighbors.
2022-08-30 20:37:42,841: Calculating probabilities.
2022-08-30 20:37:43,297: Beginning training phase 'None'.
2022-08-30 20:37:45,580: Loss after epoch 0: 22.895014762878418
2022-08-30 20:37:47,444: Loss after epoch 5: 5.923945248126984
2022-08-30 20:37:49,409: Loss after epoch 10: 3.1659086644649506
2022-08-30 20:37:51,232: Loss after epoch 15: 2.5046416670084
2022-08-30 20:37:52,935: Loss after epoch 20: 2.0659042298793793
2022-08-30 20:37:54,722: Loss after epoch 25: 1.763053223490715
2022-08-30 20:37:56,292: Loss after epoch 30: 1.4860624969005585
2022-08-30 20:37:57,876: Loss after epoch 35: 1.2686931937932968
2022-08-30 20:37:59,648: Loss after epoch 40: 1.1323856264352798
2022-08-30 20:38:01,708: Loss after epoch 45: 0.9357316344976425
2022-08-30 20:38:02,920: Weight: 5.0
2022-08-30 20:38:02,922: Registering dataset.
2022-08-30 20:38:02,923: Adding entry 'labels' to dataset.
2022-08-30 20:38:02,923: Computing global relations 'rel'.
2022-08-30 20:38:02,924: Indexing nearest neighbors.
2022-08-30 20:38:04,291: Calculating probabilities.
2022-08-30 20:38:04,518: Beginning training phase 'None'.
2022-08-30 20:38:04,931: Loss after epoch 0: 23.25223970413208
2022-08-30 20:38:06,781: Loss after epoch 5: 6.683535039424896
2022-08-30 20:38:08,322: Loss after epoch 10: 3.6716531813144684
2022-08-30 20:38:10,026: Loss after epoch 15: 2.8826711028814316
2022-08-30 20:38:11,845: Loss after epoch 20: 2.4705112278461456
2022-08-30 20:38:13,606: Loss after epoch 25: 2.174752339720726
2022-08-30 20:38:15,387: Loss after epoch 30: 1.9238889664411545
2022-08-30 20:38:17,268: Loss after epoch 35: 1.751110389828682
2022-08-30 20:38:19,449: Loss after epoch 40: 1.628341406583786
2022-08-30 20:38:21,750: Loss after epoch 45: 1.4107389375567436
2022-08-30 20:38:23,335: Weight: 20.0
2022-08-30 20:38:23,336: Registering dataset.
2022-08-30 20:38:23,338: Adding entry 'labels' to dataset.
2022-08-30 20:38:23,339: Computing global relations 'rel'.
2022-08-30 20:38:23,339: Indexing nearest neighbors.
2022-08-30 20:38:24,549: Calculating probabilities.
2022-08-30 20:38:24,778: Beginning training phase 'None'.
2022-08-30 20:38:25,149: Loss after epoch 0: 24.82417392730713
2022-08-30 20:38:26,936: Loss after epoch 5: 7.285155415534973
2022-08-30 20:38:28,715: Loss after epoch 10: 4.459277033805847
2022-08-30 20:38:30,499: Loss after epoch 15: 3.819707542657852
2022-08-30 20:38:32,411: Loss after epoch 20: 3.5017621219158173
2022-08-30 20:38:34,179: Loss after epoch 25: 3.1889511048793793
2022-08-30 20:38:35,959: Loss after epoch 30: 2.9895763993263245
2022-08-30 20:38:37,707: Loss after epoch 35: 2.743565410375595
2022-08-30 20:38:39,494: Loss after epoch 40: 2.5353642404079437
2022-08-30 20:38:41,155: Loss after epoch 45: 2.3798455893993378
2022-08-30 20:38:42,466: Weight: 50.0
2022-08-30 20:38:42,467: Registering dataset.
2022-08-30 20:38:42,470: Adding entry 'labels' to dataset.
2022-08-30 20:38:42,471: Computing global relations 'rel'.
2022-08-30 20:38:42,471: Indexing nearest neighbors.
2022-08-30 20:38:43,652: Calculating probabilities.
2022-08-30 20:38:43,895: Beginning training phase 'None'.
2022-08-30 20:38:44,285: Loss after epoch 0: 28.03474760055542
2022-08-30 20:38:46,084: Loss after epoch 5: 10.625086545944214
2022-08-30 20:38:47,883: Loss after epoch 10: 6.836755931377411
2022-08-30 20:38:49,889: Loss after epoch 15: 6.091784834861755
2022-08-30 20:38:51,669: Loss after epoch 20: 5.645148038864136
2022-08-30 20:38:53,400: Loss after epoch 25: 5.316219866275787
2022-08-30 20:38:55,000: Loss after epoch 30: 5.004549890756607
2022-08-30 20:38:56,788: Loss after epoch 35: 4.790313869714737
2022-08-30 20:38:58,793: Loss after epoch 40: 4.5410477221012115
2022-08-30 20:39:00,546: Loss after epoch 45: 4.355794429779053
2022-08-30 20:39:01,935: Weight: 100.0
2022-08-30 20:39:01,936: Registering dataset.
2022-08-30 20:39:01,938: Adding entry 'labels' to dataset.
2022-08-30 20:39:01,939: Computing global relations 'rel'.
2022-08-30 20:39:01,939: Indexing nearest neighbors.
2022-08-30 20:39:03,346: Calculating probabilities.
2022-08-30 20:39:03,580: Beginning training phase 'None'.
2022-08-30 20:39:04,132: Loss after epoch 0: 32.7530312538147
2022-08-30 20:39:05,762: Loss after epoch 5: 13.464583992958069
2022-08-30 20:39:07,510: Loss after epoch 10: 10.146591305732727
2022-08-30 20:39:09,238: Loss after epoch 15: 9.307429790496826
2022-08-30 20:39:10,904: Loss after epoch 20: 8.781771123409271
2022-08-30 20:39:12,594: Loss after epoch 25: 8.418864369392395
2022-08-30 20:39:14,475: Loss after epoch 30: 8.03650563955307
2022-08-30 20:39:16,118: Loss after epoch 35: 7.703064024448395
2022-08-30 20:39:17,741: Loss after epoch 40: 7.550269901752472
2022-08-30 20:39:19,472: Loss after epoch 45: 7.269642651081085
2022-08-30 20:39:21,026: Weight: 200.0
2022-08-30 20:39:21,027: Registering dataset.
2022-08-30 20:39:21,028: Adding entry 'labels' to dataset.
2022-08-30 20:39:21,029: Computing global relations 'rel'.
2022-08-30 20:39:21,030: Indexing nearest neighbors.
2022-08-30 20:39:22,213: Calculating probabilities.
2022-08-30 20:39:22,441: Beginning training phase 'None'.
2022-08-30 20:39:22,879: Loss after epoch 0: 42.25814485549927
2022-08-30 20:39:24,658: Loss after epoch 5: 20.35624349117279
2022-08-30 20:39:26,469: Loss after epoch 10: 17.04476749897003
2022-08-30 20:39:28,241: Loss after epoch 15: 16.222278952598572
2022-08-30 20:39:30,051: Loss after epoch 20: 15.673559188842773
2022-08-30 20:39:31,575: Loss after epoch 25: 15.100891947746277
2022-08-30 20:39:33,314: Loss after epoch 30: 14.537290573120117
2022-08-30 20:39:35,251: Loss after epoch 35: 14.1690274477005
2022-08-30 20:39:36,708: Loss after epoch 40: 13.666248321533203
2022-08-30 20:39:38,413: Loss after epoch 45: 13.398423433303833
2022-08-30 20:39:39,822: Weight: 300.0
2022-08-30 20:39:39,824: Registering dataset.
2022-08-30 20:39:39,826: Adding entry 'labels' to dataset.
2022-08-30 20:39:39,827: Computing global relations 'rel'.
2022-08-30 20:39:39,828: Indexing nearest neighbors.
2022-08-30 20:39:41,022: Calculating probabilities.
2022-08-30 20:39:41,249: Beginning training phase 'None'.
2022-08-30 20:39:41,670: Loss after epoch 0: 51.858256340026855
2022-08-30 20:39:43,494: Loss after epoch 5: 27.09006428718567
2022-08-30 20:39:45,283: Loss after epoch 10: 23.82643985748291
2022-08-30 20:39:47,192: Loss after epoch 15: 22.54674530029297
2022-08-30 20:39:48,970: Loss after epoch 20: 21.88172221183777
2022-08-30 20:39:50,720: Loss after epoch 25: 21.147352695465088
2022-08-30 20:39:52,489: Loss after epoch 30: 20.724207282066345
2022-08-30 20:39:54,252: Loss after epoch 35: 20.065298557281494
2022-08-30 20:39:56,027: Loss after epoch 40: 19.532976269721985
2022-08-30 20:39:57,778: Loss after epoch 45: 19.268828511238098
2022-08-30 20:39:59,342: Weight: 500.0
2022-08-30 20:39:59,343: Registering dataset.
2022-08-30 20:39:59,344: Adding entry 'labels' to dataset.
2022-08-30 20:39:59,345: Computing global relations 'rel'.
2022-08-30 20:39:59,345: Indexing nearest neighbors.
2022-08-30 20:40:00,514: Calculating probabilities.
2022-08-30 20:40:00,741: Beginning training phase 'None'.
2022-08-30 20:40:01,113: Loss after epoch 0: 72.07359838485718
2022-08-30 20:40:02,873: Loss after epoch 5: 43.112019538879395
2022-08-30 20:40:04,707: Loss after epoch 10: 37.126792430877686
2022-08-30 20:40:06,214: Loss after epoch 15: 35.245335817337036
2022-08-30 20:40:07,941: Loss after epoch 20: 34.216312885284424
2022-08-30 20:40:09,595: Loss after epoch 25: 32.9052791595459
2022-08-30 20:40:11,364: Loss after epoch 30: 32.117218255996704
2022-08-30 20:40:13,099: Loss after epoch 35: 31.11882710456848
2022-08-30 20:40:14,774: Loss after epoch 40: 30.738272190093994
2022-08-30 20:40:16,531: Loss after epoch 45: 30.457290410995483

Plotting the Results

Once the training has completed (which might take while, since we train 8 different routines from scratch), we can take a look at the results.

paraDime’s scatterplot utility function accepts an ax keyword parameter. We can create a grid of Matplotlib axes and, in a loop, pass each axis as ax to create the scatterplot inside the grid. To the final grid cell we add a plot of the accuracies as a function of the losses’ weights:

[7]:
fig = plt.figure(figsize=(15, 15))

for i, (emb, w) in enumerate(zip(embeddings, weights)):
    ax = fig.add_subplot(3, 3, i + 1)

    paradime.utils.plotting.scatterplot(
        emb,
        labels=target_subset,
        ax=ax,
        legend=(i == 0),
        legend_options={"loc": 3},
    )
    ax.set_title(f"w_emb / w_class = {w}")

palette = paradime.utils.plotting.get_color_palette()
ax = fig.add_subplot(3, 3, 9)
ax.plot(weights, train_accuracies, c=palette["petrol"])
ax.plot(weights, test_accuracies, c=palette["aqua"])
ax.set_xscale("log")
ax.legend(["train", "test"])
ax.set_title("Classification accuracy");
../_images/examples_hybrid_14_0.png

The plot in the top left corner shows the embedding of the latent space for pure classification. While the embedding function is random here (because the embedding branch of our model was not trained at all), we can still see some clusters because our model has learned to group classes together in the latent space. With increasing weight on the embedding loss, the plot starts to look more like we would expect a t-SNE of MNIST to look like. Despite this, the classification accuracy remains high, as cen be seen in the accuracy plot on the bottom right. Only for very high weights of the embedding loss, the train accuracy drops. Interetingly though, the test accuracy ramains stable (maybe even increasing ever so slightly). This means that the strong focus on embedding does not hurt the classifier to generalize at all.

In summary, we have succesfully trained a model that can perform both tasks, classification and t-SNE-like embedding, pretty well.