

from functools import partial
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp
from jax import jit, grad#, tree_map
from jax.tree_util import tree_map, tree_leaves
from jax.flatten_util import ravel_pytree





import os, tempfile, shutil, time, glob, json, re
import jax.tree_util as jtu
from pathlib import Path

from flax.training import checkpoints
from flax import serialization as fser

import logging
from tabulate import tabulate



def flatten_pytree(pytree):
    """Concatenate every leaf of ``pytree`` into one 1-D array (via ``ravel_pytree``).

    The unravel function returned by ``ravel_pytree`` is discarded.
    """
    flat, _unravel = ravel_pytree(pytree)
    return flat


@partial(jit, static_argnums=(0,))
def jacobian_fn(apply_fn, params, *args):
    """Return d apply_fn / d params as a single flat vector.

    ``apply_fn(params, *args)`` must return a scalar. The gradient pytree is
    flattened with ``ravel_pytree``. ``apply_fn`` is a static jit argument, so
    it must be hashable and each distinct function gets its own compilation.
    """
    grad_tree = grad(apply_fn)(params, *args)  # differentiates w.r.t. argument 0
    flat_grad, _unravel = ravel_pytree(grad_tree)
    return flat_grad


@partial(jit, static_argnums=(0,))
def ntk_fn(apply_fn, params, *args):
    """Return J . J, where J is the flattened gradient of scalar ``apply_fn`` w.r.t. ``params``.

    This is the empirical neural-tangent-kernel value for the inputs in
    ``args`` paired with themselves. ``apply_fn`` must return a scalar and be
    hashable (it is a static jit argument).
    """
    flat_grad = jacobian_fn(apply_fn, params, *args)
    return jnp.dot(flat_grad, flat_grad)


def save_checkpoint__(state, workdir, keep=5, name=None):
    """Legacy save through ``flax.training.checkpoints`` (superseded by ``save_checkpoint``).

    Every process ensures ``workdir`` exists; only process 0 writes. Element 0
    along the leading axis of every leaf is taken before saving, so ``state``
    is presumably pmap-replicated (one copy per local device) - confirm at
    call sites.

    Args:
        state: train state with a ``step`` attribute.
        workdir: checkpoint directory.
        keep: number of checkpoints flax should retain.
        name: accepted for API compatibility; unused.
    """
    # exist_ok avoids a FileExistsError race when several processes create workdir at once.
    os.makedirs(workdir, exist_ok=True)
    if jax.process_index() == 0:
        host_state = jax.device_get(tree_map(lambda x: x[0], state))
        step = int(host_state.step)
        checkpoints.save_checkpoint(workdir, host_state, step=step, keep=keep)


def _unreplicate(x):
    try:
        if getattr(x, "ndim", 0) > 0 and x.shape[0] == jax.local_device_count():
            return x[0]
    except Exception:
        pass
    return x

def _safe_write_bytes(path: Path, data: bytes):
    path.parent.mkdir(parents=True, exist_ok=True)
    with tempfile.NamedTemporaryFile(dir=path.parent, delete=False) as tmp:
        tmp.write(data)
        tmp.flush()
        os.fsync(tmp.fileno())
        tmp_name = tmp.name
    os.replace(tmp_name, path)

def _safe_rmtree(path: Path, retries=3, delay=0.2):
    for i in range(retries):
        try:
            shutil.rmtree(path)
            return
        except FileNotFoundError:
            return
        except Exception:
            if i == retries - 1:
                raise
            time.sleep(delay)

