from model_inter_slow_no_rot import Latent_MoS
import model_inter_slow_no_rot
import data_utils_inter_slow_no_rot
from data_utils_inter_slow_no_rot import gene_data, create_sequences_with_time, normalize_data, fourier_encode
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_squared_error
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA




def compute_average_gate(gate_seq):
    """
    Compute and print the average gating values across all batches and time steps,
    along with expert names.

    Experts are labelled with symmetry names when there are 3 experts
    (Pi_rr, Pi_sca, Pi_tra) or 9 experts (those three followed by the six ordered
    pairwise compositions). Any other expert count is printed with the placeholder
    label "unlabeled" instead of failing.

    Args:
        gate_seq (torch.Tensor): shape [B, T, num_experts]

    Returns:
        torch.Tensor: per-expert average gate weight, shape [num_experts].
    """
    avg_gate = gate_seq.mean(dim=(0, 1))  # Average over batch and time

    base_names = ["Pi_rr", "Pi_sca", "Pi_tra"]
    num_experts = avg_gate.shape[0]
    if num_experts == 3:
        expert_names = base_names
    elif num_experts == 9:
        # Ordered pairs of distinct symmetries:
        # rr*sca, rr*tra, sca*rr, sca*tra, tra*rr, tra*sca.
        expert_names = base_names + [
            f"{a} * {b}" for a in base_names for b in base_names if a != b
        ]
    else:
        # Fallback for other expert counts (e.g. the gate_dim=4 used by the
        # __main__ block) so expert_names is always bound.
        expert_names = ["unlabeled"] * num_experts

    print("Average gate weights across dataset and time:")
    for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
        print(f"  Expert {i} ({name}): {weight.item():.4f}")
    return avg_gate


