from __future__ import annotations
######

import os 
os.environ["TF_CUDNN_DETERMINISTIC"] = "1" 
#####


from typing import Any, Dict, List, Tuple
import hashlib, json, re, yaml, os, itertools, copy, time, wandb, matplotlib.pyplot as plt
import jax

jax.config.update("jax_default_matmul_precision", "highest")



import jax.numpy as jnp
import socket



from phijax.equations.base import *
from phijax.utils import *
from phijax.models import *
from phijax.equations import get_pde
from phijax.trainer import Trainer, TimeMarchingTrainer



Path = Tuple[str, ...]

def _path_str(p: Path) -> str:
    return ".".join(p)


_LITERAL_KEY = "_literal"

def _is_literal_dict(d: dict) -> bool:
    # treat {"_literal": <any>} as atomic (non-sweep)
    return isinstance(d, dict) and set(d.keys()) == {_LITERAL_KEY}

def _unwrap_literals(x):
    """
    Return *x* with every ``{"_literal": v}`` wrapper replaced by ``v``.

    Dicts and lists are rebuilt rather than mutated. A wrapped value is itself
    processed, so wrappers nested inside it are stripped as well. Any value
    that is neither a dict nor a list is returned unchanged.
    """
    if isinstance(x, list):
        return [_unwrap_literals(item) for item in x]
    if not isinstance(x, dict):
        return x
    if _is_literal_dict(x):
        inner = x[_LITERAL_KEY]
        return _unwrap_literals(inner)
    return {key: _unwrap_literals(val) for key, val in x.items()}

def _find_list_fields(cfg: Any, prefix: Path = ()) -> List[Tuple[Path, list]]:
    """
    Collect every list-valued leaf of *cfg* together with its key path.

    Dicts are walked recursively in insertion order. A ``{"_literal": ...}``
    wrapper is treated as opaque and never searched. A list is reported as-is
    (the same list object, not a copy) and is not descended into. Other values
    contribute nothing. If *cfg* is itself a list it is reported under *prefix*.
    """
    if isinstance(cfg, list):
        return [(prefix, cfg)]
    if not isinstance(cfg, dict) or _is_literal_dict(cfg):
        return []
    found: List[Tuple[Path, list]] = []
    for key, child in cfg.items():
        found.extend(_find_list_fields(child, prefix + (key,)))
    return found

def _set_in(d: Dict[str, Any], path: Path, value: Any) -> None:
    for k in path[:-1]:
        d = d.setdefault(k, {})
    d[path[-1]] = value

