Train a machine learning model on a collection¶
Here, we iterate over the artifacts within a collection to train a machine learning model at scale.
import lamindb as ln

ln.track("Qr1kIHvK506r0002")  # track this run under its transform uid
Query our collection:
collection = ln.Collection.get(key="scrna/collection1", version="2")
collection.describe()
Create a map-style dataset¶
Let us create a map-style dataset using mapped(): a MappedCollection.
Under the hood, it performs a virtual join of the features of the underlying AnnData objects without loading the datasets into memory. You can either perform an inner join, which intersects the variables across objects:
with collection.mapped(obs_keys=["cell_type"], join="inner") as dataset:
    print("#observations:", dataset.shape[0])
    print("#variables:", len(dataset.var_joint))
Or an outer join, which takes their union:
dataset = collection.mapped(obs_keys=["cell_type"], join="outer")
print("#variables:", len(dataset.var_joint))
This is compatible with a PyTorch DataLoader because it implements __getitem__ over a list of backed AnnData objects.
For instance, the observation at index 5 in the collection can be accessed via:
dataset[5]
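Each indexed item is a small dict-like record. A minimal sketch for inspecting what it contains; the exact keys (typically the expression vector plus the encoded obs labels) depend on your lamindb version:
item = dataset[5]
# print the available keys; the guard hedges against item being a plain array
print(list(item.keys()) if hasattr(item, "keys") else item)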
The labels are encoded into integers:
dataset.encoders
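To map an encoded integer back to its original label, you can invert an encoder. A minimal sketch, assuming encoders is a dict of {obs_key: {label: code}} mappings as shown above:
# hypothetical helper: invert the cell_type encoder to decode integer labels
decoder = {code: label for label, code in dataset.encoders["cell_type"].items()}
print(decoder[dataset[5]["cell_type"]])  # assumes "cell_type" is an item key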
It is also possible to create a dataset by selecting only observations with certain values of an .obs column. Setting obs_filter in the example below makes the dataset iterate only over observations having CD16-positive, CD56-dim natural killer cell, human or macrophage in the .obs column cell_type across all AnnData objects.
select_by_cell_type = (
    "CD16-positive, CD56-dim natural killer cell, human",
    "macrophage",
)
with collection.mapped(obs_filter=("cell_type", select_by_cell_type)) as dataset_filter:
    print(dataset_filter.shape)
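obs_keys and obs_filter are independent parameters of mapped(), so a hedged sketch for requesting labels on the filtered view, assuming the two can be combined, looks like:
with collection.mapped(
    obs_keys=["cell_type"], obs_filter=("cell_type", select_by_cell_type)
) as dataset_filter:
    # only observations with one of the two selected cell types remain
    print(dataset_filter.shape)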
Create a PyTorch DataLoader¶
Let us use a weighted sampler:
from torch.utils.data import DataLoader, WeightedRandomSampler

# the label_key passed to get_label_weights doesn't have to be among the
# obs_keys passed on init; it returns one weight per observation, inversely
# proportional to its label's frequency, so rare cell types are sampled more often
sampler = WeightedRandomSampler(
    weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
We can now iterate through the data loader:
for batch in dataloader:
    pass
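To actually train a model, replace the pass with a training step. A minimal sketch with a linear classifier, not part of the original notebook, assuming each batch is a dict holding the expression matrix under "X" and the encoded labels under "cell_type" (key names depend on the lamindb version):
import torch

n_genes = len(dataset.var_joint)
n_classes = len(dataset.encoders["cell_type"])
model = torch.nn.Linear(n_genes, n_classes)  # placeholder model for the sketch
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

for batch in dataloader:
    X = batch["X"].float()  # assumed key for the expression matrix
    y = batch["cell_type"].long()  # assumed key for the encoded labels
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    optimizer.step()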
Close the connections in MappedCollection:
dataset.close()
In practice, use a context manager so that the connections are closed automatically:
with collection.mapped(obs_keys=["cell_type"]) as dataset:
    sampler = WeightedRandomSampler(
        weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
    )
    dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
    for batch in dataloader:
        pass