def save_ckpt_bytes(state, workdir, keep: int | None = None):
    """Serialize ``state`` to ``<workdir>/<step>/state.msgpack`` (process 0 only).

    Each leaf that has ``block_until_ready`` is waited on. The state is then
    un-replicated leaf-wise with ``_unreplicate``, copied to host with
    ``jax.device_get``, encoded with ``flax.serialization.to_bytes`` and
    written atomically. The step is ``int(state.step)``, or 0 if the state has
    no ``step`` attribute.

    If ``keep`` is a positive int, only the ``keep`` highest-numbered
    digit-named step directories in ``workdir`` are retained.
    NOTE(review): if higher steps already exist, the step just written can
    itself be pruned.
    """
    if jax.process_index() != 0:
        return

    def _wait(leaf):
        if hasattr(leaf, "block_until_ready"):
            return leaf.block_until_ready()
        return leaf

    ready_state = jtu.tree_map(_wait, state)
    host_state = jax.device_get(jtu.tree_map(_unreplicate, ready_state))
    step = int(getattr(host_state, "step", 0))

    root = Path(workdir)
    encoded = fser.to_bytes(host_state)
    _safe_write_bytes(root / str(step) / "state.msgpack", encoded)

    if not keep or keep <= 0:
        return
    existing = sorted(
        int(entry.name)
        for entry in root.iterdir()
        if entry.is_dir() and entry.name.isdigit()
    )
    for stale in existing[:-keep]:
        _safe_rmtree(root / str(stale))


def restore_ckpt_bytes(workdir, target_like, step: int | str = "latest"):
    wd = Path(workdir)
    dirs = [p for p in wd.iterdir() if p.is_dir() and p.name.isdigit()]
    if not dirs:
        raise FileNotFoundError(f"No checkpoints found in {workdir}")

    if step is None or step == "latest":
        step = max(int(p.name) for p in dirs)

    ckpt_file = wd / str(int(step)) / "state.msgpack"
    if not ckpt_file.exists():
        raise FileNotFoundError(f"Checkpoint {ckpt_file} not found.")

    data = ckpt_file.read_bytes()
    return fser.from_bytes(target_like, data)


def peek_checkpoint(path):
    """Print the key path, shape and dtype of every array in a msgpack checkpoint file.

    The file is decoded with ``flax.serialization.msgpack_restore``. Nested
    dicts are walked depth-first in iteration order, and each array leaf is
    printed as ``"<k1>/<k2>/.../ [shape] dtype"``. Leaves without a ``shape``
    attribute (e.g. plain Python scalars) are not printed.
    """
    with open(path, "rb") as fh:
        raw = fh.read()
    restored = fser.msgpack_restore(raw)

    def _walk(node, prefix):
        if isinstance(node, dict):
            for key, child in node.items():
                _walk(child, f"{prefix}{key}/")
            return
        if hasattr(node, "shape"):
            print(f"{prefix} {list(node.shape)} {node.dtype}")

    _walk(restored, "")



def save_checkpoint(state, workdir, keep=5, name="state"):
    """Save ``state`` into ``workdir`` using ``save_ckpt_bytes``, retaining ``keep`` steps.

    ``name`` is accepted for signature compatibility and is not used.
    """
    return save_ckpt_bytes(state, workdir, keep=keep)

def restore_checkpoint(state, workdir, step=None):
    """Print a notice and restore ``workdir``'s checkpoint into ``state``'s structure.

    ``step=None`` is forwarded to ``restore_ckpt_bytes``, which treats it as
    the latest step.
    """
    notice = f"Restoring checkpoint from {workdir} at step {step}"
    print(notice)
    restored = restore_ckpt_bytes(workdir, state, step=step)
    return restored



def restore_checkpoint__(state, workdir, step=None):
    """Legacy restore through ``flax.training.checkpoints`` (superseded by ``restore_checkpoint``).

    If the first leaf of ``state.params`` has ``PmapSharding``, the state is
    treated as pmap-replicated and element 0 along the leading axis of every
    leaf is taken. The (now single-device) state is then used as the restore
    target.

    Args:
        state: train state with a ``params`` attribute.
        workdir: flax checkpoint directory.
        step: step to restore; ``None`` lets flax choose.

    Returns:
        The restored state.
    """
    # Uses jax.tree_util functions: the top-level jax.tree_map / jax.tree_leaves
    # aliases are deprecated.
    def _first_param_sharding(s):
        return tree_leaves(s.params)[0].sharding

    if isinstance(_first_param_sharding(state), jax.sharding.PmapSharding):
        state = tree_map(lambda x: x[0], state)

    assert isinstance(_first_param_sharding(state), jax.sharding.SingleDeviceSharding), (
        "restore_checkpoint__ expects an un-replicated, single-device state"
    )
    return checkpoints.restore_checkpoint(workdir, state, step=step)


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also serializes JAX/NumPy arrays and NumPy scalars.

    Arrays become (nested) lists via ``tolist()``. NumPy scalar types such as
    ``np.float32`` or ``np.int64``, which the stock encoder rejects, become
    plain Python numbers. Anything else falls through to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        import numpy as np  # local import: numpy is not imported at module level

        if isinstance(obj, (jnp.ndarray, np.ndarray, np.generic)):
            return obj.tolist()
        return super().default(obj)


