import os
import sys
import torch
import torch.nn as nn
from torch.linalg import matrix_exp
import math
import numpy as np
import warnings
from scipy.linalg import logm
from torchdiffeq import odeint
import contextlib
import data_utils_interpolation as data_utils
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import torch
import torch
###################



###################
@contextlib.contextmanager
def suppress_stdout():
    """Context manager that discards everything written to ``sys.stdout``.

    While the ``with`` body runs, ``sys.stdout`` points at a handle opened on
    ``os.devnull``. The previous stream is put back when the body finishes,
    including when it raises, and the devnull handle is then closed.
    """
    with open(os.devnull, "w") as sink, contextlib.redirect_stdout(sink):
        yield


class GRUUpdate(nn.Module):
    """GRU-style state update whose gates are small MLPs, with per-sample masking.

    Three sub-networks are built, each ``Linear -> Tanh -> Linear`` acting on
    the concatenation of a latent-sized vector and the input:

    * ``update_gate`` and ``reset_gate`` end in a ``Sigmoid``;
    * ``new_state`` has no final activation.

    Args:
        input_dim: Size of the observation vector ``x``.
        latent_dim: Size of the latent state ``h`` (also the hidden width).
        hidden_units: Accepted but not used by this implementation.
    """

    def __init__(self, input_dim, latent_dim, hidden_units=64):
        super().__init__()
        joint_dim = input_dim + latent_dim

        def two_layer(*tail):
            # Shared trunk; ``tail`` optionally appends an output activation.
            return nn.Sequential(
                nn.Linear(joint_dim, latent_dim),
                nn.Tanh(),
                nn.Linear(latent_dim, latent_dim),
                *tail,
            )

        # Creation order matters for reproducible initialisation.
        self.update_gate = two_layer(nn.Sigmoid())
        self.reset_gate = two_layer(nn.Sigmoid())
        self.new_state = two_layer()

    def forward(self, h, x, mask_1d):
        """Return the updated state; samples with mask 0 keep their old state.

        Args:
            h: Current latent state, ``[B, H]``.
            x: Observation at this step, ``[B, D]``.
            mask_1d: Per-sample weight, ``[B]``; 1 applies the update and
                0 leaves ``h`` untouched.

        Returns:
            Tensor of shape ``[B, H]``.
        """
        joint = torch.cat((h, x), dim=-1)  # [B, H + D]
        keep = self.update_gate(joint)
        reset = self.reset_gate(joint)

        # Candidate state is computed from the reset-scaled previous state.
        candidate = self.new_state(torch.cat((reset * h, x), dim=-1))
        proposed = (1 - keep) * candidate + keep * h

        # [B] -> [B, 1] so the mask broadcasts across the latent dimension.
        observed = mask_1d.unsqueeze(-1)
        return observed * proposed + (1 - observed) * h



class LatentODEFunc(nn.Module):
    """Time-independent latent dynamics ``dh/dt = net(h)``.

    ``net`` is ``Linear -> Tanh -> Linear`` with every layer ``latent_dim``
    wide. ``hidden_units`` is accepted but not used by this implementation.
    """

    def __init__(self, latent_dim, hidden_units=64):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, h):
        """Return the derivative of ``h``; the time argument ``t`` is ignored."""
        return self.net(h)


