# train_eval.py
import os, time, random, json
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader

import networkarch as NA
from loss import LossComputer

import matplotlib.pyplot as plt

# Sampling configuration. HighD is recorded at 25 Hz.
FPS = 25.0
DT = 1.0 / FPS  # time step between consecutive frames, derived from FPS
# Disabled 10 Hz alternative:
# FPS = 10.0
# DT = 1.0 / FPS

# Window durations; multiplied by FPS below to obtain frame counts.
HISTORY_S = 3.0
FUTURE_S = 5.0

# The same windows as whole numbers of frames (rounded to nearest).
HISTORY_STEPS = round(HISTORY_S * FPS)
FUTURE_STEPS = round(FUTURE_S * FPS)

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator (default 42).

    The CUDA generators are seeded as well when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def split_dataset(ds: NA.GKVKDataset, train_ratio=0.9):
    """Randomly partition ``ds.items`` into a training and a validation list.

    Args:
        ds: Dataset exposing ``__len__`` and an indexable ``items`` attribute.
        train_ratio: Fraction of the items assigned to the training list; the
            cut position is ``int(len(ds) * train_ratio)``.

    Returns:
        ``(train_items, val_items)`` as two plain lists.

    The permutation comes from the global ``random`` module, so the split is
    reproducible only if that generator was seeded beforehand.
    """
    total = len(ds)
    order = list(range(total))
    random.shuffle(order)
    cut = int(total * train_ratio)

    train_items = []
    for pos in order[:cut]:
        train_items.append(ds.items[pos])

    val_items = []
    for pos in order[cut:]:
        val_items.append(ds.items[pos])

    return train_items, val_items


