import torch
from torch.utils.data import DataLoader, TensorDataset
from data_utils import ODESequenceDataModule, read_data, TimeSeriesDataset, read_swing_data, read_corey_matlab_data
from model import GRUUpdate, GRUEncoder, GRUUpdateEnhanced, GRUDeltaEncoderEnhanced, TimeModulationEmbedder, FourierFeaturePositionalEncoding, TokenSelfAttentionStack, TokenPrimitiveCrossAttentionStack,DynamicHubUpdateBlock, NonlinearDecoderMLP, EmbedToParamMLP, ReconstructMLP, LassODE
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import math
import time
from itertools import zip_longest
import random


# --- The Main Evaluation Function ---
@torch.no_grad()
def evaluate_and_plot(model_class, num_tokens, model_kwargs, model_path,
                      dataloaders, device, save_dir=None, plot_indices=None,
                      true_dims=None):
    """
    Evaluate a trained model on several systems and optionally plot predictions.

    For every loader the model is run on each batch with three prefix times
    placed at 10%, 50% and 90% of the batch's time span; predictions and
    ground truth are truncated to the system's true dimensionality before the
    loss is computed with ``model.compute_loss``.

    Args:
        model_class: Callable that builds the model from ``model_kwargs``.
        num_tokens: Number of time tokens passed to ``model.tokenize_time``.
        model_kwargs: Keyword arguments used to initialize the model.
        model_path: Path to a checkpoint saved via ``state_dict``.
        dataloaders: List of DataLoaders yielding ``(x_batch, t_batch)``, one
            per system; the list index is used as the system id.
        device: torch.device on which the evaluation runs.
        save_dir: Optional directory for PNG plots; nothing is plotted when
            it is falsy.
        plot_indices: Sequence indices to visualize. Defaults to ``(0, 1, 2)``;
            indices beyond the number of evaluated sequences are skipped.
        true_dims: Per-system number of leading state dimensions to keep.
            When omitted, the module-level ``true_dims`` is used, which
            preserves the behavior of existing call sites.

    Returns:
        results: dict mapping sys_id -> list with one loss per scenario
        (the leading axis of the model output). Systems whose loader yields
        no batches map to an empty list.

    Raises:
        NameError: if ``true_dims`` is neither passed nor defined at module level.
    """
    # A list default would be shared between calls; resolve it per call instead.
    if plot_indices is None:
        plot_indices = (0, 1, 2)

    if true_dims is None:
        # Backward compatibility: earlier callers relied on a module-level global.
        true_dims = globals().get("true_dims")
        if true_dims is None:
            raise NameError(
                "true_dims was not passed and no module-level 'true_dims' is defined"
            )

    model = model_class(**model_kwargs).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    results = {}

    for sys_id, test_loader in enumerate(dataloaders):
        all_preds = []
        all_trues = []
        d_true = true_dims[sys_id]

        print(f"\n🔎 Evaluating system {sys_id}... (true_dim={d_true})")

        for x_batch, t_batch in test_loader:
            x_batch = x_batch.to(device)   # [B, T, D]
            t_batch = t_batch.to(device)   # [B, T]
            t_scalar = t_batch[0]          # assume shared time vector

            # Prefix times at 10% / 50% / 90% of the observed time span
            t_start, t_end = t_scalar[0].item(), t_scalar[-1].item()
            span = t_end - t_start
            prefix_time_list = torch.tensor([
                t_start + 0.1 * span,
                t_start + 0.5 * span,
                t_start + 0.9 * span
            ], device=device)

            # Tokenize time
            token_times = model.tokenize_time(num_tokens, device=device)

            # Predict on the observed grid itself
            t_target = t_scalar
            # Forward prediction: [S, B, T, D]
            x_pred = model(x_batch, t_scalar, t_target, prefix_time_list, token_times)

            # Cut down to true dimensions
            x_pred = x_pred[..., :d_true]      # [S, B, T, d_true]
            x_batch = x_batch[..., :d_true]    # [B, T, d_true]

            all_preds.append(x_pred.cpu())
            all_trues.append(x_batch.cpu())

        if not all_preds:
            # torch.cat rejects an empty list and t_scalar would be unbound below.
            print(f"System {sys_id}: loader yielded no batches, skipping.")
            results[sys_id] = []
            continue

        # Concatenate along the batch axis
        all_preds = torch.cat(all_preds, dim=1)   # [S, total_B, T, d_true]
        all_trues = torch.cat(all_trues, dim=0)   # [total_B, T, d_true]

        # Compute one loss per scenario
        sys_losses = []
        for s in range(all_preds.shape[0]):
            loss = model.compute_loss(all_preds[s], all_trues, true_dim=d_true).item()
            sys_losses.append(loss)
            print(f"System {sys_id} | Scenario {s} MSE: {loss:.6f}")

        results[sys_id] = sys_losses

        # Plot if needed
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            # Time axis of the last batch; all batches are assumed to share it.
            t = t_scalar.cpu().numpy()
            for idx in plot_indices:
                if idx >= all_trues.shape[0]:
                    continue
                for s in range(all_preds.shape[0]):
                    plt.figure(figsize=(8, 3))
                    for d in range(d_true):  # only plot true dims
                        plt.plot(t, all_trues[idx, :, d], label=f'True Dim {d}', linestyle='--')
                        plt.plot(t, all_preds[s, idx, :, d], label=f'Pred Dim {d}')
                    plt.title(f"System {sys_id} | Seq {idx} | Scenario {s}")
                    plt.xlabel("Time")
                    plt.ylabel("Value")
                    plt.legend()
                    save_path = f"{save_dir}/sys{sys_id}_seq{idx}_scenario{s}.png"
                    plt.savefig(save_path)
                    print(f"Saved plot to {save_path}")
                    plt.close()

    return results