class DeterministicODERNNEncoder(nn.Module):
    """ODE-RNN encoder: evolve the latent state between samples, then apply a gated update.

    Args:
        input_dim: Not used by this class; kept in the constructor signature.
        latent_dim: Size of the latent state.
        ode_func: Callable ``f(t, h)`` returning ``dh/dt``.
        gru_update: Callable ``g(h, x_i, mask_i)`` returning the updated state.
        min_step: Time gaps shorter than this use a single Euler step; longer
            gaps are integrated with ``odeint`` (``rk4``) on a grid of roughly
            ``gap / min_step`` points (at least 2).
    """

    def __init__(self, input_dim, latent_dim, ode_func, gru_update, min_step=2e-3):
        super().__init__()
        self.ode_func = ode_func
        self.gru_update = gru_update
        self.latent_dim = latent_dim
        self.min_step = min_step

    def forward(self, x, t, mask):
        """Encode a sequence into a single latent vector.

        Args:
            x:    [B, T, D] input data.
            t:    [T] time stamps shared by the whole batch; each entry is
                  read with ``.item()``, so it must be a scalar per step.
            mask: [B, T] per-sample, per-step weights passed to ``gru_update``.

        Returns:
            [B, latent_dim] latent state after the last step.
        """
        batch, steps, _ = x.size()
        state = torch.zeros(batch, self.latent_dim, device=x.device)

        for idx in range(steps):
            if idx:
                # No propagation is needed before the very first sample.
                state = self._evolve(state, t[idx - 1].item(), t[idx].item())
            state = self.gru_update(state, x[:, idx, :], mask[:, idx])

        return state

    def _evolve(self, state, t_start, t_end):
        """Advance ``state`` from ``t_start`` to ``t_end`` (either direction)."""
        gap = abs(t_end - t_start)

        if gap < self.min_step:
            # Tiny gap: one explicit Euler step with the signed interval.
            slope = self.ode_func(torch.tensor(t_start).to(state), state)
            return state + slope * (t_end - t_start)

        # Larger gap: fixed-grid RK4; only the state at ``t_end`` is kept.
        n_points = max(2, int(gap / self.min_step))
        grid = torch.linspace(t_start, t_end, n_points).to(state)
        return odeint(self.ode_func, state, grid, method='rk4')[-1]


class ODERNN(nn.Module):
    """ODE-RNN baseline for interpolation.

    The sequence is encoded backwards in time into one latent vector, which is
    linearly mapped, copied to every input time step and decoded by a small
    MLP. Gate / v / alpha outputs are all-zero placeholders so the return
    tuple has the same layout as other models used with this training code.

    Args:
        input_dim: Feature size of ``x``.
        latent_dim: Latent state size.
        output_dim: Feature size of the prediction.
        output_length: Stored on the instance; not used in ``forward``.
        time_feat_dim: Ignored; ``self.time_feat_dim`` is always 0.
        gate_dim: Last dimension of the placeholder gate tensor.
        top_k_gates: Stored on the instance; not used in ``forward``.
        ode_func: Optional latent dynamics; defaults to ``LatentODEFunc(latent_dim)``.
    """

    def __init__(self, input_dim, latent_dim, output_dim, output_length,
                 time_feat_dim=0, gate_dim=3, top_k_gates=3, ode_func=None):
        super().__init__()

        # Plain bookkeeping attributes.
        self.latent_dim = latent_dim
        self.output_length = output_length
        self.time_feat_dim = 0
        self.gate_dim = gate_dim
        self.top_k_gates = top_k_gates
        self.current_epoch = 0

        # Sub-modules are created in a fixed order: dynamics, GRU cell,
        # encoder, state map, decoder.
        dynamics = LatentODEFunc(latent_dim) if ode_func is None else ode_func
        self.encoder = DeterministicODERNNEncoder(
            input_dim=input_dim,
            latent_dim=latent_dim,
            ode_func=dynamics,
            gru_update=GRUUpdate(input_dim, latent_dim),
        )
        self.state_map = nn.Linear(latent_dim, latent_dim)

        # Per-step decoder.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, output_dim),
        )

    def set_epoch(self, epoch):
        """Record the current training epoch on the model."""
        self.current_epoch = epoch

    def forward(self, x: torch.Tensor, t: torch.Tensor, mask: torch.Tensor):
        """Predict a value at every input time step.

        Args:
            x: ``[B, T, D]`` observations.
            t: ``[T]`` time stamps.
            mask: ``[B, T]`` observation mask.

        Returns:
            Tuple ``(y_hat, z_seq, gates_seq, vt_seq, alpha_seq)`` with shapes
            ``[B, T, output_dim]``, ``[B, T, latent_dim]``, ``[B, T, gate_dim]``,
            ``[B, T, latent_dim]`` and ``[B, T, latent_dim]``; the last three
            are zeros.
        """
        # Encode the time-reversed sequence so the summary sits at the start.
        summary = self.encoder(x.flip(1), t.flip(0), mask.flip(1))  # [B, latent_dim]
        z0 = self.state_map(summary)                                 # [B, latent_dim]

        batch, steps = x.size(0), x.size(1)
        # The same latent vector is used for every prediction step.
        z_seq = z0.unsqueeze(1).repeat(1, steps, 1)                  # [B, T, latent_dim]
        y_hat = self.fc(z_seq)                                       # [B, T, output_dim]

        def blank(width):
            return torch.zeros(batch, steps, width, device=x.device)

        gates_seq = blank(self.gate_dim)
        vt_seq = blank(self.latent_dim)
        alpha_seq = blank(self.latent_dim)
        return y_hat, z_seq, gates_seq, vt_seq, alpha_seq
    