def save_config(config, workdir, name=None):
    """Write ``config`` as indented JSON to ``<workdir>/<name>.json``.

    Args:
        config: an object with a ``to_dict()`` method (presumably an
            ``ml_collections.ConfigDict`` - confirm at call sites), or a plain
            mapping without one (e.g. a ``dict`` or this module's
            ``Collection``), which is serialized directly.
        workdir: output directory; created with parents if missing.
        name: file stem; defaults to ``"config"``.
    """
    # exist_ok avoids a FileExistsError race when several processes create workdir at once.
    os.makedirs(workdir, exist_ok=True)

    if name is None:
        name = "config"
    config_path = os.path.join(workdir, name + ".json")

    payload = config.to_dict() if hasattr(config, "to_dict") else config
    with open(config_path, "w") as config_file:
        json.dump(payload, config_file, cls=CustomJSONEncoder, indent=4)




def get_log_keys(log_dict):
    """Return the keys of ``log_dict`` that belong in the per-iteration log table.

    A key is selected when it ends in ``"_loss"`` or ``"_error"``, or starts
    with ``"scale_"``. Selected keys keep the mapping's iteration order.
    """
    return [
        key
        for key in log_dict.keys()
        if key.endswith(("_loss", "_error")) or key.startswith("scale_")
    ]


class Logger:
    """Wrapper around a named stdlib logger that emits ``[HH:MM:SS] message`` lines.

    Constructing a Logger reconfigures ``logging.getLogger(name)``. It removes
    existing handlers and sets ``level`` on both the logger and one new
    ``StreamHandler``. It also disables propagation so records are not
    re-emitted by ancestor loggers.
    """

    def __init__(self, name: str = "main", level=logging.INFO):
        target = logging.getLogger(name)
        # Drop handlers left by an earlier Logger with the same name so lines are not duplicated.
        target.handlers.clear()
        # The logger's own level gates records before any handler sees them.
        target.setLevel(level)
        target.propagate = False

        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(level)
        stream_handler.setFormatter(
            logging.Formatter("[%(asctime)s] %(message)s", datefmt="%H:%M:%S")
        )
        target.addHandler(stream_handler)
        self.logger = target

    def info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def log_iter(self, step, start_time, end_time, log_dict):
        """Log a table of the selected metrics in ``log_dict`` for one iteration.

        Rows are the keys chosen by ``get_log_keys``, with values formatted as
        ``%.3e``. The header shows the step and the elapsed
        ``end_time - start_time``. A dashed rule two characters wider than the
        header line comes first, and each line is a separate INFO record.
        """
        rows = [[key, f"{float(log_dict[key]):.3e}"] for key in get_log_keys(log_dict)]
        table = tabulate(
            rows,
            headers=[f"Iter: {step:3d}", f"Time: {end_time - start_time:.3f}"],
            tablefmt="simple",
            numalign="right",
            disable_numparse=True,
        )
        table_lines = table.split("\n")
        rule = "-" * (len(table_lines[0]) + 2)
        for line in [rule, *table_lines]:
            self.logger.info(line)



