Rasteret integrates with TorchGeo to provide a seamless path from STAC search to PyTorch DataLoader. This guide shows you how to build Collections, assign train/val/test splits, and create TorchGeo datasets.

Quick Start

import rasteret
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler

# 1. Build a Collection
collection = rasteret.build(
    "earthsearch/sentinel-2-l2a",
    name="training-data",
    bbox=(11.3, 48.1, 11.5, 48.3),
    date_range=("2024-01-01", "2024-06-30"),
)

# 2. Create a TorchGeo dataset
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B08"],  # RGB + NIR
    chip_size=256,
    is_image=True,
)

# 3. Create a sampler and DataLoader
sampler = RandomGeoSampler(dataset, size=256, length=1000)
loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=8,
    num_workers=4,
    collate_fn=stack_samples,
)

# 4. Train
for batch in loader:
    images = batch["image"]  # [batch, channels, height, width]
    # ... model(images) ...

Assigning Train/Val/Test Splits

For reproducible ML workflows, assign splits to your Collection using PyArrow:
import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
from pathlib import Path

def assign_splits(
    collection,
    output_path: Path,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    seed: int = 42,
):
    """Add a 'split' column and save the Collection."""
    table = collection.dataset.to_table()
    n = len(table)

    # Deterministic random assignment
    rng = np.random.default_rng(seed)
    assignments = rng.random(n)
    splits = np.where(
        assignments < train_ratio,
        "train",
        np.where(assignments < train_ratio + val_ratio, "val", "test"),
    )

    # Add split column
    table = table.append_column("split", pa.array(splits))

    # Save as partitioned Parquet
    output_path.mkdir(parents=True, exist_ok=True)
    ds.write_dataset(
        table,
        output_path,
        format="parquet",
        partitioning=["year", "month"],
        existing_data_behavior="overwrite_or_ignore",
    )

    return rasteret.load(output_path, name=collection.name)

# Use it
collection = assign_splits(
    collection,
    output_path=Path("./collections/training_with_splits"),
    train_ratio=0.7,
    val_ratio=0.15,
)
See examples/ml_training_with_splits.py for a complete example.
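
To sanity-check the assignment, you can count rows per split with PyArrow (a quick sketch over the Collection's table, using the same accessors as assign_splits above):

import pyarrow.compute as pc

table = collection.dataset.to_table()
print(pc.value_counts(table.column("split")))  # expect roughly 70/15/15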

Creating TorchGeo Datasets

Basic Image Dataset

# RGB imagery for classification
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],  # Red, Green, Blue
    chip_size=224,
    split="train",  # Use training split only
)

print(f"Dataset length: {len(dataset)}")
print(f"CRS: {dataset.crs}")
print(f"Bounds: {dataset.bounds}")

Multi-Spectral Dataset

# All Sentinel-2 bands at 10m/20m
dataset = collection.to_torchgeo_dataset(
    bands=["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"],
    chip_size=256,
    split="train",
    allow_resample=True,  # Resample 20m bands to 10m grid
)
Important: When bands have different native resolutions (e.g. 10m vs 20m), set allow_resample=True. Otherwise, Rasteret will raise an error if you request mixed-resolution bands.
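
If you request mixed-resolution bands without allow_resample, you can surface the failure explicitly (a sketch; the exact exception type isn't documented here, so we catch broadly):

try:
    dataset = collection.to_torchgeo_dataset(
        bands=["B04", "B8A"],  # 10m + 20m, allow_resample omitted
        chip_size=256,
    )
except Exception as err:
    print(f"Mixed-resolution request rejected: {err}")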

Mask Dataset (Single-Band)

# Land cover classification mask
mask_dataset = collection.to_torchgeo_dataset(
    bands=["SCL"],  # Scene Classification Layer
    chip_size=256,
    is_image=False,  # Return as 'mask' (channel dimension squeezed)
    split="train",
)

# Sample returns {'mask': tensor, 'crs': ..., 'bbox': ...}
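
For segmentation, pair the mask dataset with an image dataset. If the returned objects are standard TorchGeo GeoDatasets (an assumption here), the & operator builds an IntersectionDataset whose samples carry both 'image' and 'mask':

from torchgeo.samplers import RandomGeoSampler

image_dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    is_image=True,
    split="train",
)

train_ds = image_dataset & mask_dataset  # samples: {'image': ..., 'mask': ...}
sampler = RandomGeoSampler(train_ds, size=256, length=1000)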

TorchGeo Samplers

Random Spatial Sampling

from torchgeo.samplers import RandomGeoSampler

sampler = RandomGeoSampler(
    dataset,
    size=256,      # Chip size in pixels (matches dataset chip_size)
    length=5000,   # Number of samples per epoch
)

Grid Sampling

from torchgeo.samplers import GridGeoSampler

# Exhaustive grid over the dataset extent
sampler = GridGeoSampler(
    dataset,
    size=256,
    stride=256,  # Non-overlapping tiles
)
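
Grid sampling is the usual choice for inference; a minimal tiled-prediction sketch (model and dataset are assumed to be defined as in the surrounding sections):

import torch
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples

loader = DataLoader(dataset, sampler=sampler, batch_size=8, collate_fn=stack_samples)

model.eval()
predictions = []
with torch.no_grad():
    for batch in loader:
        outputs = model(batch["image"].float())  # move to GPU if available
        predictions.append(outputs.cpu())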

Pre-Defined Geometries

For labeled datasets, you can filter the Collection to specific geometries:
import geopandas as gpd

# Load labeled polygons
labels = gpd.read_file("labels.geojson")

# Filter Collection to labeled areas
labeled_collection = collection.subset(geometries=labels.geometry)

# Create dataset
dataset = labeled_collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    split="train",
)

Training Loop Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler

# Setup
train_ds = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B08"],
    chip_size=256,
    split="train",
)

val_ds = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B08"],
    chip_size=256,
    split="val",
)

train_sampler = RandomGeoSampler(train_ds, size=256, length=1000)
train_loader = DataLoader(
    train_ds,
    sampler=train_sampler,
    batch_size=16,
    num_workers=4,
    collate_fn=stack_samples,
)

val_sampler = RandomGeoSampler(val_ds, size=256, length=200)
val_loader = DataLoader(
    val_ds,
    sampler=val_sampler,
    batch_size=16,
    num_workers=4,
    collate_fn=stack_samples,
)

# Model (example)
model = nn.Sequential(
    nn.Conv2d(4, 64, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),
)
model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(10):
    model.train()
    for batch in train_loader:
        images = batch["image"].cuda()  # [B, C, H, W]
        # Mock labels for demo (replace with real labels)
        labels = torch.randint(0, 10, (images.size(0),)).cuda()

        optimizer.zero_grad()
        outputs = model(images.float())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].cuda()
            labels = torch.randint(0, 10, (images.size(0),)).cuda()
            outputs = model(images.float())
            val_loss += criterion(outputs, labels).item()

    print(f"Epoch {epoch}: val_loss={val_loss/len(val_loader):.4f}")

Time-Series Models

For temporal models (e.g. LSTM, Transformer), enable time-series mode:
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=128,
    split="train",
    time_series=True,  # Stack all timesteps
)

# Sample returns: {'image': tensor[T, C, H, W], ...}
# T = number of timesteps overlapping the sampled bbox
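
With stack_samples, batches then have shape [B, T, C, H, W] when T is constant across samples (an assumption; T varies with temporal overlap). A sketch of a simple temporal model that pools spatially, then runs an LSTM over timesteps:

import torch
import torch.nn as nn

class TemporalClassifier(nn.Module):
    def __init__(self, in_channels=3, hidden=64, num_classes=10):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # spatial pooling -> [B*T, hidden, 1, 1]
        )
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x):  # x: [B, T, C, H, W]
        b, t, c, h, w = x.shape
        feats = self.encoder(x.view(b * t, c, h, w)).view(b, t, -1)
        out, _ = self.lstm(feats)     # [B, T, hidden]
        return self.head(out[:, -1])  # classify from the last timestep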

Advanced: Custom Labels

Add a label column to your Collection using PyArrow, then pass label_field to to_torchgeo_dataset():
import numpy as np
import pyarrow as pa

# Add labels (example: random for demo)
table = collection.dataset.to_table()
labels = pa.array(np.random.randint(0, 10, len(table)))
table = table.append_column("label", labels)

# Wrap as a new Collection
labeled = rasteret.as_collection(
    table,
    name=collection.name,
    data_source=collection.data_source,
)

# Create dataset with labels
dataset = labeled.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    label_field="label",  # Include in samples
)

# Sample returns: {'image': tensor, 'label': int, ...}
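
In a training loop, the label field can then replace the mock labels used earlier (a sketch; build a DataLoader over this dataset as in the Quick Start; how stack_samples collates scalar fields is an assumption here, hence the torch.as_tensor coercion):

import torch

for batch in loader:
    images = batch["image"].float()
    labels = torch.as_tensor(batch["label"])
    loss = criterion(model(images), labels)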

Data Augmentation

Use TorchGeo transforms or standard torchvision augmentations:
import torchvision.transforms as T

transforms = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(90),
])

dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    split="train",
    transforms=transforms,
)
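
TorchGeo datasets conventionally pass the whole sample dict to transforms. If to_torchgeo_dataset follows that convention (an assumption), tensor-only torchvision transforms need a thin adapter that applies them to the 'image' key; a sketch:

def image_only(transform):
    """Apply a tensor transform to sample['image'], leaving other keys intact."""
    def apply(sample):
        sample["image"] = transform(sample["image"])
        return sample
    return apply

dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    split="train",
    transforms=image_only(transforms),
)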

Handling Authentication

For datasets requiring credentials (Planetary Computer, NASA Earthdata), create a backend:
import rasteret
from obstore.auth.planetary_computer import PlanetaryComputerCredentialProvider

backend = rasteret.create_backend(
    credential_provider=PlanetaryComputerCredentialProvider(
        "https://planetarycomputer.microsoft.com/api/sas/v1/token"
    ),
)

# Pass backend to to_torchgeo_dataset()
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    backend=backend,
)
See the Cloud Authentication guide for details.

Best Practices

Dataset Size

  • Small AOIs (< 1000 km²): Build a single Collection
  • Large AOIs (> 10,000 km²): Consider spatial or temporal partitioning (see the sketch after this list)
  • Global training: Use pre-built GeoParquet indexes (e.g. Source Cooperative)
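
A sketch of temporal partitioning for a large AOI, reusing the rasteret.build call from the Quick Start (LARGE_BBOX is a placeholder for your AOI):

import rasteret

quarters = [
    ("2024-01-01", "2024-03-31"),
    ("2024-04-01", "2024-06-30"),
    ("2024-07-01", "2024-09-30"),
    ("2024-10-01", "2024-12-31"),
]

collections = [
    rasteret.build(
        "earthsearch/sentinel-2-l2a",
        name=f"training-q{i + 1}",
        bbox=LARGE_BBOX,  # placeholder: your AOI
        date_range=dr,
    )
    for i, dr in enumerate(quarters)
]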

Caching

# Build once, use in all training runs
collection = rasteret.build(
    "earthsearch/sentinel-2-l2a",
    name="training-v1",
    bbox=BBOX,
    date_range=DATE_RANGE,
    force=False,  # Reuse cache
)

# Add splits and export
collection_with_splits = assign_splits(collection, output_path)
collection_with_splits.export("./artifacts/training_collection_v1")

# Share with teammates (they don't need STAC API access)
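
On the receiving side, loading the exported artifact might look like this (a sketch, assuming rasteret.load accepts an exported directory the same way it accepts output_path above):

import rasteret

collection = rasteret.load(
    "./artifacts/training_collection_v1",
    name="training-v1",
)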

Multi-Resolution Bands

Sentinel-2 bands have different resolutions:
  • 10m: B02, B03, B04, B08
  • 20m: B05, B06, B07, B8A, B11, B12
  • 60m: B01, B09
When mixing resolutions, set allow_resample=True and choose a target_crs (e.g. UTM zone) for consistent grids:
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B8A"],  # 10m + 20m
    chip_size=256,
    allow_resample=True,
    target_crs=32630,  # UTM zone 30N
)

Dtype Handling

Rasteret returns native COG dtypes:
  • Sentinel-2: uint16
  • Landsat: uint16
  • NAIP: uint8
TorchGeo converts uint16 → int32 and uint32 → int64 for PyTorch compatibility. Normalize in your model:
images = batch["image"].float()  # Convert to float32
images = images / 10000.0  # Scale to reflectance (Sentinel-2)
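
For many models, per-band standardization works better than a single reflectance scale. A minimal sketch with placeholder statistics (compute real per-band means and stds from your own Collection):

import torch

# Placeholder statistics for B04, B03, B02, B08 -- replace with values
# computed from your own training data
mean = torch.tensor([1354.0, 1117.0, 1042.0, 2645.0]).view(1, -1, 1, 1)
std = torch.tensor([245.0, 203.0, 290.0, 380.0]).view(1, -1, 1, 1)

images = (batch["image"].float() - mean) / std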

Troubleshooting

Empty Samples

If you’re getting all-zero or NaN samples:
  1. Check that your AOI overlaps the Collection bounds: collection.bounds
  2. Verify scenes exist: len(collection.subset(split="train"))
  3. Inspect a sample manually:
    sample = dataset[dataset.bounds]
    print(sample["image"].shape, sample["image"].min(), sample["image"].max())
    

Mixed CRS Errors

If scenes have different CRS (rare), set target_crs to reproject at read time:
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    target_crs=32630,  # Force all scenes to UTM 30N
)

Slow Data Loading

If training is I/O bound (the timing sketch after this list can help confirm):
  1. Increase max_concurrent (COG fetch concurrency):
    dataset = collection.to_torchgeo_dataset(
        bands=["B04", "B03", "B02"],
        chip_size=256,
        max_concurrent=200,  # Default is 50
    )
    
  2. Use num_workers in DataLoader (parallelism across batches)
  3. Consider prefetching scenes to local disk (outside Rasteret scope)
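
To confirm loading is the bottleneck before tuning, time the loader on its own; a quick sketch (train_loader as defined in the training example above):

import time

def batches_per_second(loader, n_batches=50):
    """Iterate the loader alone; low throughput here means training is I/O bound."""
    start = time.perf_counter()
    for i, _ in enumerate(loader):
        if i + 1 >= n_batches:
            break
    return n_batches / (time.perf_counter() - start)

print(f"{batches_per_second(train_loader):.1f} batches/s")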

Next Steps