class ODERNN_extra(nn.Module):
    """ODE-RNN baseline for extrapolation.

    The observed window is encoded forward in time into one latent vector,
    which is linearly mapped, copied once per requested future time stamp and
    decoded by a small MLP. Gate / v / alpha outputs are all-zero placeholders
    so the return tuple has the same layout as other models used with this
    training code.

    Args:
        input_dim: Feature size of ``x``.
        latent_dim: Latent state size.
        output_dim: Feature size of the prediction.
        output_length: Stored on the instance; ``forward`` takes the number of
            prediction steps from ``t_extra`` instead.
        time_feat_dim: Ignored; ``self.time_feat_dim`` is always 0.
        gate_dim: Last dimension of the placeholder gate tensor.
        top_k_gates: Stored on the instance; not used in ``forward``.
        ode_func: Optional latent dynamics; defaults to ``LatentODEFunc(latent_dim)``.
    """

    def __init__(self, input_dim, latent_dim, output_dim, output_length,
                 time_feat_dim=0, gate_dim=3, top_k_gates=3, ode_func=None):
        super(ODERNN_extra, self).__init__()

        self.latent_dim = latent_dim
        self.output_length = output_length
        self.time_feat_dim = 0
        self.gate_dim = gate_dim
        self.top_k_gates = top_k_gates
        self.current_epoch = 0

        if ode_func is None:
            ode_func = LatentODEFunc(latent_dim)

        gru_update = GRUUpdate(input_dim, latent_dim)

        self.encoder = DeterministicODERNNEncoder(
            input_dim=input_dim,
            latent_dim=latent_dim,
            ode_func=ode_func,
            gru_update=gru_update,
        )

        self.state_map = nn.Linear(latent_dim, latent_dim)

        # Simple decoder applied independently at each prediction step.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, output_dim),
        )

    def set_epoch(self, epoch):
        """Record the current training epoch on the model."""
        self.current_epoch = epoch

    def forward(self, x: torch.Tensor, t_inter: torch.Tensor, t_extra: torch.Tensor, mask: torch.Tensor):
        """Predict one value per entry of ``t_extra``.

        Args:
            x: ``[B, T_in, D]`` observed window.
            t_inter: ``[T_in]`` time stamps of the observed window.
            t_extra: Future time stamps; only its length ``T_out`` is used.
            mask: ``[B, T_in]`` observation mask.

        Returns:
            Tuple ``(y_hat, z_seq, gates_seq, vt_seq, alpha_seq)`` with shapes
            ``[B, T_out, output_dim]``, ``[B, T_out, latent_dim]``,
            ``[B, T_out, gate_dim]``, ``[B, T_out, latent_dim]`` and
            ``[B, T_out, latent_dim]``; the last three are zeros.
        """
        # Encode the observed window in its natural (forward) order.
        z0 = self.encoder(x, t_inter, mask)             # [B, latent_dim]

        # Map the encoder summary to the state that gets decoded.
        z0 = self.state_map(z0)                         # [B, latent_dim]

        batch = x.size(0)
        n_pred = t_extra.size(0)

        # The same latent vector is used for every prediction step.
        z_seq = z0.unsqueeze(1).repeat(1, n_pred, 1)    # [B, T_out, latent_dim]
        y_hat = self.fc(z_seq)                          # [B, T_out, output_dim]

        # Placeholders share the prediction length so every element of the
        # returned tuple has the same time dimension as ``y_hat``.
        gates_seq = torch.zeros(batch, n_pred, self.gate_dim, device=x.device)
        vt_seq = torch.zeros(batch, n_pred, self.latent_dim, device=x.device)
        alpha_seq = torch.zeros(batch, n_pred, self.latent_dim, device=x.device)

        return y_hat, z_seq, gates_seq, vt_seq, alpha_seq

###################




