from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Literal

import copy
import json
import os
import re
import time
import yaml

import numpy as np
import matplotlib.pyplot as plt
import torch

from phijax.torch.utils import Logger, save_checkpoint, restore_checkpoint, Collection
from phijax.torch.models import create_optimizer_and_scheduler
from phijax.torch.equations.registry import get_pde


# Run-directory resolution mode: by default "train" allocates a fresh run
# directory and "eval" selects an existing one (see Trainer.resolve_run_dir).
Mode = Literal["train", "eval"]


@dataclass(frozen=True)
class RunID:
    """Immutable identity of a run, derived from its config by ``Trainer._make_run_id``."""

    exp_name: str  # slug "<model>-<pde>-<optim>-<activation>"; wandb run-name fallback
    sweep_id: str  # slug "<pde>_<model>_<optim>_<activation>"; directory under exp_path
    seed: int  # run seed; selects the "seed_<n>" subdirectory
    exp_path: str  # root directory for experiments (config default "./runs")


def _slugify(s: str) -> str:
    s = s.strip().lower()
    s = re.sub(r"[^a-z0-9._-]+", "-", s)
    s = re.sub(r"-{2,}", "-", s).strip("-")
    return s or "run"


def _scan_run_dirs(parent: str, basename: str) -> list[tuple[int, str]]:
    """List entries of *parent* named ``basename`` or ``basename-<digits>``.

    Returns ``(index, entry_name)`` pairs in ``os.listdir`` order. The bare
    ``basename`` has index 0 and ``basename-<k>`` has index ``k``. Shared by both
    branches of :func:`create_experiment_dir` so they agree on what a run
    directory looks like (``\\d`` only matches characters ``int()`` accepts).
    """
    pattern = re.compile(re.escape(basename) + r"(?:-(\d+))?")
    found: list[tuple[int, str]] = []
    for entry in os.listdir(parent):
        match = pattern.fullmatch(entry)
        if match is None:
            continue
        suffix = match.group(1)
        found.append((int(suffix) if suffix is not None else 0, entry))
    return found


def create_experiment_dir(
    exp_path: str,
    sweep_id: str,
    seed: int,
    *,
    create_if_missing: bool = True,
    select_index: Optional[int] = None,
) -> str:
    """Allocate or locate the run directory ``<exp_path>/<sweep_id>/seed_<seed>[-k]``.

    Args:
        exp_path: experiment root; a falsy value falls back to ``"./runs"``.
        sweep_id: sweep subdirectory name.
        seed: seed number used in the ``seed_<seed>`` directory name.
        create_if_missing: when true, always create and return a NEW directory.
            The first run gets the bare name and later runs get ``-<max+1>``.
            ``select_index`` is ignored in this mode.
        select_index: (lookup mode only) explicit run index. ``0`` is the bare
            directory and ``k`` is ``-k``. When ``None``, the existing run with
            the highest index is returned.

    Returns:
        The directory path. With an explicit ``select_index`` the path is built
        without checking that it exists.

    Raises:
        FileNotFoundError: lookup mode without ``select_index`` when the sweep
            directory or any matching run directory is missing.
        ValueError: ``select_index`` is negative.
    """
    exp_path = exp_path or "./runs"
    base = os.path.join(exp_path, sweep_id, f"seed_{int(seed)}")
    parent = os.path.dirname(base)
    basename = os.path.basename(base)

    if create_if_missing:
        os.makedirs(parent, exist_ok=True)
        existing = _scan_run_dirs(parent, basename)
        if existing:
            # The bare directory counts as index 0, so the next free suffix is max+1.
            path = f"{base}-{max(idx for idx, _ in existing) + 1}"
        else:
            path = base
        os.makedirs(path, exist_ok=True)
        return path

    if select_index is None:
        if not os.path.isdir(parent):
            raise FileNotFoundError(parent)
        existing = _scan_run_dirs(parent, basename)
        if not existing:
            raise FileNotFoundError(base)
        _, newest = max(existing, key=lambda item: item[0])
        return os.path.join(parent, newest)

    if select_index < 0:
        raise ValueError("select_index must be >= 0")
    return base if select_index == 0 else f"{base}-{select_index}"