def get_dataloader(x_tokens, t_tokens, mask_tokens, batch_size=8, shuffle=True):
    """
    Bundle three aligned tensors into a DataLoader.

    Each item of the resulting loader is a tuple in the order
    ``(x_tokens, t_tokens, mask_tokens)``, batched along the first axis.
    """
    return DataLoader(
        TensorDataset(x_tokens, t_tokens, mask_tokens),
        batch_size=batch_size,
        shuffle=shuffle,
    )

def build_warmup_cosine(optimizer, total_steps, warmup_ratio=0.03, min_lr_ratio=0.1):
    """
    Build a LambdaLR scheduler with linear warmup followed by cosine decay.

    The returned multiplier rises linearly from 0 to 1 over the warmup steps,
    then follows a half cosine from 1 down to ``min_lr_ratio`` and stays at
    ``min_lr_ratio`` for any step at or beyond ``total_steps``.

    Args:
        optimizer: Optimizer whose learning rate is scaled.
        total_steps: Number of scheduler steps the schedule is planned for.
        warmup_ratio: Fraction of ``total_steps`` spent in warmup (at least
            one step is always used).
        min_lr_ratio: Final multiplier relative to the optimizer's base LR.

    Returns:
        torch.optim.lr_scheduler.LambdaLR wrapping ``optimizer``.
    """
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / float(warmup_steps)
        # Clamp so that stepping past total_steps holds the floor instead of
        # letting the cosine swing the learning rate back up.
        progress = min(1.0, (step - warmup_steps) / float(decay_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        # The floor is expressed directly as a ratio, so a base LR of 0 is safe.
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def train(model, dataloader, num_epochs, num_scenarios, num_tokens, model_path, plot_path="training_loss.png"):
    """
    Train the model with a constant learning rate, checkpointing the best epoch.

    After every epoch the mean batch loss is recorded; whenever it improves on
    the best value so far, ``model.state_dict()`` is written to ``model_path``.
    When training finishes, the loss curve is saved to ``plot_path``.

    Relies on a module-level ``device`` and on ``model.optimizer`` having been
    assigned by the caller.

    Args:
        model: model instance exposing ``generate_prefix_times``,
            ``tokenize_time``, ``compute_loss`` and an ``optimizer`` attribute
        dataloader: DataLoader yielding (x_batch, mask_batch, t_batch)
        num_epochs: number of training epochs
        num_scenarios: number of different prefix times (scenarios)
        num_tokens: number of tokenized time chunks
        model_path: where to save the best model
        plot_path: path to save training loss plot

    Raises:
        ValueError: if the dataloader yields no batches in an epoch.
    """
    model.train()
    best_loss = float('inf')
    loss_history = []
    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        batch_count = 0

        # NOTE(review): get_dataloader in this file builds items ordered
        # (x, t, mask), while this loop unpacks (x, mask, t) - confirm which
        # ordering the loader passed in here actually uses.
        for x_batch, mask_batch, t_batch in dataloader:
            x_batch = x_batch.to(device)
            t_batch = t_batch.to(device)

            t_scalar = t_batch[0]  # Assume fixed resolution for all batches

            prefix_time_list = model.generate_prefix_times(t_scalar, num_scenarios)
            # NOTE(review): the other functions in this file call
            # tokenize_time(num_tokens, device=...) and pass an extra t_target
            # to the model - confirm these two calls match the current model API.
            token_times = model.tokenize_time(t_scalar, num_tokens)

            x_traj = model(x_batch, t_scalar, prefix_time_list, token_times)
            loss = model.compute_loss(x_traj, x_batch)

            model.optimizer.zero_grad()
            loss.backward()
            model.optimizer.step()

            epoch_loss += loss.item()
            batch_count += 1

        if batch_count == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("dataloader yielded no batches; cannot compute an epoch loss")

        avg_loss = epoch_loss / batch_count
        loss_history.append(avg_loss)
        print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.6f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_path)
            print(f"✅ Best model saved at epoch {epoch+1} with loss {avg_loss:.6f}")

    total_time = time.time() - start_time
    print(f"\n⏱️ Total training time: {total_time:.2f} seconds")

    # Plot and save training loss
    fig = plt.figure(figsize=(6, 4))
    plt.plot(loss_history, label='Training Loss')
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title("Training Loss over Epochs")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(plot_path)
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
    print(f"📈 Training loss plot saved to {plot_path}")

def train_lr_schedule(model, dataloader, num_epochs, num_scenarios, num_tokens, model_path, plot_path="training_loss.png"):
    """
    Train the model with warmup + cosine LR decay over batches mixed across systems.

    Every epoch the batches of all loaders are pooled and shuffled via
    ``mixed_batches``; each step clips the gradient norm to 1.0, steps the
    optimizer and then the scheduler. The best epoch (lowest mean batch loss)
    is checkpointed to ``model_path`` and the loss curve is saved to
    ``plot_path``.

    Relies on module-level ``device``, ``Tmax`` and ``true_dims`` and on
    ``model.optimizer`` having been assigned by the caller.

    Args:
        model: model instance exposing ``generate_prefix_times``,
            ``tokenize_time``, ``compute_loss`` and an ``optimizer`` attribute
        dataloader: list of DataLoaders (one per system, index = system id)
            each yielding (x_batch, t_batch); a single DataLoader is also
            accepted and treated as system 0
        num_epochs: number of training epochs
        num_scenarios: number of different prefix times (scenarios)
        num_tokens: number of tokenized time chunks
        model_path: where to save the best model
        plot_path: path to save training loss plot

    Raises:
        ValueError: if the loaders yield no batches in an epoch.
    """
    model.train()
    best_loss = float('inf')
    loss_history = []
    start_time = time.time()

    # Use the loaders that were passed in rather than a module-level global.
    loaders = [dataloader] if isinstance(dataloader, DataLoader) else list(dataloader)

    # One scheduler step is taken per batch, so the schedule length must count
    # batches across all systems - not the number of systems in the list.
    steps_per_epoch = sum(len(loader) for loader in loaders)
    total_steps = num_epochs * steps_per_epoch
    scheduler = build_warmup_cosine(model.optimizer, total_steps, warmup_ratio=0.04, min_lr_ratio=0.1)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        batch_count = 0
        epoch_loader = mixed_batches(loaders, verbose=False) # Print system id order if verbose

        for sys_id, (x_batch, t_batch) in epoch_loader:
            x_batch = x_batch.to(device)
            t_batch = t_batch.to(device)

            t_scalar = t_batch[0]  # Assume fixed resolution within each batch

            prefix_time_list = model.generate_prefix_times(t_scalar, num_scenarios)
            token_times = model.tokenize_time(num_tokens=num_tokens,device=device,T=Tmax)
            t_target = t_scalar  # during training, predict on observed grid
            x_traj = model(x_batch, t_scalar, t_target, prefix_time_list, token_times)  # TODO: use mask to generate \Delta_t for GRU-Deltat

            # Only the system's real (unpadded) dimensions contribute to the loss.
            true_dim = true_dims[sys_id]
            loss = model.compute_loss(x_traj, x_batch, true_dim=true_dim)

            model.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            model.optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            batch_count += 1

        if batch_count == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("dataloader yielded no batches; cannot compute an epoch loss")

        avg_loss = epoch_loss / batch_count
        loss_history.append(avg_loss)
        current_lr = scheduler.get_last_lr()[0]
        print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.6f} | LR: {current_lr:.2e}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_path)
            print(f"✅ Best model saved at epoch {epoch+1} with loss {avg_loss:.6f}")

    total_time = time.time() - start_time
    print(f"\n⏱️ Total training time: {total_time:.2f} seconds")

    # Plot and save training loss
    fig = plt.figure(figsize=(6, 4))
    plt.plot(loss_history, label='Training Loss')
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title("Training Loss over Epochs")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(plot_path)
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
    print(f"📈 Training loss plot saved to {plot_path}")


