from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

import jax
import jax.numpy as jnp
import numpy as np

from cyreal.datasets.dataset_protocol import DatasetProtocol
from cyreal.datasets.time_utils import prepare_seq_to_seq_windows
from taming_the_ito_lyon.config import Config
from taming_the_ito_lyon.data.integrity_checks import (
    ensure_b_w_l_c,
    validate_window_alignment,
)
from cyreal.sources import DiskSource


@dataclass
class SO3DynamicsSim(DatasetProtocol):
    """Sliding-window dataset over simulated SO(3) rotation trajectories.

    On construction the class:

    1. loads the archive at ``config.experiment_config.dataset_name.value``,
    2. stacks the four ``R_sim_damped{0..3}`` arrays and flattens every 3x3
       rotation matrix to 9 channels, giving ``(B * 4, T, 9)``,
    3. downsamples in time so the effective step is roughly ``_TARGET_DT``
       (expressed in the same units as the archive's ``dt`` entry),
    4. builds sliding windows of length ``_WINDOW_LEN`` with
       ``prepare_seq_to_seq_windows`` for the requested ``split``.

    Samples are addressed by a single flat index that enumerates every
    ``(example, window)`` pair; see ``__getitem__`` and ``make_disk_source``.

    Attributes:
        config: Experiment configuration (dataset path, split fractions,
            batch size).
        split: Which split to expose.
        ordering: ``"shuffle"`` for the training split, ``"sequential"``
            otherwise. Derived in ``__post_init__``; not an init argument.
    """

    config: Config
    split: Literal["train", "val", "test"]
    ordering: Literal["sequential", "shuffle"] = field(init=False)
    _driver_np: np.ndarray = field(init=False, repr=False)
    _solution_np: np.ndarray = field(init=False, repr=False)
    _num_windows: int = field(init=False, repr=False)
    _num_examples: int = field(init=False, repr=False)
    _dataset_len: int = field(init=False, repr=False)

    # Deliberately unannotated so that `@dataclass` treats these as plain class
    # constants rather than as fields.
    _TARGET_DT = 0.1  # desired time step after downsampling
    _WINDOW_LEN = 21  # length of both the input and the target window
    _ROTATION_KEYS = (
        "R_sim_damped0",
        "R_sim_damped1",
        "R_sim_damped2",
        "R_sim_damped3",
    )

    @staticmethod
    def _downsample_stride(dt_sim: float, target_dt: float = 0.1) -> int:
        """Return how many simulation frames make up one downsampled step.

        The stride is ``floor(target_dt / dt_sim)``, clamped to at least 1.
        A ratio that is within floating-point noise of an integer is snapped
        to that integer first; plain truncation would otherwise turn e.g.
        ``99.999995`` (``dt`` stored as float32 ``0.001``) into 99 instead
        of 100.

        Raises:
            ValueError: If ``dt_sim`` is not a finite, strictly positive number.
        """
        if not np.isfinite(dt_sim) or dt_sim <= 0.0:
            raise ValueError(
                f"Simulation dt must be a finite positive number, got {dt_sim!r}"
            )
        ratio = target_dt / dt_sim
        nearest = round(ratio)
        if np.isclose(ratio, nearest, rtol=1e-6, atol=0.0):
            stride = int(nearest)
        else:
            stride = int(ratio)
        # A simulation step coarser than the target gives 0; keep every frame.
        return max(stride, 1)

    def __post_init__(self) -> None:
        """Load the archive, downsample it and build the window arrays.

        Raises:
            ValueError: If the stored ``dt`` is not finite and positive, or if
                the stacked rotation arrays are not 5-dimensional.
        """
        self.ordering = "shuffle" if self.split == "train" else "sequential"

        experiment = self.config.experiment_config

        # Close the archive as soon as the needed entries are in memory; each
        # `data[key]` access returns a materialized array, so nothing below
        # depends on the file staying open.
        with np.load(experiment.dataset_name.value) as data:
            # Scalar dt for the whole simulation. `.item()` accepts both 0-d
            # and size-1 arrays without NumPy's array-to-scalar deprecation.
            dt_sim = float(np.asarray(data["dt"]).item())
            # Stacked on a new axis 2: (B, T, 4, 3, 3).
            rotmats = np.stack(
                [data[key] for key in self._ROTATION_KEYS],
                axis=2,
            )

        skip = self._downsample_stride(dt_sim, self._TARGET_DT)
        # print(f"Downsample rate: {skip} (effective dt = {dt_sim * skip})")

        train_fraction = float(experiment.train_fraction)
        val_fraction = float(experiment.val_fraction)

        # Flatten 3x3 -> 9 so the model sees shape (B, T, 9), and fold the four
        # stacked variants into the batch axis: (B * 4, T, 9).
        batch_size, timesteps, boxes, _, _ = rotmats.shape
        rotmats_flat = rotmats.transpose(0, 2, 1, 3, 4).reshape(
            batch_size * boxes, timesteps, 9
        )
        # Downsample first, then build windows on the downsampled sequence.
        # A window of length L therefore covers the original-grid indices
        #   start, start + skip, ..., start + skip * (L - 1)
        # while the sliding-window stride is 1 on the downsampled grid.
        rotmats_flat = rotmats_flat[:, ::skip, :]
        # NOTE: `prepare_seq_to_seq_windows` returns NumPy arrays. With the updated
        # `cyreal.datasets.time_utils.sliding_window_many`, these are typically
        # *views* (stride-trick windows), not fully materialized copies.
        driver_np, solution_np = prepare_seq_to_seq_windows(
            input_sequence=rotmats_flat,
            target_sequence=rotmats_flat,
            split=self.split,
            # we will fit the polynomial to 12 (n_recon) points only
            input_window_len=self._WINDOW_LEN,
            target_window_len=self._WINDOW_LEN,
            train_fraction=train_fraction,
            val_fraction=val_fraction,
            sliding_window_stride=1,
        )

        # `np.lib.stride_tricks.sliding_window_view` appends the window axis at the end.
        # For sequences shaped (B, T, C) this yields windows shaped (B, W, C, L).
        # We want (B, W, L, C) for downstream code and batching, so normalize by
        # ensuring the channel axis (C=9) is last instead of hard-coding L.
        driver_np = ensure_b_w_l_c("driver", driver_np)
        solution_np = ensure_b_w_l_c("solution", solution_np)
        validate_window_alignment(driver_np, solution_np)

        self._driver_np = driver_np
        self._solution_np = solution_np
        self._num_examples = int(driver_np.shape[0])  # B*4
        self._num_windows = int(driver_np.shape[1])  # windows per example
        self._dataset_len = int(self._num_examples * self._num_windows)

    def __len__(self) -> int:
        """Return the flattened sample count: one per (example, window_start)."""
        return int(self._dataset_len)

    def __getitem__(self, index: int) -> dict[str, jax.Array]:
        """Return the window pair stored at flat position ``index``.

        The flat index is laid out example-major: ``index // num_windows``
        selects the example and ``index % num_windows`` the window within it.

        Returns:
            A dict with ``"driver"`` of shape ``(T, 9)`` and ``"solution"`` of
            shape ``(T, 3, 3)``, both float32 JAX arrays.

        Raises:
            IndexError: If ``index`` is negative or not less than ``len(self)``.
        """
        # Keep this cheap: only materialize a single window pair.
        if index < 0 or index >= self._dataset_len:
            raise IndexError(f"Index out of range: {index} (len={self._dataset_len})")
        wi = index % self._num_windows
        bi = index // self._num_windows
        driver = jnp.asarray(self._driver_np[bi, wi], dtype=jnp.float32)  # (T, 9)
        solution_flat = jnp.asarray(
            self._solution_np[bi, wi], dtype=jnp.float32
        )  # (T, 9)
        solution = solution_flat.reshape(solution_flat.shape[0], 3, 3)  # (T, 3, 3)
        return {
            "driver": driver,
            "solution": solution,
        }

    def make_disk_source(
        self,
    ) -> DiskSource:
        """Wrap the window arrays in a ``DiskSource`` that reads samples lazily.

        The source uses the same flat index layout as ``__getitem__`` but
        yields NumPy arrays: ``"driver"`` of shape ``(ctx_len, 9)`` and
        ``"solution"`` of shape ``(tgt_len, 3, 3)``, both float32. It inherits
        this dataset's ``ordering`` and prefetches one batch worth of samples.

        Raises:
            ValueError: If the window arrays are not 4-dimensional, if driver
                and solution disagree on the channel count, or if that count
                is not 9.
        """
        driver_np = self._driver_np
        solution_np = self._solution_np

        _, _, ctx_len, channels = driver_np.shape
        _, _, tgt_len, channels2 = solution_np.shape
        if channels2 != channels:
            raise ValueError(
                f"Driver/solution channel mismatch: driver={channels}, solution={channels2}"
            )
        if int(channels) != 9:
            raise ValueError(
                f"Expected SO(3) rotations flattened as 9 channels, got channels={channels}. Total shape: {driver_np.shape}"
            )

        # Captured as a local so the closure does not hold a reference to `self`.
        num_windows = int(self._num_windows)

        def _read_sample(index: int | np.ndarray) -> dict[str, np.ndarray]:
            """Materialize the single window pair at flat position ``index``."""
            idx = int(np.asarray(index))
            bi, wi = divmod(idx, num_windows)
            driver = np.asarray(driver_np[bi, wi], dtype=np.float32)  # (T, 9)
            solution_flat = np.asarray(solution_np[bi, wi], dtype=np.float32)  # (T, 9)
            solution = solution_flat.reshape(int(tgt_len), 3, 3)  # (T, 3, 3)
            return {
                "driver": driver,
                "solution": solution,
            }

        sample_spec = {
            "driver": jax.ShapeDtypeStruct(
                shape=(ctx_len, channels), dtype=jnp.float32
            ),
            "solution": jax.ShapeDtypeStruct(shape=(tgt_len, 3, 3), dtype=jnp.float32),
        }

        return DiskSource(
            length=int(self._dataset_len),
            sample_fn=_read_sample,
            sample_spec=sample_spec,
            ordering=self.ordering,
            prefetch_size=self.config.experiment_config.batch_size,
        )


if __name__ == "__main__":
    # Manual smoke test: build the training split from a sample config and
    # print the shapes of the driver and solution window arrays.
    from taming_the_ito_lyon.config.config import load_toml_config

    train_dataset = SO3DynamicsSim(
        load_toml_config("configs/sg_so3_sim/m_nrde_mlp.toml"),
        "train",
    )
    for window_array in (train_dataset._driver_np, train_dataset._solution_np):
        print(window_array.shape)