class Trainer:
    """Build, train, checkpoint and evaluate a single PDE run described by a config dict.

    The trainer keeps the raw config (``raw_cfg``) and derives a :class:`RunID`
    from it. It resolves ``run_dir`` (the per-seed experiment directory) lazily
    through :func:`create_experiment_dir`.
    """

    def __init__(self, run_cfg: Dict[str, Any], *, device: Optional[str] = None):
        # Deep copy so later mutations of the caller's dict don't leak into this run.
        self.raw_cfg = copy.deepcopy(run_cfg)
        self.device_str = device
        self.device = self._resolve_device(device)
        self.run_id = self._make_run_id(self.raw_cfg)
        self.run_dir: Optional[str] = None
        self.model = None
        self.logger = Logger()
        self._optimizer = None
        self._scheduler = None
        self._opt_meta = None

    @classmethod
    def from_yaml(cls, path: str, *, device: Optional[str] = None) -> "Trainer":
        """Create a trainer (of the calling class) from the YAML config at *path*."""
        return cls(cls._read_yaml(path), device=device)

    @staticmethod
    def _read_yaml(path: str) -> Dict[str, Any]:
        """Load *path* with ``yaml.safe_load`` and require a top-level mapping.

        Raises:
            ValueError: the document is empty or not a mapping. The trainer calls
                ``.get`` on the config, so this replaces a confusing
                ``AttributeError`` later on.
        """
        with open(path, "r") as f:
            cfg = yaml.safe_load(f)
        if not isinstance(cfg, dict):
            raise ValueError(
                f"{path}: expected a YAML mapping at the top level, got {type(cfg).__name__}"
            )
        return cfg

    @classmethod
    def from_run_dir(cls, run_dir: str, *, device: Optional[str] = None) -> "Trainer":
        """Create a trainer from ``<run_dir>/config.yml`` and pin ``run_dir`` to *run_dir*."""
        cfg_path = os.path.join(run_dir, "config.yml")
        run_cfg = cls._read_yaml(cfg_path)
        t = cls(run_cfg, device=device)
        t.run_dir = run_dir
        return t

    @staticmethod
    def _resolve_device(device: Optional[str]) -> torch.device:
        """Return ``torch.device(device)``, or CUDA-if-available-else-CPU when ``None``."""
        if device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    @staticmethod
    def _make_run_id(run_cfg: Dict[str, Any]) -> RunID:
        """Derive the run identity from the config.

        Names come from ``run_cfg["_names"]`` when present. Otherwise they come
        from ``pde``, the first ``-``-separated token of ``exp_name``,
        ``optim.optimizer`` and ``activation``, with placeholder defaults.
        """
        names = run_cfg.get("_names", {})
        pde_name = names.get("pde", run_cfg.get("pde", "pde"))
        model_name = names.get("model", run_cfg.get("exp_name", "model").split("-")[0])
        opt_name = names.get("optim", run_cfg.get("optim", {}).get("optimizer", "opt"))
        activation = run_cfg.get("activation", "act")
        seed = int(run_cfg.get("seed", 0))
        exp_path = run_cfg.get("exp_path", "./runs")
        exp_name = _slugify(f"{model_name}-{pde_name}-{opt_name}-{activation}")
        sweep_id = _slugify(f"{pde_name}_{model_name}_{opt_name}_{activation}")
        return RunID(exp_name=exp_name, sweep_id=sweep_id, seed=seed, exp_path=exp_path)

    def resolve_run_dir(
        self,
        *,
        mode: Mode,
        select_index: Optional[int] = None,
        create_if_missing: Optional[bool] = None,
    ) -> str:
        """Set and return ``self.run_dir`` via :func:`create_experiment_dir`.

        ``create_if_missing`` defaults to ``mode == "train"``. Train mode
        allocates a fresh directory; eval mode selects an existing one, either
        ``select_index`` or the newest.
        """
        if create_if_missing is None:
            create_if_missing = (mode == "train")
        self.run_dir = create_experiment_dir(
            exp_path=self.run_id.exp_path,
            sweep_id=self.run_id.sweep_id,
            seed=self.run_id.seed,
            create_if_missing=create_if_missing,
            select_index=select_index,
        )
        return self.run_dir

    def to_collection_cfg(self):
        """Return a ``Collection`` built from a deep copy of ``raw_cfg``."""
        return Collection.from_dict(copy.deepcopy(self.raw_cfg))

    def build(self) -> None:
        """Instantiate the PDE model on ``self.device`` and its optimizer/scheduler."""
        run_cfg = self.to_collection_cfg()
        self.model = get_pde(run_cfg, device=self.device)
        self.model.to(self.device)

        opt = create_optimizer_and_scheduler(run_cfg, self.model.parameters())
        self._optimizer = opt.optimizer
        self._scheduler = opt.scheduler
        # Keep the whole result too: _train_step reads grad_accum_steps/clip_norm from it.
        self._opt_meta = opt

    def save_config(self) -> None:
        """Write a copy of ``raw_cfg`` to ``<run_dir>/config.yml`` with ``exp_path`` set to ``run_dir``."""
        # NOTE(review): because exp_path is replaced by the run directory, a
        # Trainer built from this saved file derives run_id.exp_path == run_dir.
        # resolve_run_dir() would then search <run_dir>/<sweep_id>/... instead of
        # the original root. from_run_dir / load_model(run_dir=...) avoid this by
        # pinning run_dir explicitly — confirm whether this overwrite is intended.
        cfg = copy.deepcopy(self.raw_cfg)
        cfg["exp_path"] = self.run_dir
        with open(os.path.join(self.run_dir, "config.yml"), "w") as f:
            yaml.dump(cfg, f, default_flow_style=False)

    def restore(self, checkpoints_root: Optional[str] = None, step: int | str | None = "latest") -> None:
        """Load model/optimizer state from a checkpoint directory.

        Args:
            checkpoints_root: passed to ``restore_checkpoint`` as ``workdir``;
                defaults to ``self.run_dir``.
            step: checkpoint step to load. ``None`` is treated as ``"latest"``.

        If the payload has ``state["weights"]`` (written by :meth:`train`) and
        the model exposes a ``weights`` mapping, those entries are copied back
        onto the model as tensors on ``self.device``.
        """
        if checkpoints_root is None:
            checkpoints_root = self.run_dir
        payload = restore_checkpoint(
            workdir=checkpoints_root,
            model=self.model,
            optimizer=self._optimizer,
            step=("latest" if step is None else step),
            map_location=self.device,
        )
        if not isinstance(payload, dict):
            return
        state = payload.get("state")
        if not isinstance(state, dict) or "weights" not in state or not hasattr(self.model, "weights"):
            return
        weights = self.model.weights
        for name, value in state["weights"].items():
            # Weights are checkpointed as plain Python floats, which carry no
            # dtype. Reuse the replaced tensor's dtype so e.g. a float64 model is
            # not silently cast to torch's default dtype.
            current = weights[name] if name in weights else None
            dtype = current.dtype if torch.is_tensor(current) else None
            weights[name] = torch.as_tensor(value, dtype=dtype, device=self.device)

    def _train_step(self, batch: torch.Tensor) -> float:
        """Apply one optimizer update on *batch* and return the mean micro-step loss.

        Accumulates ``grad_accum_steps`` backward passes (each scaled by
        ``1/accum``) before one ``optimizer.step()``. Gradients are clipped to
        ``clip_norm`` when it is positive, and the scheduler (if any) is stepped
        once per update.
        """
        self.model.train()

        opt = self._optimizer
        meta = self._opt_meta
        sched = self._scheduler

        accum = max(int(meta.grad_accum_steps), 1)
        clip = float(meta.clip_norm)

        opt.zero_grad(set_to_none=True)

        # NOTE(review): every micro-step re-evaluates the loss on the SAME batch.
        # For a deterministic loss that gives the accum == 1 gradient at accum
        # times the cost. Confirm whether batch was meant to be split into
        # micro-batches.
        loss_sum = None
        for _ in range(accum):
            loss = self.model.loss(batch)
            (loss / accum).backward()
            detached = loss.detach()
            loss_sum = detached if loss_sum is None else loss_sum + detached

        if clip > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)

        opt.step()
        if sched is not None:
            sched.step()

        # Single host sync per step instead of one .item() per micro-step.
        return float((loss_sum / accum).item())

    def _init_wandb(self):
        """Start a wandb run when ``raw_cfg["wandb"]["use"]`` is truthy; return the module or ``None``."""
        wandb_cfg = self.raw_cfg.get("wandb", {})
        if not bool(wandb_cfg.get("use", False)):
            return None
        import wandb

        wandb.init(
            project=wandb_cfg.get("project"),
            name=self.raw_cfg.get("exp_name", self.run_id.exp_name),
            config=self.raw_cfg,
            mode=wandb_cfg.get("mode", "online"),
        )
        return wandb

    def _log_step(self, step: int, step_t0: float, batch: torch.Tensor, wandb) -> None:
        """Compute ``model.log(batch)`` in eval mode and forward it to the logger (and wandb)."""
        self.model.eval()
        # Intentionally not wrapped in torch.no_grad(): model.log presumably needs
        # autograd (e.g. PDE residual derivatives) — confirm before changing.
        logs = self.model.log(batch)
        self.logger.log_iter(step, step_t0, time.time(), logs)
        if wandb is not None:
            wandb.log({k: (float(v) if torch.is_tensor(v) else v) for k, v in logs.items()}, step=step)

    def _save_step_checkpoint(self, step: int, keep) -> None:
        """Checkpoint model/optimizer plus the model's single-element ``weights`` as floats."""
        weights = getattr(self.model, "weights", {})
        extra_state = {"weights": {k: float(v.detach().cpu()) for k, v in weights.items()}}
        save_checkpoint(
            workdir=self.run_dir,
            step=step,
            model=self.model,
            optimizer=self._optimizer,
            state=extra_state,
            keep=keep,
        )

    def train(self) -> Dict[str, float]:
        """Train a fresh run end to end and return the final evaluation metrics.

        Steps:
        1. Allocate a new run directory, build the model and optimizer, and write
           ``config.yml``.
        2. Run ``training.num_epochs`` optimizer steps, one sampler batch each.
        3. Log every ``logging.log_every`` steps. Checkpoint every
           ``logging.save_every`` steps and at the final step; no checkpoints are
           written when ``save_every`` is ``None``.
        4. Call :meth:`evaluate` on the latest checkpoint.

        When wandb is enabled, its run is finished even if training raises.
        """
        self.resolve_run_dir(mode="train")
        self.build()
        self.save_config()

        run_cfg = self.to_collection_cfg()
        num_steps = int(run_cfg.training.num_epochs)
        log_every = int(run_cfg.logging.log_every)
        save_every = run_cfg.logging.save_every
        num_keep = run_cfg.logging.num_keep_ckpts

        wandb = self._init_wandb()  # None when wandb logging is disabled
        try:
            sampler = iter(self.model.sampler)
            t0 = time.time()

            for step in range(num_steps):
                step_t0 = time.time()
                batch = next(sampler)
                if not torch.is_tensor(batch):
                    batch = torch.as_tensor(batch)
                batch = batch.to(self.device)

                self._train_step(batch)

                if step % log_every == 0:
                    self._log_step(step, step_t0, batch, wandb)

                is_last = step == num_steps - 1
                if save_every is not None and (step % int(save_every) == 0 or is_last):
                    self._save_step_checkpoint(step, num_keep)

            metrics = self.evaluate(load_from_this_run=True, plot=True)

            if wandb is not None:
                wandb.log({**metrics, "total_time": time.time() - t0})
        finally:
            # Never leave a wandb run open when training or evaluation raises.
            if wandb is not None:
                wandb.finish()

        print(f"Training completed in {time.time() - t0:.2f} seconds.")
        return metrics

    def evaluate(
        self,
        *,
        select_index: Optional[int] = None,
        load_from_this_run: bool = True,
        plot: bool = True,
    ) -> Dict[str, float]:
        """Restore the latest checkpoint, compute error metrics and write ``metrics.json``.

        Args:
            select_index: run index to evaluate when a run directory is resolved.
                Ignored when ``load_from_this_run`` is true and ``run_dir`` is
                already set.
            load_from_this_run: reuse ``self.run_dir`` instead of resolving an
                existing run. A trainer without a ``run_dir`` always resolves one.
            plot: also save ``figure.pdf`` when the model exposes ``t_star``.

        Returns:
            ``{"rmse": ..., "rmae": ...}``. The ``rmse`` value comes from
            ``model.compute_l2_error()``.
        """
        # Without a run_dir there is no "this run" to load from. Resolve an
        # existing run instead of failing later on os.path.join(None, ...).
        if not load_from_this_run or self.run_dir is None:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)
        if self.model is None:
            self.build()

        self.model.eval()
        self.restore(self.run_dir, step="latest")

        with torch.no_grad():
            rmse = float(self.model.compute_l2_error().detach().cpu().item())
            rmae = float(self.model.compute_rmae().detach().cpu().item())

        metrics = {"rmse": rmse, "rmae": rmae}

        out_dir = self.run_dir
        with open(os.path.join(out_dir, "metrics.json"), "w") as f:
            json.dump(metrics, f, indent=4)

        if plot and getattr(self.model, "t_star", None) is not None:
            t_star = self.model.t_star.detach().cpu().numpy()
            x_star = self.model.x_star.detach().cpu().numpy()
            u_ref = self.model.u_ref.detach().cpu().numpy()
            u_pred = self.model.u_pred_grid().detach().cpu().numpy()
            self.plot2d(t_star, x_star, u_ref, u_pred, out_dir)

        return metrics

    def plot2d(self, t_star, x_star, u_ref, u_pred, out_dir: str):
        """Save exact / predicted / absolute-error heatmaps to ``<out_dir>/figure.pdf``.

        ``u_ref`` and ``u_pred`` are plotted on an ``indexing="ij"`` meshgrid of
        ``(t_star, x_star)``. They are presumably shaped ``(len(t_star),
        len(x_star))`` — TODO confirm.
        """
        TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
        panels = (
            ("Exact", u_ref),
            ("Predicted", u_pred),
            ("Absolute error", np.abs(u_ref - u_pred)),
        )

        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        try:
            for ax, (title, values) in zip(axes, panels):
                mesh = ax.pcolor(TT, XX, values, cmap="jet")
                fig.colorbar(mesh, ax=ax)
                ax.set_xlabel("t")
                ax.set_ylabel("x")
                ax.set_title(title)
            fig.tight_layout()
            fig.savefig(os.path.join(out_dir, "figure.pdf"), bbox_inches="tight", dpi=300)
        finally:
            # Release the figure even if plotting/saving fails.
            plt.close(fig)

    def load_model(
        self,
        *,
        select_index: Optional[int] = None,
        run_config: Optional[Dict[str, Any]] = None,
        run_config_path: Optional[str] = None,
        run_dir: Optional[str] = None,
    ):
        """Build the model and restore its latest checkpoint; return the model.

        At most one config source may be given:

        - ``run_config``: a config dict.
        - ``run_config_path``: a YAML config file.
        - ``run_dir``: read ``<run_dir>/config.yml`` and pin ``run_dir``.

        A new config replaces ``raw_cfg`` and ``run_id``. ``select_index`` is
        only used when no ``run_dir`` is known yet.

        Raises:
            ValueError: more than one config source was passed.
        """
        if sum(x is not None for x in [run_config, run_config_path, run_dir]) > 1:
            raise ValueError("Pass only one of: run_config, run_config_path, run_dir.")

        if run_config_path is not None:
            run_config = self._read_yaml(run_config_path)

        if run_dir is not None:
            cfg_path = os.path.join(run_dir, "config.yml")
            run_config = self._read_yaml(cfg_path)
            self.run_dir = run_dir

        if run_config is not None:
            self.raw_cfg = copy.deepcopy(run_config)
            self.run_id = self._make_run_id(self.raw_cfg)

        if self.run_dir is None:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)

        self.build()
        self.restore(self.run_dir, step="latest")
        return self.model