def compute_average_gate(gate_seq):
    """Print and return the gate weights averaged over batch and time.

    Args:
        gate_seq (torch.Tensor): shape [B, T, num_experts].

    Returns:
        torch.Tensor: shape [num_experts], the mean over the first two axes.

    Expert names are known for 3 experts (first-order terms) and for 9
    experts (first-order terms followed by the ordered pairwise products).
    Any other count is labelled ``gate_<index>`` instead of failing.
    """
    avg_gate = gate_seq.mean(dim=(0, 1))  # average over batch and time
    num_experts = avg_gate.shape[0]

    first_order = ["Pi_rr", "Pi_sca", "Pi_tra"]
    second_order = [
        "Pi_rr * Pi_sca",
        "Pi_rr * Pi_tra",
        "Pi_sca * Pi_rr",
        "Pi_sca * Pi_tra",
        "Pi_tra * Pi_rr",
        "Pi_tra * Pi_sca",
    ]
    known_layouts = {3: first_order, 9: first_order + second_order}

    expert_names = known_layouts.get(num_experts)
    if expert_names is None:
        expert_names = [f"gate_{idx}" for idx in range(num_experts)]

    print("Average gate weights across dataset and time:")
    for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
        print(f"  Expert {i} ({name}): {weight.item():.4f}")

    return avg_gate


def train_model(model, bucketed_loaders_train, epochs=50, lr=0.001, save_path="best_model.pth"):
    """Train ``model`` on the interpolation targets and checkpoint the best epoch.

    Args:
        model: Module called as ``model(x, t_inter, mask)`` whose first return
            value is the prediction; it must also provide ``set_epoch(epoch)``.
        bucketed_loaders_train: Mapping whose values are dicts with a
            ``'loader'`` (iterable of ``(x, y_inter, y_extra, mask)`` batches)
            and the bucket's ``'t_inter'`` time stamps. A ``'t_extra'`` entry
            may be present but is not needed here.
        epochs: Number of passes over all buckets.
        lr: Adam learning rate.
        save_path: Where ``model.state_dict()`` is written whenever the mean
            batch loss of an epoch improves on the best seen so far.

    Epochs that yield no batches are skipped without printing or saving.
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        total_batches = 0

        for bucket in bucketed_loaders_train.values():
            # Only the interpolation time stamps are required; extrapolation
            # stamps are deliberately not looked up so buckets without them work.
            t_inter = bucket['t_inter']

            for x_batch, y_inter_batch, _y_extra_batch, mask_batch in bucket['loader']:
                optimizer.zero_grad()

                predictions, _, _, _, _ = model(x_batch, t_inter, mask_batch)
                # NOTE(review): the epoch is recorded after the forward pass, so
                # the first batch of an epoch runs with the previous value.
                model.set_epoch(epoch)

                loss = criterion(predictions, y_inter_batch)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                total_batches += 1

        if total_batches == 0:
            continue

        avg_epoch_loss = epoch_loss / total_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_epoch_loss:.6f}")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with loss: {best_loss:.6f}")

def train_model_extra(model, bucketed_loaders_train, epochs=50, lr=0.001, save_path="best_model.pth"):
    """Train ``model`` on the extrapolation targets and checkpoint the best epoch.

    Args:
        model: Module called as ``model(x, t_inter, t_extra, mask)`` whose
            first return value is the prediction; it must also provide
            ``set_epoch(epoch)``.
        bucketed_loaders_train: Mapping whose values are dicts with a
            ``'loader'`` (iterable of ``(x, y_inter, y_extra, mask)`` batches)
            plus the bucket's ``'t_inter'`` and ``'t_extra'`` time stamps.
        epochs: Number of passes over all buckets.
        lr: Adam learning rate.
        save_path: Where ``model.state_dict()`` is written whenever the mean
            batch loss of an epoch improves on the best seen so far.

    Epochs that yield no batches are skipped without printing or saving, so
    empty loaders no longer cause a division by zero.
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        total_batches = 0

        for bucket in bucketed_loaders_train.values():
            train_loader = bucket['loader']
            t_inter = bucket['t_inter']
            t_extra = bucket['t_extra']

            for x_batch, _y_inter_batch, y_extra_batch, mask_batch in train_loader:
                optimizer.zero_grad()

                predictions, _, _, _, _ = model(x_batch, t_inter, t_extra, mask_batch)
                # NOTE(review): the epoch is recorded after the forward pass, so
                # the first batch of an epoch runs with the previous value.
                model.set_epoch(epoch)

                loss = criterion(predictions, y_extra_batch)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                total_batches += 1

        if total_batches == 0:
            # Nothing was trained on; there is no meaningful average to report.
            continue

        avg_epoch_loss = epoch_loss / total_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_epoch_loss:.6f}")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with loss: {best_loss:.6f}")


