import argparse
import os

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.ticker import FuncFormatter, LogFormatterMathtext, LogLocator
import numpy as np
import torch

from models import OnsagerNet_original
from models import drift_MLP


def load_checkpoint(model_path, device):
    """Load a checkpoint and split it into ``(state_dict, normalization_info)``.

    Two on-disk layouts are accepted:

    * a dict containing ``"model_state_dict"`` (and optionally
      ``"normalization_info"``) -> both entries are returned;
    * anything else (e.g. a bare state dict) -> returned as-is together
      with ``None`` for the normalization info.

    Tensors are mapped onto ``device`` while loading.
    """
    # weights_only=False allows torch to unpickle arbitrary Python objects
    # stored alongside the weights; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    if not is_wrapped:
        return checkpoint, None
    norm_info = checkpoint.get("normalization_info")
    return checkpoint["model_state_dict"], norm_info


def extract_min_max(normalization_info, z_dim=None):
    """Return ``(z_min, z_max)`` arrays from checkpoint normalization info.

    ``"z_min"``/``"z_max"`` are looked up first; if either is missing the
    generic ``"min"``/``"max"`` keys are used instead. Values are converted to
    squeezed NumPy arrays (torch tensors are moved to CPU first).

    If ``z_dim`` is given and the bounds are scalars or their leading length
    differs from ``z_dim``, they are flattened and (up to) the last ``z_dim``
    entries are kept.

    Args:
        normalization_info: mapping read from the checkpoint, or ``None``.
        z_dim: expected number of latent dimensions, or ``None`` to skip the
            length adjustment.

    Returns:
        ``(z_min, z_max)`` NumPy arrays, or ``(None, None)`` when the info is
        missing or lacks either bound.
    """
    if not normalization_info:
        return None, None

    def _as_array(value):
        # With map_location="cuda" the checkpoint may hold CUDA tensors, which
        # np.asarray cannot convert directly.
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value).squeeze()

    data_min = normalization_info.get("z_min")
    data_max = normalization_info.get("z_max")
    if data_min is None or data_max is None:
        data_min = normalization_info.get("min")
        data_max = normalization_info.get("max")
    if data_min is None or data_max is None:
        return None, None
    data_min = _as_array(data_min)
    data_max = _as_array(data_max)
    # A scalar (or size-1) bound squeezes to a 0-d array with no shape[0];
    # flatten it so it becomes a length-1 array that broadcasts over Z.
    if z_dim is not None and (data_min.ndim == 0 or data_min.shape[0] != z_dim):
        # NOTE(review): keeping the *last* z_dim entries presumably assumes the
        # stats were computed over [macro, Z] with Z last — confirm.
        data_min = data_min.reshape(-1)[-z_dim:]
        data_max = data_max.reshape(-1)[-z_dim:]
    return data_min, data_max


def load_macro_normalization(macro_norm_path):
    """Read macro min/max bounds from an ``.npz`` file.

    The key pairs ``("Z_min", "Z_max")``, ``("macro_min", "macro_max")`` and
    ``("min", "max")`` are tried in that order, and the first pair fully
    present in the archive is returned. ``(None, None)`` is returned (after a
    message for the last two cases) when the path is ``None``, the file does
    not exist, or no pair is present.
    """
    if macro_norm_path is None:
        return None, None
    if not os.path.exists(macro_norm_path):
        print(f"Macro normalization file not found: {macro_norm_path}")
        return None, None

    candidate_pairs = (("Z_min", "Z_max"), ("macro_min", "macro_max"), ("min", "max"))
    with np.load(macro_norm_path) as archive:
        found = next(
            (
                (lo_key, hi_key)
                for lo_key, hi_key in candidate_pairs
                if lo_key in archive and hi_key in archive
            ),
            None,
        )
        if found is not None:
            lo_key, hi_key = found
            return archive[lo_key], archive[hi_key]

    print(f"Macro normalization file missing min/max keys: {macro_norm_path}")
    return None, None


def denormalize_macro(values, macro_min, macro_max):
    """Map ``values`` from ``[-1, 1]`` back to ``[macro_min, macro_max]``.

    This inverts ``v = 2 * (x - min) / (max - min) - 1``. Both bounds are
    flattened and only their first element is used. If either bound is
    ``None``, ``values`` is returned unchanged (the same object).
    """
    if macro_min is None or macro_max is None:
        return values
    lower = float(np.ravel(macro_min)[0])
    upper = float(np.ravel(macro_max)[0])
    half_range = 0.5 * (upper - lower)
    return (values + 1.0) * half_range + lower


