import argparse, yaml, torch, pprint, time
from datetime import datetime
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.data.dataset import GraphFlowSamplesDataset
from src.models.network import VAndBetaMLP, EnergyNetwork
from src.training.train_eval import train_one_epoch, train_one_epoch_v2, evaluate, evaluate_v2, epoch_diagnostics, epoch_diagnostics_v2, evaluate_by_forecast_tv, get_forecast_metrics
import os
import json
import glob

# Horizontal rule (60 "=" characters) printed above and below console section headings.
BANNER = 60 * "="

def _str2bool(value):
    """Convert a textual boolean such as "true", "False" or "0" into a bool.

    Used as an argparse ``type=`` callable. argparse also applies ``type`` to
    string defaults, so a string default for the option is converted the same
    way; real booleans are passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_full_parser():
    """Build the argument parser holding every training option.

    Returns:
        argparse.ArgumentParser: parser whose defaults may later be overridden
        (e.g. via ``set_defaults``) before the command line is parsed.
    """
    p = argparse.ArgumentParser(description="Train V/beta on discrete diffusion snapshots")

    # --- Config file and I/O locations
    p.add_argument("-f", "--config_file", type=str, default=None,
                   help="YAML config path (CLI overrides keys inside).")
    p.add_argument("--data_folder", type=str, help="Folder with samples_tm, v_mat_seq, (rho_gt_seq)")
    p.add_argument("--output_folder", type=str, default=None, help="Where to save checkpoints/logs")

    # --- Optimisation schedule
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--weight_decay", type=float, default=1e-6)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--steps_per_epoch", type=int, default=1000)
    p.add_argument("--val_steps", type=int, default=200)
    p.add_argument("--use_rho_gt", action="store_true")

    # --- Model size
    p.add_argument("--embedding_dim", type=int, default=32)
    p.add_argument("--hidden_dim", type=int, default=128)

    # --- Runtime
    p.add_argument("--device", type=str, default=None)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--no_progress", action="store_true", help="Disable tqdm progress bars")

    # Without a type, any value given on the command line (including "False")
    # was stored as a non-empty string and therefore evaluated as true.
    # nargs="?" + const=True keeps `--whole-folder True` working while also
    # allowing a bare `--whole-folder` and an explicit `--whole-folder False`.
    p.add_argument("--whole-folder", type=_str2bool, nargs="?", const=True, default=False,
                   help="Treat data_folder as a superfolder and train on each of its "
                        "subfolders separately (optionally followed by true/false).")
    return p

def parse_args():
    """Parse the command line, using an optional YAML file as the source of defaults.

    Parsing happens in two passes: a minimal parser first extracts only
    ``-f/--config_file``; if a file was given, its top-level mapping is
    installed as defaults on the full parser, so values passed explicitly on
    the command line still take precedence.

    Returns:
        argparse.Namespace: the effective arguments.

    Raises:
        ValueError: if the YAML document is not a mapping, or if
            ``data_folder`` was supplied by neither the YAML file nor the CLI.
    """
    # Pass 1: look only for the config-file option and ignore everything else.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("-f", "--config_file", type=str, default=None)
    pre_args, _remaining = pre_parser.parse_known_args()

    full_parser = build_full_parser()
    config_path = pre_args.config_file
    if config_path:
        with open(config_path, "r") as stream:
            loaded = yaml.safe_load(stream)
        # An empty YAML file loads as None; treat it as "no overrides".
        yaml_defaults = loaded or {}
        if not isinstance(yaml_defaults, dict):
            raise ValueError("Config must be a YAML mapping.")
        full_parser.set_defaults(**yaml_defaults)

    # Pass 2: full parse; CLI values override the YAML-provided defaults.
    parsed = full_parser.parse_args()
    if parsed.data_folder is None:
        raise ValueError("data_folder is required (YAML or CLI).")
    return parsed

def pretty_config(args, extra=None):
    """Render the effective configuration as a pretty-printed string.

    Args:
        args: object whose attributes (via ``vars``) form the base config,
            e.g. an ``argparse.Namespace``. It is not modified.
        extra: optional mapping merged on top of the base config; its keys
            override same-named attributes of ``args``.

    Returns:
        str: ``pprint``-formatted dict, keys kept in insertion order.
    """
    merged = dict(vars(args))
    merged.update(extra or {})
    return pprint.pformat(merged, indent=2, sort_dicts=False)

def ensure_dir(p):
    """Create directory ``p`` (and any missing parents) and return ``p``.

    An already existing directory is not an error, so the call can be
    repeated safely. The path is returned unchanged so the call can be used
    inline where the path is needed.
    """
    os.makedirs(name=p, exist_ok=True)
    return p

def _utc_timestamp():
    """Return the current UTC time as an ISO-8601 string with a trailing "Z"."""
    from datetime import timezone  # local import: the file only imports `datetime`

    # datetime.utcnow() is deprecated since Python 3.12. Taking an aware UTC
    # timestamp and dropping tzinfo yields the same textual format as before.
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


def _flag_enabled(value):
    """Interpret a CLI/YAML flag value as a boolean.

    Strings that spell out "false" (e.g. "False", "0", "no") count as disabled;
    every other value falls back to its ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in {"", "0", "false", "f", "no", "n", "off", "none"}
    return bool(value)


