from __future__ import annotations

import dataclasses
from pathlib import Path
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np

from pipeline.bookkeeping.logging import eval_log
from pipeline.dataloader.dataset import DynamicalSystemDataset
from pipeline.dataloader.transforms import compose, crop, normalize
from pipeline.eval.eval_utils import EvaluateConfig, _load_checkpoint, _load_json, _make_figures, _maybe_init_wandb, _pick_checkpoint_dir, _render_video_if_requested, _write_results_txt, _make_transforms, init_state
from pipeline.eval.metrics import compute_summary_errors, compute_trajectory_errors, get_histogram
from pipeline.training.runtime_resolve import Experiment, Runtime, resolve_runtime
from pipeline.training.state_init import TrainState
from pipeline.training.train_configs import TrainConfig
from pipeline.training.train_helper import EpochFlags, make_val_step

def _compute_epoch_flags(epoch: int, train_config: TrainConfig) -> EpochFlags:
    """Derive the optimal-transport (OT) flags for a given epoch.

    OT becomes active once ``epoch`` reaches ``train_config.ot_warm_up``.
    While inactive, its weight is forced to ``0.0``; once active, the weight
    is ``train_config.lambda_ot``.
    """
    ot_active = epoch >= train_config.ot_warm_up
    ot_weight = train_config.lambda_ot if ot_active else 0.0
    return EpochFlags(use_ot=ot_active, lambda_ot=ot_weight)