def get_color_pair(idx):
    """Return the ``(true_color, pred_color)`` hex pair for trajectory ``idx``.

    The palette holds three pairs and wraps around cyclically.
    """
    palette = (
        ("#1F77B4", "#17BECF"),  # blue
        ("#ff7f0e", "#F4E02C"),  # orange
        ("#228A22", "#42F311"),  # green
    )
    slot = idx % len(palette)
    return palette[slot]


def load_macro_and_z(data_path):
    """Load the ``macro`` and ``Z`` arrays from an ``.npz`` archive.

    A 2-D ``macro`` array gets a trailing singleton axis, so it always has
    three dimensions. ``Z`` is returned as stored.

    Args:
        data_path: path to an ``.npz`` file containing ``macro`` and ``Z``.

    Returns:
        ``(macro, Z)`` NumPy arrays.

    Raises:
        KeyError: if either array is missing from the archive.
    """
    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it. Indexing reads each member fully into memory,
    # so the arrays stay valid after the archive is closed.
    with np.load(data_path) as data_npz:
        macro = data_npz["macro"]
        Z = data_npz["Z"]
    if macro.ndim == 2:
        macro = macro[:, :, None]
    return macro, Z

def rk4_step(model, z, dt):
    """Advance ``z`` by one classical fixed-step 4th-order Runge-Kutta step.

    ``model`` is called as ``model(state)`` and is expected to return the time
    derivative of ``state``, broadcast-compatible with it. ``dt`` is converted
    to a tensor on ``z``'s device and dtype before use.
    """
    h = torch.as_tensor(dt, device=z.device, dtype=z.dtype)
    half_h = 0.5 * h
    slope1 = model(z)
    slope2 = model(z + half_h * slope1)
    slope3 = model(z + half_h * slope2)
    slope4 = model(z + h * slope3)
    weighted = slope1 + 2.0 * slope2 + 2.0 * slope3 + slope4
    return z + (h / 6.0) * weighted


def _select_trajectory_indices(num_tra, first_K, selection_fractions):
    """Return row indices (into the sorted trajectories) of those to plot.

    With ``selection_fractions``, each fraction ``f`` selects row
    ``int(num_tra * f)``, clamped to the last row. With ``None``, rows are
    spread evenly from the first to the last. At most
    ``min(first_K, num_tra)`` indices are returned.
    """
    n_keep = max(min(first_K, num_tra), 0)
    if selection_fractions is None:
        picked = np.linspace(0, num_tra - 1, n_keep).astype(int)
        return [int(i) for i in picked]
    picked = [min(int(num_tra * frac), num_tra - 1) for frac in selection_fractions]
    return picked[:n_keep]


def _rollout_rk4(model, z0, dt, n_steps):
    """Integrate ``model`` from ``z0`` with RK4 and stack states along dim 1.

    Index 0 along dim 1 is ``z0`` itself, so ``n_steps - 1`` RK4 updates are
    taken and the result has ``n_steps`` entries along dim 1.
    """
    states = [z0]
    z_t = z0
    for _ in range(n_steps - 1):
        z_t = rk4_step(model, z_t, dt)
        states.append(z_t)
    return torch.stack(states, dim=1)


def _format_log_yaxis(ax, values, font_size):
    """Put ``ax`` on a log y-scale and choose tick labels suited to ``values``.

    If all positive finite values share one decade ``10^k``, major ticks are
    placed at 2..9 x 10^k and labelled with the mantissa only. A single
    "x10^k" is then drawn above the top-left corner instead of repeating it
    on every tick. Otherwise standard ``10^k`` labels are used.
    """
    ax.set_yscale("log")
    positive = values[np.isfinite(values) & (values > 0)]
    if positive.size:
        exp_low = int(np.floor(np.log10(float(np.min(positive)))))
        exp_high = int(np.floor(np.log10(float(np.max(positive)))))
        if exp_low == exp_high:
            scale = 10.0 ** exp_high
            ax.yaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(2, 10)))
            ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y / scale:g}"))
            ax.yaxis.set_minor_formatter(FuncFormatter(lambda *_: ""))
            ax.text(
                0.01,
                1.1,
                rf"$\times 10^{{{exp_high}}}$",
                transform=ax.transAxes,
                ha="left",
                va="top",
                fontsize=font_size,
            )
            return
    ax.yaxis.set_major_formatter(LogFormatterMathtext(base=10))


