from __future__ import annotations
######

import os 
os.environ["TF_CUDNN_DETERMINISTIC"] = "1" 
#####


from typing import Any, Dict, List, Tuple
import hashlib, json, re, yaml, os, itertools, copy, time, wandb, matplotlib.pyplot as plt
import jax

jax.config.update("jax_default_matmul_precision", "highest")



import jax.numpy as jnp



from phijax.equations.base import *
from phijax.utils import *
from phijax.models import *
from phijax.equations import get_pde



# Key path into a nested config dict, e.g. ("model_config", "lr").
Path = Tuple[str, ...]

def _path_str(p: Path) -> str:
    """Render the key path *p* as a dotted string, e.g. ("a", "b") -> "a.b"."""
    separator = "."
    return separator.join(p)


# Key of the single-entry wrapper that marks a value as atomic (never swept).
_LITERAL_KEY = "_literal"

def _is_literal_dict(d: dict) -> bool:
    """Return True iff *d* is a dict whose one and only key is ``_LITERAL_KEY``.

    Such ``{"_literal": <any>}`` wrappers let a config carry a list value
    without it being expanded as a sweep axis.
    """
    return isinstance(d, dict) and len(d) == 1 and _LITERAL_KEY in d

def _unwrap_literals(x):
    """Return *x* with every ``{"_literal": v}`` wrapper replaced by ``v``.

    Dicts and lists are rebuilt recursively, so the result shares no dict or
    list container with *x*. A wrapper's payload is unwrapped too, so nested
    wrappers collapse. Any other value is returned unchanged.
    """
    if isinstance(x, list):
        return [_unwrap_literals(item) for item in x]
    if not isinstance(x, dict):
        return x
    if _is_literal_dict(x):
        return _unwrap_literals(x[_LITERAL_KEY])
    return {key: _unwrap_literals(val) for key, val in x.items()}

def _find_list_fields(cfg: Any, prefix: Path = ()) -> List[Tuple[Path, list]]:
    """Collect every sweepable list in *cfg* together with its key path.

    Nested dicts are walked depth-first in insertion order, and one
    ``(path, list_value)`` pair is returned per list-typed value. Lists are
    not descended into. ``{"_literal": ...}`` wrappers are skipped entirely,
    so their contents never count as sweeps. If *cfg* itself is a list, the
    result is the single pair ``(prefix, cfg)``.
    """
    if isinstance(cfg, list):
        return [(prefix, cfg)]
    if not isinstance(cfg, dict) or _is_literal_dict(cfg):
        return []
    found: List[Tuple[Path, list]] = []
    for key, child in cfg.items():
        found.extend(_find_list_fields(child, prefix + (key,)))
    return found

def _set_in(d: Dict[str, Any], path: Path, value: Any) -> None:
    """Assign *value* at key path *path* inside the nested dict *d*, in place.

    Missing intermediate levels are created as empty dicts. Existing entries
    along the path are descended into as they are. An empty *path* raises
    IndexError without modifying *d*.
    """
    parents, leaf = path[:-1], path[-1]
    node = d
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

