from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Literal, Sequence, Union
import os, time, copy, json, yaml, re

import jax, jax.numpy as jnp, matplotlib.pyplot as plt, numpy as np
from jax.tree_util import tree_map
from functools import partial
from jax import random, pmap


from phijax.utils import(
    Logger, save_checkpoint,
    restore_checkpoint, count_params, Collection
)
from phijax.equations import get_pde


from phijax.data import BaseSampler

from .em import *

# Run mode accepted by Trainer.resolve_run_dir: with "train" a new run
# directory is created by default, with any other value an existing one is looked up.
Mode = Literal["train", "eval"] 



def _latest_run_root(base_seed_dir: str) -> str:
    parent = os.path.dirname(base_seed_dir)
    basename = os.path.basename(base_seed_dir)
    candidates = []
    for d in os.listdir(parent):
        if d == basename:
            candidates.append((0, os.path.join(parent, d)))
        else:
            m = re.fullmatch(re.escape(basename) + r"-(\d+)", d)
            if m:
                candidates.append((int(m.group(1)), os.path.join(parent, d)))
    if not candidates:
        raise FileNotFoundError(base_seed_dir)
    return max(candidates, key=lambda x: x[0])[1]


def resolve_run_root(
    *,
    exp_path: str,
    sweep_id: str,
    seed: int,
    select_index: Optional[int] = None,
) -> str:
    """Build the run-root path ``<exp_path>/<sweep_id>/seed_<seed>[-<k>]``.

    Args:
        exp_path: Experiment root; a falsy value falls back to ``"./runs"``.
        sweep_id: Sweep sub-directory name.
        seed: Seed, coerced with ``int``.
        select_index: ``None`` picks the latest existing run on disk,
            ``0`` the un-suffixed directory, ``k > 0`` the ``-k`` directory.
            An explicit index is not checked against the filesystem.

    Raises:
        ValueError: if ``select_index`` is negative.
        FileNotFoundError: if ``select_index`` is ``None`` and no run exists.
    """
    base = os.path.join(exp_path or "./runs", sweep_id, f"seed_{int(seed)}")
    if select_index is None:
        return _latest_run_root(base)
    if select_index < 0:
        raise ValueError(f"select_index must be >= 0, got {select_index}")
    return base if select_index == 0 else f"{base}-{select_index}"


def checkpoints_dir(run_root: str) -> str:
    """Return the ``checkpoints`` sub-directory path of a run root."""
    ckpt_path = os.path.join(run_root, "checkpoints")
    return ckpt_path


@dataclass(frozen=True)
class LoadedRun:
    """Immutable bundle returned by ``load_run`` for a run restored from disk."""

    # Run directory the trainer was created from.
    run_root: str
    # ``<run_root>/checkpoints``.
    ckpt_dir: str
    # The trainer's raw configuration dictionary (``trainer.raw_cfg``).
    cfg: Dict[str, Any]
    # The trainer's model after its state was restored from the checkpoints.
    model: Any
    # Parsed ``metrics.json`` if requested and present, otherwise None.
    metrics: Optional[Dict[str, float]]


def load_run(
    *,
    trainer_cls,
    run_root: Optional[str] = None,
    exp_path: Optional[str] = None,
    sweep_id: Optional[str] = None,
    seed: Optional[int] = None,
    select_index: Optional[int] = None,
    device: Optional[str] = None,
    load_metrics: bool = True,
) -> LoadedRun:
    """Rebuild a trained model from a run directory and restore its checkpoint.

    Args:
        trainer_cls: Trainer class exposing ``from_run_dir``.
        run_root: Run directory. When omitted it is resolved from
            ``exp_path``, ``sweep_id`` and ``seed`` (all three required).
        exp_path, sweep_id, seed, select_index: Forwarded to
            ``resolve_run_root`` when ``run_root`` is not given.
        device: Forwarded to the trainer constructor.
        load_metrics: Also read ``<run_root>/metrics.json`` when it exists.

    Returns:
        A ``LoadedRun`` with the restored model, its config and optional metrics.

    Raises:
        ValueError: if neither ``run_root`` nor the full
            (``exp_path``, ``sweep_id``, ``seed``) triple is supplied.
    """
    if run_root is None:
        if exp_path is None or sweep_id is None or seed is None:
            raise ValueError(
                "Pass either run_root, or all of exp_path, sweep_id and seed."
            )
        run_root = resolve_run_root(
            exp_path=exp_path,
            sweep_id=sweep_id,
            seed=seed,
            select_index=select_index,
        )

    t = trainer_cls.from_run_dir(run_root, device=device)
    t.run_dir = run_root
    try:
        t.build()
        t.restore(checkpoints_dir(run_root))
    except Exception as err:
        # Fallback: rebuild the model with flag="state_fail" and retry the
        # restore. `Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate; the failure is reported instead of hidden.
        print(
            f"load_run: build/restore failed ({err!r}); "
            "retrying with flag='state_fail'"
        )
        run_cfg = t.to_collection_cfg()
        run_cfg.flag = "state_fail"
        t.model = get_pde(run_cfg)
        t.restore(checkpoints_dir(run_root))

    metrics = None
    if load_metrics:
        mpath = os.path.join(run_root, "metrics.json")
        if os.path.isfile(mpath):
            with open(mpath, "r") as f:
                metrics = json.load(f)

    return LoadedRun(
        run_root=run_root,
        ckpt_dir=checkpoints_dir(run_root),
        cfg=t.raw_cfg,
        model=t.model,
        metrics=metrics,
    )


@dataclass(frozen=True)
class RunID:
    """Immutable identifiers of a run, as derived by ``Trainer._make_run_id``."""

    # "<model>-<pde>-<optimizer>-<res objective>" as built by _make_run_id.
    exp_name: str
    # "<pde>_<model>_<optimizer>_<res objective>" as built by _make_run_id.
    sweep_id: str
    # Random seed of the run.
    seed: int
    # Root directory under which run directories are placed.
    exp_path: str


def _slugify(s: str) -> str:
    s = s.strip().lower()
    s = re.sub(r"[^a-z0-9._-]+", "-", s)
    s = re.sub(r"-{2,}", "-", s).strip("-")
    return s or "run"


def create_experiment_dir(
    exp_path: str,
    exp_name: str,
    seed: int,
    sweep_id: Optional[str] = None,
    create_if_missing: bool = True,
    select_index: Optional[int] = None,
) -> str:
    """
    Create a new run directory, or locate the checkpoints of an existing run.

    Layout of the base directory:
      * ``<exp_path>/<sweep_id>/seed_<seed>`` when ``sweep_id`` is given;
      * ``<exp_path>/<eq_name>/<model_name>/seed_<seed>`` otherwise, where the
        slugified ``exp_name`` is expected to look like ``<model>-<eq>-...``.
    Repeated runs get numeric suffixes: ``seed_<seed>-1``, ``seed_<seed>-2``, ...

    Return value:
      * ``create_if_missing=True``: the newly created run directory (the
        base if nothing exists yet, else the next free suffix).
      * otherwise: the ``checkpoints`` directory of the run chosen by
        ``select_index`` (0 = base, k = ``-k``), or of the latest run when
        ``select_index`` is None.

    Raises:
        ValueError: if ``select_index`` is negative, or ``exp_name`` does not
            contain at least ``<model>-<eq>`` while no ``sweep_id`` is given.
        FileNotFoundError: if the requested/latest run or its ``checkpoints``
            directory does not exist.
    """
    exp_path = exp_path or "./runs"
    if sweep_id:
        # exp_name is not part of the path here, so it is not parsed at all.
        parts = [exp_path, sweep_id]
    else:
        tokens = _slugify(exp_name).split("-")
        if len(tokens) < 2:
            raise ValueError(
                f"exp_name must look like '<model>-<eq>-...', got {exp_name!r}"
            )
        model_name, eq_name = tokens[0], tokens[1]
        parts = [exp_path, eq_name, model_name]
    parts.append(f"seed_{int(seed)}")

    base = os.path.join(*parts)
    parent = os.path.dirname(base)
    basename = os.path.basename(base)

    # Scan sibling directories for the base run and its "-<n>" successors.
    suffixes = set()
    has_base = False
    if os.path.isdir(parent):
        prefix = basename + "-"
        for d in os.listdir(parent):
            if d == basename:
                has_base = True
            elif d.startswith(prefix) and d[len(prefix):].isdigit():
                suffixes.add(int(d[len(prefix):]))

    if create_if_missing:
        if not has_base and not suffixes:
            path = base
        else:
            path = f"{base}-{max(suffixes, default=0) + 1}"
        os.makedirs(path, exist_ok=True)
        print(path)
        return path

    if select_index is not None:
        if select_index < 0:
            raise ValueError("select_index must be >= 0")
        path_no_ckpt = base if select_index == 0 else f"{base}-{select_index}"
        path = os.path.join(path_no_ckpt, "checkpoints")
        if not os.path.isdir(path):
            raise FileNotFoundError(f"Requested run does not exist: {path}")
        return path

    if not has_base and not suffixes:
        raise FileNotFoundError(f"No runs exist under {parent} for {basename}")
    path_no_ckpt = f"{base}-{max(suffixes)}" if suffixes else base

    path = os.path.join(path_no_ckpt, "checkpoints")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"'checkpoints' not found for latest run: {path}")
    return path