def plot_macro_dim0(
    data_path,
    dt_scalar,
    model_path,
    first_K=5,
    macro_norm_path=None,
    output_path=".",
    output_name="TEST-true_vs_pred_dim0.png",
    drift_model="OnsagerNet",
    device=None,
    selection_fractions=(0.2, 0.7, 0.98),
    max_plot_steps=100,
):
    """Plot true vs. RK4-predicted trajectories of data dimension 0.

    ``macro`` and ``Z`` are loaded from ``data_path``. ``Z`` is normalized to
    ``[-1, 1]`` with the min/max stored in the checkpoint (if any), and the two
    arrays are concatenated along the last axis. Trajectories are sorted by
    their initial value in dimension 0. Each selected trajectory is rolled out
    from its initial state with the learned drift model using fixed-step RK4.
    The first dimension of the true and predicted curves is plotted on a log
    y-axis, denormalized with ``macro_norm_path`` when available. The figure
    is written to ``output_path/output_name`` and the plotted arrays to a
    sibling ``.npz`` file with the same stem.

    Args:
        data_path: ``.npz`` file with ``macro`` and ``Z`` arrays.
        dt_scalar: RK4 step size.
        model_path: checkpoint file read by ``load_checkpoint``.
        first_K: maximum number of trajectories to plot.
        macro_norm_path: optional ``.npz`` with macro min/max bounds.
        output_path: output directory (created if needed).
        output_name: figure filename; the ``.npz`` shares its stem.
        drift_model: ``"OnsagerNet"`` or ``"MLP"``.
        device: torch device; defaults to CUDA when available, else CPU.
        selection_fractions: fractions of the sorted trajectory count that
            pick which trajectories to plot; ``None`` spreads ``first_K``
            trajectories evenly across the sorted set.
        max_plot_steps: number of leading time steps to integrate and plot;
            ``None`` uses the full trajectory length.

    Raises:
        ValueError: for an unknown ``drift_model``, data with no dimension,
            or ``max_plot_steps < 1``.
    """
    if max_plot_steps is not None and max_plot_steps < 1:
        raise ValueError(f"max_plot_steps must be >= 1, got {max_plot_steps}")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    macro, Z = load_macro_and_z(data_path)
    if macro.shape[-1] < 1:
        raise ValueError("Expected data with at least one dimension.")

    state_dict, normalization_info = load_checkpoint(model_path, device)
    Z_min, Z_max = extract_min_max(normalization_info, z_dim=Z.shape[-1])
    if Z_min is not None and Z_max is not None:
        # Map Z to [-1, 1] with the min/max stored in the checkpoint.
        Z = 2.0 * (Z - Z_min) / (Z_max - Z_min) - 1.0
    else:
        print("No Z normalization info found in checkpoint; using raw Z.")
    macro_min, macro_max = load_macro_normalization(macro_norm_path)
    data = np.concatenate([macro, Z], axis=-1)  # [num_tra, T, D]
    num_tra, steps, n_dim = data.shape

    if num_tra == 0:
        print("No trajectories available for plotting.")
        return

    # Sort by initial value in dimension 0 so selection picks trajectories
    # by the rank of their starting value.
    data = data[np.argsort(data[:, 0, 0])]
    selected_idx = _select_trajectory_indices(num_tra, first_K, selection_fractions)
    print(f"Selected trajectory indices for plotting: {selected_idx}")

    # build model and load weights
    if drift_model == "OnsagerNet":
        model = OnsagerNet_original(input_dim=n_dim).to(device)
    elif drift_model == "MLP":
        model = drift_MLP(input_dim=n_dim).to(device)
    else:
        raise ValueError(f"Unknown drift_model: {drift_model}")
    model.load_state_dict(state_dict)
    model.eval()

    if not selected_idx:
        print("No trajectories available for plotting.")
        return

    # Only the plotted prefix is integrated: each RK4 step depends only on the
    # previous state, so later steps cannot affect the plotted ones.
    n_plot = steps if max_plot_steps is None else min(steps, max_plot_steps)
    data_sel = torch.from_numpy(data[selected_idx]).float().to(device)  # [n_exp, T, D]
    with torch.no_grad():
        pred_tra = _rollout_rk4(model, data_sel[:, 0, :], dt_scalar, n_plot)

    t = np.arange(n_plot)
    true_all = data_sel[:, :n_plot, 0].detach().cpu().numpy()
    pred_all = pred_tra[:, :n_plot, 0].detach().cpu().numpy()
    if macro_min is not None and macro_max is not None:
        true_all = denormalize_macro(true_all, macro_min, macro_max)
        pred_all = denormalize_macro(pred_all, macro_min, macro_max)

    font_size = 6
    label_pad = 1
    tick_pad = 1
    os.makedirs(output_path, exist_ok=True)
    out_file = os.path.join(output_path, output_name)
    fig, ax = plt.subplots(figsize=(1.6, 1.6))
    try:
        for exp_idx, (true_curve, pred_curve) in enumerate(zip(true_all, pred_all)):
            base_color, pred_color = get_color_pair(exp_idx)
            ax.plot(t, true_curve, linewidth=1, color=base_color, zorder=3)
            ax.plot(
                t,
                pred_curve,
                linestyle=":",
                color=pred_color,
                linewidth=1,
                zorder=4,
            )

        ax.set_xlabel("Time steps", fontsize=font_size, labelpad=label_pad)
        ax.set_ylabel("Energy", fontsize=font_size, labelpad=label_pad)
        all_vals = np.concatenate([true_all.reshape(-1), pred_all.reshape(-1)])
        _format_log_yaxis(ax, all_vals, font_size)
        ax.tick_params(axis="both", which="major", labelsize=font_size, pad=tick_pad)
        # Proxy handles: one "true"/"pred" legend entry regardless of how many
        # trajectories (colors) are drawn.
        legend_handles = [
            Line2D([0], [0], color="black", linewidth=1.0, linestyle="-", label="true"),
            Line2D([0], [0], color="black", linewidth=1.0, linestyle=":", label="pred"),
        ]
        ax.legend(handles=legend_handles, frameon=True, fontsize=font_size, loc="best")
        fig.tight_layout(pad=0.1)
        fig.savefig(out_file, dpi=300)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    npz_name = os.path.splitext(output_name)[0] + ".npz"
    npz_path = os.path.join(output_path, npz_name)
    np.savez(npz_path, true_all=true_all, pred_all=pred_all, t=t)
    print(f"Saved comparison plot to {out_file}")
    print(f"Saved trajectories to {npz_path}")