def test_and_visualize_model_win_extra(
    model,
    data_test,
    time_all,
    input_length=10,
    output_length=10,
    stride=4,
    mask_ratio=0.5,
    device='cpu',
    num_plot=3
):
    """Evaluate extrapolation on the test trajectories and plot a few of them.

    Predictions, ground truth, the observation mask and the average gate
    weights come from ``data_utils.reconstruct_full_test_trajectories_extra``,
    which receives all the windowing arguments unchanged.

    Args:
        model: Trained model handed to the reconstruction helper.
        data_test: Test trajectories handed to the reconstruction helper.
        time_all: Full time axis; its last ``pred.shape[1]`` entries are used
            as the x-axis of the plots.
        input_length, output_length, stride, mask_ratio, device: Forwarded to
            the reconstruction helper.
        num_plot: Maximum number of trajectories to plot.

    Returns:
        float: Mean squared error between prediction and ground truth.
    """
    # Imported locally, as in the original implementation of this function.
    import matplotlib.pyplot as plt

    def _as_numpy(value):
        # ``Tensor.numpy()`` only works for CPU tensors without grad, so
        # detach and move first; non-tensors go through ``np.asarray``.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    # --- Get extrapolation predictions and ground truth ---
    pred, gt, mask, avg_gate = data_utils.reconstruct_full_test_trajectories_extra(
        model,
        data_test,
        time_all,
        input_length=input_length,
        output_length=output_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device
    )

    assert pred.shape == gt.shape, f"Mismatch: pred {pred.shape}, gt {gt.shape}"

    # --- Compute MSE ---
    mse = torch.nn.functional.mse_loss(pred, gt)
    print(f"\nExtrapolation MSE over future points: {mse.item():.6f}")
    mse_we_want = mse.item()

    # --- Gate weights ---
    if avg_gate is not None:
        expert_names = ["Pi_rr", "Pi_sca", "Pi_tra"] if avg_gate.shape[0] == 3 else [
            "Pi_rr", "Pi_sca", "Pi_tra",
            "Pi_rr * Pi_sca", "Pi_rr * Pi_tra",
            "Pi_sca * Pi_rr", "Pi_sca * Pi_tra",
            "Pi_tra * Pi_rr", "Pi_tra * Pi_sca"
        ]
        print("\nAverage gate weights across extrapolation:")
        for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
            print(f"  Expert {i} ({name}): {weight.item():.4f}")

    # --- Visualization ---
    pred_np = _as_numpy(pred)
    gt_np = _as_numpy(gt)
    mask_np = _as_numpy(mask)
    N, T, D = gt_np.shape
    t_np = _as_numpy(time_all)[-pred_np.shape[1]:]  # align with extrapolated portion

    for i in range(min(num_plot, N)):
        # Observed positions (mask value 1) within the last T steps; this does
        # not depend on the feature dimension, so compute it once per trajectory.
        unmasked_idx = np.where(mask_np[i, -T:] == 1)[0]

        plt.figure(figsize=(12, 4))
        for d in range(D):
            plt.subplot(1, D, d + 1)

            # Plot true and predicted
            plt.plot(t_np, gt_np[i, :, d], label='Ground Truth', color='black')
            plt.plot(t_np, pred_np[i, :, d], '--', label='Prediction (Extrapolated)', color='blue')

            if len(unmasked_idx) > 0:
                plt.scatter(t_np[unmasked_idx], gt_np[i, unmasked_idx, d],
                            color='green', label='Input (Unmasked)', zorder=10)

            # Mark extrapolation start
            plt.axvline(x=t_np[0], color='gray', linestyle=':', label='Extrapolation Start')

            plt.title(f"Trajectory {i}, Dim {d}")
            plt.xlabel("Time")
            plt.ylabel("Value")
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.show()

    return mse_we_want

    # # --- PCA of latent trajectory (optional) ---
    # z_continuous, t_continuous = extract_latent_flow_nonoverlap(
    #     model,
    #     data_test,
    #     time_all,
    #     input_length,
    #     visual_traj_indices=[0],
    #     mask_ratio=mask_ratio,
    #     device=device
    # )
    #
    # if z_continuous is not None:
    #     z_pca = PCA(n_components=2).fit_transform(z_continuous.cpu().numpy())
    #     plt.figure(figsize=(8, 6))
    #     plt.plot(z_pca[:, 0], z_pca[:, 1], marker='o', linewidth=1)
    #     plt.title("Latent Flow Trajectory (PCA)")
    #     plt.xlabel("PC 1")
    #     plt.ylabel("PC 2")
    #     plt.grid(True)
    #     plt.show()
    

