import json
import pickle
import numpy as np
from pathlib import Path
from dataclasses import asdict
from typing import Any, Callable
from pipeline.eval.eval_configs import EvaluateConfig
from pipeline.training.train_configs import TrainConfig
from pipeline.dataloader.transforms import add_noise, crop, normalize
from pipeline.eval.metrics import get_histogram
from pipeline.visualization import plot_histogram_comparison
from pipeline.training.state_init import CriticState, EmulatorState, SummaryState, TrainState
from pipeline.visualization.lorentz63_plots import (
    plot_l63_attractor_colored_3d,
    plot_l63_discriminator_phase_space,
    plot_l63_full_double,
    plot_l63_full_single,
    plot_l63_projection,
    plot_l63_projection_psd,
    plot_l63_regime_identification,
    plot_l63_projections_colored_by_summary,
    plot_summary_vs_state,
)
from modules.distance.wgan_distance import make_critic

# Utilities
# ----------------------------
def _require_manim() -> Any:
    """Import Manim on demand and return its global ``config`` object.

    Manim is an optional dependency, so it is imported here instead of at
    module import time.

    Returns:
        The ``config`` object exported by the ``manim`` package.

    Raises:
        ImportError: If ``manim`` cannot be imported. The message spells out
            the install command and the original failure is chained as the
            cause.
    """
    try:
        # Importing a name from the package also imports the package itself,
        # so a separate ``import manim`` is not needed.
        from manim import config  # type: ignore
    except ImportError as e:
        raise ImportError(
            "Manim is required for video rendering. Install it with:\n"
            "uv sync --extra manim"
        ) from e
    return config