def _expand_dict(d: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return every variant of *d* obtained by sweeping its list-typed values.

    Each list found by ``_find_list_fields`` is one sweep axis. One variant is
    produced per element of the cartesian product of all axes (the first-found
    axis varies slowest), with the chosen element substituted for the list.
    Every variant starts from a fresh deep copy of *d* and has its
    ``_literal`` wrappers unwrapped. With no lists, the result is a single
    unwrapped copy of *d*. If any sweep list is empty, the result is ``[]``.
    """
    axes = _find_list_fields(d)
    if not axes:
        return [copy.deepcopy(_unwrap_literals(d))]

    axis_paths, axis_values = zip(*axes)

    variants: List[Dict[str, Any]] = []
    for choice in itertools.product(*axis_values):
        variant = copy.deepcopy(d)
        for path, picked in zip(axis_paths, choice):
            _set_in(variant, path, picked)
        variants.append(_unwrap_literals(variant))
    return variants

def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge *src* into *dst* in place and return *dst*.

    When both sides hold a dict under the same key, the two are merged
    recursively. Otherwise the value from *src* replaces (or adds) the entry
    in *dst*. Values from *src* are inserted by reference, not copied.
    """
    for key, incoming in src.items():
        existing = dst.get(key)
        if not (isinstance(incoming, dict) and isinstance(existing, dict)):
            dst[key] = incoming
            continue
        _deep_merge(existing, incoming)
    return dst

def _expand_named_entry(entry: Dict[str, Any], *, config_key="config") -> List[Tuple[str, Dict[str, Any], Dict[str, Any]]]:
    """
    Expand one ``pdes`` / ``models`` / ``optimizers`` entry into concrete variants.

    Returns a list of (name, meta, cfg_variant) tuples, one per element of
    ``names x expanded configs`` (names vary slowest), where:
      - name is a single value: ``entry["name"]``, or one element of it when
        it is a list; ``None`` when the entry has no ``name`` key
      - meta is a shallow copy of the entry WITHOUT 'name' and ``config_key``
        (e.g. tag, active); it is not swept
      - cfg_variant is one expanded config dict (see ``_expand_dict``)
    Every tuple gets its own ``meta`` dict and its own deep-copied
    ``cfg_variant``, so a caller may mutate one variant without affecting
    the others.

    Supports entry like:
      { name: ["pinnmamba","pinnsformer"], tag: "small", config: {... lists ...} }
    """
    names_val = entry.get("name")
    names = names_val if isinstance(names_val, list) else [names_val]

    cfg_base = entry.get(config_key, {}) or {}
    cfg_variants = _expand_dict(cfg_base)

    # Remaining (non-swept) fields of the entry, e.g. tag / active.
    meta = {k: v for k, v in entry.items() if k not in ("name", config_key)}

    # itertools.product yields the *same* cfg object for every name; copy per
    # tuple so in-place edits of one variant cannot leak into another.
    return [
        (n, dict(meta), copy.deepcopy(cfg))
        for n, cfg in itertools.product(names, cfg_variants)
    ]

def _build_run(
    top_cfg: Dict[str, Any],
    pde_variant: Tuple[str, dict, dict],
    model_variant: Tuple[str, dict, dict],
    opt_variant: Tuple[str, dict, dict],
) -> Dict[str, Any]:
    """Combine one variant of each sweep axis into a standalone run dict."""
    p_name, _, p_cfg = pde_variant
    m_name, _, m_cfg = model_variant
    o_name, _, o_cfg = opt_variant

    # The same variant dicts are reused by many runs; work on private copies so
    # the in-place edits below cannot leak from one run into the others.
    run = copy.deepcopy(top_cfg)
    pde_cfg = copy.deepcopy(p_cfg)
    model_cfg = copy.deepcopy(m_cfg)
    opt_cfg = copy.deepcopy(o_cfg)

    # what get_pde will read
    run["pde"] = p_name
    run["pde_config"] = pde_cfg
    run["activation"] = model_cfg.get("activation", run.get("activation", "tanh"))

    # what get_model will read (arch selection via exp_name prefix)
    run["exp_name"] = f"{m_name}-{p_name}-{run['activation']}"

    # A pde-level fourier_embeddings value (anything but None/False) overrides
    # the model's. Otherwise the model keeps its own value. If the model has
    # none, the key is set to the pde's value (None/False), as before.
    fourier = pde_cfg.get("fourier_embeddings")
    if fourier is not None and fourier is not False:
        model_cfg["fourier_embeddings"] = copy.deepcopy(fourier)
    else:
        model_cfg.setdefault("fourier_embeddings", fourier)
    run["model_config"] = model_cfg

    if pde_cfg.get("batch_size") is not None:
        # pde-specific batch size overrides training.batch_size
        # (`or {}` also covers an explicit `training: null`).
        training = run.get("training") or {}
        training["batch_size"] = pde_cfg["batch_size"]
        run["training"] = training
    run["input_dim"] = int(model_cfg.get("input_dim", run.get("input_dim", 2)))
    run["init_batch_size"] = int((run.get("init") or {}).get("batch_size", 1))

    # optimizer config
    run["optim"] = {"optimizer": o_name, **opt_cfg}

    # keep for bookkeeping
    run["_names"] = {"pde": p_name, "model": m_name, "optim": o_name}
    return run


def assemble_runs(cfg: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand a sweep config into a flat list of fully specified run configs.

    *cfg* must contain ``pdes``, ``models`` and ``optimizers`` lists. Entries
    with ``active: false`` are dropped. Every other top-level key is shared by
    all runs and may itself contain sweep lists (see ``_expand_dict``). The
    result is the cartesian product top-variant x pde x model x optimizer, in
    that nesting order (optimizer varies fastest).

    Each run gets ``pde``, ``pde_config``, ``activation``, ``exp_name``,
    ``model_config``, ``input_dim``, ``init_batch_size``, ``optim`` and a
    ``_names`` bookkeeping entry. A pde-level ``batch_size`` overrides
    ``training.batch_size``. Runs share no mutable state with each other or
    with *cfg*.

    Raises:
        ValueError: if any of the three lists is missing, empty, or has no
            active entries.
    """
    base = copy.deepcopy(cfg)

    def active_entries(key: str) -> List[Dict[str, Any]]:
        # `or []` so an explicit `key: null` in the YAML counts as empty.
        return [e for e in (base.get(key) or []) if e.get("active", True)]

    pdes = active_entries("pdes")
    models = active_entries("models")
    opts = active_entries("optimizers")

    if not pdes:
        raise ValueError("config must contain a non-empty or active 'pdes' list")
    if not models:
        raise ValueError("config must contain a non-empty or active 'models' list")
    if not opts:
        raise ValueError("config must contain a non-empty or active 'optimizers' list")

    top_base = {k: v for k, v in base.items() if k not in ("pdes", "models", "optimizers")}
    top_variants = _expand_dict(top_base)

    def named_variants(entries: List[Dict[str, Any]]) -> List[Tuple[str, dict, dict]]:
        return [
            variant
            for entry in entries
            for variant in _expand_named_entry(entry, config_key="config")
        ]

    pde_variants = named_variants(pdes)
    model_variants = named_variants(models)
    opt_variants = named_variants(opts)

    return [
        _build_run(top_cfg, pde_v, model_v, opt_v)
        for top_cfg, pde_v, model_v, opt_v in itertools.product(
            top_variants, pde_variants, model_variants, opt_variants
        )
    ]



def _slugify(s: str) -> str:
    """Turn *s* into a lowercase, filesystem-friendly slug.

    After trimming and lowercasing, each run of characters outside
    ``[a-z0-9._-]`` becomes one ``-``. Repeated dashes are collapsed and
    leading/trailing dashes dropped. Returns ``"run"`` if nothing is left.
    """
    slug = re.sub(r"[^a-z0-9._-]+", "-", s.strip().lower())
    slug = re.sub(r"-{2,}", "-", slug).strip("-")
    return slug if slug else "run"


def _short_hash(obj: Any, n: int = 8) -> str:
    """Return the first *n* hex digits of the SHA-1 of *obj*'s canonical JSON.

    Keys are sorted, so dicts with equal content hash equally regardless of
    insertion order. Values JSON cannot encode are serialized with ``str``.
    """
    canonical = json.dumps(obj, sort_keys=True, default=str)
    digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()
    return digest[:n]




def create_experiment_dir(
    exp_path: str,
    exp_name: str,
    seed: int,
    sweep_id: Optional[str] = None,
    create_if_missing: bool = True,
    select_index: Optional[int] = None,
) -> str:
    """Create a new run directory, or locate the checkpoints of an existing run.

    The run's base directory is ``<exp_path>/<sweep_id>/seed_<seed>`` when
    *sweep_id* is given. Otherwise it is ``<exp_path>/<eq>/<model>/seed_<seed>``,
    where ``<model>`` and ``<eq>`` are the first two ``-``-separated fields of
    the slugified *exp_name*. Repeated runs of the same base are told apart
    by a numeric suffix: ``seed_0``, ``seed_0-1``, ``seed_0-2``, ...

    Args:
        exp_path: root directory; ``"./runs"`` is used when it is empty/None.
        exp_name: ``<model>-<eq>[-...]``; only parsed when *sweep_id* is falsy.
        seed: used in the ``seed_<seed>`` directory name.
        sweep_id: optional directory name that replaces ``<eq>/<model>``.
        create_if_missing: if True, create and return the next free run
            directory (``base``, then ``base-<max suffix + 1>``). If False,
            create nothing and return the ``checkpoints`` subdirectory of an
            existing run.
        select_index: lookup mode only. 0 selects ``base`` and k > 0 selects
            ``base-k``. The default picks the highest existing suffix, or
            ``base`` if there is none.

    Returns:
        The new run directory (create mode) or ``<run>/checkpoints`` (lookup mode).

    Raises:
        ValueError: *exp_name* has fewer than two fields when it is needed,
            or *select_index* is negative.
        FileNotFoundError: in lookup mode, the requested run or its
            ``checkpoints`` directory does not exist.
    """
    exp_path = exp_path or "./runs"
    if sweep_id:
        parts = [exp_path, sweep_id]
    else:
        # Only this fallback layout needs <model>/<eq>, so exp_name is parsed
        # here rather than unconditionally (which broke calls with a sweep_id).
        fields = _slugify(exp_name).split("-")
        if len(fields) < 2:
            raise ValueError(
                f"exp_name must have the form '<model>-<eq>[-...]', got {exp_name!r}"
            )
        model_name, eq_name = fields[0], fields[1]
        parts = [exp_path, eq_name, model_name]
    parts.append(f"seed_{int(seed)}")

    base = os.path.join(*parts)
    parent, basename = os.path.split(base)

    # Existing sibling runs: "<basename>" itself and "<basename>-<k>" for integer k.
    has_base = False
    suffixes = set()
    if os.path.isdir(parent):
        dash_prefix = basename + "-"
        for entry in os.listdir(parent):
            if entry == basename:
                has_base = True
            elif entry.startswith(dash_prefix):
                tail = entry[len(dash_prefix):]
                # isdecimal (not isdigit): accept only what int() can parse.
                if tail.isdecimal():
                    suffixes.add(int(tail))

    if create_if_missing:
        if has_base or suffixes:
            path = f"{base}-{max(suffixes, default=0) + 1}"
        else:
            path = base
        os.makedirs(path, exist_ok=True)
        print(path)
        return path

    if select_index is not None:
        if select_index < 0:
            raise ValueError("select_index must be >= 0")
        run_dir = base if select_index == 0 else f"{base}-{select_index}"
        path = os.path.join(run_dir, "checkpoints")
        if not os.path.isdir(path):
            raise FileNotFoundError(f"Requested run does not exist: {path}")
        return path

    if not has_base and not suffixes:
        raise FileNotFoundError(f"No runs exist under {parent} for {basename}")
    run_dir = f"{base}-{max(suffixes)}" if suffixes else base

    path = os.path.join(run_dir, "checkpoints")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"'checkpoints' not found for latest run: {path}")
    return path




def run_one_experiment(run_cfg, device=None):
    """Create a run directory, train one configuration, checkpoint it and evaluate it.

    Steps, in order:
      1. derive ``exp_name``/``sweep_id`` from ``run_cfg["_names"]`` (falling
         back to ``pde`` / ``exp_name`` / ``optim.optimizer``) and call
         ``create_experiment_dir``;
      2. write ``config.yml`` into that directory. It is a copy of *run_cfg*
         with ``exp_path`` set to the run directory;
      3. start a wandb run if ``wandb.use`` is set;
      4. build the model with ``get_pde`` and run ``training.num_epochs`` steps,
         logging every ``logging.log_every`` steps on process 0;
      5. on process 0, save checkpoints to ``<run_dir>/checkpoints`` every
         ``logging.save_every`` steps, and always after the final step;
      6. call ``evaluate``. On process 0, print the metrics and log them to
         wandb together with wall time and parameter count.

    Args:
        run_cfg: one plain-dict run as produced by ``assemble_runs``.
        device: accepted for CLI compatibility; not used here.
    """
    names = run_cfg.get("_names", {})
    pde_name = names.get("pde", run_cfg.get("pde", "pde"))
    model_name = names.get("model", run_cfg.get("exp_name", "model").split("-")[0])
    opt_name = names.get("optim", run_cfg.get("optim", {}).get("optimizer", "opt"))

    seed = int(run_cfg.get("seed", 0))
    exp_path = run_cfg.get("exp_path", "./runs")

    activation = run_cfg.get("activation")

    exp_name = f"{model_name}-{pde_name}-{activation}"
    sweep_id = f"{pde_name}_{model_name}_{opt_name}_{activation}"
    run_dir = create_experiment_dir(
        exp_path,
        exp_name,
        seed,
        sweep_id=sweep_id,
        # NOTE(review): when mode != "train" this returns an existing run's
        # checkpoints/ directory, yet training and checkpointing below still
        # run against it -- confirm the intended eval-only behaviour.
        create_if_missing=run_cfg.get("mode", "train") == "train",
    )

    # --- Save the initial config for this specific run ---
    config_save_path = os.path.join(run_dir, "config.yml")
    save_cfg = copy.deepcopy(run_cfg)
    save_cfg['exp_path'] = run_dir
    with open(config_save_path, "w") as f:
        yaml.dump(save_cfg, f, default_flow_style=False)

    # Wandb
    wandb_cfg = run_cfg.get("wandb", {})
    if wandb_cfg.get("use", False):
        wandb.init(
            project=wandb_cfg.get("project", "default_project"),
            name=run_cfg["exp_name"],  # Keep short
            config=run_cfg,
            mode=wandb_cfg.get("mode", "online"),
        )

    run_cfg = Collection.from_dict(run_cfg)
    logger = Logger()
    t0 = time.time()

    model = get_pde(run_cfg)
    num_epochs = run_cfg.training.num_epochs
    res_sampler = iter(model.sampler)
    u_ref = model.u_ref
    log_every = run_cfg.logging.log_every
    save_every = run_cfg.logging.save_every
    ckpt_path = os.path.join(run_dir, "checkpoints")

    # parameter count
    param_count = count_params(model.state.params)
    print(f"Number of parameters: {param_count}")

    for step in range(num_epochs):
        step_start = time.time()
        batch = next(res_sampler)
        model.state = model.step(model.state, batch)

        if jax.process_index() == 0 and (step % log_every) == 0:
            # x[0]: take the first entry of the leading axis of every leaf
            # (presumably the per-device replica axis -- TODO confirm).
            state = jax.device_get(tree_map(lambda x: x[0], model.state))
            batch = jax.device_get(tree_map(lambda x: x[0], batch))
            log_dict = model.log(state, batch, u_ref)
            logger.log_iter(step, step_start, time.time(), log_dict)
            # Wandb logging
            if wandb_cfg.get("use", False):
                wandb.log(log_dict, step=step)

        # Checkpoint every `save_every` steps and ALWAYS after the last step:
        # evaluate() below restores from <run_dir>/checkpoints, so a run with
        # save_every = None (or 0) must still leave a final checkpoint behind.
        is_last = (step + 1) == num_epochs
        is_periodic = bool(save_every) and (step + 1) % save_every == 0
        if jax.process_index() == 0 and (is_last or is_periodic):
            save_checkpoint(
                model.state,
                ckpt_path,
                keep=run_cfg.logging.num_keep_ckpts,
            )

    # evaluate and plot
    metrics = evaluate(run_cfg)
    t_end = time.time()
    if jax.process_index() == 0:
        print(f"Experiment completed in {t_end - t0:.2f} seconds.")
        print("Final Metrics:", metrics)

        if wandb_cfg.get("use", False):
            wandb.log(
                {
                    **metrics,
                    "total_time": t_end - t0,
                    "param_count": param_count,
                }
            )
            wandb.finish()





def train(run_cfg):
    """Train the PDE model described by *run_cfg* and return it.

    This runs the same loop as ``run_one_experiment``: ``training.num_epochs``
    optimizer steps, with logging every ``logging.log_every`` steps on
    process 0. There is no checkpointing, wandb logging or evaluation.

    Args:
        run_cfg: config with attribute access (``run_cfg.training.num_epochs``,
            ``run_cfg.logging.log_every``) that ``get_pde`` accepts.

    Returns:
        The model object built by ``get_pde``, whose ``state`` is the final
        training state.
    """
    logger = Logger()
    model = get_pde(run_cfg)
    num_epochs = run_cfg.training.num_epochs
    res_sampler = iter(model.sampler)
    u_ref = model.u_ref
    log_every = run_cfg.logging.log_every
    for step in range(num_epochs):
        step_start = time.time()
        batch = next(res_sampler)
        model.state = model.step(model.state, batch)

        if jax.process_index() == 0 and (step % log_every) == 0:
            state = jax.device_get(tree_map(lambda x: x[0], model.state))
            batch = jax.device_get(tree_map(lambda x: x[0], batch))
            log_dict = model.log(state, batch, u_ref)
            logger.log_iter(step, step_start, time.time(), log_dict)

    # Hand the trained model back; otherwise the final state is unreachable.
    return model


def evaluate(run_cfg):
    """Restore a run's latest checkpoint, compute error metrics and plot the solution.

    The run is located with ``create_experiment_dir(..., create_if_missing=False)``,
    using the same ``exp_name``/``sweep_id`` scheme as ``run_one_experiment``.
    It therefore resolves to the highest-numbered run for this seed. The
    returned path is that run's ``checkpoints`` directory, and
    ``metrics.json`` and ``convection.pdf`` are also written there.

    Args:
        run_cfg: run config that supports ``.get`` and is accepted by ``get_pde``.

    Returns:
        ``{"rmse": ..., "rmae": ...}`` as Python floats.
        NOTE(review): "rmse" holds ``model.compute_l2_error`` -- confirm the key name.
    """
    names = run_cfg.get("_names", {})
    pde_name = names.get("pde", run_cfg.get("pde", "pde"))
    model_name = names.get("model", run_cfg.get("exp_name", "model").split("-")[0])
    opt_name = names.get("optim", run_cfg.get("optim", {}).get("optimizer", "opt"))

    seed = int(run_cfg.get("seed", 0))
    exp_path = run_cfg.get("exp_path", "./runs")

    activation = run_cfg.get("activation")

    exp_name = f"{model_name}-{pde_name}-{activation}"
    sweep_id = f"{pde_name}_{model_name}_{opt_name}_{activation}"
    run_dir = create_experiment_dir(
        exp_path,
        exp_name,
        seed,
        sweep_id=sweep_id,
        create_if_missing=False,
    )
    # restore the model and evaluate
    model = get_pde(run_cfg)
    model.state = restore_checkpoint(model.state, run_dir)
    params = model.state.params

    u_ref = model.u_ref
    l2_error = model.compute_l2_error(params, u_ref)
    rmae_error = model.compute_rmae(params, u_ref)

    metrics = {"rmse": float(l2_error), "rmae": float(rmae_error)}

    with open(os.path.join(run_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=4)

    t_star = model.t_star
    x_star = model.x_star
    u_pred = model.u_pred_fn(params, t_star, x_star)
    TT, XX = jnp.meshgrid(t_star, x_star, indexing="ij")

    panels = (
        ("Exact", u_ref),
        ("Predicted", u_pred),
        ("Absolute error", jnp.abs(u_ref - u_pred)),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        for ax, (title, values) in zip(axes, panels):
            mesh = ax.pcolor(TT, XX, values, cmap="jet")
            fig.colorbar(mesh, ax=ax)
            ax.set_xlabel("t")
            ax.set_ylabel("x")
            ax.set_title(title)
        fig.tight_layout()

        # NOTE(review): file name is fixed regardless of which PDE was run.
        fig_path = os.path.join(run_dir, "convection.pdf")
        fig.savefig(fig_path, bbox_inches="tight", dpi=300)
    finally:
        # pyplot keeps every figure alive until closed; without this each run
        # of a sweep leaks one figure.
        plt.close(fig)

    return metrics






if __name__ == "__main__":
    import argparse
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="report a failing run and continue with the next one "
        "instead of aborting the whole sweep (default: abort)",
    )
    args = parser.parse_args()

    with open(args.config, "r") as f:
        base_config = yaml.safe_load(f)
    all_runs = assemble_runs(base_config)
    for sweep_params in all_runs:
        print(yaml.dump(sweep_params, default_flow_style=True))
        try:
            run_one_experiment(sweep_params, device=args.device)
        except Exception as e:
            # Fail fast unless skipping failed runs was explicitly requested.
            if not args.continue_on_error:
                raise
            print(f"Error running experiment {sweep_params['exp_name']}: {e}")
            traceback.print_exc()
            # A run that died mid-way can leave its wandb run open; finish it so
            # the next experiment does not log into the failed run.
            if wandb.run is not None:
                wandb.finish(exit_code=1)
            print("Skipping this run due to error.")
            print("=" * 80 + "\n")