def test_and_visualize_model_win(
    model,
    data_test,
    time_all,
    input_length=10,
    stride=4,
    mask_ratio=0.5,
    device='cpu',
    num_plot=3,
    save_path="./results/interpolation_ODERNN_0.6.png",
):
    """Evaluate interpolation on the test trajectories, plot them and a latent PCA.

    Predictions, ground truth, the observation mask and the average gate
    weights come from ``data_utils.reconstruct_full_test_trajectories``. The
    MSE is printed twice: over masked points only (mask value 0) and over all
    points. The latent trajectory for the PCA plot comes from
    ``data_utils.extract_latent_flow_nonoverlap`` for trajectory index 2.

    Args:
        model: Trained model handed to the ``data_utils`` helpers.
        data_test: Test trajectories handed to the ``data_utils`` helpers.
        time_all: Full time axis, used as the x-axis of the plots.
        input_length, stride, mask_ratio, device: Forwarded to the helpers.
        num_plot: Maximum number of trajectories to plot.
        save_path: File each trajectory figure is saved to; its directory is
            created if missing. Every figure uses the same path, so the file
            ends up holding the last one. Pass ``None`` to skip saving.
    """
    def _as_numpy(value):
        # ``Tensor.numpy()`` only works for CPU tensors without grad, so
        # detach and move first; non-tensors go through ``np.asarray``.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    pred, gt, mask, avg_gate = data_utils.reconstruct_full_test_trajectories(
        model,
        data_test,
        time_all,
        input_length=input_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device
    )

    # Evaluate only masked regions
    mask_bool = (mask == 0)
    mse = torch.nn.functional.mse_loss(pred[mask_bool], gt[mask_bool])
    print(f"\nInterpolation MSE over masked points: {mse.item():.6f}")

    mse = torch.nn.functional.mse_loss(pred, gt)
    print(f"\nInterpolation MSE over all points: {mse.item():.6f}")

    # Gate weights
    if avg_gate is not None:
        expert_names = ["Pi_rr", "Pi_sca", "Pi_tra"] if avg_gate.shape[0] == 3 else [
            "Pi_rr", "Pi_sca", "Pi_tra",
            "Pi_rr * Pi_sca", "Pi_rr * Pi_tra",
            "Pi_sca * Pi_rr", "Pi_sca * Pi_tra",
            "Pi_tra * Pi_rr", "Pi_tra * Pi_sca"
        ]
        print("\nAverage gate weights across test reconstruction:")
        for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
            print(f"  Expert {i} ({name}): {weight.item():.4f}")

    # Plot predictions
    t_np = _as_numpy(time_all)
    pred_np = _as_numpy(pred)
    gt_np = _as_numpy(gt)
    mask_np = _as_numpy(mask)
    N, T, D = gt_np.shape

    if save_path:
        # ``savefig`` does not create missing directories.
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    for i in range(min(num_plot, N)):
        observed = mask_np[i] == 1  # same for every feature dimension

        plt.figure(figsize=(12, 4))
        for d in range(D):
            plt.subplot(1, D, d + 1)
            plt.plot(t_np, gt_np[i, :, d], label='Ground Truth', color='black')
            plt.plot(t_np, pred_np[i, :, d], '--', label='Prediction', color='blue')
            plt.scatter(t_np[observed], gt_np[i, observed, d], color='green',
                        label='Observed', zorder=10)
            plt.title(f"Trajectory {i}, Dim {d}")
            plt.xlabel("Time")
            plt.ylabel("Value")
            plt.legend()
            plt.grid(True)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path)
        plt.show()

    # PCA of latent z
    z_continuous, t_continuous = data_utils.extract_latent_flow_nonoverlap(
        model,
        data_test,
        time_all,
        input_length,
        visual_traj_indices=[2],
        mask_ratio=mask_ratio,
        device=device
    )

    if z_continuous is not None:
        z_pca = PCA(n_components=2).fit_transform(_as_numpy(z_continuous))

        plt.figure(figsize=(8, 6))
        plt.plot(z_pca[:, 0], z_pca[:, 1], marker='o', linewidth=1)
        plt.title("Latent Flow Trajectory (PCA)")
        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.grid(True)
        plt.show()



