from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Literal, Sequence, Union
import os, time, copy, json, yaml, re

import jax, jax.numpy as jnp, matplotlib.pyplot as plt, numpy as np
from jax.tree_util import tree_map


from phijax.utils import(
    Logger, save_checkpoint,
    restore_checkpoint, count_params, Collection
)
from phijax.equations import get_pde


# Allowed values for the ``mode`` argument of ``Trainer.resolve_run_dir``.
Mode = Literal[
    "train",
    "eval",
]



def _latest_run_root(base_seed_dir: str) -> str:
    """Return the sibling of *base_seed_dir* with the highest run suffix.

    Candidates are ``<base_seed_dir>`` itself (index 0) and siblings named
    ``<base_seed_dir>-<k>`` (index ``k``). The first candidate with the
    largest index in ``os.listdir`` order wins.

    Raises:
        FileNotFoundError: If no candidate exists (or the parent directory
            itself is missing).
    """
    parent, stem = os.path.split(base_seed_dir)
    suffix_pattern = re.compile(re.escape(stem) + r"-(\d+)")

    best_index: Optional[int] = None
    best_path: Optional[str] = None
    for entry in os.listdir(parent):
        if entry == stem:
            index = 0
        else:
            match = suffix_pattern.fullmatch(entry)
            if match is None:
                continue
            index = int(match.group(1))
        # Strict ">" keeps the first maximum, mirroring max() semantics.
        if best_index is None or index > best_index:
            best_index = index
            best_path = os.path.join(parent, entry)

    if best_path is None:
        raise FileNotFoundError(base_seed_dir)
    return best_path


def resolve_run_root(
    *,
    exp_path: str,
    sweep_id: str,
    seed: int,
    select_index: Optional[int] = None,
) -> str:
    """Return the run-root directory for ``(exp_path, sweep_id, seed)``.

    Runs are laid out as ``<exp_path>/<sweep_id>/seed_<seed>`` (index 0) and
    ``<exp_path>/<sweep_id>/seed_<seed>-<k>`` (index ``k``).

    Args:
        exp_path: Root directory of all runs; ``"./runs"`` when empty/None.
        sweep_id: Sweep identifier (first directory level under *exp_path*).
        seed: Random seed, encoded as ``seed_<seed>``.
        select_index: Which run to return. None picks the existing run with
            the highest suffix on disk. Otherwise, 0 gives the bare
            ``seed_<seed>`` directory and ``k > 0`` gives ``seed_<seed>-<k>``.
            In the explicit case existence is NOT checked.

    Raises:
        ValueError: If *select_index* is negative.
        FileNotFoundError: If *select_index* is None and no run exists.
    """
    base = os.path.join(exp_path or "./runs", sweep_id, f"seed_{int(seed)}")
    if select_index is None:
        return _latest_run_root(base)
    if select_index < 0:
        raise ValueError(f"select_index must be >= 0, got {select_index}")
    return base if select_index == 0 else f"{base}-{select_index}"


def checkpoints_dir(run_root: str) -> str:
    """Return the path of the ``checkpoints`` folder inside *run_root* (nothing is created)."""
    leaf = "checkpoints"
    return os.path.join(run_root, leaf)


@dataclass(frozen=True)
class LoadedRun:
    """Immutable result of :func:`load_run`: a restored model plus its run metadata."""

    run_root: str  # run directory holding config.yml / checkpoints / metrics.json
    ckpt_dir: str  # "<run_root>/checkpoints"
    cfg: dict[str, Any]  # trainer's raw config dict
    model: Any  # model built by the trainer, with its checkpoint state restored
    metrics: dict[str, float] | None  # parsed metrics.json, or None if skipped/missing