def save_ckpt(state, path):
    """Serialize ``state`` to ``path`` with ``torch.save``.

    Args:
        state: Object to serialize (the callers in this file pass a dict).
        path: Destination file. Missing parent directories are created.

    A bare filename has an empty ``dirname``; ``os.makedirs('')`` would raise
    ``FileNotFoundError``, so directory creation is skipped in that case and
    the file is written relative to the current working directory.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, path)


def collate_batch(batch_list):
    """Collate a list of graph data objects via ``Batch.from_data_list``.

    Args:
        batch_list: The per-sample objects to merge.

    Returns:
        Whatever ``torch_geometric.data.Batch.from_data_list`` returns for
        ``batch_list``.
    """
    # Function-scope import: torch_geometric.data is only loaded on first use.
    import torch_geometric.data as pyg_data

    merged = pyg_data.Batch.from_data_list(batch_list)
    return merged

def _tensor_stats(t, name, batch_size):
    """Return mean/std/min/max of ``t`` as ``{f'{name}/<stat>': float}``.

    Best-effort diagnostics: ``None`` yields an empty dict, and if a reduction
    fails the statistics gathered so far are returned instead of raising.
    When the batch holds more than one graph and ``t`` has more than one
    dimension, only ``t[0]`` is summarised.
    """
    stats = {}
    if t is None:
        return stats
    try:
        t = t.detach().float()
        if t.dim() > 1 and batch_size > 1:
            t = t[0]
        stats[f'{name}/mean'] = float(t.mean())
        stats[f'{name}/std'] = float(t.std())
        stats[f'{name}/min'] = float(t.min())
        stats[f'{name}/max'] = float(t.max())
    except (AttributeError, IndexError, RuntimeError, TypeError, ValueError):
        # Logging diagnostics must never abort a training step.
        pass
    return stats


def train_one_epoch(model, loader, optimizer, loss_comp, device, writer, global_step, grad_clip=1.0):
    """Run one optimisation pass over ``loader`` and log to ``writer``.

    Args:
        model: Called as ``model(batch)``; must return ``((y_gk, y_vk), aux)``.
            If it has a ``phys_params`` attribute, per-parameter gradient
            norms and the ``"c"`` / ``"nu"`` entries are logged.
        loader: Iterable of batches exposing ``.to(device)`` and
            ``.num_graphs``.
        optimizer: Optimizer stepped once per batch.
        loss_comp: Called as ``loss_comp((y_gk, y_vk), batch, aux, model)``;
            must return ``(loss, logs)`` where ``logs`` has ``"loss/total"``.
        device: Target passed to ``batch.to``.
        writer: Object with ``add_scalar(tag, value, step)``.
        global_step: Step counter used for logging; incremented per batch.
        grad_clip: Maximum total gradient norm for ``clip_grad_norm_``.

    Returns:
        ``(epoch_mean, global_step)``: the mean of ``logs["loss/total"]`` over
        the batches (NaN when ``loader`` is empty) and the updated counter.
    """
    model.train()
    has_phys = hasattr(model, 'phys_params')
    losses = []
    for batch in loader:
        batch = batch.to(device)
        batch_size = batch.num_graphs

        optimizer.zero_grad()

        (y_gk, y_vk), aux = model(batch)

        # Summary statistics of reconstruction tokens, when the model exposes them.
        tokens = aux.get('tokens', {}) if isinstance(aux, dict) else {}
        for k_pred, k_tgt, tag in [
            ('vk_ctx', 'vk', 'recon/v'),
            ('G_ctx', 'G', 'recon/G'),
        ]:
            for key, suffix in ((k_pred, 'pred'), (k_tgt, 'tgt')):
                stats = _tensor_stats(tokens.get(key, None), f'{tag}_{suffix}', batch_size)
                for kk, vv in stats.items():
                    writer.add_scalar(f'train/{kk}', vv, global_step)

        loss, logs = loss_comp((y_gk, y_vk), batch, aux, model)

        # Reduce a per-sample loss to a scalar before back-propagating.
        if isinstance(loss, torch.Tensor) and loss.dim() > 0:
            loss = loss.mean()

        loss.backward()
        # Log a plain float so the writer never holds on to the autograd graph.
        writer.add_scalar("train/loss", float(loss.detach()), global_step)

        # Gradient norm of every physical parameter (0.0 when it has no grad).
        if has_phys:
            for name, p in model.phys_params.items():
                g = p.grad.norm().item() if p.grad is not None else 0.0
                writer.add_scalar(f'grad/{name}_norm', g, global_step)

        # Collect each parameter exactly once. If phys_params is a registered
        # sub-module its tensors are already in model.parameters(); listing
        # them twice would inflate the reported norm and rescale their
        # gradients twice during clipping.
        candidates = list(model.parameters())
        if has_phys:
            candidates.extend(model.phys_params.parameters())
        total_params = list({id(p): p for p in candidates}.values())

        grad_norm = torch.nn.utils.clip_grad_norm_(total_params, grad_clip)
        writer.add_scalar('train/grad_norm', float(grad_norm), global_step)
        torch.nn.utils.clip_grad_value_(total_params, 5.0)

        optimizer.step()

        losses.append(float(logs["loss/total"]))

        for k, v in logs.items():
            if torch.is_tensor(v):
                v = v.detach()
                # Non-scalar entries are reduced to their mean.
                val = float(v.cpu()) if v.dim() == 0 else float(v.mean().cpu())
            else:
                val = float(v)
            writer.add_scalar("train/" + k, val, global_step)

        if has_phys:
            writer.add_scalar("phys/c", model.phys_params["c"].item(), global_step)
            writer.add_scalar("phys/nu", model.phys_params["nu"].item(), global_step)

        global_step += 1

    epoch_mean = float(np.mean(losses)) if losses else float('nan')
    writer.add_scalar('train/epoch_mean_loss', epoch_mean, global_step)
    return epoch_mean, global_step


@torch.no_grad()
def evaluate(model, loader, loss_comp, device, writer, epoch):
    """Compute the mean validation loss over ``loader`` and log it.

    Args:
        model: Called as ``model(batch)``; must return ``((y_gk, y_vk), aux)``.
        loader: Iterable of batches exposing ``.to(device)``.
        loss_comp: Called as ``loss_comp((y_gk, y_vk), batch, aux, model)``;
            must return ``(loss, logs)``.
        device: Target passed to ``batch.to``.
        writer: Object with ``add_scalar(tag, value, step)``.
        epoch: Step value used for every scalar written here.

    Returns:
        The unweighted mean of the per-batch losses, or NaN when ``loader``
        yields no batches.

    Auxiliary metrics found in ``logs`` are averaged over the batches and
    written once per epoch; writing them per batch would put several points
    on the same step.
    """
    model.eval()
    tracked = ('recon/weighted', 'macro/mse',
               'macro/mse_1s', 'macro/mse_3s', 'macro/mse_5s')
    metric_values = {key: [] for key in tracked}
    losses = []
    for batch in loader:
        batch = batch.to(device)
        (y_gk, y_vk), aux = model(batch)
        loss, logs = loss_comp((y_gk, y_vk), batch, aux, model)

        # Reduce a per-sample loss to a scalar.
        if isinstance(loss, torch.Tensor) and loss.dim() > 0:
            loss = loss.mean()

        losses.append(loss.item())

        for key in tracked:
            if key in logs:
                metric_values[key].append(float(logs[key]))

    for key, values in metric_values.items():
        if values:
            # Tag is the log key with '/' replaced by '_' under the 'val/' prefix.
            writer.add_scalar(f'val/{key.replace("/", "_")}', float(np.mean(values)), epoch)

    # An empty loader gives NaN explicitly rather than a NumPy warning.
    val = float(np.mean(losses)) if losses else float('nan')
    writer.add_scalar("val/loss_total", val, epoch)
    return val

def main(processedDataset: str, batch_size: int = 64):
    """Train a ``GraphKoopmanModel`` on a pre-processed dataset file.

    Args:
        processedDataset: Path handed to ``NA.GKVKDataset``.
        batch_size: Mini-batch size for both loaders; also used to scale the
            learning rate.

    Side effects:
        Writes TensorBoard events through a default ``SummaryWriter``, a
        checkpoint to ``./checkpoints/ngsim_<epoch>.pt`` every 5 epochs, and
        ``./checkpoints/ngsim_last.pt`` after the final epoch. The number of
        loader workers is read from the ``GK_NUM_WORKERS`` environment
        variable (default 4) and capped at the CPU count.
    """
    set_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    pt_path = processedDataset
    ds = NA.GKVKDataset(pt_path)
    print(f"Dataset size: {len(ds)} items")
    train_items, val_items = split_dataset(ds, 0.9)
    print(f"Split: train={len(train_items)}  val={len(val_items)}")

    nw_env = int(os.environ.get('GK_NUM_WORKERS', '4'))
    # Clamp to [0, cpu_count]: DataLoader rejects a negative worker count.
    nw = max(0, min(nw_env, os.cpu_count() or 2))
    pin = (device == "cuda")

    cfg = {
        "dt": DT,
        "history_steps": HISTORY_STEPS,
        "future_steps": FUTURE_STEPS,
        "w_micro": 1.0, "w_macro": 1.0, 'w_recon_micro': 0.1, 'w_recon_macro': 0.1,
        "w_jad": 0.2, "w_spec": 0.2,
        "spec_cfg": {
            'dt': DT,
            'c_min': -20.0, 'c_max': 100.0,
            'nu_min': 0.1, 'nu_max': 100.0,
            'c_target': None, 'nu_target': None,
            'lam_c_range': 0.01, 'lam_nu_range': 0.01,
            'adv_strength_target': 0.02, 'lam_adv_presence': 0.05,
            'diff_strength_target': 0.02, 'lam_diff_presence': 0.05
        },
        "jad_beta_comm": 1.0, "jad_alpha1": 1.0, "jad_alpha2": 0.1,
        "lam_diff_edge_l1": 1e-4, "lam_adv_edge_l1": 1e-4,
        "lam_alpha_l1": 1e-5, "lam_beta_l1": 1e-5,
        "lam_iss": 0.2, "delta_iss": 0.05,
        "vk_as_delta": True,
        "w_traj_abs": 1.0, "w_traj_delta": 0.5,
        "delta_loss_beta": 1.0,
        "spectral_clip": 1.0
    }

    # NOTE(review): torch_geometric's DataLoader may discard a user-supplied
    # collate_fn in favour of its own collater -- confirm against the
    # installed version before relying on collate_batch being called.
    train_loader = DataLoader(train_items, batch_size=batch_size, shuffle=True,
                              num_workers=nw, pin_memory=pin, persistent_workers=(nw > 0),
                              collate_fn=collate_batch)
    # The validation loader is built but only consumed by the disabled
    # evaluation code further down.
    val_loader = DataLoader(val_items, batch_size=batch_size, shuffle=False,
                            num_workers=nw, pin_memory=pin, persistent_workers=(nw > 0),
                            collate_fn=collate_batch)

    # NOTE(review): the module constants and the default dataset path target
    # HighD (25 Hz), while stride=10 and the "ngsim_" checkpoint names follow
    # the NGSIM setting -- confirm which dataset this configuration is for.
    model = NA.GraphKoopmanModel(
        cfg=cfg, gk_node_dim=10, gk_edge_dim=9, vk_dim=10,
        gk_hidden=128, vk_hidden=128, d_model=128,
        d_control_vk=64,
        out_dim_gk=1,
        out_dim_vk=2,
        n_heads=8,
        stride=10,  # 25 for HighD, 10 for NGSIM
    ).to(device)
    writer = SummaryWriter()

    loss_comp = LossComputer(cfg)

    # Square-root learning-rate scaling relative to a reference batch size of 1.
    base_lr = 1e-3
    scaled_lr = base_lr * (batch_size / 1.0) ** 0.5

    # phys_group = list(model.phys_params.parameters())
    # phys_ids = {id(p) for p in phys_group}
    optimizer = torch.optim.Adam(model.parameters(), lr=scaled_lr, weight_decay=1e-5)
    # base_group = [p for p in model.parameters() if p.requires_grad and (id(p) not in phys_ids)]
    # optimizer = torch.optim.Adam([
    #     {"params": base_group, "lr": scaled_lr, "weight_decay": 1e-5},
    #     #{"params": phys_group, "lr": scaled_lr * 0.3, "weight_decay": 0.0},
    # ])

    # Early-stopping state; only referenced by the disabled block below.
    best = float("inf")
    patience = 10
    bad = 0

    global_step = 0
    max_epochs = 60
    epoch = 0  # stays bound for the final checkpoint even if the loop never runs
    try:
        for epoch in range(1, max_epochs + 1):
            # t0 = time.time()
            tr, global_step = train_one_epoch(model, train_loader, optimizer, loss_comp, device,
                                              writer, global_step, grad_clip=1.0)
            # val = evaluate(model, val_loader, loss_comp, device, writer, epoch)
            # print(f"[{epoch:03d}] train={tr:.6f}  val={val:.6f}  ({time.time() - t0:.1f}s)")
            #
            # if epoch % 5 == 0 or epoch == 1:
            #     visualize_one_batch(model, val_loader, writer, "viz/val", epoch)
            if epoch % 5 == 0:
                print()
                save_ckpt({"epoch": epoch,
                           "model": model.state_dict(),
                           # "phys": model.phys_params.state_dict()
                           },
                          f"./checkpoints/ngsim_{epoch}.pt")

            # if val < best - 1e-6:
            #     best = val
            #     bad = 0
            #     save_ckpt({"epoch": epoch,
            #                "model": model.state_dict(),
            #                # "phys": model.phys_params.state_dict(),
            #                "best": best}, "./checkpoints/best.pt")
            #     print(f"  ↳ saved best (val={best:.6f})")
            # else:
            #     bad += 1
            #     if bad >= patience:
            #         print("Early stopping.")
            #         break
    finally:
        # Flush and close the event file even when training raises.
        writer.close()

    save_ckpt({"epoch": epoch,
               "model": model.state_dict(),
               # "phys": model.phys_params.state_dict()
               },
              "./checkpoints/ngsim_last.pt")


if __name__ == "__main__":
    # Mini-batch size handed to main(); adjust as needed.
    batch_size = 32
    # Dataset file to train on. The NGSIM alternative is kept below, disabled.
    processedDataset_path = 'ProcessedDataset/HighD/01_gk_vk_dataset.pt'
    # processedDataset_path = 'ProcessedDataset/NGSIM/trajectories-0750am-0805am_gk_vk_dataset.pt'
    main(processedDataset=processedDataset_path, batch_size=batch_size)