def _list_data_folders(root, whole_folder):
    """Return the dataset folders to train on.

    With ``whole_folder`` set, these are the immediate subdirectories of
    ``root`` in sorted order (``os.listdir`` order is arbitrary, which made
    multi-dataset runs non-reproducible); otherwise just ``[root]``.
    """
    if not whole_folder:
        return [root]
    candidates = (os.path.join(root, name) for name in sorted(os.listdir(root)))
    return [path for path in candidates if os.path.isdir(path)]


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON, replacing any existing file."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)


def _attach_graph_metadata(folder, datasets):
    """Load ``graph_data.pkl`` from ``folder`` (if present) into each dataset's metadata.

    The pickle must hold a 5-tuple ``(K, pi, pos, W, G)``; only ``K`` and
    ``pi`` are used. Both are stored as tensors under ``metadata['K']`` and
    ``metadata['pi']``. Nothing happens when the file does not exist.
    """
    import pickle

    graph_data_path = os.path.join(folder, 'graph_data.pkl')
    if not os.path.exists(graph_data_path):
        return
    # NOTE(security): unpickling can execute arbitrary code; only point this
    # script at data folders from a trusted source.
    with open(graph_data_path, 'rb') as f:
        K, pi, _pos, _W, _G = pickle.load(f)
    # Convert once and share the resulting tensors between all datasets.
    K = K if isinstance(K, torch.Tensor) else torch.tensor(K)
    pi = pi if isinstance(pi, torch.Tensor) else torch.tensor(pi)
    for ds in datasets:
        ds.metadata['K'] = K
        ds.metadata['pi'] = pi