def load_run(
    *,
    trainer_cls,
    run_root: Optional[str] = None,
    exp_path: Optional[str] = None,
    sweep_id: Optional[str] = None,
    seed: Optional[int] = None,
    select_index: Optional[int] = None,
    device: Optional[str] = None,
    load_metrics: bool = True,
) -> LoadedRun:
    """Rebuild a trained model from its run directory and restore its checkpoint.

    The run is either given directly via *run_root* or located with
    :func:`resolve_run_root` from *exp_path*, *sweep_id* and *seed*
    (plus *select_index*; None picks the latest run).

    Args:
        trainer_cls: Trainer class providing ``from_run_dir``, ``build``,
            ``restore`` and the ``raw_cfg`` / ``model`` attributes
            (e.g. ``Trainer``).
        run_root: Run directory; takes precedence over the lookup arguments.
        exp_path, sweep_id, seed, select_index: Forwarded to
            :func:`resolve_run_root` when *run_root* is None.
        device: Forwarded to ``trainer_cls.from_run_dir``.
        load_metrics: Also read ``<run_root>/metrics.json`` if it exists.

    Returns:
        A :class:`LoadedRun` with the restored model, its raw config and the
        metrics dict (None when not requested or the file is absent).

    Raises:
        ValueError: If *run_root* is None and any of *exp_path*, *sweep_id*
            or *seed* is missing.
    """
    if run_root is None:
        missing = [
            name
            for name, value in (("exp_path", exp_path), ("sweep_id", sweep_id), ("seed", seed))
            if value is None
        ]
        if missing:
            raise ValueError(
                "Pass run_root, or all of exp_path/sweep_id/seed "
                f"(missing: {', '.join(missing)})."
            )
        run_root = resolve_run_root(
            exp_path=exp_path,
            sweep_id=sweep_id,
            seed=seed,
            select_index=select_index,
        )

    ckpt_dir = checkpoints_dir(run_root)

    trainer = trainer_cls.from_run_dir(run_root, device=device)
    # Trainer.from_run_dir already sets run_dir; kept for trainer classes
    # whose from_run_dir may not.
    trainer.run_dir = run_root
    trainer.build()
    trainer.restore(ckpt_dir)

    metrics = None
    if load_metrics:
        metrics_path = os.path.join(run_root, "metrics.json")
        if os.path.isfile(metrics_path):
            with open(metrics_path, "r") as f:
                metrics = json.load(f)

    return LoadedRun(
        run_root=run_root,
        ckpt_dir=ckpt_dir,
        cfg=trainer.raw_cfg,
        model=trainer.model,
        metrics=metrics,
    )


@dataclass(eq=True, frozen=True)
class RunID:
    """Immutable identity of a training run, as built by ``Trainer._make_run_id``."""

    exp_name: str  # "<model>-<pde>-<optimizer>-<activation>" in Trainer._make_run_id
    sweep_id: str  # "<pde>_<model>_<optimizer>_<activation>" in Trainer._make_run_id
    seed: int  # random seed; run folders are named seed_<seed>[-k]
    exp_path: str  # root directory under which run folders are created


def _slugify(s: str) -> str:
    """Turn *s* into a lowercase, filesystem-friendly slug.

    Surrounding whitespace is stripped. Each run of characters outside
    ``[a-z0-9._-]`` becomes one ``-``. Repeated dashes are collapsed and
    leading/trailing dashes removed. Returns ``"run"`` if nothing is left.
    """
    lowered = s.strip().lower()
    dashed = re.sub(r"[^a-z0-9._-]+", "-", lowered)
    collapsed = re.sub(r"-{2,}", "-", dashed)
    trimmed = collapsed.strip("-")
    return trimmed if trimmed else "run"


