from model_extra import Latent_MoS_extra
import model_extra
import data_utils_extra
from data_utils_extra import gene_data, create_sequences_with_time, normalize_data, fourier_encode
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_squared_error
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA




def compute_average_gate(gate_seq):
    """
    Compute and print the average gating values across all batches and time steps,
    along with expert names.

    Args:
        gate_seq (torch.Tensor): shape [B, T, num_experts]

    Returns:
        torch.Tensor: per-expert average gate weight, shape [num_experts].
    """
    avg_gate = gate_seq.mean(dim=(0, 1))  # Average over batch and time
    num_experts = avg_gate.shape[0]

    if num_experts == 3:
        # First-order symmetry experts only.
        expert_names = [
            "Pi_rr",
            "Pi_sca",
            "Pi_tra",
        ]
    elif num_experts == 9:
        # First-order experts followed by their ordered pairwise compositions.
        expert_names = [
            "Pi_rr",
            "Pi_sca",
            "Pi_tra",
            "Pi_rr * Pi_sca",
            "Pi_rr * Pi_tra",
            "Pi_sca * Pi_rr",
            "Pi_sca * Pi_tra",
            "Pi_tra * Pi_rr",
            "Pi_tra * Pi_sca",
        ]
    else:
        # Unknown gate layout: label generically instead of failing with NameError.
        expert_names = ["unnamed"] * num_experts

    print("Average gate weights across dataset and time:")
    for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
        print(f"  Expert {i} ({name}): {weight.item():.4f}")

    return avg_gate