def train_model(model, bucketed_loaders_train, epochs=50, lr=0.001, save_path="LatentMoS.pth"):
    """
    Train ``model`` on bucketed interpolation windows with an MSE objective and
    checkpoint the state_dict whenever the mean epoch loss improves.

    Args:
        model: torch module called as ``model(x, t_inter, mask)`` and exposing
            ``set_epoch(epoch)``.
        bucketed_loaders_train (dict): maps a bucket key to a dict holding at least
            ``'loader'`` (yielding ``(x, y_inter, y_extra, mask)`` batches) and
            ``'t_inter'`` (time input shared by every batch of that bucket).
        epochs (int): number of passes over all buckets.
        lr (float): Adam learning rate.
        save_path (str): file the best ``state_dict`` is written to.

    Raises:
        ValueError: if an epoch yields no batches (the mean loss is undefined).
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')
    model.train()

    for epoch in range(epochs):
        # Inform the model of the current epoch before any forward pass, so every
        # batch of this epoch (including the first) sees the same epoch value.
        model.set_epoch(epoch)
        epoch_loss = 0.0
        total_batches = 0

        for bucket in bucketed_loaders_train.values():
            t_inter = bucket['t_inter']
            # y_extra is not part of the interpolation objective.
            for x_batch, y_inter_batch, _y_extra_batch, mask_batch in bucket['loader']:
                optimizer.zero_grad()

                predictions = model(x_batch, t_inter, mask_batch)
                loss = criterion(predictions, y_inter_batch)

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                total_batches += 1

        if total_batches == 0:
            raise ValueError("bucketed_loaders_train produced no training batches")

        avg_epoch_loss = epoch_loss / total_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_epoch_loss:.6f}")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with loss: {best_loss:.6f}")


def test_and_visualize_model_win(
    model,
    data_test,
    time_all,
    input_length=10,
    stride=4,
    mask_ratio=0.5,
    device='cpu',
    num_plot=3,
    visual_traj_indices=(2,),
):
    """
    Evaluate ``model`` on full test trajectories reconstructed from sliding windows,
    print error metrics and average gate weights, and plot predictions and a PCA
    projection of the latent flow.

    Windowing, masking and the combination of overlapping windows are performed by
    ``data_utils_inter_slow_no_rot.reconstruct_full_test_trajectories``. The latent
    flow comes from ``extract_latent_flow_nonoverlap``.

    Args:
        model: trained model passed through to the data_utils helpers.
        data_test: held-out trajectories.
        time_all: time stamps shared by all trajectories (list, array or tensor).
        input_length (int): window length.
        stride (int): step between consecutive windows.
        mask_ratio (float): masking ratio passed to the helpers.
        device (str): device the helpers run inference on.
        num_plot (int): maximum number of trajectories to plot.
        visual_traj_indices (sequence of int): test trajectories whose latent flow
            is projected with PCA (default: trajectory 2).
    """
    # Kept local so the function is self-contained (both are also imported at file top).
    from sklearn.decomposition import PCA
    import matplotlib.pyplot as plt

    pred, gt, mask, avg_gate = data_utils_inter_slow_no_rot.reconstruct_full_test_trajectories(
        model,
        data_test,
        time_all,
        input_length=input_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device
    )
    # Move results to the CPU and drop autograd history, so the metrics and the
    # NumPy conversions below work whatever device inference ran on.
    pred = pred.detach().cpu()
    gt = gt.detach().cpu()
    mask = mask.detach().cpu()

    # mask == 0 marks unobserved (masked) points; evaluate those separately.
    mask_bool = (mask == 0)
    if mask_bool.any():
        mse = torch.nn.functional.mse_loss(pred[mask_bool], gt[mask_bool])
        print(f"\nInterpolation MSE over masked points: {mse.item():.6f}")
    else:
        # An empty selection would make mse_loss return NaN.
        print("\nInterpolation MSE over masked points: n/a (no masked points)")

    mse2 = torch.nn.functional.mse_loss(pred, gt)
    print(f"\nInterpolation MSE: {mse2.item():.6f}")

    # Gate weights
    if avg_gate is not None:
        base_names = ["Pi_rr", "Pi_sca", "Pi_tra"]
        num_experts = avg_gate.shape[0]
        if num_experts == 3:
            expert_names = base_names
        elif num_experts == 9:
            # Ordered pairs of distinct symmetries:
            # rr*sca, rr*tra, sca*rr, sca*tra, tra*rr, tra*sca.
            expert_names = base_names + [
                f"{a} * {b}" for a in base_names for b in base_names if a != b
            ]
        else:
            # Other counts (e.g. gate_dim=4) have no known labels; borrowing the
            # 9-expert names would mislabel them.
            expert_names = ["unlabeled"] * num_experts
        print("\nAverage gate weights across test reconstruction:")
        for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
            print(f"  Expert {i} ({name}): {weight.item():.4f}")

    # Plot predictions. as_tensor avoids a copy/warning when time_all is already a tensor.
    t_np = torch.as_tensor(time_all).cpu().numpy()
    pred_np = pred.numpy()
    gt_np = gt.numpy()
    mask_np = mask.numpy()
    N, _, D = gt_np.shape

    for i in range(min(num_plot, N)):
        observed = mask_np[i] == 1
        plt.figure(figsize=(12, 4))
        for d in range(D):
            plt.subplot(1, D, d + 1)
            plt.plot(t_np, gt_np[i, :, d], label='Ground Truth', color='black')
            plt.plot(t_np, pred_np[i, :, d], '--', label='Prediction', color='blue')
            plt.scatter(t_np[observed], gt_np[i, observed, d], color='green',
                        label='Observed', zorder=10)
            plt.title(f"Trajectory {i}, Dim {d}")
            plt.xlabel("Time")
            plt.ylabel("Value")
            plt.legend()
            plt.grid(True)
        plt.tight_layout()
        plt.show()

    # PCA of latent z
    z_continuous, _ = data_utils_inter_slow_no_rot.extract_latent_flow_nonoverlap(
        model,
        data_test,
        time_all,
        input_length,
        visual_traj_indices=list(visual_traj_indices),
        mask_ratio=mask_ratio,
        device=device
    )

    if z_continuous is not None:
        z_pca = PCA(n_components=2).fit_transform(z_continuous.detach().cpu().numpy())

        plt.figure(figsize=(8, 6))
        plt.plot(z_pca[:, 0], z_pca[:, 1], marker='o', linewidth=1)
        plt.title("Latent Flow Trajectory (PCA)")
        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.grid(True)
        plt.show()



if __name__ == "__main__":

    # ---- Hyper-parameters ----
    seed = 24
    data_utils_inter_slow_no_rot.set_seed(seed)
    latent_dim = 15  # symmetry dimensions for the latent space
    batch_size = 32
    num_epochs = 150
    learning_rate = 0.002  # 0.001 for dense data (0.3/0.6), 0.002 for sparse data
    top_k_gates = 2
    if_mut_sym = 1  # whether to consider second-order correlations of the symmetries

    input_length = 30
    num_subintervals = 2  # subintervals sharing the same equivariance: 5 for dense data (0.3/0.6), 2 for sparse data
    output_length = 1
    stride = 4
    mask_ratio = 0.9
    train_ratio = 0.6

    # Number of gating experts: 3 without second-order symmetry terms, 9 with them.
    # NOTE(review): the script previously derived this from if_mut_sym and then
    # unconditionally overwrote it with 4. The override is kept (same model shape,
    # so existing checkpoints still load) but made explicit; set it to None to use
    # the if_mut_sym-derived value.
    gate_dim_override = 4
    if gate_dim_override is not None:
        gate_dim = gate_dim_override
    else:
        gate_dim = 3 if if_mut_sym == 0 else 9

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Step 1: Data preparation
    # NOTE(review): the commented-out alternatives below reference a `data_utils`
    # module that this script does not import.

    # spiral data
    # data, time_all = data_utils.generate_spiral_dataset(n_trajectories=80, total_steps=60, visualize=True)
    # name = "spiral"

    # load data
    # data, time_all = data_utils.generate_load_dataset(n_trajectories=100, total_steps=144)
    # name = "load"

    # glycolytic
    data, time_all = data_utils_inter_slow_no_rot.generate_glycolytic_dataset(n_trajectories=100, total_steps=200)
    name = "glycolytic"

    # lotka
    # data, time_all = data_utils.generate_lotka_dataset(n_trajectories=100, total_steps=200, noise=0.0)
    # name = "lotka"

    # data, time_all = data_utils.generate_power_event_dataset(n_trajectories=100, total_steps=100)
    # name = "power_event"

    # solar power
    # data, time_all = data_utils.generate_PV_dataset(n_trajectories=124, total_steps=100)
    # name = "PV"

    # Split trajectories into train and test sets
    N = data.shape[0]
    train_N = int(train_ratio * N)

    data_train = data[:train_N]
    data_test = data[train_N:]

    # Step 2: Create windowed data from training trajectories
    windows_train = data_utils_inter_slow_no_rot.create_tagged_windows(
        data_train, time_all,
        input_length=input_length,
        output_length=output_length,
        mask_ratio=mask_ratio,
        stride=stride,
        seed=42
    )

    # Step 3: Bucket windows by start index
    buckets_train = data_utils_inter_slow_no_rot.bucket_windows_by_start(windows_train)

    # Step 4: Build DataLoaders
    bucketed_loaders_train = data_utils_inter_slow_no_rot.build_bucket_dataloaders(
        buckets_train, batch_size=batch_size
    )

    # Step 5: Initialize the model
    output_dim = data.shape[2]
    input_dim = data.shape[2]

    model_kwargs = dict(
        input_dim=input_dim,
        latent_dim=latent_dim,
        output_dim=output_dim,
        output_length=output_length,
        time_feat_dim=1,
        gate_dim=gate_dim,
        top_k_gates=top_k_gates,
        num_subintervals=num_subintervals,
    )
    model = Latent_MoS(**model_kwargs)

    # Step 6: Train the model (best checkpoint is written to model_name)
    model_name = f"LatentMoS_{int(mask_ratio * 100)}mask_" + name + ".pth"
    train_model(model, bucketed_loaders_train, epochs=num_epochs, lr=learning_rate, save_path=model_name)

    # Step 7: Reload the best checkpoint into a fresh model
    best_model = Latent_MoS(**model_kwargs)
    best_model.load_state_dict(torch.load(model_name, map_location="cpu"))
    best_model.eval()

    # # Test mode 1: test the trajectory at once
    # # NOTE(review): references data_utils / test_and_visualize_model, which are
    # # not available in this script.
    # x_test, y_test, mask_test, t_test = data_utils.prepare_masked_full_test_set(
    #     data_test=data_test,
    #     time_all=time_all,
    #     mask_ratio=0.8,
    #     seed=999
    # )
    # test_and_visualize_model(
    #     best_model, x_test, y_test, mask_test, t_test,
    #     device=device,
    #     num_plot=1
    # )

    # Step 8 (test mode 2): evaluate with moving windows, averaging overlaps
    test_and_visualize_model_win(
        model=best_model,
        data_test=data_test,
        time_all=time_all,
        input_length=input_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device,
        num_plot=1
    )