def create_experiment_dir(
    exp_path: str,
    exp_name: str,
    seed: int,
    sweep_id: Optional[str] = None,
    create_if_missing: bool = True,
    select_index: Optional[int] = None,
) -> str:
    """Create a new run directory, or locate the checkpoints of an existing run.

    Runs of one experiment/seed are stored side by side as ``<base>``,
    ``<base>-1``, ``<base>-2``, ... where ``<base>`` is

    * ``<exp_path>/<sweep_id>/seed_<seed>`` when *sweep_id* is truthy, or
    * ``<exp_path>/<eq_name>/<model_name>/seed_<seed>`` otherwise, where
      ``model_name`` and ``eq_name`` are the first two ``-``-separated
      fields of the slugified *exp_name*.

    Args:
        exp_path: Root directory for all runs; ``"./runs"`` when empty/None.
        exp_name: Experiment name; only used when *sweep_id* is falsy.
        seed: Random seed, encoded as ``seed_<seed>``.
        sweep_id: Optional sweep identifier selecting the flat layout above.
        create_if_missing: If True, always create a *new* run directory,
            print it and return it. The directory is ``<base>`` when no run
            exists yet; otherwise it is ``<base>-<k>``, with ``k`` one past
            the largest existing suffix. Existing runs are never reused in
            this mode.
        select_index: Lookup mode only (*create_if_missing* False). 0 selects
            ``<base>`` and ``k > 0`` selects ``<base>-<k>``. None selects the
            run with the largest numeric suffix (``<base>`` if there is none).

    Returns:
        Create mode: the new run directory. Lookup mode: the ``checkpoints``
        sub-directory of the selected run.

    Raises:
        ValueError: In lookup mode if *select_index* is negative, or if
            *sweep_id* is falsy and *exp_name* lacks a model and an equation
            field.
        FileNotFoundError: In lookup mode, if no matching run or its
            ``checkpoints`` directory exists.
    """
    exp_path = exp_path or "./runs"
    if sweep_id:
        parts = [exp_path, sweep_id]
    else:
        # Model/equation names are only needed for the non-sweep layout. The
        # old code parsed them unconditionally, so a name without a "-"
        # crashed even when sweep_id made them irrelevant.
        fields = _slugify(exp_name).split("-")
        if len(fields) < 2:
            raise ValueError(
                f"exp_name {exp_name!r} must look like '<model>-<equation>[-...]' "
                "when no sweep_id is given"
            )
        model_name, eq_name = fields[0], fields[1]
        parts = [exp_path, eq_name, model_name]

    base = os.path.join(*parts, f"seed_{int(seed)}")
    parent, basename = os.path.split(base)

    # Collect numeric suffixes k of existing "<basename>-<k>" siblings and
    # whether a bare "<basename>" exists. isdecimal() (not isdigit()) so that
    # every accepted tail is parseable by int(), e.g. "²" is rejected.
    suffixes = set()
    has_base = False
    if os.path.isdir(parent):
        prefix = basename + "-"
        for entry in os.listdir(parent):
            if entry == basename:
                has_base = True
            elif entry.startswith(prefix):
                tail = entry[len(prefix):]
                if tail.isdecimal():
                    suffixes.add(int(tail))

    if create_if_missing:
        if suffixes:
            path = f"{base}-{max(suffixes) + 1}"
        elif has_base:
            path = f"{base}-1"
        else:
            path = base
        os.makedirs(path, exist_ok=True)
        print(path)
        return path

    if select_index is not None:
        if select_index < 0:
            raise ValueError("select_index must be >= 0")
        run_root = base if select_index == 0 else f"{base}-{select_index}"
        path = os.path.join(run_root, "checkpoints")
        if not os.path.isdir(path):
            raise FileNotFoundError(f"Requested run does not exist: {path}")
        return path

    if not has_base and not suffixes:
        raise FileNotFoundError(f"No runs exist under {parent} for {basename}")
    run_root = f"{base}-{max(suffixes)}" if suffixes else base
    path = os.path.join(run_root, "checkpoints")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"'checkpoints' not found for latest run: {path}")
    return path