def _load_json(path: Path) -> dict[str, Any]:
    """Read and parse the JSON file at ``path``.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    default locale encoding, so non-ASCII content loads the same everywhere.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed document. The annotation says ``dict``, but the top-level
        type is not checked here.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _is_valid_ckpt_dir(d: Path) -> bool:
    """Tell whether ``d`` is a directory that directly holds a ``checkpoint.pkl`` file."""
    if not d.is_dir():
        return False
    ckpt_file = d / "checkpoint.pkl"
    return ckpt_file.is_file()


def _pick_checkpoint_dir(ckpt_root: Path, requested: str | None) -> Path:
    """
    Returns a CHECKPOINT DIRECTORY (not a file).
    Priority:
      - requested (if provided and valid)
      - best/ (optional, if exists)
      - final/
      - highest epoch_XXXX/
      - most recently modified valid dir
    """
    if not ckpt_root.exists():
        raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")

    if requested is not None:
        cand = ckpt_root / requested
        if _is_valid_ckpt_dir(cand):
            return cand
        raise FileNotFoundError(
            f"Requested checkpoint '{requested}' not found or missing checkpoint.pkl in: {cand}"
        )

    best = ckpt_root / "best"
    if _is_valid_ckpt_dir(best):
        return best

    final = ckpt_root / "final"
    if _is_valid_ckpt_dir(final):
        return final

    epoch_dirs = [p for p in ckpt_root.glob("epoch_*") if _is_valid_ckpt_dir(p)]
    if epoch_dirs:
        # Prefer highest epoch index if name matches epoch_XXXX
        def _epoch_num(p: Path) -> int:
            try:
                return int(p.name.split("_", 1)[1])
            except Exception:
                return -1

        epoch_dirs_sorted = sorted(epoch_dirs, key=_epoch_num, reverse=True)
        if _epoch_num(epoch_dirs_sorted[0]) >= 0:
            return epoch_dirs_sorted[0]

        # Fallback to newest mtime
        return max(epoch_dirs, key=lambda p: p.stat().st_mtime)

    raise FileNotFoundError(f"No valid checkpoint dirs found under: {ckpt_root}")

def init_state(ckpt: dict[str, Any]) -> TrainState:
    """Rebuild a ``TrainState`` from a checkpoint dictionary.

    Args:
        ckpt: Mapping that must contain ``"<name>_params"`` and
            ``"<name>_opt_state"`` for each of ``emulator``, ``summary`` and
            ``critic``, plus ``"step"``.

    Returns:
        A ``TrainState`` whose emulator, summary and critic sub-states carry
        the stored parameters and optimizer states.

    Raises:
        KeyError: If any of the expected keys is missing from ``ckpt``.
    """
    return TrainState(
        emulator=EmulatorState(
            params=ckpt["emulator_params"],
            opt_state=ckpt["emulator_opt_state"],
        ),
        summary=SummaryState(
            params=ckpt["summary_params"],
            opt_state=ckpt["summary_opt_state"],
        ),
        critic=CriticState(
            params=ckpt["critic_params"],
            opt_state=ckpt["critic_opt_state"],
        ),
        step=ckpt["step"],
    )

def _load_checkpoint(ckpt_dir: Path) -> dict[str, Any]:
    """Unpickle ``checkpoint.pkl`` from ``ckpt_dir`` and return its content.

    NOTE(review): unpickling can execute code embedded in the file, so this
    should only be pointed at checkpoints from a trusted source.

    Args:
        ckpt_dir: Directory containing ``checkpoint.pkl``.

    Returns:
        The unpickled checkpoint dictionary.

    Raises:
        OSError: If the checkpoint file cannot be opened.
        TypeError: If the unpickled object is not a ``dict``.
    """
    with (ckpt_dir / "checkpoint.pkl").open("rb") as handle:
        payload = pickle.load(handle)
    if isinstance(payload, dict):
        return payload
    raise TypeError(f"Expected dict checkpoint, got: {type(payload)}")


def _make_transforms(eval_cfg: EvaluateConfig, train_cfg: TrainConfig) -> list[Callable]:
    """Build the list of data transforms to apply during evaluation.

    Noise level and crop window come from ``eval_cfg`` when set there;
    a value of ``None`` means "use the value from ``train_cfg``". A training
    value that is missing or ``None`` counts as disabled (0).

    Args:
        eval_cfg: Provides ``noise_level``, ``crop_window_size`` and
            ``normalize``.
        train_cfg: Fallback source for ``noise_level`` and
            ``crop_window_size``.

    Returns:
        Transforms in application order: noise (if level > 0), crop (if
        window > 0), normalize (if ``eval_cfg.normalize`` is truthy).
    """
    noise_level = eval_cfg.noise_level
    if noise_level is None:
        # ``or 0.0`` keeps a training value of None from reaching float(),
        # matching how the crop window fallback below is guarded.
        noise_level = float(getattr(train_cfg, "noise_level", 0.0) or 0.0)

    crop_window_size = eval_cfg.crop_window_size
    if crop_window_size is None:
        crop_window_size = int(getattr(train_cfg, "crop_window_size", 0) or 0)

    transforms: list[Callable] = []

    if noise_level and noise_level > 0.0:
        transforms.append(add_noise(noise_level=float(noise_level)))

    if crop_window_size and crop_window_size > 0:
        transforms.append(crop(window_size=int(crop_window_size)))

    if eval_cfg.normalize:
        transforms.append(normalize())

    return transforms

def _write_results_txt(
    out_dir: Path,
    *,
    summary_type: str,
    train_config: TrainConfig,
    ckpt_dir: Path,
    metrics: dict[str, Any],
) -> None:
    """Write a plain-text summary of an evaluation run to ``out_dir/results.txt``.

    The file lists the checkpoint directory, the summary type, whether OT was
    used, the training noise level, and then whichever of the trajectory
    metrics ``mse``, ``histogram_error``, ``energy_spectrum`` and ``ipm`` are
    present, in that order.

    Args:
        out_dir: Existing directory to write into (it is not created here).
        summary_type: Label written verbatim to the file.
        train_config: Provides ``distance_config["type"]`` and, optionally,
            ``noise_level``.
        ckpt_dir: Checkpoint directory the results belong to.
        metrics: Metrics mapping; only the ``"trajectory"`` entry is read.
    """
    # A missing or None "trajectory" entry simply yields no metric lines.
    traj = metrics.get("trajectory") or {}

    ot_or_noot = "ot" if train_config.distance_config["type"] != "no_ot" else "no_ot"
    # ``or 0.0`` treats a noise level of None like an absent one instead of
    # letting float(None) raise.
    noise_level = float(getattr(train_config, "noise_level", 0.0) or 0.0)

    lines: list[str] = [
        f"Checkpoint dir: {ckpt_dir}\n",
        f"Summary type: {summary_type}\n",
        f"OT: {ot_or_noot}\n",
        f"Noise level: {noise_level}\n",
    ]
    lines.extend(
        f"{key}: {traj[key]}\n"
        for key in ("mse", "histogram_error", "energy_spectrum", "ipm")
        if key in traj
    )

    # Explicit encoding so the file does not depend on the platform locale.
    with open(out_dir / "results.txt", "w", encoding="utf-8") as f:
        f.writelines(lines)

def _make_figures(
    out_dir: Path,
    *,
    u_true: np.ndarray,
    u_hat: np.ndarray,
    u_hat_full: np.ndarray,
    s_true: np.ndarray,
    s_hat: np.ndarray,
    train_config: TrainConfig,
    critic_params: Any | None = None,
) -> list[Path]:
    """Create the evaluation figures and save them as PDFs in ``out_dir``.

    Histograms are computed from the full arrays; every other plot uses only
    index 0 along the first axis of its inputs.

    Figures are produced in this order (which also fixes the file numbering):
      - always: trajectory histogram comparison
      - OT enabled (``distance_config["type"] != "no_ot"``): summary
        histogram comparison
      - always: true trajectory, predicted vs true trajectory
      - one summary component and OT enabled: projection, projection PSD,
        summary-coloured 3-D attractor
      - one summary component: regime identification, summary vs state,
        summary-coloured projections for the true and predicted data
      - ``distance_config["type"] == "wgan"`` and ``critic_params`` given:
        critic values over phase space
      - three summary components and OT enabled: summary curves, predicted vs
        true summary, trajectory vs summary

    Args:
        out_dir: Directory the PDFs are written to; it is not created here.
        u_true: Reference trajectories. Assumed to be laid out as
            (batch, time, component) — TODO confirm against the caller.
        u_hat: Predicted trajectories, same assumed layout.
        u_hat_full: Predicted trajectories used for the summary-coloured and
            critic plots. ``u_hat`` is used instead when this is ``None``.
        s_true: Summaries of the reference data. A 2-D array gets a trailing
            component axis of length 1.
        s_hat: Summaries of the predicted data, handled like ``s_true``.
        train_config: Only ``distance_config["type"]`` is read.
        critic_params: Parameters handed to the critic's apply function; the
            critic figure is skipped when ``None``.

    Returns:
        Paths of the saved files ``figure_01.pdf``, ``figure_02.pdf``, ... in
        creation order.
    """
    figures = []

    if s_true.ndim == 2:
        s_true = s_true[..., None]
    if s_hat.ndim == 2:
        s_hat = s_hat[..., None]

    # Resolve the fallback once so every plot of the "full" prediction
    # tolerates a missing ``u_hat_full`` in the same way.
    u_fake = u_hat_full if u_hat_full is not None else u_hat

    n_summary = s_true.shape[2]
    ot_enabled = train_config.distance_config["type"] != "no_ot"

    # Histograms: bin count is the square root of the longer axis-1 length.
    k = int(np.ceil(np.sqrt(max(u_true.shape[1], u_hat.shape[1]))))
    hist_s_true = get_histogram(s_true, num_bins=k, density=True)
    hist_s_hat = get_histogram(s_hat, num_bins=k, density=True)
    hist_u_true = get_histogram(u_true, num_bins=k, density=True)
    hist_u_hat = get_histogram(u_hat, num_bins=k, density=True)

    # Labels
    u_label = r"u(t)"
    u_comp_labels = [r"$x(t)$", r"$y(t)$", r"$z(t)$"]

    figures.append(
        plot_histogram_comparison(
            hist_u_true[0],
            hist_u_hat[0],
            comp_labels=u_comp_labels,
            title="Trajectory Histogram Comparison",
        )
    )
    if ot_enabled:
        # A single component is labelled without an index.
        if n_summary > 1:
            s_hist_labels = [f"$f_{i}(u(t))$" for i in range(1, n_summary + 1)]
        else:
            s_hist_labels = ["$f(u(t))$" for _ in range(n_summary)]
        figures.append(
            plot_histogram_comparison(
                hist_s_true[0],
                hist_s_hat[0],
                comp_labels=s_hist_labels,
                title="Summary Histogram Comparison",
            )
        )

    # Trajectory curves
    figures.append(plot_l63_full_single(u_true[0], plotted_variable=u_label, u_comp_labels=u_comp_labels))
    figures.append(
        plot_l63_full_double(
            u_hat[0],
            u_true[0],
            plotted_variable_1=u_label + " Predicted",
            plotted_variable_2=u_label + " True",
            first=[r"$x(t)_{\mathrm{pred}}$", r"$y(t)_{\mathrm{pred}}$", r"$z(t)_{\mathrm{pred}}$"],
            second=[r"$x(t)_{\mathrm{true}}$", r"$y(t)_{\mathrm{true}}$", r"$z(t)_{\mathrm{true}}$"],
        )
    )

    if n_summary == 1 and ot_enabled:
        figures.append(plot_l63_projection(u_true[0], s_true[0], s_hat[0]))
        figures.append(plot_l63_projection_psd(u_true[0], s_true[0], s_hat[0]))
        figures.append(plot_l63_attractor_colored_3d(u_true[0], s_true[0]))
    if n_summary == 1:
        figures.append(plot_l63_regime_identification(u_true[0], s_true[0]))
        figures.append(plot_summary_vs_state(u_true[0], s_true[0]))
        figures.append(plot_l63_projections_colored_by_summary(u_true[0], s_true[0], cbar_label="s_true"))
        figures.append(
            plot_l63_projections_colored_by_summary(u_fake[0], s_hat[0], cbar_label="s_pred")
        )

    if train_config.distance_config["type"] == "wgan" and critic_params is not None:
        # Critic settings are read from the repository's WGAN config when
        # that file exists; the input dimension always follows the summaries.
        repo_root = Path(__file__).resolve().parents[3]
        dist_cfg_path = repo_root / "configs" / "distance" / "wgan_config.json"
        if dist_cfg_path.exists():
            dist_cfg = _load_json(dist_cfg_path)
        else:
            dist_cfg = {}
        input_dim = int(s_true.shape[2]) if s_true.ndim == 3 else 1
        dist_cfg = {**dist_cfg, "input_dim": input_dim}

        _, apply_critic = make_critic(dist_cfg)
        phi_true = np.asarray(apply_critic(critic_params, s_true))
        phi_fake = np.asarray(apply_critic(critic_params, s_hat))

        # Collapse a trailing axis of critic outputs to one value per entry.
        if phi_true.ndim == 3:
            phi_true = np.mean(phi_true, axis=-1)
        if phi_fake.ndim == 3:
            phi_fake = np.mean(phi_fake, axis=-1)

        figures.append(
            plot_l63_discriminator_phase_space(
                u_true[0],
                u_fake[0],
                phi_true[0],
                phi_fake[0],
            )
        )
    # If summary is 3D and OT enabled, reproduce extra plots
    if n_summary == 3 and ot_enabled:
        s_label = r"f(u(t))"
        s_comp_labels = [f"$f_{i}(u(t))$" for i in range(1, n_summary + 1)]
        # NOTE(review): the trajectory call above passes its labels as
        # ``u_comp_labels=``; confirm that plot_l63_full_single also accepts
        # ``s_comp_labels=``.
        figures.append(plot_l63_full_single(s_true[0], plotted_variable=s_label, s_comp_labels=s_comp_labels))
        figures.append(
            plot_l63_full_double(
                s_hat[0],
                s_true[0],
                plotted_variable_1=s_label + " Predicted",
                plotted_variable_2=s_label + " True",
                first=[r"$f_1(\cdot)_{\mathrm{pred}}$", r"$f_2(\cdot)_{\mathrm{pred}}$", r"$f_3(\cdot)_{\mathrm{pred}}$"],
                second=[r"$f_1(\cdot)_{\mathrm{true}}$", r"$f_2(\cdot)_{\mathrm{true}}$", r"$f_3(\cdot)_{\mathrm{true}}$"],
            )
        )
        # In the other plot_l63_full_double calls the first array goes with
        # ``plotted_variable_1``/``first``; pass the trajectory first here so
        # the data matches its trajectory labels (and likewise the summary).
        figures.append(
            plot_l63_full_double(
                u_true[0],
                s_true[0],
                plotted_variable_1=u_label,
                plotted_variable_2=s_label,
                first=u_comp_labels,
                second=s_comp_labels,
            )
        )

    # NOTE(review): figures are not closed after saving; if these are
    # matplotlib figures, consider closing them to release memory.
    image_paths: list[Path] = []
    for i, fig in enumerate(figures, 1):
        p = out_dir / f"figure_{i:02d}.pdf"
        fig.savefig(p, dpi=300, bbox_inches="tight")
        image_paths.append(p)
    return image_paths


def _render_video_if_requested(
    *,
    out_dir: Path,
    eval_cfg: EvaluateConfig,
    u_true: np.ndarray,
    u_hat: np.ndarray,
    s_true: np.ndarray,
) -> Path | None:
    """Render the Lorenz-63 animation with Manim when the config asks for it.

    Only index 0 along the first axis of each array is animated.

    Args:
        out_dir: Directory Manim is configured to write its output to.
        eval_cfg: Rendering happens only if ``eval_cfg.make_video`` is truthy.
        u_true: Reference trajectories.
        u_hat: Predicted trajectories.
        s_true: Summaries of the reference data. A 2-D array is given a
            trailing component axis, as ``_make_figures`` does. A single
            component is padded with two leading zero components before it
            is handed to the scene.

    Returns:
        ``out_dir / "lorenz63.mp4"`` if that file exists after rendering,
        otherwise ``None`` (also when no video was requested).

    Raises:
        ImportError: If a video is requested but Manim is not installed.
    """
    if not eval_cfg.make_video:
        return None

    config = _require_manim()

    # Your existing manim Scene
    from pipeline.visualization.animations_ce import Lorenz63  # noqa: WPS433

    config["media_dir"] = str(out_dir)
    config["video_dir"] = str(out_dir)
    config["images_dir"] = str(out_dir)
    config["output_file"] = "lorenz63.mp4"
    config["pixel_height"] = 2160
    config["pixel_width"] = 3840
    config["frame_rate"] = 60

    # Without this, a 2-D summary array would fail on ``shape[2]`` below.
    if s_true.ndim == 2:
        s_true = s_true[..., None]

    s_first = s_true[0]
    if s_true.shape[2] == 1:
        # Pad only the element that is rendered instead of the whole batch;
        # axis 2 of the batch is axis 1 of a single element.
        zeros = np.zeros_like(s_first)
        s_first = np.concatenate([zeros, zeros, s_first], axis=1)

    scene = Lorenz63(u_true[0], u_hat[0], s_first)
    scene.render()

    video_path = out_dir / "lorenz63.mp4"
    return video_path if video_path.exists() else None


def _maybe_init_wandb(eval_cfg: EvaluateConfig, *, train_cfg: TrainConfig, ckpt_dir: Path) -> Any | None:
    """Start a Weights & Biases evaluation run if a project is configured.

    Args:
        eval_cfg: Supplies ``wandb_project``, ``wandb_group`` and
            ``wandb_run_name``; logged under the ``"eval"`` config key via
            ``dataclasses.asdict``.
        train_cfg: Its ``__dict__`` is logged under ``"train_config"``.
        ckpt_dir: Logged as a string under ``"checkpoint_dir"``.

    Returns:
        The object returned by ``wandb.init``, or ``None`` when
        ``eval_cfg.wandb_project`` is ``None``.

    Raises:
        ImportError: If a project is configured but ``wandb`` is not
            installed.
    """
    if eval_cfg.wandb_project is None:
        return None

    # wandb is optional, so it is only imported once logging is requested.
    try:
        import wandb
    except ImportError as e:
        raise ImportError(
            "wandb is not installed. Install it with:\n"
            "uv sync --extra wandb"
        ) from e

    init_kwargs = {
        "project": eval_cfg.wandb_project,
        "job_type": "eval",
        "group": eval_cfg.wandb_group,
        "name": eval_cfg.wandb_run_name,
        "config": {
            "eval": asdict(eval_cfg),
            "train_config": train_cfg.__dict__,
            "checkpoint_dir": str(ckpt_dir),
        },
    }
    return wandb.init(**init_kwargs)

    