def load_systems(Tmax):
    """
    Load the three datasets and bring them to a common state dimension.

    Each dataset is read from a hard-coded path, converted to float32 tensors
    and its time tensor multiplied by ``Tmax`` (the synthetic one is
    additionally multiplied by 2 first). Systems with fewer state channels
    than the widest one are widened by repeating their channels cyclically
    and trimming to the common width.

    Args:
        Tmax: scalar factor applied to every time tensor.

    Returns:
        (padded_systems, max_dim, true_dims) where ``padded_systems`` is a
        list of ``(x, t)`` tensor pairs whose last axis of ``x`` has size
        ``max_dim``, and ``true_dims`` lists each system's channel count
        before widening.
    """
    def _to_float_tensor(values):
        return torch.tensor(values, dtype=torch.float32)

    # --- Synthetic system ---
    synthetic_path = "C:\\Software\\Large_Dynamic_Model\\ODE_embedding\\Dataset\\Synthetic1"
    x_syn, t_syn, _ = read_data(synthetic_path, plot=True)
    x_syn = _to_float_tensor(x_syn)               # [N, T, D1]
    t_syn = _to_float_tensor(t_syn) * 2 * Tmax    # [N, T]

    # --- Swing system ---
    swing_path = "C:\\Software\\Large_Dynamic_Model\\ODE_embedding\\Dataset\\Swing_data\\swing_data.npy"
    x_swing, t_swing, _ = read_swing_data(
        filepath=swing_path,
        state_dim=2,
        max_seq_length=150,
        sample_resolution=0.03,
        use_seq=30,
        plot=True
    )
    x_swing = _to_float_tensor(x_swing)           # [N, T, D2]
    t_swing = _to_float_tensor(t_swing) * Tmax    # [N, T]

    # --- Corey Matlab data ---
    corey_path = "C:\\Software\\Large_Dynamic_Model\\ODE_embedding\\Dataset\\Corey_matlab_data\\"
    x_corey, t_corey, _ = read_corey_matlab_data(
        filepath=corey_path,
        state_dim=3,
        max_seq_length=150,
        sample_resolution=0.1,
        use_seq=30,
        plot=True
    )
    x_corey = _to_float_tensor(x_corey)           # [N, T, D3]
    t_corey = _to_float_tensor(t_corey) * Tmax    # [N, T]

    systems = [(x_syn, t_syn), (x_swing, t_swing), (x_corey, t_corey)]

    # Channel counts before any widening, and the width everything is raised to
    true_dims = [states.size(2) for states, _ in systems]
    max_dim = max(true_dims)

    def _widen(states):
        width = states.size(2)
        if width >= max_dim:
            return states
        # ceil(max_dim / width) copies along the channel axis, then trim the excess
        copies = -(-max_dim // width)
        return states.repeat(1, 1, copies)[:, :, :max_dim]

    padded_systems = [(_widen(states), times) for states, times in systems]

    return padded_systems, max_dim, true_dims

def mixed_batches(dataloaders, verbose=False):
    """
    Yield ``(sys_id, batch)`` pairs from all loaders in one random order.

    All loaders are fully drained first (``sys_id`` is the loader's position
    in ``dataloaders``), the pooled list is shuffled with ``random.shuffle``,
    and the pairs are then yielded one by one. With ``verbose`` the shuffled
    sequence of system ids is printed before yielding.
    """
    # Tag every batch with the index of the loader it came from
    tagged = [
        (system_index, batch)
        for system_index, loader in enumerate(dataloaders)
        for batch in loader
    ]

    # Interleave systems at batch granularity
    random.shuffle(tagged)

    if verbose:
        print("🔀 Shuffled system order:", [system_index for system_index, _ in tagged])

    yield from tagged


# Script entry point: loads the datasets, builds the LassODE model from its
# sub-modules, and evaluates a saved checkpoint (the training call is
# currently commented out). Several names assigned here (device, Tmax,
# true_dims, dataloaders) are read as module-level globals by the functions
# defined above, so they must be set before those functions are called.
if __name__ == "__main__":
    # ----------------------------
    # Hyperparameters
    # ----------------------------
    # Only torch's RNG is seeded here; Python's `random` (used by
    # mixed_batches) and NumPy are left unseeded.
    seed = 44
    torch.manual_seed(seed)
    print("Torch version:", torch.__version__)
    print("NumPy version:", np.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = "cpu"


    # ODE data generation (disabled; kept for reference)
    # t_cutoff = 20 # cut-off time, e.g., some ODEs may only have 15s rather than 20s
    # t_resolution = 1 # skip steps. 1: same, no mask applied. no skip for the data resolution
    # num_sequences = 20
    # total_time = 20
    # seq_length = 400
    # data_module = ODESequenceDataModule(
    #     num_sequences=num_sequences,
    #     seq_length=seq_length,
    #     state_dim=data_dim,
    #     total_time=total_time,
    #     batch_size=batch_size,
    #     device=device
    # )
    #
    # data_module.generate_data()
    # dataloader = data_module.get_dataloader(t_res=t_resolution, t_cutoff=t_cutoff)
    #
    # data_module.plot_sequences(num_to_plot=6)

    # Tmax multiplies every time tensor inside load_systems and is also passed
    # to model.tokenize_time by train_lr_schedule.
    Tmax = 1
    systems, max_dim, true_dims  = load_systems(Tmax)
    print("Unified input dimension:", max_dim)
    for i, (x, t) in enumerate(systems):
        print(f"System {i}: x {x.shape}, t {t.shape}")

    # model architecture
    data_dim = max_dim # maximum dimension across different ODE systems
    latent_dim = 15
    embed_dim = 256
    num_layer_GRU = 2 # Encoding GRU cell layers
    rnn_hidden_dim = 256  # Larger dimension for the encoder's "working memory"
    num_tokens = 50  # number of tokens in each ODE process [0, total_time]
    model_path = "Lass-ODE-swing.pt" # large scale small ODE
    # attention parameters
    num_primitives = 30 # number of primitives in the ODE primitive library
    num_heads = 8 # number of heads
    depth = 6 # number of attention blocks
    num_block = 2 # number of inter-intra-GRU blocks
    dropout = 0 # 0.1
    mlp_ratio = 4 # MLP hidden = mlp_ratio × embed_dim
    hidden_unit = 128


    save_dir = "figures/eval"
    os.makedirs(save_dir, exist_ok=True)
    # Printed unconditionally, including when the directory already existed.
    print(f"Created directory: {save_dir}")

    # training setting
    lr = 5e-4
    num_epochs = 1500
    num_scenarios = 15 # number of scenarios for different windows of prefix as input
    batch_size = 32

    # One loader per system; the position in this list is the system id used
    # to index true_dims.
    dataloaders = []
    for (x, t) in systems:
        dataset = TensorDataset(x, t)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)  # no reshuffle
        dataloaders.append(loader)


    # Initialize encoder model (ODE-RNN variant, disabled)
    # encoder = ODERNNEncoder(
    #     input_dim=data_dim,  # D
    #     latent_dim=latent_dim,  # z0 size
    #     rnn_hidden_dim=rnn_hidden_dim,  # hidden size for ODE-RNN
    #     hidden_units=hidden_unit,  # hidden size for ODE func
    #     method='rk4'  # or 'dopri5'
    # ).to(device)

    # GRU version 1, less powerful, more efficient
    # gru_update1 = GRUUpdate(input_dim=data_dim, rnn_hidden_dim=rnn_hidden_dim)
    # GRU version 2, more powerful
    # gru_update2 = GRUUpdateEnhanced(input_dim=data_dim, rnn_hidden_dim=rnn_hidden_dim, hidden_units=hidden_unit, use_layernorm=True)
    # encoder = GRUEncoder(
    #     latent_dim=latent_dim,  # z0 size
    #     rnn_hidden_dim=rnn_hidden_dim,  # hidden size for ODE-RNN
    #     gru_update_module=gru_update2
    # ).to(device)

    # GRU version 3, powerful, use GRU-Deltat to speed up
    encoder = GRUDeltaEncoderEnhanced(
        input_dim=data_dim,
        latent_dim=latent_dim,
        rnn_hidden_dim=rnn_hidden_dim,
        num_layers=num_layer_GRU,
        dropout=dropout,
    ).to(device)

    # embedder = Z0ToEmbedMLP(latent_dim=rnn_hidden_dim, embed_dim=embed_dim).to(device) # use hT but not z0
    embedder = TimeModulationEmbedder(rnn_hidden_dim=rnn_hidden_dim, embed_dim=embed_dim).to(device)
    pos_encoder = FourierFeaturePositionalEncoding(embed_dim=embed_dim).to(device)
    param_decoder = EmbedToParamMLP(latent_dim=latent_dim, embed_dim=embed_dim).to(device)
    # nonlinear_decoder = NonlinearDecoderMLP(latent_dim=latent_dim, embed_dim=embed_dim, hidden_units=hidden_unit).to(device)
    intra_attn_stack = TokenSelfAttentionStack(
        embed_dim=embed_dim,
        num_heads=num_heads,
        depth=depth,
        dropout=dropout,
        mlp_ratio=mlp_ratio
    ).to(device)

    inter_attn_stack = TokenPrimitiveCrossAttentionStack(
        embed_dim=embed_dim,
        num_heads=num_heads,
        depth=depth,
        dropout=dropout,
        mlp_ratio=mlp_ratio,
    ).to(device)

    reconstruct_mlp = ReconstructMLP(
        latent_dim=latent_dim,  # dimensionality of z
        output_dim=data_dim,  # dimensionality of x (e.g., state variables)
        hidden_units=hidden_unit  # optional, hidden layer size
    ).to(device)

    dynamic_hub_block = DynamicHubUpdateBlock(
        embed_dim=embed_dim,
        num_heads=num_heads,
        dropout=dropout
    ).to(device)

    # Construct the LassODE model
    lass_ode = LassODE(
        method='dopri5',
        embed_dim=embed_dim,
        num_blocks=num_block,
        num_primitives=num_primitives,
        encoder=encoder,
        embedder=embedder,
        pos_encoder=pos_encoder,
        param_decoder=param_decoder,
        reconstruct_mlp=reconstruct_mlp,
        intra_attn_stack=intra_attn_stack,
        inter_attn_stack=inter_attn_stack,
        dynamic_hub_block=dynamic_hub_block,
    ).to(device)

    # Report the parameter count in billions
    num_params = sum(p.numel() for p in lass_ode.parameters())
    num_params_B = num_params / 1e9
    print(f"Model size: {num_params_B:.3f}B parameters")

    # Assign optimizer inside the model (the training functions read model.optimizer)
    # lass_ode.optimizer = torch.optim.Adam(lass_ode.parameters(), lr=lr)

    # AdamW: Adam with decoupled weight decay

    lass_ode.optimizer = torch.optim.AdamW(
        lass_ode.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.05  # helps generalization
    )

    # training procedure (disabled: this run only evaluates the checkpoint at model_path)
    # train_lr_schedule(
    #     model=lass_ode,
    #     dataloader=dataloaders,
    #     num_epochs=num_epochs,
    #     num_scenarios=num_scenarios,
    #     num_tokens=num_tokens,
    #     model_path=model_path,
    #     plot_path="figures/training_loss.png",
    # )

    # Same constructor arguments as lass_ode above. The sub-module objects are
    # passed by reference, so the model that evaluate_and_plot builds from
    # these kwargs wraps the very same sub-module instances as lass_ode.
    model_kwargs = {
        'method': 'dopri5',
        'embed_dim': embed_dim,
        'num_blocks': num_block,
        'num_primitives': num_primitives,
        'encoder': encoder,
        'embedder': embedder,
        'pos_encoder': pos_encoder,
        'param_decoder': param_decoder,
        'reconstruct_mlp': reconstruct_mlp,
        'intra_attn_stack': intra_attn_stack,
        'inter_attn_stack': inter_attn_stack,
        'dynamic_hub_block': dynamic_hub_block,
    }

    # The output directory literal below repeats the value of save_dir set above.
    results = evaluate_and_plot(
        model_class=LassODE,
        model_kwargs=model_kwargs,
        model_path=model_path,
        num_tokens=num_tokens,
        dataloaders=dataloaders,  # list of dataloaders, order = sys_id
        device=device,
        save_dir="figures/eval",
        plot_indices=[0, 5, 10, 20, 30]
    )