def _train_dataset(args, device, folder, base_output_dir):
    """Train and evaluate one model on the dataset stored in ``folder``.

    Reads the ``train`` and ``val`` subfolders of ``folder`` and writes into
    ``<base_output_dir>/<dataset name>``:

    * ``run_facts.json``            - static facts about the run
    * ``epoch_XXX_eval.json``       - losses and timings for each epoch
    * ``epoch_XXX_forecast.json``   - forecast metrics for each epoch
    * ``metrics.jsonl``             - one combined record per epoch (appended)
    * ``vbeta_best.pt``             - state dict with the best validation loss
    * ``summary.json``              - end-of-run summary

    Args:
        args: parsed arguments (hyper-parameters and flags).
        device: torch device string the model is moved to.
        folder: dataset folder containing ``train``/``val``.
        base_output_dir: parent directory for this dataset's output folder.
    """
    dataset_name = os.path.basename(os.path.normpath(folder))
    # Each dataset goes into its own subdir
    save_dir = ensure_dir(os.path.join(base_output_dir, dataset_name))

    print(f"\n{BANNER}\n[INIT] Starting…\n{BANNER}")
    print("[INIT] Loading dataset from:", folder)

    ds_train = GraphFlowSamplesDataset.from_folder(os.path.join(folder, 'train'), dtype=torch.float64)
    ds_val = GraphFlowSamplesDataset.from_folder(os.path.join(folder, 'val'), dtype=torch.float64)
    _attach_graph_metadata(folder, (ds_train, ds_val))

    print("[INIT] Building model/optimizer…")
    # EnergyNetwork is the active architecture; VAndBetaMLP was previously
    # constructed here with the same keyword arguments.
    model = EnergyNetwork(num_nodes=ds_train.n,
                          embedding_dim=args.embedding_dim,
                          hidden_dim=args.hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Basic run facts
    facts = {
        "device": device,
        "dtype": str(torch.get_default_dtype()),
        "n_nodes": ds_train.n,
        "T_snapshots": len(ds_train),
        "M_samples_per_t": ds_train.M,
        "smoothing": getattr(ds_train, "smoothing", None),
        "n_params": sum(p.numel() for p in model.parameters()),
        "dataset_name": dataset_name,
        "data_folder": folder,
        "start_time_iso": _utc_timestamp(),
    }
    print("[INIT] Configuration (effective):")
    print(pretty_config(args, extra=facts))

    best_val = float("inf")
    run_t0 = time.perf_counter()

    # JSONL file accumulating per-epoch metrics (easy to parse later). It is
    # opened in append mode below, so records from earlier runs are kept.
    metrics_jsonl_path = os.path.join(save_dir, "metrics.jsonl")
    _write_json(os.path.join(save_dir, "run_facts.json"), facts)

    # Loop-invariant call arguments. The beta options are not declared by the
    # parser, so they are only present when supplied through the YAML config.
    show_progress = not args.no_progress
    beta_kwargs = {
        "reg_beta": getattr(args, "reg_beta", None),
        "beta0": getattr(args, "beta0", None),
        "lam_beta": getattr(args, "lam_beta", None),
    }

    print(f"\n{BANNER}\n[TRAIN] Begin training ({dataset_name})\n{BANNER}")
    for epoch in range(1, args.epochs + 1):
        # ---- Train (timed)
        t_train0 = time.perf_counter()
        tr = train_one_epoch_v2(
            model, ds_train, optimizer,
            device=device,
            batch_size=args.batch_size,
            steps_per_epoch=args.steps_per_epoch,
            use_rho_gt=args.use_rho_gt,
            grad_clip=1.0,
            seed=args.seed + epoch,
            show_progress=show_progress,
            **beta_kwargs,
        )
        t_train = time.perf_counter() - t_train0

        # ---- Eval (timed)
        t_eval0 = time.perf_counter()
        va = evaluate_v2(
            model, ds_val,
            device=device,
            batch_size=args.batch_size,
            steps=args.val_steps,
            use_rho_gt=args.use_rho_gt,
            seed=args.seed + epoch,
            show_progress=show_progress,
            **beta_kwargs,
        )
        t_eval = time.perf_counter() - t_eval0

        # ---- Forecast metrics (timed)
        t_fore0 = time.perf_counter()
        forecast_metrics = get_forecast_metrics(
            model, ds_val,
            metrics=['tv', 'l2', 'hellinger', 'random_check'],
            device=device,
        )
        t_fore = time.perf_counter() - t_fore0

        # Each of these entries must provide 'mean' and 'max'.
        TV = forecast_metrics['TV']
        L2 = forecast_metrics['L2']
        H = forecast_metrics['Hellinger']

        print(
            f"[EPOCH {epoch:03d}] "
            f"train {tr:.6f} | val {va:.6f} | "
            f"forecast: TV={TV['mean']:.4e}/{TV['max']:.4e}, "
            f"L2={L2['mean']:.4e}/{L2['max']:.4e}, "
            f"H={H['mean']:.4e}/{H['max']:.4e} | "
            f"times: train={t_train:.2f}s, eval={t_eval:.2f}s, fore={t_fore:.2f}s"
        )

        # ---- Save per-epoch files (separate files per type)
        # 1) evaluation metrics file
        _write_json(
            os.path.join(save_dir, f"epoch_{epoch:03d}_eval.json"),
            {
                "epoch": epoch,
                "train_loss": tr,
                "val_loss": va,
                "times_sec": {
                    "train": t_train,
                    "eval": t_eval,
                    "forecast": t_fore,
                },
                "timestamp_iso": _utc_timestamp(),
            },
        )

        # 2) forecast metrics file (written as returned by get_forecast_metrics)
        _write_json(os.path.join(save_dir, f"epoch_{epoch:03d}_forecast.json"), forecast_metrics)

        # 3) append combined record to metrics.jsonl (easy for pandas)
        combo_record = {
            "dataset": dataset_name,
            "epoch": epoch,
            "train_loss": tr,
            "val_loss": va,
            "TV_mean": TV["mean"], "TV_max": TV["max"],
            "L2_mean": L2["mean"], "L2_max": L2["max"],
            "H_mean": H["mean"], "H_max": H["max"],
            "time_train_sec": t_train,
            "time_eval_sec": t_eval,
            "time_forecast_sec": t_fore,
            "timestamp_iso": _utc_timestamp(),
        }
        with open(metrics_jsonl_path, "a") as f:
            f.write(json.dumps(combo_record) + "\n")

        # ---- Best checkpoint per dataset folder
        # Only an improvement larger than 1e-6 counts as a new best.
        if va + 1e-6 < best_val:
            best_val = va
            ckpt_path = os.path.join(save_dir, "vbeta_best.pt")
            torch.save(model.state_dict(), ckpt_path)
            print(f"[EPOCH {epoch:03d}] ✓ Saved checkpoint: {ckpt_path} (best val {best_val:.6f})")

    total_dur = time.perf_counter() - run_t0
    _write_json(
        os.path.join(save_dir, "summary.json"),
        {
            "dataset": dataset_name,
            "best_val": best_val,
            "total_training_time_sec": total_dur,
            "total_training_time_min": total_dur / 60.0,
            "epochs": args.epochs,
            "end_time_iso": _utc_timestamp(),
        },
    )

    print(f"\n{BANNER}\n[DONE] {dataset_name}: Finished training in {total_dur/60.0:.2f} min. Best val={best_val:.6f}\n{BANNER}\n")


def main():
    """Script entry point: parse arguments and train on one or several dataset folders.

    Selects the device (CUDA when available unless ``--device`` is given),
    sets float64 as the default dtype, seeds torch once, resolves the dataset
    folders and the base output directory, then trains on each folder in turn.
    """
    args = parse_args()
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(args.seed)

    args.data_folder = os.path.expanduser(args.data_folder)
    # The flag may arrive as a string from the command line (e.g. "False"),
    # which plain truthiness would treat as enabled.
    whole_folder = _flag_enabled(args.whole_folder)
    data_folders = _list_data_folders(args.data_folder, whole_folder)
    if not data_folders:
        print(f"[INIT] No dataset subfolders found in: {args.data_folder}")

    # Base output location
    if args.output_folder is None:
        # Derive base name from the (super) data folder
        base_name = os.path.basename(os.path.normpath(args.data_folder))
        base_output_dir = f"./experiments/{base_name}"
    else:
        base_output_dir = os.path.expanduser(args.output_folder)

    for folder in data_folders:
        _train_dataset(args, device, folder, base_output_dir)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
