#!pip install torchgeo

import sys
from pathlib import Path
# Make the directory above this script's folder importable, so that the
# `project_config` module imported right below can be resolved when the
# script is launched directly.
_this_file = Path(__file__).resolve()
project_root = str(_this_file.parent.parent)
if project_root not in sys.path:
    # Appended (not prepended), so already-importable modules keep priority.
    sys.path.append(project_root)
import project_config

# download the data manually
#rsync -avz --progress rsync://m1483140@dataserv.ub.tum.de/m1483140/ /ewsc/ewsc/so2sat/
# password: m1483140

# references:
#https://torchgeo.readthedocs.io/en/stable/api/datasets.html#so2sat
#https://ieeexplore.ieee.org/document/9014553

from torchgeo.datasets import So2Sat
from torch.utils.data import DataLoader
import torch

# 1. Open the training split of So2Sat from the local data root.
# NOTE(review): no download flag is passed here, so the files are expected to
# already exist under SO2SAT_DATA_ROOT (see the manual rsync note above).
# NOTE(review): `bands` is not passed, so the dataset's default band set is
# used. torchgeo appears to expect `bands` as a sequence of band names (with
# presets such as "rgb" / "s2" exposed via So2Sat.BAND_SETS) rather than a bare
# string -- confirm against the torchgeo docs before narrowing the bands.
dataset_options = {
    "root": project_config.SO2SAT_DATA_ROOT,
    "version": "2",
    "split": "train",
    "transforms": None,  # samples are returned untransformed
    "checksum": False,  # checksum verification is switched off
}
train_dataset = So2Sat(**dataset_options)
# default bands (channel order of the "image" tensor):
#   S1: ('S1_B1', 'S1_B2', 'S1_B3', 'S1_B4', 'S1_B5', 'S1_B6', 'S1_B7', 'S1_B8')
#   S2: ('S2_B02', 'S2_B03', 'S2_B04', 'S2_B05', 'S2_B06', 'S2_B07', 'S2_B08', 'S2_B8A', 'S2_B11', 'S2_B12')

# Show everything the dataset object exposes, then the two attributes that
# matter most here: the full band list and the class names.
print(dir(train_dataset))
for attr_name in ("all_band_names", "classes"):
    print(getattr(train_dataset, attr_name))

# 2. Create Loader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 3. Inspect a single batch
# Fetch exactly one batch. An empty loader is reported here with a clear
# message instead of surfacing later as a NameError on `image`.
batch = next(iter(train_loader), None)
if batch is None:
    raise RuntimeError(
        "So2Sat train loader yielded no batches; check SO2SAT_DATA_ROOT"
    )

# So2Sat returns a dictionary
# print all keys in the batch
print(batch.keys())
image = batch["image"]  # Shape: (B, 18, 32, 32) first 8 bands are S1 (radar), last 10 bands are S2 (optical)
label = batch["label"]  # Shape: (B,) <- Class Index (0-16)

print(f"Image shape: {image.shape}")
print(f"Label: {label}")

# Sentinel-2 bands B02, B03, B04 are blue, green, red. With the default band
# order they sit at channels 8, 9, 10, so a plain 8:11 slice would give
# B, G, R. Index them in reverse to get a true R, G, B tensor.
rgb_image = image[:, [10, 9, 8], :, :]  # (B, 3, H, W) in R, G, B order
print(f"RGB Image shape: {rgb_image.shape}")

# For multimodal learning we want to use both radar and optical data.
# Open question: from the radar, do we mostly care about the intensities?

# Split the stacked tensor along the channel axis into one tensor per
# modality. With the default band order, channels 0-7 are the S1 (radar)
# bands and channels 8-17 are the ten S2 (optical) bands.
radar_image = image[:, 0:8, :, :]  # all 8 S1 bands
optical_image = image[:, 8:18, :, :]  # all 10 S2 bands (not just B02-B04)