@dataclass
class Track:
    """Mutable progress record for one scheduled value list (built by ``Stepper``)."""

    # Scheduled values, visited in order.
    values: List[float]
    # Number of scheduled values.
    n: int 
    # Number of steps per value; Stepper sets it to ceil(max_steps / n).
    seglen: int
    # Index into `values` of the currently active value.
    idx: int
    # Step at which the track next switches value.
    next_switch: int


class Stepper_:
    """Piecewise-constant schedule for the ``nu["res"]`` entry of ``state.st_params``.

    ``max_steps`` is split into ``len(values)`` equal segments (segment length
    rounded up); at each segment boundary the next value becomes active.
    """

    def __init__(self, values, max_steps: int):
        """
        Args:
            values: Non-empty iterable of numbers, one per segment.
            max_steps: Total number of training steps the schedule spans.

        Raises:
            ValueError: if ``values`` is empty.
        """
        import math

        self.values = [float(v) for v in values]
        if not self.values:
            raise ValueError("values must contain at least one entry")
        self.max_steps = max_steps
        # Measure the materialised list so generators are accepted as well.
        self.n = len(self.values)
        self.seglen = math.ceil(max_steps / self.n)
        self.idx = 0
        self.next_switch = self.seglen

    def maybe_step(self, state, step: int, set_st_params_fn) -> tuple[Any, bool]:
        """Advance the schedule if ``step`` has reached the next boundary.

        Args:
            state: Object exposing ``st_params["nu"]`` as a mapping.
            step: Current training step.
            set_st_params_fn: Callable ``(state, nu=...) -> new_state``.

        Returns:
            ``(state, changed)``. ``state`` is returned untouched with
            ``changed=False`` when the last value is already active or the
            boundary has not been reached yet.
        """
        if self.idx >= self.n - 1:
            return state, False

        if step < self.next_switch:
            return state, False

        # Jump straight to the segment containing `step`, so skipped
        # boundaries do not require one call each.
        self.idx = min(step // self.seglen, self.n - 1)
        self.next_switch = (self.idx + 1) * self.seglen

        # Copy before editing so the caller's mapping is not mutated.
        nu_new = dict(state.st_params["nu"])
        nu_new["res"] = self.values[self.idx]

        state = set_st_params_fn(state, nu=nu_new)
        return state, True
    

from dataclasses import dataclass
import math

# NOTE(review): same fields as `Track` above; nothing in the visible part of
# this file references `_Track` (Stepper constructs `Track`) — possibly a leftover.
@dataclass
class _Track:
    """Progress record for one scheduled value list; mirrors ``Track``."""

    # Scheduled values.
    values: list
    # Number of scheduled values.
    n: int
    # Number of steps per value.
    seglen: int
    # Index of the currently active value.
    idx: int
    # Step at which the next switch happens.
    next_switch: int


class Stepper:
    """Piecewise-constant schedules for several ``st_params`` entries at once.

    ``schedules`` maps ``(param, term)`` to a list of values. Each list splits
    ``max_steps`` into equal segments (length rounded up) and is tracked
    independently. ``param == "nu"`` updates ``st_params["nu"][term]``; any
    other ``param`` updates ``st_params["lam"][term]``.
    """

    def __init__(self, schedules, max_steps):
        self.max_steps = int(max_steps)
        self.tracks = {}

        for (param, term), raw_values in schedules.items():
            as_floats = [float(item) for item in raw_values]
            count = len(as_floats)
            segment = math.ceil(self.max_steps / count)
            self.tracks[(param, term)] = Track(
                values=as_floats,
                n=count,
                seglen=segment,
                idx=0,
                next_switch=segment,
            )

    def maybe_step(self, state, step, set_st_params_fn):
        """Apply every schedule whose next boundary has been reached.

        Returns ``(state, changed)``; ``state`` is passed through untouched
        when no track switched at this step.
        """
        pending = {"nu": {}, "lam": {}}

        for (param, term), track in self.tracks.items():
            finished = track.idx >= track.n - 1
            if finished or step < track.next_switch:
                continue

            # Jump directly to the segment that contains `step`.
            track.idx = min(step // track.seglen, track.n - 1)
            track.next_switch = (track.idx + 1) * track.seglen

            group = "nu" if param == "nu" else "lam"
            pending[group][term] = track.values[track.idx]

        changes = {}
        for group in ("nu", "lam"):
            if pending[group]:
                # Copy the current mapping so it is not mutated in place.
                merged = dict(state.st_params[group])
                merged.update(pending[group])
                changes[group] = merged

        if not changes:
            return state, False

        return set_st_params_fn(state, **changes), True

def update_em(state, nu, lam, set_st_params_fn):
    """Overwrite the ``"res"`` entries of ``st_params["nu"]`` and ``st_params["lam"]``.

    The current mappings are copied, their ``"res"`` key is set to ``nu`` /
    ``lam``, and both copies are handed to
    ``set_st_params_fn(state, nu=..., lam=...)``, whose result is returned.
    """
    nu_table = dict(state.st_params["nu"], res=nu)
    lam_table = dict(state.st_params["lam"], res=lam)
    return set_st_params_fn(state, nu=nu_table, lam=lam_table)





class Trainer:
    def __init__(self, run_cfg: Dict[str, Any], *, device: Optional[str] = None):
        """Store a private copy of the run configuration and derive the run id.

        Args:
            run_cfg: Raw configuration dictionary; deep-copied so later edits
                by the caller do not affect this trainer.
            device: Stored as-is on ``self.device``.
        """
        self.raw_cfg = copy.deepcopy(run_cfg)
        self.device = device
        self.run_id = self._make_run_id(self.raw_cfg)
        # Set later by resolve_run_dir / from_run_dir / load_model.
        self.run_dir: Optional[str] = None
        # Set by build().
        self.model = None
        self.logger = Logger()

    @classmethod
    def from_yaml(cls, path: str, *, device: Optional[str] = None) -> "Trainer":
        """Create a trainer from a YAML configuration file.

        A classmethod, like ``from_run_dir`` and ``from_run_config``, so that
        calling it on a subclass returns an instance of that subclass.

        Args:
            path: Path of the YAML file, parsed with ``yaml.safe_load``.
            device: Forwarded to the constructor.
        """
        # Context manager so the file handle is closed deterministically.
        with open(path, "r") as f:
            cfg = yaml.safe_load(f)
        return cls(cfg, device=device)
    
    @staticmethod
    def _read_yaml(path: str) -> Dict[str, Any]:
        """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
        with open(path, "r") as handle:
            parsed = yaml.safe_load(handle)
        return parsed
        
    @classmethod
    def from_run_dir(cls, run_dir: str, *, device: Optional[str] = None) -> "Trainer":
        """Create a trainer from ``<run_dir>/config.yml`` and bind it to ``run_dir``."""
        stored_cfg = cls._read_yaml(os.path.join(run_dir, "config.yml"))
        trainer = cls(stored_cfg, device=device)
        trainer.run_dir = run_dir
        return trainer

    @classmethod
    def from_run_config(cls, run_cfg: Dict[str, Any], *, device: Optional[str] = None) -> "Trainer":
        """Create a trainer directly from an in-memory configuration dictionary."""
        trainer = cls(run_cfg, device=device)
        return trainer

    @staticmethod
    def _make_run_id(run_cfg: Dict[str, Any]) -> RunID:
        """Derive experiment name, sweep id, seed and root path from a config.

        Names are taken from the optional ``_names`` section first and fall
        back to top-level keys and then to fixed defaults:
          * pde:   ``_names.pde``   -> ``pde``             -> ``"pde"``
          * model: ``_names.model`` -> first "-" token of ``exp_name`` -> ``"model"``
          * optim: ``_names.optim`` -> ``optim.optimizer`` -> ``"opt"``
        The residual objective comes from ``objectives.terms.res`` (default
        ``"mse"``). Sections that are present but null (e.g. an empty
        ``objectives:`` key in YAML) are treated as empty.
        """
        names = run_cfg.get("_names") or {}
        optim_cfg = run_cfg.get("optim") or {}
        objective_terms = (run_cfg.get("objectives") or {}).get("terms") or {}

        pde_name = names.get("pde", run_cfg.get("pde", "pde"))
        if "model" in names:
            model_name = names["model"]
        else:
            # Only parse exp_name when it is actually needed.
            model_name = (run_cfg.get("exp_name") or "model").split("-")[0]
        opt_name = names.get("optim", optim_cfg.get("optimizer", "opt"))
        res_objective = objective_terms.get("res", "mse")

        seed = int(run_cfg.get("seed", 0))
        exp_path = run_cfg.get("exp_path", "./runs")

        exp_name = f"{model_name}-{pde_name}-{opt_name}-{res_objective}"
        sweep_id = f"{pde_name}_{model_name}_{opt_name}_{res_objective}"
        return RunID(exp_name, sweep_id, seed, exp_path)

    def resolve_run_dir(
        self,
        *,
        mode: Mode,
        select_index: Optional[int] = None,
        create_if_missing: Optional[bool] = None,
    ) -> str:
        """Resolve ``self.run_dir`` via ``create_experiment_dir`` and return it.

        When ``create_if_missing`` is left as None it defaults to True for
        ``mode == "train"`` and False otherwise.
        """
        should_create = (mode == "train") if create_if_missing is None else create_if_missing
        print("Sweep ID:", self.run_id.sweep_id, select_index)

        run = self.run_id
        self.run_dir = create_experiment_dir(
            exp_path=run.exp_path,
            exp_name=run.exp_name,
            seed=run.seed,
            sweep_id=run.sweep_id,
            create_if_missing=should_create,
            select_index=select_index,
        )
        return self.run_dir

    def to_collection_cfg(self):
        """Return a fresh ``Collection`` built from a deep copy of ``raw_cfg``."""
        snapshot = copy.deepcopy(self.raw_cfg)
        return Collection.from_dict(snapshot)

    def build(self) -> None:
        """Instantiate ``self.model`` by passing the collection config to ``get_pde``."""
        self.model = get_pde(self.to_collection_cfg())

    def save_config(self) -> None:
        """Write the configuration to ``<run_dir>/config.yml``.

        The saved copy has its ``exp_path`` entry replaced by ``self.run_dir``;
        ``self.raw_cfg`` itself is left unchanged.
        """
        snapshot = copy.deepcopy(self.raw_cfg)
        snapshot["exp_path"] = self.run_dir
        target = os.path.join(self.run_dir, "config.yml")
        with open(target, "w") as handle:
            yaml.dump(snapshot, handle, default_flow_style=False)

    def restore(self, checkpoints_dir: Optional[str] = None, step=None) -> None:
        """Load a checkpoint into ``self.model.state``.

        Args:
            checkpoints_dir: Directory to restore from; defaults to
                ``self.run_dir``. ``"checkpoints"`` is appended unless the
                path already ends with it.
            step: Forwarded to ``restore_checkpoint``.
        """
        target = self.run_dir if checkpoints_dir is None else checkpoints_dir
        if not target.endswith("checkpoints"):
            target = os.path.join(target, "checkpoints")
        self.model.state = restore_checkpoint(self.model.state, target, step=step)


    


    def train(self) -> Dict[str, float]:
        """Run the full training loop and return the final evaluation metrics.

        Steps, in order: create a new run directory, optionally start a wandb
        run, build the model, save the config, then for ``training.num_epochs``
        steps draw a batch and call ``model.step``. Every 5000 steps the
        ``"res"`` entries of ``st_params["nu"]`` / ``st_params["lam"]`` are
        re-estimated with ``em`` and written back. Logging and checkpointing
        happen on process 0 only. Finally ``evaluate()`` is called.

        Returns:
            The metrics dictionary produced by ``evaluate()``.
        """
        self.resolve_run_dir(mode="train")
        
        

        wandb_cfg = self.raw_cfg.get("wandb", {})
        if wandb_cfg.get("use", False):
            # Imported lazily so wandb is only required when enabled.
            import wandb
            wandb.init(
                project=wandb_cfg.get("project"),
                name= self.run_id.exp_name ,#self.raw_cfg.get("exp_name", self.run_id.exp_name),
                config=self.raw_cfg,
                mode=wandb_cfg.get("mode", "online"),
            )

        self.build()
        run_cfg = self.to_collection_cfg()
        self.save_config()

        num_epochs = int(run_cfg.training.num_epochs)
        log_every = int(run_cfg.logging.log_every)
        save_every = run_cfg.logging.save_every
        num_keep = run_cfg.logging.num_keep_ckpts

        sampler = iter(self.model.sampler)
        u_ref = self.model.u_ref
        param_count = count_params(self.model.state.params)

        if jax.process_index() == 0:
            print(f"Number of parameters: {param_count}")

        # NOTE(review): unlike the line above, this shape dump is printed on every process.
        print(jax.tree_util.tree_map(lambda x: x.shape, self.model.state.params))

        t0 = time.time()

        ## weighting
        # NOTE(review): `scheme`, `weight_update_freq` and `st_params` are
        # assigned here but not used anywhere else in this method.
        weighting_cfg = run_cfg.weighting
        scheme = weighting_cfg.scheme
        weight_update_freq = int(weighting_cfg.get("update_freq", 1000))
        st_params = self.model.state.st_params
        #print("Initial st_params:", st_params)

        #print(run_cfg.get("objectives", {}).get("student_t", {}))
        print("Devices:", jax.local_devices())

        # Optional schedule: objectives.student_t.scheduler is flattened from
        # {param: {term: values}} into {(param, term): values} for Stepper.
        # NOTE(review): `stepper` is only bound inside this branch, and the
        # only call to stepper.maybe_step in the loop below is commented out,
        # so the schedule currently has no effect.
        if run_cfg.get("objectives", {}).get("student_t", {}).get("scheduler", None) is not None:
            print("Using student-t scheduler")
            sched_cfg = run_cfg.objectives.student_t.scheduler
            stepper = Stepper(
                schedules = {
                    (par, term): vals for par, terms in sched_cfg.items() for term, vals in terms.items()
                }, 
                max_steps = num_epochs
            )
            
            #stepper = Stepper(
            #    schedules = {
            #        #("nu", "res"): [5.0, 4.5, 3.0],
            #        ("lam", "bcs"): [0.02, 0.02, 0.02, 0.1, 0.5],
            #    }, 
            #    max_steps = num_epochs
            #)
            
            print( {
                    (par, term): vals for par, terms in sched_cfg.items() for term, vals in terms.items()
                })

        # Hard-coded settings for the periodic EM update below.
        # EMConfig and em presumably come from the `.em` star import — TODO confirm.
        em_config = EMConfig(
            newton_steps_nu=10, 
            nu_min=1,
            nu_max=1000,
            eps_denom=1e-8, 
            a_lam=1000,
            b_lam=1000,

        )
        # Initial "res" values, kept on the trainer; not read again in this method.
        self.nu_template  = self.model.state.st_params["nu"]["res"]
        self.lam_template = self.model.state.st_params["lam"]["res"]



        for step in range(num_epochs):
            step_t0 = time.time()
            batch = next(sampler)
            #self.model.state = self.model.set_st_params(self.model.state, **st_params)
            #self.model.state, stepped = stepper.maybe_step(
            #    self.model.state, step, self.model.set_st_params
            #)

            self.model.state = self.model.step(self.model.state, batch)

            #if step > 5000 and (step ) % 2000 == 0:
            #    state0 = jax.tree_util.tree_map(lambda x: x[0], self.model.state)
            #    batch0 = jax.tree_util.tree_map(lambda x: x[0], batch)
            #    new_nu, new_lam = self.model.run_em(state0, batch0, em_config)
                #print(new_nu)
                #print(new_lam)
                #self.model.state = self.model.set_st_params(
                #    self.model.state,
                #    nu=new_nu,
                #    lam=new_lam
                #)
            # NOTE(review): this helper is re-defined on every iteration and is
            # not called anywhere in this method.
            def bcast_to_like(x, like):
                # like is a replicated leaf, e.g. shape (n_devices,) or (n_devices,1)
                x = jnp.asarray(x, dtype=like.dtype)
                return jnp.broadcast_to(x, like.shape)

            # Periodic EM re-estimation of lam/nu for the "res" term; the
            # 5000-step period is hard-coded.
            if (step + 1) % 5000 == 0:
                # Host copy of replica 0
                state0 = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                batch0 = jax.device_get(tree_map(lambda x: x[0], batch))

                residuals = self.model.residuals(state0, batch0)['res']
                lam, nu = em(residuals,
                            state0.st_params['lam']['res'],
                            state0.st_params['nu']['res'],
                            em_config.a_lam, em_config.b_lam, em_config)

                # IMPORTANT: build nu/lam dicts from state0 (unreplicated)
                # The EM results are converted to Python floats before being stored.
                nu_host = dict(state0.st_params["nu"]);  nu_host["res"]  = float(jnp.asarray(nu).reshape(()))
                lam_host = dict(state0.st_params["lam"]); lam_host["res"] = float(jnp.asarray(lam).reshape(()))

                # Now replicate ONCE inside set_st_params
                self.model.state = self.model.set_st_params(self.model.state, nu=nu_host, lam=lam_host)

                print(lam, nu)


                
                

            # Logging uses a host copy of index 0 of every state/batch leaf.
            if step % log_every == 0 and jax.process_index() == 0:
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                batch0 = jax.device_get(tree_map(lambda x: x[0], batch))
                logs = self.model.log(state, batch0, u_ref)
                #print(logs['log_sigma'])
                self.logger.log_iter(step, step_t0, time.time(), logs)
                if wandb_cfg.get("use", False):
                    import wandb
                    wandb.log(logs, step=step)

            # Checkpoint every `save_every` steps (including step 0) and after
            # the final step; disabled entirely when save_every is None.
            if save_every is not None and jax.process_index() == 0 and (step % save_every == 0 or (step + 1) == num_epochs):
                    
                    save_checkpoint(
                        self.model.state,
                        os.path.join(self.run_dir, "checkpoints"),
                        keep=num_keep,
                    )

        # evaluate() rebuilds the model and restores it from this run's checkpoints.
        metrics = self.evaluate()

        if jax.process_index() == 0:
            t1 = time.time()
            if wandb_cfg.get("use", False):
                import wandb
                wandb.log(
                    {**metrics, "total_time": t1 - t0, "param_count": param_count}
                )
                wandb.finish()
            print(f"Training completed in {t1 - t0:.2f} seconds.")

        return metrics

    def evaluate(
        self,
        *,
        select_index: Optional[int] = None,
        load_from_this_run: bool = True,
        plot: bool = True
    ) -> Dict[str, float]:
        if not load_from_this_run:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)

        self.build()

        ckpt_dir = (
            os.path.join(self.run_dir, "checkpoints")
            if load_from_this_run
            else self.run_dir
        )
        self.restore(ckpt_dir)

        state = self.model.state
        params = self.model.state.params
        u_ref = self.model.u_ref

        rmse = float(self.model.compute_l2_error(state))
        rmae = float(self.model.compute_rmae(state))

        metrics = {"rmse": rmse, "rmae": rmae}

        if jax.process_index() == 0:
            out_dir = (
                os.path.dirname(self.run_dir)
                if self.run_dir.endswith("checkpoints")
                else self.run_dir
            )
            with open(os.path.join(out_dir, "metrics.json"), "w") as f:
                json.dump(metrics, f, indent=4)
            if plot:
                t_star = self.model.t_star
                x_star = self.model.x_star
                u_pred = self.model.u_pred_fn(state, t_star, x_star)
                self.plot2d(t_star, x_star, u_ref, u_pred, out_dir)

        return metrics


    def plot2d(self, 
        t_star, 
        x_star,
        u_ref,
        u_pred, 
        out_dir

    ):
        """Save a three-panel comparison figure to ``<out_dir>/figure.pdf``.

        Panels, left to right: ``u_ref`` ("Exact"), ``u_pred`` ("Predicted")
        and ``|u_ref - u_pred|`` ("Absolute error"), each drawn with
        ``pcolor`` on the ``(t_star, x_star)`` mesh with ``t`` on the
        horizontal axis.

        Args:
            t_star, x_star: 1-D coordinate arrays passed to ``jnp.meshgrid``
                with ``indexing="ij"``.
            u_ref, u_pred: Fields on that mesh (presumably of shape
                ``(len(t_star), len(x_star))`` — TODO confirm).
            out_dir: Existing directory that receives ``figure.pdf``.
        """
        TT, XX = jnp.meshgrid(t_star, x_star, indexing="ij")

        panels = (
            ("Exact", u_ref),
            ("Predicted", u_pred),
            ("Absolute error", jnp.abs(u_ref - u_pred)),
        )

        fig = plt.figure(figsize=(18, 5))
        try:
            for position, (title, field) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.pcolor(TT, XX, field, cmap="jet")
                plt.colorbar()
                plt.xlabel("t")
                plt.ylabel("x")
                plt.title(title)
                plt.tight_layout()

            fig_path = os.path.join(out_dir, "figure.pdf")
            fig.savefig(fig_path, bbox_inches="tight", dpi=300)
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated evaluations do not accumulate open figures.
            plt.close(fig)
    

    def load_model(
        self,
        *,
        select_index: Optional[int] = None,
        run_config: Optional[Dict[str, Any]] = None,
        run_config_path: Optional[str] = None,
        run_dir: Optional[str] = None,
    ):
        if sum(x is not None for x in [run_config, run_config_path, run_dir]) > 1:
            raise ValueError("Pass only one of: run_config, run_config_path, run_dir.")

        if run_config_path is not None:
            run_config = self._read_yaml(run_config_path)

        if run_dir is not None:
            cfg_path = os.path.join(run_dir, "config.yml")
            run_config = self._read_yaml(cfg_path)
            self.run_dir = run_dir

        if run_config is not None:
            self.raw_cfg = copy.deepcopy(run_config)
            self.run_id = self._make_run_id(self.raw_cfg)

        if self.run_dir is None:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)
        self.build()
        self.restore()  
        return self.model


from phijax.data import UniformSampler

class CurriculumTrainer(Trainer):
    """Trainer variant that keeps one checkpoint directory per training window."""

    def _window_ckpt_dir(self, idx: int) -> str:
        """Return ``<run_dir>/checkpoints/window_<idx, zero-padded to 3 digits>``.

        Raises:
            RuntimeError: if ``run_dir`` has not been set yet.
        """
        run_dir = self.run_dir
        if run_dir is not None:
            return os.path.join(run_dir, "checkpoints", f"window_{idx:03d}")
        raise RuntimeError("run_dir is not set. Call resolve_run_dir() first.")

    def restore_window(self, idx: int) -> None:
        """Restore ``model.state`` from window ``idx``, building the model first if needed."""
        if self.model is None:
            self.build()
        window_dir = self._window_ckpt_dir(idx)
        self.model.state = restore_checkpoint(self.model.state, window_dir)
    



class TimeMarchingTrainer(CurriculumTrainer):
    def _window_ckpt_dir(self, idx: int) -> str:
        """Return ``<run_dir>/checkpoints/window_<idx, zero-padded to 3 digits>``.

        Raises:
            RuntimeError: if ``run_dir`` has not been set yet.
        """
        run_dir = self.run_dir
        if run_dir is not None:
            return os.path.join(run_dir, "checkpoints", f"window_{idx:03d}")
        raise RuntimeError("run_dir is not set. Call resolve_run_dir() first.")
    def train(self) -> Dict[str, float]:
        """Train the model window by window in time and return evaluation metrics.

        The reference time grid is cut into ``training.num_time_windows``
        windows of equal length. Each window is trained for
        ``training.num_epochs`` steps via ``train_window``; between windows
        the model's own prediction becomes the next initial condition.

        Returns:
            The metrics dictionary produced by ``evaluate(plot=True)``.
        """
        self.resolve_run_dir(mode="train")
        self.save_config()

        wandb_cfg = self.raw_cfg.get("wandb", {})
        use_wandb = bool(wandb_cfg.get("use", False))
        if use_wandb:
            # Imported lazily so wandb is only required when enabled.
            import wandb
            wandb.init(
                project=wandb_cfg.get("project"),
                name=self.raw_cfg.get("exp_name", self.run_id.exp_name),
                config=self.raw_cfg,
                mode=wandb_cfg.get("mode", "online"),
            )

        self.build()
        run_cfg = self.to_collection_cfg()

        num_epochs = int(run_cfg.training.num_epochs)
        log_every = int(run_cfg.logging.log_every)
        save_every = run_cfg.logging.save_every
        num_keep = run_cfg.logging.num_keep_ckpts

        weighting_cfg = run_cfg.weighting
        scheme = weighting_cfg.scheme
        weight_update_freq = int(weighting_cfg.get("update_freq", 1000))

        num_time_windows = int(run_cfg.training.num_time_windows)

        u_ref = self.model.u_ref
        v_ref = self.model.v_ref
        t_star = self.model.t_star
        x_star = self.model.x_star
        y_star = self.model.y_star

        # Initial condition of the first window: first slice of the reference data.
        u0 = u_ref[0, ...]
        v0 = v_ref[0, ...]

        # Time grid of a single window; the same `t` is passed for every window.
        num_time_steps = len(t_star) // num_time_windows
        t = t_star[:num_time_steps]

        dt = t[1] - t[0]
        t0 = t[0]
        # Upper time bound is extended by 10% of one time step.
        t1 = t[-1] + 0.1 * dt

        x0 = x_star[0]
        x1 = x_star[-1]
        y0 = y_star[0]
        y1 = y_star[-1]

        print(t.shape, x_star.shape, y_star.shape)

        # Sampling domain rows: [t0, t1], [x0, x1], [y0, y1].
        dom = jnp.array([[t0, t1], [x0, x1], [y0, y1]])
        sampler = iter(UniformSampler(dom, int(run_cfg.training.batch_size)))

        if jax.process_index() == 0:
            param_count = count_params(self.model.state.params)
            print(f"Number of parameters: {param_count}")
            print(tree_map(lambda x: x.shape, self.model.state))
            print(tree_map(lambda x: x.shape, self.model.state.params))

        for idx in range(num_time_windows):
            # Global step offset, used only for wandb logging in train_window.
            step_offset = idx * num_epochs

            

            # Reference solution restricted to the current window.
            u_star = u_ref[num_time_steps * idx : num_time_steps * (idx + 1), ...]
            v_star = v_ref[num_time_steps * idx : num_time_steps * (idx + 1), ...]

            self.model.set_initial_condition(u0, v0, u_star, v_star, t)

            self.train_window(
                sampler=sampler,
                u_star=u_star,
                v_star=v_star,
                num_steps=num_epochs,
                step_offset=step_offset,
                log_every=log_every,
                save_every=save_every,
                num_keep=num_keep,
                scheme=scheme,
                weight_update_freq=weight_update_freq,
                use_wandb=use_wandb,
                idx = idx
            )

            if num_time_windows > 1:
                # Host copy of index 0 of every state leaf.
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))

                # Next initial condition: the model's prediction at
                # t_star[num_time_steps]; the same time index is used for
                # every window (presumably because each window is trained on
                # the same local time grid `t` — TODO confirm).
                u0 = self.model.u0_pred_fn(state, t_star[num_time_steps], x_star, y_star)
                v0 = self.model.v0_pred_fn(state, t_star[num_time_steps], x_star, y_star)

                del state
            print(f"Completed time window {idx + 1} / {num_time_windows}")

        return self.evaluate(plot=True)

    def train_window(
        self,
        *,
        sampler,
        u_star,
        v_star,
        num_steps: int,
        step_offset: int,
        log_every: int,
        save_every,
        num_keep: int,
        scheme: str,
        weight_update_freq: int,
        use_wandb: bool,
        idx: int
    ) -> None:
        """Train the model on a single time window.

        Args:
            sampler: Iterator yielding training batches.
            u_star, v_star: Reference data of the window; accepted for
                interface compatibility but not used in this method.
            num_steps: Number of optimisation steps for this window.
            step_offset: Added to the local step when logging to wandb.
            log_every: Logging period in steps (process 0 only).
            save_every: Checkpoint period in steps, or None to disable
                checkpointing; the final step is always saved otherwise.
            num_keep: Forwarded to ``save_checkpoint`` as ``keep``.
            scheme: Weighting scheme; ``"ntk"``, ``"align"`` and
                ``"groupdro"`` trigger periodic ``model.update_weights``.
            weight_update_freq: Period of those weight updates in steps.
            use_wandb: Whether to mirror logs to wandb.
            idx: Window index, selects the checkpoint sub-directory.
        """
        # Loop-invariant values, computed once instead of on every step.
        is_main_process = jax.process_index() == 0
        adaptive_weights = scheme in ["ntk", "align", "groupdro"]
        ckpt_dir = self._window_ckpt_dir(idx)
        save_interval = None if save_every is None else int(save_every)

        for step in range(num_steps):
            step_t0 = time.time()
            batch = next(sampler)

            self.model.state = self.model.step(self.model.state, batch)

            if (
                is_main_process
                and adaptive_weights
                and (step + 1) % weight_update_freq == 0
            ):
                self.model.state = self.model.update_weights(self.model.state, batch)

            if step % log_every == 0 and is_main_process:
                # Log from a host copy of index 0 of every state/batch leaf.
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                batch0 = jax.device_get(tree_map(lambda x: x[0], batch))
                logs = self.model.log(state, batch0)
                self.logger.log_iter(step, step_t0, time.time(), logs)
                if use_wandb:
                    import wandb
                    wandb.log(logs, step=step + step_offset)

            if save_interval is not None and is_main_process:
                if (step + 1) % save_interval == 0 or (step + 1) == num_steps:
                    save_checkpoint(
                        self.model.state,
                        ckpt_dir,
                        keep=num_keep,
                    )

    def restore_window(self, idx: int) -> None:
        """Restore ``model.state`` from window ``idx``, building the model first if needed."""
        needs_build = self.model is None
        if needs_build:
            self.build()
        window_dir = self._window_ckpt_dir(idx)
        self.model.state = restore_checkpoint(self.model.state, window_dir)



class SpaceSampler(BaseSampler):
    """Sampler that draws random rows from a fixed array of coordinates."""

    def __init__(self, coords, batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.coords = coords

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """Return ``batch_size`` rows of ``coords`` at randomly chosen indices."""
        num_points = self.coords.shape[0]
        picked = random.choice(key, num_points, shape=(self.batch_size,))
        return self.coords[picked, :]


class ICSampler(SpaceSampler):
    """Sampler returning coordinates together with the matching field values."""

    def __init__(self, u, v, p, temp, coords, batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(coords, batch_size, rng_key)

        # Field arrays, indexed with the same row indices as `coords`.
        self.u, self.v = u, v
        self.p, self.temp = p, temp

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """Return ``(coords, u, v, p, temp)`` at ``batch_size`` randomly chosen indices."""
        picked = random.choice(key, self.coords.shape[0], shape=(self.batch_size,))

        return (
            self.coords[picked, :],
            self.u[picked],
            self.v[picked],
            self.p[picked],
            self.temp[picked],
        )


class BCsSampler(BaseSampler):
    """Sampler for points ``(t, x, y)`` on the two lines ``y = 0`` and ``y = 2``."""

    def __init__(self, dom,  batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)

        self.dom = dom

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """Return boundary points stacked as rows ``(t, x, y)``.

        ``batch_size // 2`` pairs ``(t, x)`` are drawn uniformly from
        ``dom[0]`` and ``dom[1]``; the same pairs are used once with
        ``y = 0`` and once with ``y = 2``.
        """
        keys = random.split(key, 3)
        half = self.batch_size // 2

        t_range = self.dom[0]
        x_range = self.dom[1]
        t = random.uniform(keys[0], (half,), minval=t_range[0], maxval=t_range[1])
        x = random.uniform(keys[1], (half,), minval=x_range[0], maxval=x_range[1])

        lower = jnp.stack([t, x, jnp.zeros_like(x)]).T
        upper = jnp.stack([t, x, 2 * jnp.ones_like(x)]).T

        return jnp.vstack([lower, upper])

class NSTrainer(CurriculumTrainer):
    def train(self) -> Optional[Dict[str, float]]:
        """Train the model one time window after another.

        Resolves the run directory, saves the config, optionally starts a
        wandb run, builds the model and then loops over
        ``training.num_time_windows`` windows. Every window is trained on the
        same local time axis ``t_star[:num_time_steps]``; from the second
        window on a fresh model is created with ``get_pde`` and seeded with
        the fields predicted by the previous window.

        Returns:
            ``None``. This trainer does not run the final evaluation itself;
            call ``evaluate`` separately.
        """
        self.resolve_run_dir(mode="train")
        self.save_config()

        wandb_cfg = self.raw_cfg.get("wandb", {})
        use_wandb = bool(wandb_cfg.get("use", False))
        if use_wandb:
            import wandb
            wandb.init(
                project=wandb_cfg.get("project"),
                name=self.raw_cfg.get("exp_name", self.run_id.exp_name),
                config=self.raw_cfg,
                mode=wandb_cfg.get("mode", "online"),
            )

        self.build()
        run_cfg = self.to_collection_cfg()

        num_epochs = int(run_cfg.training.num_epochs)
        log_every = int(run_cfg.logging.log_every)
        save_every = run_cfg.logging.save_every
        num_keep = run_cfg.logging.num_keep_ckpts

        weighting_cfg = run_cfg.weighting
        scheme = weighting_cfg.scheme
        weight_update_freq = int(weighting_cfg.get("update_freq", 1000))

        num_time_windows = int(run_cfg.training.num_time_windows)

        # Only velocity, temperature, t_star and coords are used below.
        velocity, pressure, temperature, t_star, coords, alpha1, alpha2, alpha3, alpha4 = self.model.get_ref()

        # Initial fields for the first window come from the freshly built model.
        u0 = self.model.u0
        v0 = self.model.v0
        p0 = self.model.p0
        T0 = self.model.T0

        num_time_steps = len(t_star) // num_time_windows
        t = t_star[:num_time_steps]

        dt = t[1] - t[0]
        t0 = t[0]
        # NOTE(review): the next window's initial condition is evaluated at
        # t_star[num_time_steps]; if t_star is uniformly spaced that is
        # t[-1] + dt, which lies beyond t1 = t[-1] + 0.1 * dt. Confirm the
        # 0.1 factor is intended.
        t1 = t[-1] + 0.1 * dt

        x0 = 0.0
        x1 = 1.0

        y0 = 0.0
        y1 = 2.0

        dom = jnp.array([[t0, t1], [x0, x1], [y0, y1]])

        if jax.process_index() == 0:
            param_count = count_params(self.model.state.params)
            print(f"Number of parameters: {param_count}")

        for idx in range(num_time_windows):
            step_offset = idx * num_epochs
            ics_sampler = ICSampler(
                u0,
                v0,
                p0,
                T0,
                coords,
                run_cfg.training.batch_size * 2,
            )
            bcs_sampler = BCsSampler(
                dom,
                run_cfg.training.batch_size * 2,
            )
            res_sampler = UniformSampler(
                dom,
                run_cfg.training.batch_size,
            )
            samplers = {"ics": iter(ics_sampler), "bcs": iter(bcs_sampler), "res": iter(res_sampler)}

            # Reference snapshots belonging to this window.
            window = slice(num_time_steps * idx, num_time_steps * (idx + 1))
            u_star = velocity[window, :, 0]
            v_star = velocity[window, :, 1]
            temp_star = temperature[window]

            if idx > 0:
                # Start every later window from a newly created model.
                self.model = get_pde(self.to_collection_cfg())
                self.model.set_initial_condition(
                    u0,
                    v0,
                    p0,
                    T0,
                    u_star,
                    v_star,
                    temp_star,
                    t,
                )

            self.train_window(
                sampler=samplers,
                u_star=u_star,
                v_star=v_star,
                temp_star=temp_star,
                t_star=t,
                coords=coords,
                num_steps=num_epochs,
                step_offset=step_offset,
                log_every=log_every,
                save_every=save_every,
                num_keep=num_keep,
                scheme=scheme,
                weight_update_freq=weight_update_freq,
                use_wandb=use_wandb,
                idx=idx,
            )

            if num_time_windows > 1:
                # The predicted fields are only needed to seed a following
                # window, so skip the forward passes after the last one.
                if idx + 1 < num_time_windows:
                    # Take the first device's copy of the state.
                    state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                    t_next = t_star[num_time_steps]
                    x_pts = coords[:, 0]
                    y_pts = coords[:, 1]

                    u0 = self.model.u_ic_pred_fn(state, t_next, x_pts, y_pts)
                    v0 = self.model.v_ic_pred_fn(state, t_next, x_pts, y_pts)
                    p0 = self.model.p_ic_pred_fn(state, t_next, x_pts, y_pts)
                    T0 = self.model.temp_ic_pred_fn(state, t_next, x_pts, y_pts)

                    del state

                # As before, the trained model is dropped after every window
                # (including the last one) when training is windowed.
                del self.model
            print(f"Completed time window {idx + 1} / {num_time_windows}")

        # NOTE: no evaluation is run here; the method returns None.

    def evaluate(
        self,
        *,
        select_index: Optional[int] = None,
        load_from_this_run: bool = True,
        plot: bool = True
    ) -> Dict[str, float]:
        """Restore every time window and report errors against the reference.

        For each window the checkpoint is restored, per-window metrics from
        ``model.compute_metrics`` are logged, and the predicted ``u``, ``v``
        and temperature fields are collected. The collected predictions are
        then compared with the reference over all covered snapshots using a
        relative norm error.

        Args:
            select_index: Forwarded to ``resolve_run_dir`` when
                ``load_from_this_run`` is false.
            load_from_this_run: When false, the run directory is resolved in
                ``"eval"`` mode before the model is built.
            plot: Currently unused; no plots are produced.

        Returns:
            Dict with the overall relative errors under the keys
            ``"u_error"``, ``"v_error"`` and ``"temp_error"``.
        """
        if not load_from_this_run:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)

        self.build()

        # Kept from the original flow: the first window is restored before the
        # reference data is read (the loop below restores each window again).
        self.restore_window(0)

        velocity, pressure, temperature, t_star, coords, alpha1, alpha2, alpha3, alpha4 = self.model.get_ref()

        cfg = self.to_collection_cfg()
        num_time_windows = int(cfg.training.num_time_windows)

        num_time_steps = len(t_star) // num_time_windows
        t = t_star[:num_time_steps]

        u_preds = []
        v_preds = []
        temp_preds = []
        for idx in range(num_time_windows):
            self.restore_window(idx)
            state = self.model.state

            window = slice(num_time_steps * idx, num_time_steps * (idx + 1))
            u_star = velocity[window, :, 0]
            v_star = velocity[window, :, 1]
            temp_star = temperature[window]

            u_error, v_error, temp_error = self.model.compute_metrics(state, t, coords, u_star, v_star, temp_star)
            self.logger.info(f"Window {idx}: u RMSE: {u_error:.6f}, v RMSE: {v_error:.6f}, temp RMSE: {temp_error:.6f}")

            u_preds.append(self.model.u_pred_fn(state, t, coords[:, 0], coords[:, 1]))
            v_preds.append(self.model.v_pred_fn(state, t, coords[:, 0], coords[:, 1]))
            temp_preds.append(self.model.temp_pred_fn(state, t, coords[:, 0], coords[:, 1]))

        u_preds = jnp.concatenate(u_preds, axis=0)
        v_preds = jnp.concatenate(v_preds, axis=0)
        temp_preds = jnp.concatenate(temp_preds, axis=0)

        # The windows cover num_time_steps * num_time_windows snapshots, which
        # is fewer than len(t_star) when the division above has a remainder.
        # Trim the reference so it lines up with the predictions.
        num_covered = num_time_steps * num_time_windows
        u_ref = velocity[:num_covered, ..., 0]
        v_ref = velocity[:num_covered, ..., 1]
        temp_ref = temperature[:num_covered]

        u_error = jnp.linalg.norm(u_ref - u_preds) / jnp.linalg.norm(u_ref)
        v_error = jnp.linalg.norm(v_ref - v_preds) / jnp.linalg.norm(v_ref)
        temp_error = jnp.linalg.norm(temp_ref - temp_preds) / jnp.linalg.norm(temp_ref)

        self.logger.info("Overall u error: {:.3e}".format(u_error))
        self.logger.info("Overall v error: {:.3e}".format(v_error))
        self.logger.info("Overall temp error: {:.3e}".format(temp_error))

        return {
            "u_error": float(u_error),
            "v_error": float(v_error),
            "temp_error": float(temp_error),
        }


    def train_window(
        self,
        *,
        sampler,
        u_star,
        v_star,
        temp_star,
        t_star,
        coords,
        num_steps: int,
        step_offset: int,
        log_every: int,
        save_every,
        num_keep: int,
        scheme: str,
        weight_update_freq: int,
        use_wandb: bool,
        idx: int
    ) -> None:
        """Run the optimisation steps of one time window.

        Args:
            sampler: Mapping from batch name to an iterator; one item is drawn
                from every iterator per step.
            u_star, v_star, temp_star: Reference fields passed to ``model.log``.
            t_star: Time stamps passed to ``model.log``.
            coords: Spatial points passed to ``model.log``.
            num_steps: Number of training steps in this window.
            step_offset: Added to the local step when logging to wandb.
            log_every: Log whenever ``step % log_every == 0`` (process 0 only).
            save_every: Checkpoint interval in steps; ``None`` disables saving.
            num_keep: Passed as ``keep`` to ``save_checkpoint``.
            scheme: Not used by this method.
            weight_update_freq: Not used by this method.
            use_wandb: Also send the logs to wandb when true.
            idx: Window index, used to pick the checkpoint directory.
        """
        # Loop invariants, computed once instead of on every step.
        is_main_process = jax.process_index() == 0
        ckpt_dir = self._window_ckpt_dir(idx)
        save_interval = (
            int(save_every) if (save_every is not None and is_main_process) else None
        )

        for step in range(num_steps):
            step_t0 = time.time()
            batch = {k: next(sam) for k, sam in sampler.items()}

            self.model.state = self.model.step(self.model.state, batch)

            if step % log_every == 0 and is_main_process:
                # Log from the first device's copy of the state and batch.
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                batch0 = jax.device_get(tree_map(lambda x: x[0], batch))
                logs = self.model.log(state, batch0, t_star, coords, u_star, v_star, temp_star)
                self.logger.log_iter(step, step_t0, time.time(), logs)
                if use_wandb:
                    import wandb
                    wandb.log(logs, step=step + step_offset)

            if save_interval is not None:
                # Save periodically and always after the final step.
                if (step + 1) % save_interval == 0 or (step + 1) == num_steps:
                    save_checkpoint(
                        self.model.state,
                        ckpt_dir,
                        keep=num_keep,
                    )

    

class KFICSSampler(BaseSampler):
    """Sampler returning random coordinates with the matching ``u``, ``v``, ``w`` values."""

    def __init__(self, u, v,w, coords, batch_size, rng_key=random.PRNGKey(1234)):
        """Store the coordinates and the three field arrays.

        Args:
            u, v, w: Field arrays indexed with the same row indices as
                ``coords``.
            coords: Array of candidate points; rows are indexed when sampling.
            batch_size: Number of rows drawn per call to ``data_generation``.
            rng_key: PRNG key forwarded to ``BaseSampler``.
        """
        # BaseSampler is initialised with (batch_size, rng_key) only, as the
        # other samplers in this file do; coords must be stored here because
        # data_generation reads self.coords.
        super().__init__(batch_size, rng_key)
        self.coords = coords

        self.u = u
        self.v = v
        self.w = w

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """Return ``(coords, u, v, w)`` gathered at the same random rows."""
        idx = random.choice(key, self.coords.shape[0], shape=(self.batch_size,))

        coords_batch = self.coords[idx, :]
        u_batch = self.u[idx]
        v_batch = self.v[idx]
        w_batch = self.w[idx]

        batch = (coords_batch, u_batch, v_batch, w_batch)

        return batch

class KolmogorovFlowTrainer(CurriculumTrainer):
    def train(self) -> Dict[str, float]:
        """Train the model one time window after another, then evaluate.

        Resolves the run directory, saves the config, optionally starts a
        wandb run, builds the model and then loops over
        ``training.num_time_windows`` windows. Every window is trained on the
        same local time axis ``t_star[:num_time_steps]``; from the second
        window on a fresh model is created with ``get_pde`` and seeded with
        the fields predicted by the previous window.

        Returns:
            Whatever ``self.evaluate(plot=True)`` returns.
        """
        self.resolve_run_dir(mode="train")
        self.save_config()

        wandb_cfg = self.raw_cfg.get("wandb", {})
        use_wandb = bool(wandb_cfg.get("use", False))
        if use_wandb:
            import wandb
            wandb.init(
                project=wandb_cfg.get("project"),
                name=self.raw_cfg.get("exp_name", self.run_id.exp_name),
                config=self.raw_cfg,
                mode=wandb_cfg.get("mode", "online"),
            )

        self.build()
        run_cfg = self.to_collection_cfg()

        num_epochs = int(run_cfg.training.num_epochs)
        log_every = int(run_cfg.logging.log_every)
        save_every = run_cfg.logging.save_every
        num_keep = run_cfg.logging.num_keep_ckpts

        weighting_cfg = run_cfg.weighting
        scheme = weighting_cfg.scheme
        weight_update_freq = int(weighting_cfg.get("update_freq", 1000))

        num_time_windows = int(run_cfg.training.num_time_windows)

        # ``nu`` is returned by get_ref but not used in this method.
        u_ref, v_ref, w_ref, t_star, coords, nu  = self.model.get_ref()

        # Initial fields for the first window come from the freshly built model.
        u0 = self.model.u0
        v0 = self.model.v0
        w0 = self.model.w0

        num_time_steps = len(t_star) // num_time_windows
        t = t_star[:num_time_steps]

        dt = t[1] - t[0]
        t0 = t[0]
        # Extends past t[-1] by a little more than one step; presumably so the
        # time used for the next window's initial condition
        # (t_star[num_time_steps]) lies inside the sampled domain - assumes
        # uniform spacing, TODO confirm.
        t1 = t[-1] + 1.1 * dt

        # NOTE(review): spatial bounds are hard-coded to the unit square;
        # confirm they match the extent of ``coords``.
        x0 = 0.0
        x1 = 1.0

        y0 = 0.0
        y1 = 1.0

        dom = jnp.array([[t0, t1], [x0, x1], [y0, y1]])

        if jax.process_index() == 0:
            param_count = count_params(self.model.state.params)
            print(f"Number of parameters: {param_count}")

        for idx in range(num_time_windows):
            step_offset = idx * num_epochs
            ics_sampler =  KFICSSampler(u0, v0, w0, coords, run_cfg.training.batch_size * 2)
            res_sampler = UniformSampler(dom, run_cfg.training.batch_size)

            samplers = {"ics": iter(ics_sampler),  "res": iter(res_sampler)}

            # Reference snapshots belonging to this window.
            window = slice(num_time_steps * idx, num_time_steps * (idx + 1))
            u_star = u_ref[window, :]
            v_star = v_ref[window, :]
            w_star = w_ref[window, :]

            if idx > 0:
                # Start every later window from a newly created model.
                self.model = get_pde(self.to_collection_cfg())
                self.model.set_initial_condition(
                    u0,
                    v0,
                    w0,
                    u_star,
                    v_star,
                    w_star,
                    t
                )

            self.train_window(
                sampler=samplers,
                u_star=u_star,
                v_star=v_star,
                w_star=w_star,
                t_star=t,
                coords=coords,
                num_steps=num_epochs,
                step_offset=step_offset,
                log_every=log_every,
                save_every=save_every,
                num_keep=num_keep,
                scheme=scheme,
                weight_update_freq=weight_update_freq,
                use_wandb=use_wandb,
                idx=idx,
            )

            if num_time_windows > 1:
                # The predicted fields are only needed to seed a following
                # window, so skip the forward passes after the last one.
                if idx + 1 < num_time_windows:
                    # Take the first device's copy of the state.
                    state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                    t_next = t_star[num_time_steps]
                    x_pts = coords[:, 0]
                    y_pts = coords[:, 1]

                    u0 = self.model.u_ic_pred_fn(state, t_next, x_pts, y_pts)
                    v0 = self.model.v_ic_pred_fn(state, t_next, x_pts, y_pts)
                    w0 = self.model.w_ic_pred_fn(state, t_next, x_pts, y_pts)

                    del state

                # As before, the trained model is dropped after every window
                # (including the last one) when training is windowed.
                del self.model
            print(f"Completed time window {idx + 1} / {num_time_windows}")

        return self.evaluate(plot=True)

    def train_window(
        self,
        *,
        sampler,
        u_star,
        v_star,
        w_star,
        t_star,
        coords,
        num_steps: int,
        step_offset: int,
        log_every: int,
        save_every,
        num_keep: int,
        scheme: str,
        weight_update_freq: int,
        use_wandb: bool,
        idx: int
    ) -> None:
        """Run the optimisation steps of one time window.

        Args:
            sampler: Mapping from batch name to an iterator; one item is drawn
                from every iterator per step.
            u_star, v_star, w_star: Reference fields passed to ``model.log``.
            t_star: Time stamps passed to ``model.log``.
            coords: Spatial points passed to ``model.log``.
            num_steps: Number of training steps in this window.
            step_offset: Added to the local step when logging to wandb.
            log_every: Log whenever ``step % log_every == 0`` (process 0 only).
            save_every: Checkpoint interval in steps; ``None`` disables saving.
            num_keep: Passed as ``keep`` to ``save_checkpoint``.
            scheme: Not used by this method.
            weight_update_freq: Not used by this method.
            use_wandb: Also send the logs to wandb when true.
            idx: Window index, used to pick the checkpoint directory.
        """
        # Loop invariants, computed once instead of on every step.
        is_main_process = jax.process_index() == 0
        ckpt_dir = self._window_ckpt_dir(idx)
        save_interval = (
            int(save_every) if (save_every is not None and is_main_process) else None
        )

        for step in range(num_steps):
            step_t0 = time.time()
            batch = {k: next(sam) for k, sam in sampler.items()}

            self.model.state = self.model.step(self.model.state, batch)
            if step % log_every == 0 and is_main_process:
                # Log from the first device's copy of the state and batch.
                state = jax.device_get(tree_map(lambda x: x[0], self.model.state))
                batch0 = jax.device_get(tree_map(lambda x: x[0], batch))
                logs = self.model.log(state, batch0, t_star, coords, u_star, v_star, w_star)
                self.logger.log_iter(step, step_t0, time.time(), logs)
                if use_wandb:
                    import wandb
                    wandb.log(logs, step=step + step_offset)

            if save_interval is not None:
                # Save periodically and always after the final step.
                if (step + 1) % save_interval == 0 or (step + 1) == num_steps:
                    save_checkpoint(
                        self.model.state,
                        ckpt_dir,
                        keep=num_keep,
                    )

    def evaluate(
        self,
        *,
        select_index: Optional[int] = None,
        load_from_this_run: bool = True,
        plot: bool = True
    ) -> Dict[str, float]:
        """Restore every time window and report errors against the reference.

        For each window the checkpoint is restored, per-window metrics from
        ``model.compute_metrics`` are logged, and the predicted ``u``, ``v``
        and ``w`` fields are collected. The collected predictions are then
        compared with the reference over all covered snapshots using a
        relative norm error.

        Args:
            select_index: Forwarded to ``resolve_run_dir`` when
                ``load_from_this_run`` is false.
            load_from_this_run: When false, the run directory is resolved in
                ``"eval"`` mode before the model is built.
            plot: Currently unused; no plots are produced.

        Returns:
            Dict with the overall relative errors under the keys
            ``"u_error"``, ``"v_error"`` and ``"w_error"``.
        """
        if not load_from_this_run:
            self.resolve_run_dir(mode="eval", select_index=select_index, create_if_missing=False)

        self.build()

        # Kept from the original flow: the first window is restored before the
        # reference data is read (the loop below restores each window again).
        self.restore_window(0)

        u_ref, v_ref, w_ref, t_star, coords, nu  = self.model.get_ref()

        cfg = self.to_collection_cfg()
        num_time_windows = int(cfg.training.num_time_windows)

        num_time_steps = len(t_star) // num_time_windows
        t = t_star[:num_time_steps]

        u_preds = []
        v_preds = []
        w_preds = []
        for idx in range(num_time_windows):
            self.restore_window(idx)
            state = self.model.state

            window = slice(num_time_steps * idx, num_time_steps * (idx + 1))
            u_star = u_ref[window, :]
            v_star = v_ref[window, :]
            w_star = w_ref[window, :]

            u_error, v_error, w_error = self.model.compute_metrics(state, t, coords, u_star, v_star, w_star)
            self.logger.info(f"Window {idx}: u RMSE: {u_error:.6f}, v RMSE: {v_error:.6f}, w RMSE: {w_error:.6f}")

            u_preds.append(self.model.u_pred_fn(state, t, coords[:, 0], coords[:, 1]))
            v_preds.append(self.model.v_pred_fn(state, t, coords[:, 0], coords[:, 1]))
            w_preds.append(self.model.w_pred_fn(state, t, coords[:, 0], coords[:, 1]))

        u_preds = jnp.concatenate(u_preds, axis=0)
        v_preds = jnp.concatenate(v_preds, axis=0)
        w_preds = jnp.concatenate(w_preds, axis=0)

        # The windows cover num_time_steps * num_time_windows snapshots, which
        # is fewer than len(t_star) when the division above has a remainder.
        # Trim the reference so it lines up with the predictions.
        num_covered = num_time_steps * num_time_windows
        u_ref_all = u_ref[:num_covered]
        v_ref_all = v_ref[:num_covered]
        w_ref_all = w_ref[:num_covered]

        u_error = jnp.linalg.norm(u_ref_all - u_preds) / jnp.linalg.norm(u_ref_all)
        v_error = jnp.linalg.norm(v_ref_all - v_preds) / jnp.linalg.norm(v_ref_all)
        w_error = jnp.linalg.norm(w_ref_all - w_preds) / jnp.linalg.norm(w_ref_all)
        self.logger.info("Overall u error: {:.3e}".format(u_error))
        self.logger.info("Overall v error: {:.3e}".format(v_error))
        self.logger.info("Overall w error: {:.3e}".format(w_error))

        return {
            "u_error": float(u_error),
            "v_error": float(v_error),
            "w_error": float(w_error),
        }


    

        

    
        