class Collection(dict):
    __slots__ = ()

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"No such config key: {key}")

    def __setattr__(self, key, value):
        if key.startswith("_"):
            return super().__setattr__(key, value)
        self[key] = self._wrap(value)

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"No such config key: {key}")

    @staticmethod
    def _wrap(value):
        if isinstance(value, dict):
            return Collection({k: Collection._wrap(v) for k, v in value.items()})
        if isinstance(value, list):
            return [Collection._wrap(v) for v in value]
        return value

    @classmethod
    def from_dict(cls, d: dict):
        return cls({k: cls._wrap(v) for k, v in d.items()})
    




def count_params(params) -> int:
    """Return the total number of elements across all leaves of the ``params`` pytree."""
    total = 0
    for leaf in tree_leaves(params):
        total += leaf.size
    return total
    



##### ======= Training utils ======= #####





@dataclass(frozen=True)
class LoadedRuns:
    """Immutable record of one checkpoint loaded from a run directory.

    Produced by ``load_checkpoint_as_run``. Only ``run_root`` is required;
    every other field defaults to ``None``.
    """

    # Run directory the checkpoint was loaded from.
    run_root: str
    # Raw run configuration (the trainer's ``raw_cfg``).
    cfg: Any = None
    # Model object after the checkpoint was restored.
    model: Any = None
    # Checkpoint step that was restored.
    step: Any = None
  

def _parse_step(path):
    b = os.path.basename(path.rstrip("/"))
    m = re.search(r"(\d+)(?!.*\d)", b)
    return int(m.group(1)) if m else None


def list_checkpoint_steps(run_root):
    """Return the sorted, de-duplicated steps found under ``<run_root>/checkpoints``.

    Every non-hidden entry (file or directory) matched by the ``*`` glob is
    parsed with ``_parse_step``. Entries with no digits in their name are
    ignored.
    """
    pattern = os.path.join(run_root, "checkpoints", "*")
    parsed = (_parse_step(entry) for entry in glob.glob(pattern))
    return sorted({s for s in parsed if s is not None})


def load_checkpoint_as_run(trainer_cls, run_root, step, device=None):
    """Recreate a trainer for ``run_root`` and restore checkpoint ``step`` into it.

    The trainer comes from ``trainer_cls.from_run_dir``. First ``build()`` and
    ``restore()`` are tried. If either raises an ``Exception``, a warning is
    logged and the model is instead rebuilt with ``get_pde`` from the run
    config with ``flag = "state_fail"``, after which ``restore()`` is retried.
    KeyboardInterrupt/SystemExit are no longer swallowed. If the fallback also
    fails, its traceback shows the original error as context.

    Returns:
        LoadedRuns holding the run root, step, raw config and restored model.
    """
    from phijax.equations import get_pde

    t = trainer_cls.from_run_dir(run_root, device=device)
    t.run_dir = run_root
    try:
        t.build()
        t.restore(run_root, step=step)
    except Exception as err:
        logging.getLogger(__name__).warning(
            "Restoring %s at step %s failed (%r); retrying with flag='state_fail'.",
            run_root, step, err,
        )
        run_cfg = t.to_collection_cfg()
        run_cfg.flag = "state_fail"
        t.model = get_pde(run_cfg)
        t.restore(run_root, step=step)
    return LoadedRuns(
        run_root=run_root,
        step=step,
        cfg=t.raw_cfg,
        model=t.model,
    )


def load_all_checkpoints_as_runs(trainer_cls, run_root, device=None, steps=None):
    """Load every requested checkpoint of a run with ``load_checkpoint_as_run``.

    Args:
        trainer_cls: trainer class passed through to ``load_checkpoint_as_run``.
        run_root: run directory.
        device: forwarded to ``trainer_cls.from_run_dir``.
        steps: iterable of steps to load; ``None`` means every step reported
            by ``list_checkpoint_steps(run_root)``.

    Returns:
        A list of ``LoadedRuns`` in the order of ``steps``.
    """
    # The former unused `get_pde` import was removed; it forced importing
    # phijax even when there was nothing to load.
    if steps is None:
        steps = list_checkpoint_steps(run_root)
    return [load_checkpoint_as_run(trainer_cls, run_root, s, device=device) for s in steps]
