"""
Dataset classes for handling graph flow snapshots.

This module provides PyTorch Dataset classes for loading and managing
time-series data of discrete probability flows on graphs.
"""

import os
import json
import torch
from torch.utils.data import Dataset
from typing import Optional


def _load_array(path: str):
    """Load numpy or PyTorch array from file."""
    ext = os.path.splitext(path)[1].lower()
    if ext in (".pt", ".pth"):
        return torch.load(path, map_location="cpu")
    if ext == ".npy":
        import numpy as np
        return torch.from_numpy(np.load(path, allow_pickle=False))
    raise ValueError(f"Unsupported extension: {ext}")


class GraphFlowSamplesDataset(Dataset):
    """
    Dataset for T snapshots of discrete flow dynamics on a fixed n-node graph.

    Each snapshot contains:
        - samples_tm[t]: (M,) sampled node indices at time t, each in [0, n)
        - v_mat_seq[t]: (n,n) antisymmetric ground-truth field at time t
          (antisymmetry is not verified here)
        - rho_gt_seq[t]: (n,) optional ground-truth density

    Args:
        samples_tm: Tensor (or array-like) of shape (T, M) containing node
            samples over time; values are cast to ``torch.long``.
        v_mat_seq: Tensor (or array-like) of shape (T, n, n) containing
            vector fields.
        n: Number of nodes in the graph.
        rho_gt_seq: Optional tensor (or array-like) of shape (T, n) with
            ground-truth densities.
        smoothing: Non-negative constant added to every node count before
            normalizing the empirical density.
        dtype: Floating dtype used for fields and densities.
        metadata: Dictionary of additional metadata (defaults to ``{}``).

    Raises:
        ValueError: If shapes are inconsistent, a sample index lies outside
            ``[0, n)``, or ``smoothing`` is negative.
    """

    # Lower bound applied before taking logs so zero densities stay finite.
    _LOG_EPS = 1e-16

    def __init__(
        self,
        samples_tm: torch.Tensor,
        v_mat_seq: torch.Tensor,
        n: int,
        rho_gt_seq: Optional[torch.Tensor] = None,
        smoothing: float = 1e-8,
        dtype=torch.float64,
        metadata=None
    ):
        n = int(n)
        # as_tensor is a no-op for tensors and also accepts NumPy arrays/lists.
        samples_tm = torch.as_tensor(samples_tm)
        v_mat_seq = torch.as_tensor(v_mat_seq)

        # Explicit raises instead of `assert`, which `python -O` strips.
        if samples_tm.ndim != 2:
            raise ValueError("samples_tm must be (T,M)")
        if v_mat_seq.ndim != 3 or v_mat_seq.shape[1] != n or v_mat_seq.shape[2] != n:
            raise ValueError(
                f"v_mat_seq must be (T,n,n) with n={n}, got {tuple(v_mat_seq.shape)}"
            )
        T = samples_tm.shape[0]
        if v_mat_seq.shape[0] != T:
            raise ValueError("T mismatch")
        if rho_gt_seq is not None:
            rho_gt_seq = torch.as_tensor(rho_gt_seq)
            if tuple(rho_gt_seq.shape) != (T, n):
                raise ValueError("rho_gt_seq must be (T,n)")

        smoothing = float(smoothing)
        if smoothing < 0:
            raise ValueError(f"smoothing must be non-negative, got {smoothing}")

        samples_long = samples_tm.long()
        # Out-of-range indices would make bincount(minlength=n) return a
        # vector longer than n (or fail later for negatives); reject early.
        if samples_long.numel() > 0:
            lo = int(samples_long.min())
            hi = int(samples_long.max())
            if lo < 0 or hi >= n:
                raise ValueError(
                    f"samples_tm values must lie in [0, {n}), got range [{lo}, {hi}]"
                )

        self.samples_tm = samples_long
        self.v_mat_seq = v_mat_seq.to(dtype)
        self.rho_gt_seq = None if rho_gt_seq is None else rho_gt_seq.to(dtype)
        self.T, self.M = self.samples_tm.shape
        self.n = n
        self.smoothing = smoothing
        self.dtype = dtype
        self.metadata = metadata or {}

    @classmethod
    def from_folder(cls, folder: str, smoothing: Optional[float] = None, dtype=torch.float64):
        """
        Load a dataset from a folder with samples_tm, v_mat_seq and optional rho_gt_seq files.

        Each array is looked up as ``<stem>.pt``, ``<stem>.pth`` or
        ``<stem>.npy`` (first match wins). An optional ``metadata.json`` is
        read as the metadata dict; if it has an ``"n"`` entry it must match
        the node count of ``v_mat_seq``, and its ``"smoothing"`` entry is
        used when ``smoothing`` is not given.

        Args:
            folder: Path to data folder.
            smoothing: Smoothing constant (overrides metadata if provided;
                falls back to 1e-8).
            dtype: Data type for tensors.

        Returns:
            GraphFlowSamplesDataset instance.

        Raises:
            FileNotFoundError: If samples_tm.* or v_mat_seq.* is missing.
            ValueError: If v_mat_seq is not 3-D, metadata "n" disagrees with
                the data, or the constructor rejects the arrays.
        """
        def find(stem):
            # Return the first existing file for `stem`, or None.
            for ext in (".pt", ".pth", ".npy"):
                p = os.path.join(folder, stem + ext)
                if os.path.isfile(p):
                    return p
            return None

        p_samples = find("samples_tm")
        p_vmat = find("v_mat_seq")
        p_rho = find("rho_gt_seq")

        if not p_samples or not p_vmat:
            raise FileNotFoundError(
                f"Need samples_tm.* and v_mat_seq.* in the folder: {folder}"
            )

        samples_tm = torch.as_tensor(_load_array(p_samples))
        v_mat_seq = torch.as_tensor(_load_array(p_vmat))
        rho_gt_seq = _load_array(p_rho) if p_rho else None

        # n is derived from the data; check the rank first so a malformed
        # file produces a clear error instead of an IndexError.
        if v_mat_seq.ndim != 3:
            raise ValueError(
                f"v_mat_seq must be (T,n,n), got shape {tuple(v_mat_seq.shape)}"
            )
        n = int(v_mat_seq.shape[1])

        meta = {}
        mpath = os.path.join(folder, "metadata.json")
        if os.path.isfile(mpath):
            with open(mpath, encoding="utf-8") as f:
                meta = json.load(f)
            if "n" in meta and meta["n"] != n:
                raise ValueError(f"metadata n={meta['n']} != data n={n}")

        smooth = smoothing if smoothing is not None else meta.get("smoothing", 1e-8)
        return cls(
            samples_tm, v_mat_seq, n,
            rho_gt_seq=rho_gt_seq,
            smoothing=smooth,
            dtype=dtype,
            metadata=meta
        )

    def __len__(self):
        """Return the number of snapshots T."""
        return self.T

    def _rho_hat(self, t: int):
        """
        Estimate the node density at time index t from its samples.

        Counts are computed with ``bincount`` over the M samples, shifted by
        ``self.smoothing`` and normalized to sum to one. If M == 0 and
        smoothing == 0 the normalizer is zero and the result is NaN.

        Returns:
            Tuple ``(rho, log_rho)`` of shape-(n,) tensors in ``self.dtype``;
            ``log_rho`` clamps ``rho`` to ``_LOG_EPS`` before the log.
        """
        counts = torch.bincount(self.samples_tm[t], minlength=self.n).to(self.dtype)
        rho = counts + self.smoothing
        rho = rho / rho.sum()
        return rho, torch.log(torch.clamp(rho, min=self._LOG_EPS))

    def __getitem__(self, t: int):
        """
        Get snapshot data at time index t.

        Returns:
            Dictionary containing samples, v_mat, rho_hat, logrho_hat, all_nodes,
            and optionally rho_gt and logrho_gt if ground truth is available.
        """
        samples_t = self.samples_tm[t]
        v_mat_t = self.v_mat_seq[t]
        rho_hat_t, logrho_hat_t = self._rho_hat(t)

        item = {
            "samples": samples_t,
            "v_mat": v_mat_t,
            "rho_hat": rho_hat_t,
            "logrho_hat": logrho_hat_t,
            # Fresh tensor per call so callers may mutate it safely.
            "all_nodes": torch.arange(self.n, dtype=torch.long),
        }

        if self.rho_gt_seq is not None:
            rho_gt_t = self.rho_gt_seq[t]
            item["rho_gt"] = rho_gt_t
            item["logrho_gt"] = torch.log(torch.clamp(rho_gt_t, min=self._LOG_EPS))

        return item
