"""Cross-section dataset for Stage 1 tokenizer."""

import json
from pathlib import Path

import numpy as np
import torch
import zarr
from torch.utils.data import Dataset


class CrossSectionDataset(Dataset):
    """Daily cross-sections from Zarr store.

    The store's ``X`` array is indexed as ``X[day, asset, feature]``. Each
    dataset item is one day's cross-section, truncated to the first
    ``max_assets`` assets, together with a per-asset validity mask.

    Split boundaries are read from a ``manifest.json`` file located in the
    parent directory of the Zarr store. The manifest must provide
    ``splits[<name>]["end_idx"]`` for ``train``, ``val`` and ``test``; the
    splits are treated as contiguous, half-open day ranges:
    train = ``[0, train_end)``, val = ``[train_end, val_end)``,
    test = ``[val_end, test_end)``.
    """

    def __init__(
        self,
        zarr_path: str,
        split: str = "train",
        max_assets: int = 512,
        min_valid_features: int = 1,  # Asset valid if >= N features are non-NaN
        cache_in_memory: bool = True,  # Cache for speed (recommended)
    ):
        """Open the store, resolve the split's day range, optionally cache it.

        Args:
            zarr_path: Path to the Zarr group containing the ``X`` array.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            max_assets: Upper bound on the number of assets (axis 1) kept
                per day; only the leading ``max_assets`` rows are read.
            min_valid_features: An asset is marked valid when at least this
                many of its features are non-NaN.
            cache_in_memory: When true, every day of the split is read and
                preprocessed up front and kept as tensors in memory.

        Raises:
            ValueError: If ``split`` is not a recognised split name.
        """
        zarr_path = Path(zarr_path)
        self.root = zarr.open_group(str(zarr_path), mode="r")
        self.X_zarr = self.root["X"]
        self.max_assets = max_assets
        self.min_valid_features = min_valid_features
        self.cache_in_memory = cache_in_memory

        # Load manifest for split boundaries
        manifest_path = zarr_path.parent / "manifest.json"
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(manifest_path, encoding="utf-8") as f:
            self.manifest = json.load(f)

        splits = self.manifest["splits"]
        if split == "train":
            start, end = 0, splits["train"]["end_idx"]
        elif split == "val":
            start, end = splits["train"]["end_idx"], splits["val"]["end_idx"]
        elif split == "test":
            start, end = splits["val"]["end_idx"], splits["test"]["end_idx"]
        else:
            raise ValueError(f"Unknown split: {split}")

        self.indices = list(range(start, end))
        self.split = split

        # Cache entire split in memory for fast access
        if cache_in_memory:
            print(f"Caching {split} split in memory ({len(self.indices)} days)...")
            self._cache_X = []
            self._cache_mask = []
            for t in self.indices:
                X_t, mask_t = self._load_day(t)
                self._cache_X.append(X_t)
                self._cache_mask.append(mask_t)
            print(f"Cached {len(self._cache_X)} days in memory")

    def _load_day(self, t):
        """Read day ``t`` from the store and return ``(X, mask)`` tensors.

        Single source of truth for per-day preprocessing, shared by the
        in-memory cache and the on-demand path so both yield identical items.

        Args:
            t: Absolute day index into axis 0 of the store's ``X`` array.

        Returns:
            A ``(X, mask)`` pair: ``X`` is a float32 tensor of the day's
            features with NaN/inf replaced by 0 and invalid asset rows
            zeroed; ``mask`` is a bool tensor flagging valid assets.
        """
        X_t = np.array(self.X_zarr[t, : self.max_assets, :])
        # Validity is decided on the raw values, before any are overwritten.
        # NOTE(review): +/-inf entries are not NaN, so they count toward the
        # threshold even though they are zeroed below - confirm this is intended.
        valid_t = (~np.isnan(X_t)).sum(axis=1) >= self.min_valid_features
        X_t = np.nan_to_num(X_t, nan=0.0, posinf=0.0, neginf=0.0)
        # Zero whole rows of invalid assets so no partial features leak through.
        X_t = np.where(valid_t[:, None], X_t, 0.0)
        return torch.from_numpy(X_t).float(), torch.from_numpy(valid_t).bool()

    def __len__(self):
        """Return the number of days in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the cross-section for the ``idx``-th day of the split.

        Args:
            idx: Position within the split (not the absolute day index).

        Returns:
            A dict with keys ``"X"`` (float tensor of features), ``"mask"``
            (bool tensor of valid assets) and ``"day_idx"`` (absolute day
            index into the store).
        """
        t = self.indices[idx]

        if self.cache_in_memory:
            X_t, mask_t = self._cache_X[idx], self._cache_mask[idx]
        else:
            # Fallback: read from Zarr (slow on NFS)
            X_t, mask_t = self._load_day(t)

        return {
            "X": X_t,
            "mask": mask_t,
            "day_idx": t,
        }

    @property
    def num_features(self):
        """Size of the last (feature) axis of the store's ``X`` array."""
        return self.X_zarr.shape[-1]

    @property
    def num_assets(self):
        """Number of assets per item: the store's axis-1 size capped at ``max_assets``."""
        return min(self.X_zarr.shape[1], self.max_assets)
