import argparse
import math
import os
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt

from models import OnsagerNet_original
from models import drift_MLP


def load_checkpoint(model_path, device):
    """Load a checkpoint file and separate the weights from optional metadata.

    Args:
        model_path: Path passed to ``torch.load``.
        device: ``map_location`` for ``torch.load``.

    Returns:
        ``(state_dict, normalization_info)``. If the loaded object is a dict
        with a ``"model_state_dict"`` key, that entry is returned as the state
        dict and ``normalization_info`` is its ``"normalization_info"`` entry
        (or ``None`` when absent). Otherwise the loaded object itself is
        returned with ``None`` metadata.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    if not wrapped:
        return checkpoint, None
    return checkpoint["model_state_dict"], checkpoint.get("normalization_info")


def extract_min_max(normalization_info, z_dim=None):
    """Pull per-dimension min/max bounds out of a normalization-info mapping.

    The ``("z_min", "z_max")`` pair is preferred; if either entry is missing
    or ``None``, the ``("min", "max")`` pair is tried instead.

    Args:
        normalization_info: Mapping of normalization metadata, or a falsy
            value when no normalization was stored.
        z_dim: Optional expected length. When given and the flattened bounds
            have a different size, only their last ``z_dim`` entries are kept.

    Returns:
        ``(data_min, data_max)`` as flattened NumPy arrays, or ``(None, None)``
        when no complete pair is available.
    """
    if not normalization_info:
        return None, None

    bounds = None
    for lo_key, hi_key in (("z_min", "z_max"), ("min", "max")):
        lo = normalization_info.get(lo_key)
        hi = normalization_info.get(hi_key)
        if lo is not None and hi is not None:
            bounds = (lo, hi)
            break
    if bounds is None:
        return None, None

    lo_arr, hi_arr = (np.asarray(b).reshape(-1) for b in bounds)
    # Bounds may cover more columns than Z (e.g. macro + Z); keep the trailing ones.
    if z_dim is not None and lo_arr.size != z_dim:
        lo_arr, hi_arr = lo_arr[-z_dim:], hi_arr[-z_dim:]
    return lo_arr, hi_arr


def load_macro_and_z(data_path):
    """Load the ``macro`` and ``Z`` arrays from an ``.npz`` archive.

    Args:
        data_path: Path to an ``.npz`` file containing ``"macro"`` and ``"Z"``.

    Returns:
        ``(macro, Z)``. A 2-D ``macro`` array gets a trailing singleton axis
        (``[a, b] -> [a, b, 1]``) so it can be concatenated with ``Z`` along
        the last axis; ``Z`` is returned exactly as stored.
        (Presumably ``[num_tra, T, ...]`` — confirm against the data writer.)
    """
    # np.load on an .npz returns an NpzFile that holds the file open until it is
    # closed; the context manager releases the handle. Indexing reads each
    # array fully into memory, so the results remain valid after closing.
    with np.load(data_path) as data_npz:
        macro = data_npz["macro"]
        Z = data_npz["Z"]
    if macro.ndim == 2:
        macro = macro[:, :, None]
    return macro, Z


def rk4_step(model, z, dt):
    """Advance ``z`` by one step of size ``dt`` with classical fourth-order Runge-Kutta.

    Args:
        model: Callable returning the time derivative ``dz/dt`` for a state tensor.
        z: Current state tensor.
        dt: Step size (Python number or tensor); converted to ``z``'s device and dtype.

    Returns:
        The state after one RK4 step, same shape as ``z``.
    """
    h = torch.as_tensor(dt, device=z.device, dtype=z.dtype)
    half_h = 0.5 * h

    slope_start = model(z)
    slope_mid_a = model(z + half_h * slope_start)
    slope_mid_b = model(z + half_h * slope_mid_a)
    slope_end = model(z + h * slope_mid_b)

    weighted = slope_start + 2.0 * slope_mid_a + 2.0 * slope_mid_b + slope_end
    return z + (h / 6.0) * weighted




# Supported drift networks, keyed by the name accepted in ``drift_model``.
_DRIFT_MODEL_CLASSES = {
    "OnsagerNet": OnsagerNet_original,
    "MLP": drift_MLP,
}


def _rollout_relative_error(model, data_t, dt_scalar, eps):
    """Integrate ``model`` from each trajectory's first state and score coordinate 0.

    Args:
        model: Drift callable used by ``rk4_step``.
        data_t: Tensor indexed as ``[trajectory, time, feature]``.
        dt_scalar: Step size between consecutive time slices.
        eps: Added to the per-trajectory denominator to avoid division by zero.

    Returns:
        Python float: mean over trajectories of
        ``sum_t (pred_0(t) - true_0(t))**2 / (sum_t true_0(t)**2 + eps)``.
    """
    num_steps = data_t.shape[1]
    target = data_t[:, :, 0]  # reference values of feature 0 at every time slice
    z_t = data_t[:, 0, :]  # rollout starts from the recorded initial state

    # Step-0 term is zero by construction; kept so the sum spans all T slices.
    sq_err = (z_t[:, 0] - target[:, 0]) ** 2
    for step in range(1, num_steps):
        z_t = rk4_step(model, z_t, dt_scalar)
        sq_err += (z_t[:, 0] - target[:, step]) ** 2

    denom = (target ** 2).sum(dim=1)
    return (sq_err / (denom + eps)).mean().item()


def compute_rollout_loss(
    data_path,
    train_frac,
    dt_scalar,
    model_path,
    drift_model="OnsagerNet",
    device=None,
    eps=1e-12,
):
    """Roll out a trained drift model and return its relative error on feature 0.

    The state fed to the model is ``concat([macro, Z], axis=-1)``. When the
    checkpoint stores min/max bounds, ``Z`` is first rescaled via
    ``2 * (Z - min) / (max - min) - 1``. Each trajectory is integrated with
    ``rk4_step`` from its first recorded state, and only feature 0 (the first
    ``macro`` column) is scored.

    Args:
        data_path: ``.npz`` file with ``"macro"`` and ``"Z"`` arrays.
        train_frac: Currently unused. The train/validation split is disabled,
            so every trajectory in ``data_path`` is evaluated. Kept for API
            compatibility.
        dt_scalar: Time step between consecutive slices.
        model_path: Checkpoint loaded with ``load_checkpoint``.
        drift_model: ``"OnsagerNet"`` or ``"MLP"``.
        device: Torch device; defaults to CUDA when available, else CPU.
        eps: Stabilizer added to each trajectory's denominator.

    Returns:
        The mean relative rollout error as a float, or ``None`` if the data
        file contains no trajectories.

    Raises:
        ValueError: If ``drift_model`` is not a supported name. This is
            checked before any file is read.
    """
    model_cls = _DRIFT_MODEL_CLASSES.get(drift_model)
    if model_cls is None:
        raise ValueError(f"Unknown drift_model: {drift_model}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    macro, Z = load_macro_and_z(data_path)
    # A 2-D Z (single latent stored without a trailing axis) would otherwise make
    # Z.shape[-1] the time length and break the concatenation with macro.
    if Z.ndim == 2:
        Z = Z[:, :, None]

    state_dict, normalization_info = load_checkpoint(model_path, device)
    Z_min, Z_max = extract_min_max(normalization_info, z_dim=Z.shape[-1])
    if Z_min is not None and Z_max is not None:
        Z = 2.0 * (Z - Z_min) / (Z_max - Z_min) - 1.0
    data = np.concatenate([macro, Z], axis=-1)

    if data.shape[0] == 0:
        print("No validation trajectories available for rollout loss.")
        return None
    n_dim = data.shape[-1]

    model = model_cls(input_dim=n_dim).to(device)
    model.load_state_dict(state_dict)
    model.eval()

    data_t = torch.from_numpy(data).float().to(device)
    with torch.no_grad():
        return _rollout_relative_error(model, data_t, dt_scalar, eps)


def main():
    """Score each trained drift model on one test set and pickle the losses.

    For every ``z_dim`` in ``--z_dims`` and experiment id in ``--exp_ids``, this
    evaluates ``compute_rollout_loss`` using
    ``<results_root>/exp<id>/macro_input_Z<z_dim>/<input_file>`` as data and
    ``.../learned_dynamics/<drift_model>/best_drift_mlp.pth`` as the model.
    It writes ``{z_dim: [loss per experiment, ...]}`` to ``--output``.
    A ``None`` entry means the data file contained no trajectories.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dt",
        type=float,
        default=0.002,
        help="Time step between successive slices in the trajectories.",
    )
    parser.add_argument("--train_frac", type=float, default=0.8)
    # NOTE(review): --first_K and --rollout_loss are parsed but not used below;
    # kept so existing command lines keep working.
    parser.add_argument("--first_K", type=int, default=5, help="Number of first trajectories to test and plot.")
    parser.add_argument("--drift_model", type=str, default="MLP", choices=["MLP", "OnsagerNet"])
    parser.add_argument(
        "--rollout_loss",
        action="store_true",
        help="Compute rollout mean relative error on validation trajectories.",
    )
    # With this default the flag is always True.
    parser.set_defaults(rollout_loss=True)
    parser.add_argument("--device", type=int, default=0, help="CUDA device index to use.")
    parser.add_argument(
        "--input_file",
        type=str,
        default="test_inDistribution.npz",
        help="Test file inside each model folder, e.g. test_inDistribution.npz, "
        "test_outDistribution_3gmm.npz or test_outDistribution_400N.npz.",
    )
    parser.add_argument(
        "--results_root",
        type=str,
        default="../trained_nflow_gmm2_diffZ",
        help="Directory containing the exp<id>/macro_input_Z<z_dim>/ folders.",
    )
    parser.add_argument("--z_dims", type=int, nargs="+", default=list(range(1, 8)))
    parser.add_argument("--exp_ids", type=int, nargs="+", default=[1, 2, 3])
    parser.add_argument("--output", type=str, default="loss.pkl", help="Pickle file for the results.")

    args = parser.parse_args()
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

    loss = {}
    for z_dim in args.z_dims:
        loss[z_dim] = []
        for exp_id in args.exp_ids:
            base_folder = os.path.join(args.results_root, f"exp{exp_id}", f"macro_input_Z{z_dim}")
            data_path = os.path.join(base_folder, args.input_file)
            model_path = os.path.join(base_folder, "learned_dynamics", args.drift_model, "best_drift_mlp.pth")

            this_loss = compute_rollout_loss(
                data_path=data_path,
                train_frac=args.train_frac,
                dt_scalar=args.dt,
                model_path=model_path,
                drift_model=args.drift_model,
                device=device,
            )
            loss[z_dim].append(this_loss)

            # compute_rollout_loss returns None for empty data; ":.6e" would raise on it.
            if this_loss is None:
                print(f"z_dim {z_dim}, exp {exp_id}, Rollout loss: unavailable (no trajectories)")
            else:
                print(f"z_dim {z_dim}, exp {exp_id}, Rollout loss: {this_loss:.6e}")

    with open(args.output, "wb") as f:
        pickle.dump(loss, f)


if __name__ == "__main__":
    main()