class Trainer:
    """Config-driven training / evaluation driver for a single PDE model.

    The raw config dict is kept in ``raw_cfg``; :meth:`build` turns it into a
    model via ``get_pde``. Run directories are created / located through
    :func:`create_experiment_dir`.
    """

    def __init__(self, run_cfg: Dict[str, Any], *, device: Optional[str] = None):
        # Deep copy so later edits to the caller's dict do not leak in.
        self.raw_cfg = copy.deepcopy(run_cfg)
        # NOTE(review): ``device`` is stored but not read anywhere in this file.
        self.device = device
        self.run_id = self._make_run_id(self.raw_cfg)
        # Run directory, or its "checkpoints" sub-directory after an eval lookup
        # (create_experiment_dir returns that path in lookup mode).
        self.run_dir: Optional[str] = None
        self.model = None  # populated by build()
        self.logger = Logger()

    @classmethod
    def from_yaml(cls, path: str, *, device: Optional[str] = None) -> "Trainer":
        """Construct a trainer from the YAML config file at *path*.

        A classmethod so subclasses get an instance of their own type; the
        file is closed after reading.
        """
        return cls(cls._read_yaml(path), device=device)

    @staticmethod
    def _read_yaml(path: str) -> Dict[str, Any]:
        """Parse the YAML file at *path* with ``yaml.safe_load``."""
        with open(path, "r") as f:
            return yaml.safe_load(f)

    @classmethod
    def from_run_dir(cls, run_dir: str, *, device: Optional[str] = None) -> "Trainer":
        """Construct a trainer from ``<run_dir>/config.yml`` and bind it to *run_dir*."""
        cfg_path = os.path.join(run_dir, "config.yml")
        run_cfg = cls._read_yaml(cfg_path)
        t = cls(run_cfg, device=device)
        t.run_dir = run_dir
        return t

    @classmethod
    def from_run_config(cls, run_cfg: Dict[str, Any], *, device: Optional[str] = None) -> "Trainer":
        """Construct a trainer from an in-memory config dict."""
        return cls(run_cfg, device=device)

    @staticmethod
    def _make_run_id(run_cfg: Dict[str, Any]) -> RunID:
        """Derive the run identity from the config.

        Names come from ``run_cfg["_names"]`` when present and otherwise fall
        back to top-level config entries.
        """
        names = run_cfg.get("_names", {})
        pde_name = names.get("pde", run_cfg.get("pde", "pde"))
        model_name = names.get("model", run_cfg.get("exp_name", "model").split("-")[0])
        opt_name = names.get("optim", run_cfg.get("optim", {}).get("optimizer", "opt"))
        activation = run_cfg.get("activation")
        seed = int(run_cfg.get("seed", 0))
        exp_path = run_cfg.get("exp_path", "./runs")
        exp_name = f"{model_name}-{pde_name}-{opt_name}-{activation}"
        sweep_id = f"{pde_name}_{model_name}_{opt_name}_{activation}"
        return RunID(exp_name, sweep_id, seed, exp_path)

    def resolve_run_dir(
        self,
        *,
        mode: Mode,
        select_index: Optional[int] = None,
        create_if_missing: Optional[bool] = None,
    ) -> str:
        """Set and return ``self.run_dir`` via :func:`create_experiment_dir`.

        In ``"train"`` mode (default ``create_if_missing=True``) a new run
        directory is created. In ``"eval"`` mode an existing run is looked up
        and its ``checkpoints`` sub-directory is returned.
        """
        if create_if_missing is None:
            create_if_missing = (mode == "train")
        print("Sweep ID:", self.run_id.sweep_id, select_index)
        self.run_dir = create_experiment_dir(
            exp_path=self.run_id.exp_path,
            exp_name=self.run_id.exp_name,
            seed=self.run_id.seed,
            sweep_id=self.run_id.sweep_id,
            create_if_missing=create_if_missing,
            select_index=select_index,
        )
        return self.run_dir

    def _run_root_dir(self) -> str:
        """Return ``self.run_dir`` without a trailing ``checkpoints`` component."""
        if self.run_dir.endswith("checkpoints"):
            return os.path.dirname(self.run_dir)
        return self.run_dir

    def to_collection_cfg(self):
        """Return a fresh ``Collection`` built from a deep copy of ``raw_cfg``."""
        return Collection.from_dict(copy.deepcopy(self.raw_cfg))

    def build(self) -> None:
        """(Re)create ``self.model`` from the config via ``get_pde``."""
        run_cfg = self.to_collection_cfg()
        self.model = get_pde(run_cfg)

    def save_config(self) -> None:
        """Write ``raw_cfg`` to ``<run_dir>/config.yml``, with ``exp_path`` set to the run dir.

        Raises:
            RuntimeError: If ``run_dir`` has not been resolved yet.
        """
        if self.run_dir is None:
            raise RuntimeError("run_dir is not set. Call resolve_run_dir() first.")
        cfg = copy.deepcopy(self.raw_cfg)
        # NOTE(review): a trainer rebuilt from this file (from_run_dir) ends up
        # with run_id.exp_path == this run directory, not the original root.
        cfg["exp_path"] = self.run_dir
        with open(os.path.join(self.run_dir, "config.yml"), "w") as f:
            yaml.dump(cfg, f, default_flow_style=False)

    def restore(self, checkpoints_dir: Optional[str] = None, step=None) -> None:
        """Load checkpointed state into ``self.model.state``.

        Args:
            checkpoints_dir: Checkpoint directory, or a run directory (then
                ``checkpoints`` is appended). Defaults to ``self.run_dir``.
            step: Forwarded to ``restore_checkpoint``.

        Raises:
            RuntimeError: If no directory is available or the model is not built.
        """
        if checkpoints_dir is None:
            checkpoints_dir = self.run_dir
        if checkpoints_dir is None:
            raise RuntimeError("No checkpoint directory given and run_dir is not set.")
        if self.model is None:
            raise RuntimeError("Model is not built; call build() before restore().")

        if not checkpoints_dir.endswith("checkpoints"):
            checkpoints_dir = os.path.join(checkpoints_dir, "checkpoints")
        self.model.state = restore_checkpoint(self.model.state, checkpoints_dir, step=step)

    def train(self) -> Dict[str, float]:
        """Train in a fresh run directory, checkpoint, then evaluate.

        Runs ``training.num_epochs`` steps. Every ``logging.log_every`` steps
        it logs metrics. Every ``logging.save_every`` steps (when set) it saves
        a checkpoint, and it always saves one on the final step so that
        :meth:`evaluate`, which restores the latest checkpoint, sees the
        trained state. Optionally mirrors logs to Weights & Biases
        (``wandb.use``).

        Returns:
            The metrics dict produced by :meth:`evaluate`.
        """
        self.resolve_run_dir(mode="train")
        self.save_config()

        wandb_cfg = self.raw_cfg.get("wandb", {})
        use_wandb = bool(wandb_cfg.get("use", False))
        if use_wandb:
            import wandb
            wandb.init(
                project=wandb_cfg.get("project"),
                name=self.raw_cfg.get("exp_name", self.run_id.exp_name),
                config=self.raw_cfg,
                mode=wandb_cfg.get("mode", "online"),
            )

        self.build()
        run_cfg = self.to_collection_cfg()

        num_epochs = int(run_cfg.training.num_epochs)
        log_every = int(run_cfg.logging.log_every)
        save_every = run_cfg.logging.save_every
        num_keep = run_cfg.logging.num_keep_ckpts
        ckpt_dir = os.path.join(self.run_dir, "checkpoints")
        is_main = jax.process_index() == 0

        sampler = iter(self.model.sampler)
        u_ref = self.model.u_ref
        param_count = count_params(self.model.state.params)

        if is_main:
            print(f"Number of parameters: {param_count}")
        print(tree_map(lambda x: x.shape, self.model.state.params))

        t0 = time.time()

        # NOTE: adaptive loss re-weighting (model.update_weights) and the
        # PCGrad step variant are not used by this trainer; see
        # TimeMarchingTrainer.train_window for the weight-update hook.
        for step in range(num_epochs):
            step_t0 = time.time()
            batch = next(sampler)
            self.model.state = self.model.step(self.model.state, batch)

            if step % log_every == 0 and is_main:
                # Take the first replica of the (presumably pmapped) state/batch.
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                batch0 = jax.device_get(tree_map(lambda x: x[0], batch))
                logs = self.model.log(state, batch0, u_ref)
                self.logger.log_iter(step, step_t0, time.time(), logs)
                if use_wandb:
                    wandb.log(logs, step=step)

            if save_every is not None and is_main:
                # The old check `step == num_epochs` could never be true, so
                # the final state was only saved when it fell on the schedule.
                is_last = step == num_epochs - 1
                if step % int(save_every) == 0 or is_last:
                    save_checkpoint(self.model.state, ckpt_dir, keep=num_keep)

        metrics = self.evaluate()

        if is_main:
            elapsed = time.time() - t0
            if use_wandb:
                wandb.log({**metrics, "total_time": elapsed, "param_count": param_count})
                wandb.finish()
            print(f"Training completed in {elapsed:.2f} seconds.")

        return metrics

    def evaluate(
        self,
        *,
        select_index: Optional[int] = None,
        load_from_this_run: bool = True,
        plot: bool = True
    ) -> Dict[str, float]:
        """Rebuild the model, restore its checkpoint and compute error metrics.

        Args:
            select_index: Run to look up when *load_from_this_run* is False
                (see :func:`create_experiment_dir`).
            load_from_this_run: Use the current ``run_dir`` instead of looking
                a run up.
            plot: On process 0, also save ``figure.pdf`` next to ``metrics.json``.

        Returns:
            ``{"rmse": ..., "rmae": ...}``. The ``"rmse"`` value comes from
            ``model.compute_l2_error``; the key name is kept for compatibility.

        Raises:
            RuntimeError: If no run directory is known.
        """
        if not load_from_this_run:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)
        if self.run_dir is None:
            raise RuntimeError(
                "run_dir is not set; train first or call evaluate(load_from_this_run=False)."
            )
        # run_dir may point at the run itself or at its checkpoints folder.
        run_root = self._run_root_dir()

        self.build()
        self.restore(os.path.join(run_root, "checkpoints"))

        state = self.model.state
        u_ref = self.model.u_ref

        metrics = {
            "rmse": float(self.model.compute_l2_error(state)),
            "rmae": float(self.model.compute_rmae(state)),
        }

        if jax.process_index() == 0:
            with open(os.path.join(run_root, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)
            if plot:
                t_star = self.model.t_star
                x_star = self.model.x_star
                u_pred = self.model.u_pred_fn(state, t_star, x_star)
                self.plot2d(t_star, x_star, u_ref, u_pred, run_root)

        return metrics

    def plot2d(self,
        t_star,
        x_star,
        u_ref,
        u_pred,
        out_dir

    ):
        """Save Exact / Predicted / Absolute-error heat maps to ``<out_dir>/figure.pdf``.

        Values are drawn on the ``(t_star, x_star)`` grid built with
        ``indexing="ij"``, so *u_ref* / *u_pred* are expected to be
        ``(len(t_star), len(x_star))`` arrays. The figure is always closed
        so repeated evaluations do not accumulate open figures.
        """
        TT, XX = jnp.meshgrid(t_star, x_star, indexing="ij")
        panels = (
            ("Exact", u_ref),
            ("Predicted", u_pred),
            ("Absolute error", jnp.abs(u_ref - u_pred)),
        )

        fig = plt.figure(figsize=(18, 5))
        try:
            for position, (title, values) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.pcolor(TT, XX, values, cmap="jet")
                plt.colorbar()
                plt.xlabel("t")
                plt.ylabel("x")
                plt.title(title)
            plt.tight_layout()
            fig_path = os.path.join(out_dir, "figure.pdf")
            fig.savefig(fig_path, bbox_inches="tight", dpi=300)
        finally:
            plt.close(fig)

    def load_model(
        self,
        *,
        select_index: Optional[int] = None,
        run_config: Optional[Dict[str, Any]] = None,
        run_config_path: Optional[str] = None,
        run_dir: Optional[str] = None,
    ):
        """Build the model and restore it from a run's checkpoints.

        At most one of *run_config*, *run_config_path*, *run_dir* may be
        given; a provided config replaces ``raw_cfg`` / ``run_id``. If
        ``self.run_dir`` is still unset afterwards, the run is looked up with
        *select_index* (eval mode). Note that an already-set ``run_dir``
        takes precedence over a lookup.

        Returns:
            The restored model.

        Raises:
            ValueError: If more than one source is passed.
        """
        if sum(x is not None for x in (run_config, run_config_path, run_dir)) > 1:
            raise ValueError("Pass only one of: run_config, run_config_path, run_dir.")

        if run_config_path is not None:
            run_config = self._read_yaml(run_config_path)

        if run_dir is not None:
            run_config = self._read_yaml(os.path.join(run_dir, "config.yml"))
            self.run_dir = run_dir

        if run_config is not None:
            self.raw_cfg = copy.deepcopy(run_config)
            self.run_id = self._make_run_id(self.raw_cfg)

        if self.run_dir is None:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)
        self.build()
        self.restore()
        return self.model