def _expand_dict(d: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Expand *d* into one dict per point of the cartesian product of its list leaves.

    Every list found in *d* (outside a ``{"_literal": ...}`` wrapper) is a
    sweep axis. Each returned dict has one element of each such list
    substituted at that list's position, and all ``_literal`` wrappers
    stripped. With no list leaves, a single unwrapped copy of *d* is returned.
    *d* itself is never modified, and the results share no mutable state
    with it.

    Raises:
        ValueError: if a sweep axis is an empty list. A cartesian product over
            an empty axis has no points, so the entry would otherwise vanish
            from the sweep without any message. Wrap the value as
            ``{"_literal": []}`` to pass an empty list through as a value.
    """
    fields = _find_list_fields(d)
    if not fields:
        return [copy.deepcopy(_unwrap_literals(d))]

    for path, values in fields:
        if not values:
            dotted = ".".join(map(str, path))
            raise ValueError(
                f"config list at {dotted!r} is empty, which would produce zero "
                f"sweep variants; use {{'_literal': []}} to pass an empty list as a value"
            )

    paths = [path for path, _ in fields]
    axes = [values for _, values in fields]

    variants: List[Dict[str, Any]] = []
    for combo in itertools.product(*axes):
        inst = copy.deepcopy(d)
        for path, value in zip(paths, combo):
            # Copy the chosen value so no variant aliases an object inside *d*.
            _set_in(inst, path, copy.deepcopy(value))
        variants.append(_unwrap_literals(inst))
    return variants

def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
    for k, v in src.items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            _deep_merge(dst[k], v)
        else:
            dst[k] = v
    return dst

def _expand_named_entry(entry: Dict[str, Any], *, config_key="config") -> List[Tuple[str, Dict[str, Any], Dict[str, Any]]]:
    """
    Expand one named sweep entry (a ``pdes``/``models``/``optimizers`` item).

    Returns a list of ``(name, meta, cfg_variant)`` tuples, one per element of
    the cartesian product ``names x expanded configs`` (names vary slowest):
      - name: a single name, never a list. ``entry["name"]`` may be a scalar
        or a list of names; a list is swept. A missing name yields ``None``.
      - meta: a shallow copy of *entry* without ``"name"`` and *config_key*
        (e.g. ``tag``, ``active``). Meta values are NOT swept, even if they
        are lists. Each tuple gets its own meta dict.
      - cfg_variant: one dict produced by ``_expand_dict`` from
        ``entry[config_key]`` (a missing or empty config counts as ``{}``).

    Example entry::

        { name: ["pinnmamba", "pinnsformer"], tag: "small", config: {... lists ...} }

    Raises:
        ValueError: if ``entry["name"]`` is an empty list, which would
            otherwise drop the entry from the sweep without any message.
    """
    raw_name = entry.get("name")
    names = raw_name if isinstance(raw_name, list) else [raw_name]
    if not names:
        raise ValueError(
            f"'name' must not be an empty list (it would produce no runs); entry: {entry!r}"
        )

    cfg_variants = _expand_dict(entry.get(config_key, {}) or {})

    # Everything except the name and the config block; carried along as-is.
    meta = {k: v for k, v in entry.items() if k not in ("name", config_key)}

    return [
        (name, dict(meta), cfg)
        for name, cfg in itertools.product(names, cfg_variants)
    ]

def assemble_runs(cfg: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Expand a sweep config into a flat list of independent run dicts.

    *cfg* must contain ``pdes``, ``models`` and ``optimizers`` lists of named
    entries (see ``_expand_named_entry``). Entries with a falsy ``active``
    field are skipped. All remaining top-level keys form a shared base, which
    is itself swept over its list-valued leaves. One run is produced per
    (top-level variant, pde variant, model variant, optimizer variant), in
    that nesting order.

    Each run is a deep copy of its top-level variant plus:
      - ``pde`` / ``pde_config``: the PDE name and its expanded config.
      - ``activation``: from the model config, else the top-level value,
        else "tanh".
      - ``exp_name``: ``"<model>-<pde>-<activation>"``.
      - ``model_config``: the model config. Its ``fourier_embeddings`` is
        taken from the PDE config when the PDE sets a non-None value;
        otherwise the model's own value is kept (``None`` if it has none).
      - ``training.batch_size``: overridden by the PDE config's ``batch_size``
        when that is set.
      - ``input_dim``: model ``input_dim``, else top-level, else 2 (as int).
      - ``init_batch_size``: top-level ``init.batch_size``, else 1 (as int).
      - ``optim``: ``{"optimizer": <optimizer name>, **optimizer config}``.
      - ``model_meta``: the model name, the entry's non-config fields, and
        ``config`` (the same object as ``model_config``).
      - ``_names``: the pde/model/optimizer names, for bookkeeping.

    Every run owns its nested dicts. Variant configs are deep-copied per run
    because the same variant is reused across many combinations, and
    ``model_config`` is modified per PDE.

    Raises:
        ValueError: if ``pdes``, ``models`` or ``optimizers`` is missing,
            empty, or has no active entry.
    """
    base = copy.deepcopy(cfg)

    active: Dict[str, List[Dict[str, Any]]] = {}
    for section in ("pdes", "models", "optimizers"):
        # `or []` also covers a key present but left empty in YAML (None).
        entries = [e for e in (base.get(section) or []) if e.get("active", True)]
        if not entries:
            raise ValueError(f"config must contain a non-empty or active '{section}' list")
        active[section] = entries

    top_variants = _expand_dict(
        {k: v for k, v in base.items() if k not in ("pdes", "models", "optimizers")}
    )

    def _variants(section: str) -> List[Tuple[str, dict, dict]]:
        # Concatenate the expanded variants of every active entry in a section.
        out: List[Tuple[str, dict, dict]] = []
        for entry in active[section]:
            out.extend(_expand_named_entry(entry, config_key="config"))
        return out

    pde_variants = _variants("pdes")
    model_variants = _variants("models")
    opt_variants = _variants("optimizers")

    runs: List[Dict[str, Any]] = []
    for top_cfg in top_variants:
        for p_name, _, p_cfg_shared in pde_variants:
            for m_name, m_meta, m_cfg_shared in model_variants:
                for o_name, _, o_cfg_shared in opt_variants:
                    run = copy.deepcopy(top_cfg)
                    # Private copies: each variant dict is reused by every
                    # combination, and model_config is mutated below.
                    p_cfg = copy.deepcopy(p_cfg_shared)
                    m_cfg = copy.deepcopy(m_cfg_shared)
                    o_cfg = copy.deepcopy(o_cfg_shared)

                    # what get_pde will read
                    run["pde"] = p_name
                    run["pde_config"] = p_cfg
                    run["activation"] = m_cfg.get("activation", run.get("activation", "tanh"))

                    # what get_model will read (arch selection via exp_name prefix)
                    run["exp_name"] = f"{m_name}-{p_name}-{run['activation']}"

                    run["model_config"] = m_cfg
                    # A PDE-level fourier_embeddings setting (including False)
                    # overrides the model's; otherwise keep the model's own
                    # value and make sure the key exists.
                    pde_fourier = p_cfg.get("fourier_embeddings")
                    if pde_fourier is not None:
                        m_cfg["fourier_embeddings"] = pde_fourier
                    else:
                        m_cfg.setdefault("fourier_embeddings", None)

                    pde_batch_size = p_cfg.get("batch_size")
                    if pde_batch_size is not None:
                        training = run.get("training") or {}
                        training["batch_size"] = pde_batch_size
                        run["training"] = training

                    run["input_dim"] = int(m_cfg.get("input_dim", run.get("input_dim", 2)))
                    run["init_batch_size"] = int((run.get("init") or {}).get("batch_size", 1))

                    # optimizer config
                    run["optim"] = {"optimizer": o_name, **o_cfg}
                    run["model_meta"] = {"name": m_name, **copy.deepcopy(m_meta), "config": m_cfg}

                    # keep for bookkeeping
                    run["_names"] = {"pde": p_name, "model": m_name, "optim": o_name}

                    runs.append(run)

    return runs

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description="Expand a sweep config into runs and train or evaluate each one."
    )
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--tm", action="store_true", default=False)
    parser.add_argument("--ns", action="store_true", default=False)
    parser.add_argument("--run_idx", type=int, default=None)
    parser.add_argument("--print_num_runs", action="store_true", default=False)
    parser.add_argument(
        "--continue_on_error",
        action="store_true",
        default=False,
        help="Log a failing run and move on to the next one instead of aborting the sweep.",
    )
    args = parser.parse_args()

    # Close the config file as soon as it is parsed.
    with open(args.config, "r", encoding="utf-8") as config_file:
        base_config = yaml.safe_load(config_file)

    all_runs = assemble_runs(base_config)
    print(f"Total runs to execute: {len(all_runs)}")

    if args.print_num_runs:
        print(len(all_runs))
        raise SystemExit(0)

    # Run a single sweep point (e.g. one job-array task) when an index is given.
    if args.run_idx is not None:
        all_runs = [all_runs[args.run_idx]]

    for sweep_params in all_runs:
        print(yaml.dump(sweep_params, default_flow_style=True))
        sweep_params['hostname'] = socket.gethostname()
        try:
            if args.tm:
                trainer = TimeMarchingTrainer(sweep_params, device=args.device)
            elif args.ns:
                # Imported only on this branch (presumably to keep non-NS runs
                # from loading it).
                from phijax.trainer import NSTrainer
                trainer = NSTrainer(sweep_params, device=args.device)
            else:
                trainer = Trainer(sweep_params, device=args.device)

            if sweep_params.get("mode", "train") == "eval":
                eval_dir = sweep_params.get("eval_dir", None)
                if eval_dir is not None:
                    trainer.run_dir = eval_dir
                    trainer.build()
                trainer.evaluate()
            else:
                trainer.train()
        except Exception as e:
            # Default is fail-fast; bare `raise` keeps the original traceback.
            if not args.continue_on_error:
                raise
            print(f"Error running experiment {sweep_params['exp_name']}: {e}")
            print("Skipping this run due to error.")
            print("=" * 80 + "\n")