def _eval_dataset_batched(
    *,
    state: TrainState,
    dataset: DynamicalSystemDataset,
    rng: jax.Array,
    data_seed: int,
    train_config: TrainConfig,
    eval_cfg: EvaluateConfig,
    runtime: Runtime,
    eval_step: Callable,
    epoch_for_flags: int,
) -> tuple[dict[str, Any], np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run ``eval_step`` over every sample of ``dataset`` and compute error metrics.

    Samples are visited in index order, in chunks of ``train_config.batch_size``.
    The last chunk holds any remainder, so every sample contributes to the
    reported metrics.

    Args:
      state: Model state forwarded to ``eval_step``.
      dataset: Trajectory source; must support ``len()`` and ``get_batch(indices)``.
      rng: PRNG key, split once per chunk.
      data_seed: Unused here; kept for interface compatibility.
      train_config: Provides ``batch_size`` and the OT warm-up settings.
      eval_cfg: Unused here; kept for interface compatibility.
      runtime: Provides the ``device`` each batch is placed on.
      eval_step: Called as ``eval_step(state=..., batch=..., rng=..., epoch_flags=...)``.
        It must return a mapping with ``u_true``, ``u_hat``, ``s_true`` and
        ``s_hat``, and optionally ``u_hat_full`` (``u_hat`` is used when absent).
      epoch_for_flags: Epoch from which the OT flags are derived.

    Returns:
      ``(metrics, u_true, u_hat, u_hat_full, s_true, s_hat)``.
      ``metrics`` holds ``"trajectory"`` and ``"summary"`` sub-dicts with any
      ``"mse_per_time"`` entry removed. The arrays are host NumPy arrays,
      concatenated along axis 0 across chunks.

    Raises:
      ValueError: if ``dataset`` is empty or ``train_config.batch_size`` is
        not positive.
    """
    epoch_flags = _compute_epoch_flags(epoch_for_flags, train_config)

    num_samples = len(dataset)
    batch_size = int(train_config.batch_size)
    if num_samples == 0:
        raise ValueError("Cannot evaluate an empty dataset.")
    if batch_size <= 0:
        raise ValueError(f"train_config.batch_size must be positive, got {batch_size}.")

    # Contiguous index chunks. The final chunk may be shorter than batch_size.
    # If eval_step is jitted, that shape costs one extra trace, which is
    # preferable to excluding the tail samples from the evaluation.
    indices = np.arange(num_samples)
    batch_indices = [
        indices[start:start + batch_size]
        for start in range(0, num_samples, batch_size)
    ]

    # Evaluation compares against clean, uncropped trajectories, so only
    # normalization is applied. Noise and cropping are deliberately disabled.
    transforms = compose([
        # add_noise(noise_level=train_config.noise_level),
        # crop(window_size=train_config.crop_window_size),
        normalize(),
    ])

    output_keys = ("u_true", "u_hat", "u_hat_full", "s_true", "s_hat")
    collected: dict[str, list[jax.Array]] = {key: [] for key in output_keys}

    for batch_idx in batch_indices:
        rng, step_rng = jax.random.split(rng)
        raw_batch = dataset.get_batch(batch_idx)
        # Every sample in the chunk is transformed with the same step_rng.
        traj_batch = np.stack(
            [transforms(raw_batch[i], step_rng) for i in range(len(batch_idx))],
            axis=0,
        )
        traj_batch = jax.device_put(traj_batch, runtime.device)

        outputs = eval_step(
            state=state,
            batch=traj_batch,
            rng=step_rng,
            epoch_flags=epoch_flags,
        )

        collected["u_true"].append(outputs["u_true"])
        collected["u_hat"].append(outputs["u_hat"])
        # Fall back to u_hat when eval_step produces no separate full rollout.
        collected["u_hat_full"].append(outputs.get("u_hat_full", outputs["u_hat"]))
        collected["s_true"].append(outputs["s_true"])
        collected["s_hat"].append(outputs["s_hat"])

    # Join the per-chunk outputs and copy each result to host as NumPy.
    u_true, u_hat, u_hat_full, s_true, s_hat = (
        np.asarray(jnp.concatenate(collected[key], axis=0)) for key in output_keys
    )

    traj_metrics = compute_trajectory_errors(u_true, u_hat)
    summary_metrics = compute_summary_errors(s_true, s_hat)

    # Drop the per-time MSE entry (if present) from both metric dicts.
    traj_metrics.pop("mse_per_time", None)
    summary_metrics.pop("mse_per_time", None)

    metrics = {
        "trajectory": traj_metrics,
        "summary": summary_metrics,
    }
    return metrics, u_true, u_hat, u_hat_full, s_true, s_hat


def eval_exp(eval_cfg: EvaluateConfig) -> Path:
    """Evaluate a trained experiment on a held-out dataset and write its artefacts.

    The function:
    - loads ``train_config.json`` and ``exp_config.json`` plus a checkpoint
      from ``eval_cfg.experiment_dir``;
    - runs the validation step over ``eval_cfg.test_data_path``;
    - writes a text summary, plus figures when ``eval_cfg.make_figures`` is set
      and a video when requested;
    - logs everything via ``eval_log``, using a W&B run when
      ``eval_cfg.use_wandb`` is set.

    Args:
      eval_cfg: Evaluation settings.

    Returns:
      The experiment directory, where all outputs are written.

    Raises:
      FileNotFoundError: if the experiment directory does not exist.
      KeyError: if the checkpoint lacks a required state key.
    """
    exp_dir = Path(eval_cfg.experiment_dir)
    if not exp_dir.exists():
        raise FileNotFoundError(f"Experiment dir not found: {exp_dir}")

    # Load configs saved by the training script.
    train_cfg_dict = _load_json(exp_dir / "train_config.json")
    exp_cfg_dict = _load_json(exp_dir / "exp_config.json")

    # Keep only keys TrainConfig.__init__ accepts. Legacy keys no longer in
    # TrainConfig, and fields declared with init=False, would otherwise raise
    # TypeError.
    allowed = {f.name for f in dataclasses.fields(TrainConfig) if f.init}
    train_cfg = TrainConfig(**{k: v for k, v in train_cfg_dict.items() if k in allowed})
    experiment = Experiment(**exp_cfg_dict)

    # The runtime (device, rollout/summary callables) is resolved from the
    # training config. eval_cfg does not currently override it.
    runtime = resolve_runtime(train_cfg)
    ckpt_dir = _pick_checkpoint_dir(exp_dir / "checkpoints", eval_cfg.checkpoint)
    ckpt = _load_checkpoint(ckpt_dir)

    # Fail fast with an actionable message if the checkpoint layout is stale.
    required_keys = (
        "epoch",
        "step",
        "emulator_params",
        "emulator_opt_state",
        "summary_params",
        "summary_opt_state",
    )
    for key in required_keys:
        if key not in ckpt:
            raise KeyError(
                f"Checkpoint dict must contain key '{key}'. "
                f"Update save_checkpoint() to store '{{'{key}': value, ...}}'."
            )

    state: TrainState = init_state(ckpt)
    epoch = int(ckpt["epoch"])  # presence validated above

    test_dataset = DynamicalSystemDataset(data_path=eval_cfg.test_data_path)

    eval_step = make_val_step(
        rollout_mse=runtime.rollout_mse,
        rollout_ot=runtime.rollout_ot,
        summary_apply=runtime.summary_apply,
    )
    # epoch_flags is static, so it must be hashable. Each distinct value
    # triggers a retrace.
    eval_step = jax.jit(eval_step, static_argnames=("epoch_flags",))

    rng = jax.random.PRNGKey(int(experiment.seed))
    data_seed = int(experiment.seed)

    metrics, u_true, u_hat, u_hat_full, s_true, s_hat = _eval_dataset_batched(
        state=state,
        dataset=test_dataset,
        rng=rng,
        data_seed=data_seed,
        train_config=train_cfg,
        eval_cfg=eval_cfg,
        runtime=runtime,
        eval_step=eval_step,
        epoch_for_flags=epoch,
    )

    # Text summary. Treat a missing or None summary_config as "identity".
    summary_config = getattr(train_cfg, "summary_config", None) or {}
    summary_type = str(summary_config.get("type", "identity"))
    _write_results_txt(
        exp_dir,
        summary_type=summary_type,
        train_config=train_cfg,
        ckpt_dir=ckpt_dir,
        metrics=metrics,
    )

    # Figures
    image_paths: list[Path] = []
    if eval_cfg.make_figures:
        image_paths = _make_figures(
            exp_dir,
            u_true=u_true,
            u_hat=u_hat,
            u_hat_full=u_hat_full,
            s_true=s_true,
            s_hat=s_hat,
            train_config=train_cfg,
            critic_params=state.critic.params,
        )

    # Video
    video_path = _render_video_if_requested(
        out_dir=exp_dir,
        eval_cfg=eval_cfg,
        u_true=u_true,
        u_hat=u_hat,
        s_true=s_true,
    )

    # Logging (W&B only when enabled)
    wandb_run = (
        _maybe_init_wandb(eval_cfg, train_cfg=train_cfg, ckpt_dir=ckpt_dir)
        if eval_cfg.use_wandb
        else None
    )
    eval_log(
        eval_cfg,
        wandb_run=wandb_run,
        image_paths=image_paths,
        video_path=video_path,
        out_dir=exp_dir,
        ckpt_dir=ckpt_dir,
    )

    print(f"[OK] Evaluation complete. Outputs at: {exp_dir}")
    return exp_dir


if __name__ == "__main__":
    # Ad-hoc entry point: evaluate one experiment against a fixed test split.
    # Other EvaluateConfig options (e.g. device, dtype, make_video) are left
    # at their defaults.
    eval_exp(
        EvaluateConfig(
            experiment_dir=Path("outputs/lorenz63_no_ot_noisy/18c2773d"),
            test_data_path=Path("data/lorenz63/9bda2e47/test_data.npz"),
            checkpoint=None,  # or "final" or "epoch_0009"
            make_figures=True,
            wandb_project=None,
        )
    )