if __name__ == "__main__":
    # Extrapolation experiment: train ODERNN_extra on windowed trajectories,
    # reload the best checkpoint and evaluate it on held-out trajectories.

    # ---- parameters ----
    seed = 200
    data_utils.set_seed(seed)
    latent_dim = 15  # 8 15 # symmetry dimensions for the latent space
    batch_size = 32
    num_epochs = 150  # 150
    learning_rate = 0.0008
    top_k_gates = 1
    if_mut_sym = 1  # consider the second order correlation of the symmetries or not
    data_name = "spiral"
    model_name = "ODERNN"
    input_length = 30
    output_length = 1
    stride = 4         # 1
    mask_ratio = 0.9
    train_ratio = 0.6  # 0.8

    gate_dim = 3 if if_mut_sym == 0 else 9

    # ---- dataset selection: name -> (message printed with the shape, builder) ----
    dataset_builders = {
        "spiral": ("Sprial data shape:", lambda: data_utils.generate_spiral_dataset(
            n_trajectories=80, total_steps=60, noise_std=0.01, seed=42, visualize=False)),
        "glycolytic": ("Glycolytic data shape:", lambda: data_utils.generate_glycolytic_dataset(
            n_trajectories=100, total_steps=200)),
        "lotka": ("Lotka data shape:", lambda: data_utils.generate_lotka_dataset(
            n_trajectories=100, total_steps=200)),
        "load": ("Load data shape:", lambda: data_utils.generate_load_dataset(
            n_trajectories=100, total_steps=144)),
        "PV": ("PV data shape:", lambda: data_utils.generate_PV_dataset(n_trajectories=100)),
        "power-event": ("Power event data shape:", lambda: data_utils.generate_power_event_dataset(
            n_trajectories=100)),
        "air-quality": ("AirQuality data shape:", lambda: data_utils.generate_AirQuality_dataset(
            n_trajectories=100)),
        "ECG": ("ECG event data shape:", lambda: data_utils.generate_ECG_dataset()),
    }
    if data_name not in dataset_builders:
        # Fail with a clear message instead of a NameError on ``data`` below.
        raise ValueError(
            f"Unknown data_name {data_name!r}; expected one of {sorted(dataset_builders)}"
        )
    shape_message, build_dataset = dataset_builders[data_name]
    data, time_all = build_dataset()
    print(shape_message, data.shape)

    # ---- split train and test along the trajectory axis ----
    N = data.shape[0]
    train_N = int(train_ratio * N)

    data_train = data[:train_N]
    data_test = data[train_N:]

    # ---- create windowed data from the training trajectories ----
    windows_train = data_utils.create_tagged_windows(
        data_train, time_all,
        input_length=input_length,
        output_length=output_length,
        mask_ratio=mask_ratio,
        stride=stride,
        seed=42
    )

    # ---- bucket windows by start index and build one DataLoader per bucket ----
    buckets_train = data_utils.bucket_windows_by_start(windows_train)
    bucketed_loaders_train = data_utils.build_bucket_dataloaders(buckets_train, batch_size=batch_size)

    # ---- model ----
    output_dim = data.shape[2]
    input_dim = data.shape[2]

    model_kwargs = dict(
        input_dim=input_dim,
        latent_dim=latent_dim,
        output_dim=output_dim,
        output_length=output_length,
        time_feat_dim=1,
        gate_dim=gate_dim,
        top_k_gates=top_k_gates,
    )
    checkpoint_path = f"{model_name}_{int(mask_ratio*100)}mask_{data_name}.pth"

    model = ODERNN_extra(**model_kwargs)

    # ---- train ----
    train_model_extra(model, bucketed_loaders_train, epochs=num_epochs, lr=learning_rate,
                      save_path=checkpoint_path)

    # ---- reload the best checkpoint into a fresh model ----
    best_model = ODERNN_extra(**model_kwargs)
    best_model.load_state_dict(torch.load(checkpoint_path))
    best_model.eval()

    # ---- test on moving windows (overlaps are averaged by the helper) ----
    # ``output_length`` is forwarded so evaluation uses the same horizon as
    # the training windows instead of the function's default.
    test_and_visualize_model_win_extra(
        model=best_model,
        data_test=data_test,
        time_all=time_all,
        input_length=input_length,
        output_length=output_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        num_plot=1
    )
