Train a machine learning model on a collection#
Here, we iterate over the artifacts within a collection to train a machine learning model at scale.
import lamindb as ln
import anndata as ad
import numpy as np
💡 lamindb instance: testuser1/test-scrna
ln.track()
💡 notebook imports: anndata==0.9.2 lamindb==0.67.2 numpy==1.26.3 torch==2.1.2
💡 saved: Transform(uid='Qr1kIHvK506r5zKv', name='Train a machine learning model on a collection', short_name='scrna5', version='1', type=notebook, updated_at=2024-01-24 13:38:20 UTC, created_by_id=1)
💡 saved: Run(uid='RLUdmi0MJv6ooCFMCsS6', run_at=2024-01-24 13:38:20 UTC, transform_id=5, created_by_id=1)
Query our collection:
collection = ln.Collection.filter(
name="My versioned scRNA-seq collection", version="2"
).one()
collection.describe()
Collection(uid='5rQPk6jQmbjiJEGvAHaw', name='My versioned scRNA-seq collection', version='2', hash='BOAf0T5UbN_iOe3fQDyq', visibility=1, updated_at=2024-01-24 13:37:57 UTC)
Provenance:
📗 transform: Transform(uid='ManDYgmftZ8C5zKv', name='Standardize and append a batch of data', short_name='scrna2', version='1', type='notebook', updated_at=2024-01-24 13:37:44 UTC, created_by_id=1)
👣 run: Run(uid='NI5PiOJiQX8AWaRqFHuD', run_at=2024-01-24 13:37:44 UTC, transform_id=2, created_by_id=1)
👤 created_by: User(uid='DzTjkKse', handle='testuser1', name='Test User1', updated_at=2024-01-24 13:37:11 UTC)
⬇️ input_of (core.Run): ['2024-01-24 13:38:09 UTC']
Features:
var: FeatureSet(uid='GgyyMaxOALjCcPCIm3Yq', n=36390, type='number', registry='bionty.Gene', hash='gRQGj3QB8ZsIfXA1BjiL', updated_at=2024-01-24 13:37:35 UTC, created_by_id=1)
'MIR1302-2HG', 'FAM138A', 'OR4F5', 'None', 'None', 'None', 'None', 'None', 'None', 'None', 'OR4F29', 'None', 'OR4F16', 'None', 'LINC01409', 'FAM87B', 'LINC01128', 'LINC00115', 'FAM41C', 'None', ...
obs: FeatureSet(uid='4TEcPnmG7T3dRdQymmyY', n=4, registry='core.Feature', hash='jKvG9U7UNNKJMc6t0H99', updated_at=2024-01-24 13:37:36 UTC, created_by_id=1)
🔗 cell_type (40, bionty.CellType): 'dendritic cell', 'B cell, CD19-positive', 'effector memory CD4-positive, alpha-beta T cell, terminally differentiated', 'cytotoxic T cell', 'CD8-positive, CD25-positive, alpha-beta regulatory T cell', 'CD14-positive, CD16-negative classical monocyte', 'CD38-positive naive B cell', 'CD4-positive, alpha-beta T cell', 'classical monocyte', 'T follicular helper cell', ...
🔗 assay (4, bionty.ExperimentalFactor): '10x 3' v3', '10x 5' v2', '10x 5' v1', 'single-cell RNA sequencing'
🔗 tissue (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
🔗 donor (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...
external: FeatureSet(uid='kK8VDsHHaGBf4MOeBl3M', n=2, registry='core.Feature', hash='Cd6sfM0NoF0o0l1mYdrj', updated_at=2024-01-24 13:37:56 UTC, created_by_id=1)
🔗 assay (4, bionty.ExperimentalFactor): '10x 3' v3', '10x 5' v2', '10x 5' v1', 'single-cell RNA sequencing'
🔗 organism (1, bionty.Organism): 'human'
Labels:
🏷️ organism (1, bionty.Organism): 'human'
🏷️ tissues (17, bionty.Tissue): 'blood', 'thoracic lymph node', 'spleen', 'lung', 'mesenteric lymph node', 'lamina propria', 'liver', 'jejunal epithelium', 'omentum', 'bone marrow', ...
🏷️ cell_types (40, bionty.CellType): 'dendritic cell', 'B cell, CD19-positive', 'effector memory CD4-positive, alpha-beta T cell, terminally differentiated', 'cytotoxic T cell', 'CD8-positive, CD25-positive, alpha-beta regulatory T cell', 'CD14-positive, CD16-negative classical monocyte', 'CD38-positive naive B cell', 'CD4-positive, alpha-beta T cell', 'classical monocyte', 'T follicular helper cell', ...
🏷️ experimental_factors (4, bionty.ExperimentalFactor): '10x 3' v3', '10x 5' v2', '10x 5' v1', 'single-cell RNA sequencing'
🏷️ ulabels (12, core.ULabel): 'D496', '621B', 'A29', 'A36', 'A35', '637C', 'A52', 'A37', 'D503', '640C', ...
Create a map-style dataset#
Let us create a map-style dataset using mapped(): a MappedCollection. This is what, for example, the PyTorch DataLoader expects as input.
Under the hood, it performs a virtual inner or outer join of the features of the underlying AnnData objects and thus allows working with very large collections.
You can either perform a virtual inner join:
with collection.mapped(label_keys=["cell_type"], join="inner") as dataset:
print(len(dataset.var_joint))
749
Or a virtual outer join:
dataset = collection.mapped(label_keys=["cell_type"], join="outer")
len(dataset.var_joint)
36503
This is compatible with a PyTorch DataLoader because it implements __getitem__ over a list of backed AnnData objects.
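For intuition, a map-style dataset is simply an object that implements __len__ and __getitem__. The toy class below is a hypothetical illustration of that interface, not the actual MappedCollection implementation, which operates lazily over backed AnnData objects:
import numpy as np

class ToyMapStyleDataset:
    # toy stand-in for illustration only
    def __init__(self, x: np.ndarray, labels: list):
        self.x = x            # feature matrix of shape (n_cells, n_genes)
        self.labels = labels  # one encoded label per cell

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # return (expression vector, encoded label) for one cell
        return self.x[idx], self.labels[idx]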
The 5th cell in the collection can be accessed like:
dataset[5]
[array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 19]
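Each element pairs a cell's expression vector with its encoded cell_type. A small unpacking sketch (the shape and label value are assumptions based on the outputs above):
x, label = dataset[5]
x.shape  # (36503,) -- one value per gene of the outer join
label    # 19, the integer code for this cell's cell_type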
The labels are encoded into integers:
dataset.encoders
[{'plasmablast': 0,
'megakaryocyte': 1,
'lymphocyte': 2,
'naive thymus-derived CD8-positive, alpha-beta T cell': 3,
'animal cell': 4,
'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 5,
'CD4-positive helper T cell': 6,
'CD8-positive, alpha-beta memory T cell': 7,
'dendritic cell': 8,
'regulatory T cell': 9,
'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 10,
'naive B cell': 11,
'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 12,
'gamma-delta T cell': 13,
'classical monocyte': 14,
'CD16-positive, CD56-dim natural killer cell, human': 15,
'effector memory CD4-positive, alpha-beta T cell': 16,
'group 3 innate lymphoid cell': 17,
'non-classical monocyte': 18,
'cytotoxic T cell': 19,
'alveolar macrophage': 20,
'B cell, CD19-positive': 21,
'progenitor cell': 22,
'mast cell': 23,
'plasma cell': 24,
'alpha-beta T cell': 25,
'naive thymus-derived CD4-positive, alpha-beta T cell': 26,
'macrophage': 27,
'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 28,
'T follicular helper cell': 29,
'mucosal invariant T cell': 30,
'dendritic cell, human': 31,
'CD4-positive, alpha-beta T cell': 32,
'CD16-negative, CD56-bright natural killer cell, human': 33,
'CD14-positive, CD16-negative classical monocyte': 34,
'CD38-positive naive B cell': 35,
'plasmacytoid dendritic cell': 36,
'memory B cell': 37,
'conventional dendritic cell': 38,
'germinal center B cell': 39}]
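To map integer codes back to cell type names, invert the encoder (a small sketch, assuming one encoder dict per label key as shown above):
# build a decoder from the cell_type encoder
decoder = {code: name for name, code in dataset.encoders[0].items()}
decoder[19]  # 'cytotoxic T cell' -- the label returned by dataset[5]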
Create a PyTorch DataLoader#
Let us use a weighted sampler:
from torch.utils.data import DataLoader, WeightedRandomSampler
# the label_key used for weights doesn't have to be passed to label_keys on init
sampler = WeightedRandomSampler(
weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
We can now iterate through the data loader:
for batch in dataloader:
pass
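With the data loader in place, training is standard PyTorch. Below is a minimal sketch of one training epoch with a toy linear classifier; unpacking each batch into an expression matrix and encoded cell_type labels, and sizing the layer from dataset.var_joint and dataset.encoders[0], are assumptions based on the outputs shown above:
import torch

n_genes = len(dataset.var_joint)      # 36503 features after the outer join
n_classes = len(dataset.encoders[0])  # 40 encoded cell types
model = torch.nn.Linear(n_genes, n_classes)  # toy linear classifier
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

for x, y in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)  # x: float32 expression batch, y: integer labels
    loss.backward()
    optimizer.step()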
Close the connections in MappedCollection:
dataset.close()
In practice, use a context manager:
with collection.mapped(label_keys=["cell_type"]) as dataset:
sampler = WeightedRandomSampler(
weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
for batch in dataloader:
pass