def train_model_extra(model, bucketed_loaders_train, epochs=50, lr=0.001, save_path="best_model.pth"):
    """
    Train an extrapolation model on bucketed windows, checkpointing the best epoch.

    Each value of ``bucketed_loaders_train`` is a dict with keys ``'loader'``
    (an iterable of ``(x, y_inter, y_extra, mask)`` batches), ``'t_inter'`` and
    ``'t_extra'``; the two time objects are shared by every batch of that bucket.
    Only the extrapolation target ``y_extra`` contributes to the MSE loss.

    Args:
        model: module called as ``model(x, t_inter, t_extra, mask)`` that also
            exposes ``set_epoch(epoch)``.
        bucketed_loaders_train (dict): bucket key -> bucket dict as described above.
        epochs (int): number of passes over all buckets.
        lr (float): Adam learning rate.
        save_path (str): file the state dict with the lowest average epoch loss
            is written to.

    Returns:
        float: the best (lowest) average epoch training loss
        (``inf`` if ``epochs`` is 0).

    Raises:
        ValueError: if the loaders yield no batches in an epoch.
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')
    model.train()

    for epoch in range(epochs):
        # Inform the model of the current epoch *before* any forward pass, so
        # epoch-dependent behaviour applies to every batch, including the first.
        model.set_epoch(epoch)

        epoch_loss = 0.0
        total_batches = 0

        for bucket in bucketed_loaders_train.values():
            train_loader = bucket['loader']
            t_inter = bucket['t_inter']
            t_extra = bucket['t_extra']

            # The interpolation target is not part of the training objective.
            for x_batch, _y_inter_batch, y_extra_batch, mask_batch in train_loader:
                optimizer.zero_grad()

                predictions = model(x_batch, t_inter, t_extra, mask_batch)
                loss = criterion(predictions, y_extra_batch)

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                total_batches += 1

        if total_batches == 0:
            raise ValueError("bucketed_loaders_train yielded no batches; nothing to train on.")

        avg_epoch_loss = epoch_loss / total_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_epoch_loss:.6f}")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with loss: {best_loss:.6f}")

    return best_loss


def test_and_visualize_model_win_extra(
        model,
        data_test,
        time_all,
        input_length=10,
        output_length=10,
        stride=4,
        mask_ratio=0.5,
        device='cpu',
        num_plot=3
):
    """
    Evaluate a model on sliding-window extrapolation and plot the results.

    Test trajectories are rebuilt by
    ``data_utils_extra.reconstruct_full_test_trajectories_extra``. This function
    then prints the MSE between the reconstructed predictions and the ground
    truth, prints the average gate weights when they are returned, and plots up
    to ``num_plot`` trajectories with one subplot per feature dimension.

    Args:
        model: trained model forwarded to the reconstruction helper.
        data_test: test trajectories forwarded to the reconstruction helper.
        time_all: time stamps; the trailing ``T`` entries are used as the x-axis.
        input_length, output_length, stride, mask_ratio, device: forwarded
            unchanged to the reconstruction helper.
        num_plot (int): maximum number of trajectories to plot.

    Returns:
        float: extrapolation MSE over the reconstructed points.

    Raises:
        ValueError: if predictions and ground truth have different shapes.
    """
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    import torch.nn.functional as F

    # --- Get extrapolation predictions and ground truth ---
    pred, gt, mask, avg_gate = data_utils_extra.reconstruct_full_test_trajectories_extra(
        model,
        data_test,
        time_all,
        input_length=input_length,
        output_length=output_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device
    )

    # Explicit check (an assert would be stripped under `python -O`).
    if pred.shape != gt.shape:
        raise ValueError(f"Mismatch: pred {pred.shape}, gt {gt.shape}")

    # --- Compute MSE ---
    mse_we_want = F.mse_loss(pred, gt).item()
    print(f"\nExtrapolation MSE over future points: {mse_we_want:.6f}")

    # --- Gate weights ---
    if avg_gate is not None:
        expert_names = ["Pi_rr", "Pi_sca", "Pi_tra"] if avg_gate.shape[0] == 3 else [
            "Pi_rr", "Pi_sca", "Pi_tra",
            "Pi_rr * Pi_sca", "Pi_rr * Pi_tra",
            "Pi_sca * Pi_rr", "Pi_sca * Pi_tra",
            "Pi_tra * Pi_rr", "Pi_tra * Pi_sca"
        ]
        print("\nAverage gate weights across extrapolation:")
        for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
            print(f"  Expert {i} ({name}): {weight.item():.4f}")

    # --- Visualization ---
    # Bring results to host numpy arrays; detach/cpu makes this safe when the
    # reconstruction ran on a GPU or kept autograd history.
    pred_np = pred.detach().cpu().numpy()
    gt_np = gt.detach().cpu().numpy()
    mask_np = mask.detach().cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)
    N, T, D = gt_np.shape
    t_np = np.asarray(time_all)[-T:]  # align with extrapolated portion

    for i in range(min(num_plot, N)):
        plt.figure(figsize=(12, 4))
        for d in range(D):
            plt.subplot(1, D, d + 1)

            # Plot true and predicted
            plt.plot(t_np, gt_np[i, :, d], label='Ground Truth', color='black')
            plt.plot(t_np, pred_np[i, :, d], '--', label='Prediction (Extrapolated)', color='blue')

            # Get unmasked indices from the mask (0 means masked, 1 means observed)
            unmasked_idx = np.where(mask_np[i, -T:] == 1)[0]
            if len(unmasked_idx) > 0:
                plt.scatter(t_np[unmasked_idx], gt_np[i, unmasked_idx, d],
                            color='green', label='Input (Unmasked)', zorder=10)

            # Mark extrapolation start
            plt.axvline(x=t_np[0], color='gray', linestyle=':', label='Extrapolation Start')

            plt.title(f"Trajectory {i}, Dim {d}")
            plt.xlabel("Time")
            plt.ylabel("Value")
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.show()

    return mse_we_want



if __name__ == "__main__":

    # ---------------- Parameters ----------------
    seed = 24
    data_utils_extra.set_seed(seed)
    latent_dim = 15  # symmetry dimensions for the latent space
    batch_size = 32
    num_epochs = 150
    learning_rate = 0.002  # 0.001 for dense data 0.3/0.6, 0.002 for sparse data
    top_k_gates = 4
    if_mut_sym = 1  # consider the second-order correlation of the symmetries or not

    input_length = 10  # 10 for air quality, 30 for others
    num_subintervals = 2  # number of subintervals that maintain the same equivariance, 5 for dense data 0.3/0.6, 2 for sparse data
    output_length = 10
    stride = 4
    mask_ratio = 0.6
    train_ratio = 0.6

    # 3 first-order symmetry experts, or 9 when their pairwise compositions are included.
    gate_dim = 3 if if_mut_sym == 0 else 9

    # ---------------- Step 1: Data preparation ----------------
    # Alternative datasets (uncomment one pair to switch):

    # spiral data
    # data, time_all = data_utils_extra.generate_spiral_dataset(n_trajectories=80, total_steps=60, visualize=True)
    # name = "spiral_extra"

    # load data
    # data, time_all = data_utils_extra.generate_load_dataset(n_trajectories=100, total_steps=144)
    # name = "load_extra"

    # glycolytic
    # data, time_all = data_utils_extra.generate_glycolytic_dataset(n_trajectories=100, total_steps=200)
    # name = "glycolytic_extra"

    # lotka
    # data, time_all = data_utils_extra.generate_lotka_dataset(n_trajectories=100, total_steps=200, noise=0.0)
    # name = "lotka_extra"

    # power event
    # data, time_all = data_utils_extra.generate_power_event_dataset(n_trajectories=100, total_steps=100)
    # name = "power_event_extra"

    # solar power
    # data, time_all = data_utils_extra.generate_PV_dataset(n_trajectories=124, total_steps=100)
    # name = "PV_extra"

    # air quality
    # data, time_all = data_utils_extra.generate_AirQuality_dataset(n_trajectories=73, total_steps=24)
    # name = "AirQuality_extra"

    data, time_all = data_utils_extra.generate_ECG_dataset()
    name = "ECG_extra"

    # Split trajectories into train and test sets.
    N = data.shape[0]
    train_N = int(train_ratio * N)

    data_train = data[:train_N]
    data_test = data[train_N:]

    # ---------------- Step 2: Windowed training data ----------------
    windows_train = data_utils_extra.create_tagged_windows(
        data_train, time_all,
        input_length=input_length,
        output_length=output_length,
        mask_ratio=mask_ratio,
        stride=stride,
        seed=42
    )

    # Bucket windows by start index, then build one DataLoader per bucket.
    buckets_train = data_utils_extra.bucket_windows_by_start(windows_train)
    bucketed_loaders_train = data_utils_extra.build_bucket_dataloaders(buckets_train, batch_size=batch_size)

    # ---------------- Step 3: Model ----------------
    input_dim = data.shape[2]
    output_dim = data.shape[2]

    def build_model():
        """Construct a fresh model; shared by training and evaluation so the architectures match."""
        return Latent_MoS_extra(
            input_dim=input_dim,
            latent_dim=latent_dim,
            output_dim=output_dim,
            output_length=output_length,
            time_feat_dim=1,
            gate_dim=gate_dim,
            top_k_gates=top_k_gates,
            num_subintervals=num_subintervals
        )

    model = build_model()

    # ---------------- Step 4: Training ----------------
    model_name = f"LatentMoS_{int(mask_ratio * 100)}mask_" + name + ".pth"
    train_model_extra(model, bucketed_loaders_train, epochs=num_epochs, lr=learning_rate, save_path=model_name)

    # ---------------- Step 5: Test and visualize ----------------
    best_model = build_model()
    # map_location keeps loading device-independent regardless of where the checkpoint was saved.
    best_model.load_state_dict(torch.load(model_name, map_location="cpu"))
    best_model.eval()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE: an older "test the whole trajectory at once" mode used
    # data_utils.prepare_masked_full_test_set and test_and_visualize_model, neither of
    # which is imported by this script.

    # Test the moving windows and average over overlaps.
    test_and_visualize_model_win_extra(
        model=best_model,
        data_test=data_test,
        time_all=time_all,
        input_length=input_length,
        output_length=output_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device,
        num_plot=1
    )