def _build_arg_parser():
    """Create the command-line parser for this plotting script."""
    parser = argparse.ArgumentParser(
        description="Plot true vs predicted macro dimension 0 trajectories."
    )
    parser.add_argument(
        "--base_folder",
        type=str,
        default="../trained_nflow_gmm2/exp1/",
        help="Base folder containing macro_input and learned_dynamics.",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default="test_inDistribution.npz",
        help="Test dataset filename under macro_input.",
    )
    parser.add_argument(
        "--dt",
        type=float,
        default=0.002,
        help="Time step between successive slices in the trajectories.",
    )
    parser.add_argument(
        "--first_K",
        type=int,
        default=3,
        help="Number of trajectories to plot.",
    )
    parser.add_argument(
        "--drift_model",
        type=str,
        default="MLP",
        choices=["MLP", "OnsagerNet"],
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="TEST-true_vs_pred_dim0.pdf",
        help="Output figure filename.",
    )
    default_norm_path = os.path.join(
        os.path.dirname(__file__),
        "..",
        "generate_dataset",
        "data",
        "macro_feature_normalization.npz",
    )
    parser.add_argument(
        "--macro_norm_path",
        type=str,
        default=default_norm_path,
        help="Path to macro normalization info for denormalizing plots.",
    )
    return parser


def main() -> None:
    """Parse CLI arguments, resolve paths under ``--base_folder`` and plot.

    The dataset is read from ``<base_folder>/macro_input/<input_file>``. The
    checkpoint ``best_drift_mlp.pth`` is read from, and the outputs are
    written to, ``<base_folder>/learned_dynamics/<drift_model>``.
    """
    args = _build_arg_parser().parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dynamics_dir = os.path.join(args.base_folder, "learned_dynamics", args.drift_model)
    plot_macro_dim0(
        data_path=os.path.join(args.base_folder, "macro_input", args.input_file),
        dt_scalar=args.dt,
        model_path=os.path.join(dynamics_dir, "best_drift_mlp.pth"),
        first_K=args.first_K,
        macro_norm_path=args.macro_norm_path,
        output_path=dynamics_dir,
        output_name=args.output_name,
        drift_model=args.drift_model,
        device=device,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