from phijax.data import UniformSampler


class TimeMarchingTrainer(Trainer):
    """Trainer that fits the solution one time window at a time.

    The reference time axis ``t_star`` is split into
    ``training.num_time_windows`` equal windows of ``len(t_star) // windows``
    steps. Every window is trained for ``training.num_epochs`` steps on points
    drawn uniformly from the same window-length domain. After a window, the
    model's prediction at ``t_star[num_time_steps]`` becomes the initial
    condition of the next window. Checkpoints are stored per window under
    ``<run_dir>/checkpoints/window_<idx:03d>``.
    """

    def _window_ckpt_dir(self, idx: int) -> str:
        """Return the checkpoint directory of time window *idx*."""
        if self.run_dir is None:
            raise RuntimeError("run_dir is not set. Call resolve_run_dir() first.")
        return os.path.join(self.run_dir, "checkpoints", f"window_{idx:03d}")

    def train(self) -> Dict[str, float]:
        """Train all time windows sequentially, then evaluate.

        Returns:
            The metrics dict produced by :meth:`evaluate`.

        Raises:
            ValueError: If there are fewer than two time steps per window.
        """
        self.resolve_run_dir(mode="train")
        self.save_config()

        wandb_cfg = self.raw_cfg.get("wandb", {})
        use_wandb = bool(wandb_cfg.get("use", False))
        if use_wandb:
            import wandb
            wandb.init(
                project=wandb_cfg.get("project"),
                name=self.raw_cfg.get("exp_name", self.run_id.exp_name),
                config=self.raw_cfg,
                mode=wandb_cfg.get("mode", "online"),
            )

        self.build()
        run_cfg = self.to_collection_cfg()

        num_epochs = int(run_cfg.training.num_epochs)
        log_every = int(run_cfg.logging.log_every)
        save_every = run_cfg.logging.save_every
        num_keep = run_cfg.logging.num_keep_ckpts

        weighting_cfg = run_cfg.weighting
        scheme = weighting_cfg.scheme
        weight_update_freq = int(weighting_cfg.get("update_freq", 1000))

        num_time_windows = int(run_cfg.training.num_time_windows)
        if num_time_windows < 1:
            raise ValueError(f"num_time_windows must be >= 1, got {num_time_windows}")

        u_ref = self.model.u_ref
        v_ref = self.model.v_ref
        t_star = self.model.t_star
        x_star = self.model.x_star
        y_star = self.model.y_star

        # Initial condition of the first window: the reference at t_star[0].
        u0 = u_ref[0, ...]
        v0 = v_ref[0, ...]

        num_time_steps = len(t_star) // num_time_windows
        if num_time_steps < 2:
            # dt below needs at least two time points per window.
            raise ValueError(
                f"Each time window needs at least 2 time steps; got {len(t_star)} "
                f"time steps for {num_time_windows} windows."
            )
        t = t_star[:num_time_steps]

        dt = t[1] - t[0]
        t0 = t[0]
        # Sampling domain extends slightly past the last window time.
        t1 = t[-1] + 0.1 * dt

        x0 = x_star[0]
        x1 = x_star[-1]
        y0 = y_star[0]
        y1 = y_star[-1]

        print(t.shape, x_star.shape, y_star.shape)

        dom = jnp.array([[t0, t1], [x0, x1], [y0, y1]])
        # One sampler shared by all windows: every window uses the same
        # window-length time domain (presumably the model's time is window-local).
        sampler = iter(UniformSampler(dom, int(run_cfg.training.batch_size)))

        if jax.process_index() == 0:
            param_count = count_params(self.model.state.params)
            print(f"Number of parameters: {param_count}")
            print(tree_map(lambda x: x.shape, self.model.state))
            print(tree_map(lambda x: x.shape, self.model.state.params))

        for idx in range(num_time_windows):
            step_offset = idx * num_epochs
            lo, hi = num_time_steps * idx, num_time_steps * (idx + 1)

            u_star = u_ref[lo:hi, ...]
            v_star = v_ref[lo:hi, ...]

            self.model.set_initial_condition(u0, v0, u_star, v_star, t)

            self.train_window(
                sampler=sampler,
                u_star=u_star,
                v_star=v_star,
                num_steps=num_epochs,
                step_offset=step_offset,
                log_every=log_every,
                save_every=save_every,
                num_keep=num_keep,
                scheme=scheme,
                weight_update_freq=weight_update_freq,
                use_wandb=use_wandb,
                idx=idx,
            )

            if num_time_windows > 1:
                # Next window's initial condition: prediction at the window end.
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                t_next = t_star[num_time_steps]
                u0 = self.model.u0_pred_fn(state, t_next, x_star, y_star)
                v0 = self.model.v0_pred_fn(state, t_next, x_star, y_star)
                del state
            print(f"Completed time window {idx + 1} / {num_time_windows}")

        # NOTE(review): the inherited evaluate() restores from
        # "<run_dir>/checkpoints" (windows save under checkpoints/window_XXX)
        # and plots via the 2-D u_pred_fn(state, t_star, x_star). Confirm this
        # is intended for this (t, x, y) model.
        metrics = self.evaluate(plot=True)

        if use_wandb:
            if jax.process_index() == 0:
                wandb.log(metrics)
            # Previously the run was initialised but never finished.
            wandb.finish()

        return metrics

    def train_window(
        self,
        *,
        sampler,
        u_star,
        v_star,
        num_steps: int,
        step_offset: int,
        log_every: int,
        save_every,
        num_keep: int,
        scheme: str,
        weight_update_freq: int,
        use_wandb: bool,
        idx: int
    ) -> None:
        """Run *num_steps* optimisation steps for time window *idx*.

        Args:
            sampler: Iterator yielding training batches.
            u_star, v_star: Reference slices for this window (accepted for
                API compatibility; not used here).
            num_steps: Number of steps in this window.
            step_offset: Added to the local step for the wandb global step.
            log_every: Logging period in steps.
            save_every: Checkpoint period in steps, or None to disable saving.
                When enabled, the final step of the window is always saved.
            num_keep: Forwarded to ``save_checkpoint(keep=...)``.
            scheme: Weighting scheme. ``"ntk"``, ``"align"`` and ``"groupdro"``
                trigger ``model.update_weights`` every *weight_update_freq* steps.
            weight_update_freq: Period of the weight update, in steps.
            use_wandb: Mirror logs to Weights & Biases.
            idx: Time-window index (selects the checkpoint directory).
        """
        # Loop-invariant: same directory for every step of this window.
        ckpt_dir = self._window_ckpt_dir(idx)
        is_main = jax.process_index() == 0
        reweights = scheme in ["ntk", "align", "groupdro"]

        for step in range(num_steps):
            step_t0 = time.time()
            batch = next(sampler)

            self.model.state = self.model.step(self.model.state, batch)

            # NOTE(review): weights are only updated on process 0; confirm this
            # is safe in multi-host runs.
            if is_main and reweights and (step + 1) % weight_update_freq == 0:
                self.model.state = self.model.update_weights(self.model.state, batch)

            if step % log_every == 0 and is_main:
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                batch0 = jax.device_get(tree_map(lambda x: x[0], batch))
                logs = self.model.log(state, batch0)
                self.logger.log_iter(step, step_t0, time.time(), logs)
                if use_wandb:
                    import wandb
                    wandb.log(logs, step=step + step_offset)

            if save_every is not None and is_main:
                if (step + 1) % int(save_every) == 0 or (step + 1) == num_steps:
                    save_checkpoint(
                        self.model.state,
                        ckpt_dir,
                        keep=num_keep,
                    )

    def restore_window(self, idx: int) -> None:
        """Restore ``self.model.state`` from the checkpoints of window *idx*.

        Builds the model first if it has not been built yet.
        """
        if self.model is None:
            self.build()
        ckpt_dir = self._window_ckpt_dir(idx)
        self.model.state = restore_checkpoint(self.model.state, ckpt_dir)

        

    
        




