import data_utils_v12
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch.nn.functional as F
############################################################
from Informer2020.models.model import Informer
import torch, torch.nn as nn
from typing import Optional

class GRUDTLayer(nn.Module):
    """GRU layer whose hidden state is decayed by the time gap between steps.

    A single ``nn.GRUCell`` is unrolled over the sequence.  Before each step
    ``i > 0`` the hidden state is scaled by ``exp(-Δt)``, where
    ``Δt = max(t[i] - t[i-1], 0)``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.GRUCell(input_size, hidden_size)

    def forward(
            self,
            x: torch.Tensor,  # B × L × D
            t_vec: torch.Tensor,  # L   or   B × L
            h0: Optional[torch.Tensor] = None,
    ):
        """Unroll the cell over ``x``.

        Args:
            x: inputs of shape B × L × D.
            t_vec: timestamps, either shared across the batch (L) or
                per-sample (B × L).  Anything ``torch.as_tensor`` accepts.
            h0: optional initial hidden state (B × H); zeros when omitted.

        Returns:
            ``(hidden_seq, h_last)`` with ``hidden_seq`` of shape B × L × H
            and ``h_last`` equal to ``hidden_seq[:, -1]``.
        """
        B, L, _ = x.shape
        if h0 is None:
            # Match dtype as well as device so half/double models work.
            h = torch.zeros(B, self.hidden_size, device=x.device, dtype=x.dtype)
        else:
            h = h0

        if L == 0:
            # Nothing to unroll: return an empty sequence instead of crashing.
            return h.new_zeros(B, 0, self.hidden_size), h

        # Δt and the decay factors do not depend on the recurrence, so compute
        # them once for all steps.
        t = torch.as_tensor(t_vec, device=x.device)
        batched = t.dim() == 2
        dt = (t[:, 1:] - t[:, :-1]) if batched else (t[1:] - t[:-1])
        # Clamp negative gaps (unsorted or identical timestamps) to zero, and
        # cast to the hidden dtype: per-sample float64 stamps would otherwise
        # promote the hidden state and break the float32 GRUCell.
        decay_all = torch.exp(-dt.clamp(min=0).to(dtype=h.dtype))

        # Step 0 has no predecessor, hence no decay.
        h = self.cell(x[:, 0, :], h)
        hidden_seq = [h]

        for i in range(1, L):
            # (B, 1) for per-sample stamps, 0-dim (broadcast) for shared ones.
            decay = decay_all[:, i - 1].unsqueeze(-1) if batched else decay_all[i - 1]
            x_i = x[:, i, :]

            # NOTE(review): the cell is applied twice per step — once on the
            # decayed state and once on that result — and the two outputs are
            # blended by `decay`.  This is kept exactly as the original
            # computed it; confirm whether a single update was intended.
            h = self.cell(x_i, decay * h)
            h_from_cell = self.cell(x_i, h)
            h = decay * h + (1.0 - decay) * h_from_cell

            hidden_seq.append(h)

        return torch.stack(hidden_seq, dim=1), h  # B × L × H, B × H


# ────────────────────────────────────────────────────────────────
# 2.  Core GRU‑Δ‑T model that outputs a *sequence* prediction
# ────────────────────────────────────────────────────────────────
class GRUDTModel(nn.Module):
    """GRU‑Δ‑T encoder followed by a linear head applied at every time step."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        # Creation order is kept (encoder first, then head) so parameter
        # initialisation and state-dict ordering are unchanged.
        self.encoder = GRUDTLayer(input_size, hidden_size)
        self.readout = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, t_vec: torch.Tensor):
        """
        Predict an output for each input step.

        x       : B × L × D_in
        t_vec   : L or B × L
        returns : B × L × output_size
        """
        # The encoder yields (per-step states, final state); only the
        # per-step states (B × L × H) feed the read-out.
        step_states = self.encoder(x, t_vec)[0]
        return self.readout(step_states)


class GRUDTMoSInterface(nn.Module):
    """
    A thin wrapper that emulates InformerMoSInterface:

        y_hat, _, _, gate_seq, _ = model(x_window, t_inter, t_extra, mask)

    • x_window : B × L_in × D      (possibly masked inputs)
    • t_inter  : L_in   (or  B × L_in)   interpolation stamps
    • t_extra  : L_out  (or  B × L_out)  extrapolation stamps
    • mask     : B × L_in           1 = observed, 0 = masked

    Returns
    -------
    y_hat     : B × L_out × D       (prediction for future points)
    _ (None)  : placeholder 1
    _ (None)  : placeholder 2
    gate_seq  : B × L_in × gate_dim (dummy “expert‑weights” so the
                                     rest of your MoS code keeps working)
    _ (None)  : placeholder 3

    Note: from the second decoding step on, each prediction is fed back into
    the decoder cell, so ``output_dim`` must equal ``input_dim`` whenever
    ``output_length > 1``.
    """
    def __init__(
        self,
        input_dim    : int,
        output_dim   : int,
        input_length : int,
        output_length: int,
        hidden_dim   : int = 128,
        gate_dim     : int = 3,   # keep =3 or 9 to match your cfg
    ):
        super().__init__()
        self.input_length   = input_length
        self.output_length  = output_length
        self.gate_dim       = gate_dim

        # 1) GRU‑Δ‑T encoder
        self.encoder = GRUDTLayer(input_dim, hidden_dim)

        # 2) separate GRUCell used for one‑step‑ahead decoding
        self.gru_cell = nn.GRUCell(input_dim, hidden_dim)

        # 3) linear read‑out
        self.readout  = nn.Linear(hidden_dim, output_dim)

        # 4) (optional) store current epoch for compatibility
        self.current_epoch = 0

    # ------------------------------------------------------------
    # helper: compute Δt between consecutive time points
    # ------------------------------------------------------------
    @staticmethod
    def _delta_t(t_now, t_prev):
        """Return ``t_now - t_prev`` clamped at zero (scalars or tensors)."""
        dt = t_now - t_prev
        if isinstance(dt, torch.Tensor):
            return dt.clamp(min=0)
        return max(0.0, dt)

    # ------------------------------------------------------------
    # forward pass
    # ------------------------------------------------------------
    def forward(
        self,
        x_window : torch.Tensor,   # B × L_in × D
        t_inter  : torch.Tensor,   # L_in  or  B × L_in
        t_extra  : torch.Tensor,   # L_out or  B × L_out
        mask     : torch.Tensor    # B × L_in
    ):
        """Encode the input window, then roll the decoder ``output_length`` steps."""
        B, L_in, _ = x_window.shape
        device = x_window.device

        # Bring the auxiliary inputs onto the input's device; also accepts
        # array-likes for the timestamps/mask.
        mask    = torch.as_tensor(mask, device=device)
        t_inter = torch.as_tensor(t_inter, device=device)
        t_extra = torch.as_tensor(t_extra, device=device)

        # ---- 1) zero‑impute masked points (caller's tensor is left intact) ----
        x_inp = x_window.clone()
        x_inp[mask == 0] = 0.0

        # ---- 2) encode the window with GRU‑Δ‑T; keep only the final state ----
        _, h_t = self.encoder(x_inp, t_inter)           # state at t_inter[-1]

        # ---- 3) autoregressively extrapolate L_out steps ----
        # First decoder input is the last input row *after* imputation, i.e.
        # zeros if that step was masked.
        x_prev = x_inp[:, -1, :]
        t_prev = t_inter[:, -1] if t_inter.dim() == 2 else t_inter[-1]

        preds = []
        for i in range(self.output_length):
            t_now = t_extra[:, i] if t_extra.dim() == 2 else t_extra[i]
            dt    = self._delta_t(t_now, t_prev)

            # Cast to the hidden dtype: per-sample float64 stamps would
            # otherwise promote the state and break the float32 GRUCell.
            decay = torch.exp(-dt.to(dtype=h_t.dtype))  # simple exponential decay
            if decay.dim() == 1:                        # (B,) → (B, 1)
                decay = decay.unsqueeze(-1)

            h_t = self.gru_cell(x_prev, decay * h_t)    # GRU update on decayed state
            y_t = self.readout(h_t)                     # map to output space
            preds.append(y_t)

            # Free-running decoding (not teacher forcing): the model's own
            # prediction is the next input, detached so gradients do not flow
            # through the fed-back input.
            x_prev = y_t.detach()
            t_prev = t_now

        y_hat = torch.stack(preds, dim=1)               # B × L_out × D

        # ---- 4) dummy gate sequence (all weight on expert 0) ----
        gate_seq = torch.zeros(B, L_in, self.gate_dim, device=device)
        gate_seq[..., 0] = 1.0

        return y_hat, None, None, gate_seq, None

    def set_epoch(self, epoch:int):   # keeps your training loop unchanged
        self.current_epoch = epoch


def train_model(model, bucketed_loaders_train, epochs=50, lr=0.001, save_path="best_model.pth"):
    """Train ``model`` on the extrapolation target and checkpoint the best epoch.

    Args:
        model: module called as ``model(x, t_inter, t_extra, mask)`` and
            returning a 5-tuple whose first element is the prediction.  If it
            has a ``set_epoch`` method, that is called at the start of each
            epoch.
        bucketed_loaders_train: mapping whose values are dicts with keys
            ``'loader'`` (iterable of ``(x, y_inter, y_extra, mask)`` batches),
            ``'t_inter'`` and ``'t_extra'``.
        epochs: number of passes over all buckets.
        lr: Adam learning rate.
        save_path: where the state dict of the lowest-loss epoch is written.

    Returns:
        The best (lowest) average epoch loss, or ``inf`` if nothing was trained.
    """
    import os

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Batches are moved to wherever the model lives so a CUDA model can be
    # trained from CPU data loaders.
    device = next(model.parameters()).device
    set_epoch = getattr(model, "set_epoch", None)

    def _on_device(obj):
        return obj.to(device) if torch.is_tensor(obj) else obj

    best_loss = float('inf')
    model.train()

    for epoch in range(epochs):
        # Tell the model the epoch *before* its forward passes of that epoch.
        if callable(set_epoch):
            set_epoch(epoch)

        epoch_loss = 0.0
        total_batches = 0

        for bucket in bucketed_loaders_train.values():
            t_inter = _on_device(bucket['t_inter'])
            t_extra = _on_device(bucket['t_extra'])

            # y_inter is part of the batch tuple but unused: only the
            # extrapolation loss is optimised.
            for x_batch, _y_inter_batch, y_extra_batch, mask_batch in bucket['loader']:
                x_batch = _on_device(x_batch)
                y_extra_batch = _on_device(y_extra_batch)
                mask_batch = _on_device(mask_batch)

                optimizer.zero_grad()
                predictions, _, _, _, _ = model(x_batch, t_inter, t_extra, mask_batch)
                loss = criterion(predictions, y_extra_batch)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                total_batches += 1

        if total_batches == 0:
            # Avoid dividing by zero when every loader is empty.
            print("No training batches found; nothing to train.")
            break

        avg_epoch_loss = epoch_loss / total_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_epoch_loss:.6f}")

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            # Make sure the checkpoint directory exists before writing.
            if isinstance(save_path, (str, os.PathLike)):
                os.makedirs(os.path.dirname(os.fspath(save_path)) or ".", exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with loss: {best_loss:.6f}")

    return best_loss


def test_and_visualize_model_win(
    model,
    data_test,
    time_all,
    input_length=10,
    stride=4,
    mask_ratio=0.5,
    device='cpu',
    num_plot=3
):
    """Evaluate interpolation on the test set and plot the results.

    Calls ``data_utils_v12.reconstruct_full_test_trajectories`` to obtain
    predictions, ground truth, the observation mask and average gate weights,
    prints the MSE over masked points and the gate weights, plots up to
    ``num_plot`` trajectories, and finally shows a 2-D PCA of the latent
    sequence returned by ``data_utils_v12.extract_latent_flow_nonoverlap``.
    """

    pred, gt, mask, avg_gate = data_utils_v12.reconstruct_full_test_trajectories(
        model,
        data_test,
        time_all,
        input_length=input_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device
    )

    # Evaluate only masked regions (mask == 0 marks masked points)
    mask_bool = (mask == 0)
    if mask_bool.any():
        mse = torch.nn.functional.mse_loss(pred[mask_bool], gt[mask_bool])
        print(f"\nInterpolation MSE over masked points: {mse.item():.6f}")
    else:
        # An empty selection would otherwise yield a NaN loss.
        print("\nInterpolation MSE over masked points: n/a (no masked points)")

    # Gate weights
    if avg_gate is not None:
        expert_names = ["Pi_rr", "Pi_sca", "Pi_tra"] if avg_gate.shape[0] == 3 else [
            "Pi_rr", "Pi_sca", "Pi_tra",
            "Pi_rr * Pi_sca", "Pi_rr * Pi_tra",
            "Pi_sca * Pi_rr", "Pi_sca * Pi_tra",
            "Pi_tra * Pi_rr", "Pi_tra * Pi_sca"
        ]
        print("\nAverage gate weights across test reconstruction:")
        for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
            print(f"  Expert {i} ({name}): {weight.item():.4f}")

    # Plot predictions.  detach().cpu() makes the NumPy conversion safe for
    # tensors that live on the GPU or are still attached to the graph.
    t_np = torch.as_tensor(time_all).detach().cpu().numpy()
    pred_np = pred.detach().cpu().numpy()
    gt_np = gt.detach().cpu().numpy()
    mask_np = mask.detach().cpu().numpy()
    N, T, D = gt_np.shape

    for i in range(min(num_plot, N)):
        # The observed/masked split is the same for every dimension.
        masked_pts = mask_np[i] == 0
        observed_pts = mask_np[i] == 1

        plt.figure(figsize=(12, 4))
        for d in range(D):
            plt.subplot(1, D, d + 1)
            plt.plot(t_np, gt_np[i, :, d], label='Ground Truth', color='black')
            plt.plot(t_np, pred_np[i, :, d], '--', label='Prediction', color='blue')
            plt.scatter(t_np[masked_pts], gt_np[i, masked_pts, d], color='red',
                        label='Masked (Unobserved)', zorder=10)
            plt.scatter(t_np[observed_pts], gt_np[i, observed_pts, d], color='green',
                        label='Observed', zorder=10)
            plt.title(f"Trajectory {i}, Dim {d}")
            plt.xlabel("Time"); plt.ylabel("Value")
            plt.legend(); plt.grid(True)
        plt.tight_layout()
        plt.show()

    # PCA of latent z
    # NOTE(review): trajectory index 2 is hard-coded — presumably requires at
    # least three test trajectories; confirm against the helper.
    z_continuous, t_continuous = data_utils_v12.extract_latent_flow_nonoverlap(
        model,
        data_test,
        time_all,
        input_length,
        visual_traj_indices=[2],
        mask_ratio=mask_ratio,
        device=device
    )

    if z_continuous is not None:
        z_pca = PCA(n_components=2).fit_transform(z_continuous.detach().cpu().numpy())

        plt.figure(figsize=(8, 6))
        plt.plot(z_pca[:, 0], z_pca[:, 1], marker='o', linewidth=1)
        plt.title("Latent Flow Trajectory (PCA)")
        plt.xlabel("PC 1")
        plt.ylabel("PC 2")
        plt.grid(True)
        plt.show()



def test_and_visualize_model_win_extra(
    model,
    data_test,
    time_all,
    input_length=10,
    output_length=10,
    stride=4,
    mask_ratio=0.5,
    device='cpu',
    num_plot=3,
    data_name=None,
    model_name=None
):
    """
    Evaluate model via extrapolation, print the MSE and gate weights, and
    plot/save up to ``num_plot`` trajectories.

    ``data_name`` / ``model_name`` are used only in the saved figure's file
    name.  When omitted they fall back to module-level variables of the same
    name (set when this file runs as a script), and otherwise to generic
    placeholders, so the function can also be called after an import.

    Figures are written to ``./results``.  A single figure keeps the name
    ``extrapolation_{data_name}_{model_name}_{mask_ratio}.png``; when several
    trajectories are plotted each gets a ``_traj{i}`` suffix so they do not
    overwrite one another.
    """
    import os

    if data_name is None:
        data_name = globals().get("data_name", "data")
    if model_name is None:
        model_name = globals().get("model_name", "model")

    # --- Get extrapolation predictions and ground truth ---
    pred, gt, mask, avg_gate = data_utils_v12.reconstruct_full_test_trajectories_extra(
        model,
        data_test,
        time_all,
        input_length=input_length,
        output_length=output_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device,
        spline=True
    )

    assert pred.shape == gt.shape, f"Mismatch: pred {pred.shape}, gt {gt.shape}"

    # --- Compute MSE ---
    mse = F.mse_loss(pred, gt)
    print(f"\nExtrapolation MSE over future points: {mse.item():.6f}")

    # --- Gate weights ---
    if avg_gate is not None:
        expert_names = ["Pi_rr", "Pi_sca", "Pi_tra"] if avg_gate.shape[0] == 3 else [
            "Pi_rr", "Pi_sca", "Pi_tra",
            "Pi_rr * Pi_sca", "Pi_rr * Pi_tra",
            "Pi_sca * Pi_rr", "Pi_sca * Pi_tra",
            "Pi_tra * Pi_rr", "Pi_tra * Pi_sca"
        ]
        print("\nAverage gate weights across extrapolation:")
        for i, (name, weight) in enumerate(zip(expert_names, avg_gate)):
            print(f"  Expert {i} ({name}): {weight.item():.4f}")

    # --- Visualization ---
    # detach().cpu() keeps the NumPy conversions valid for GPU tensors.
    pred_np = pred.detach().cpu().numpy()
    gt_np = gt.detach().cpu().numpy()
    mask_np = torch.as_tensor(mask).detach().cpu().numpy()
    N, T, D = gt_np.shape
    # align the time axis with the extrapolated portion (last T stamps)
    t_np = torch.as_tensor(time_all).detach().cpu().numpy()[-pred.shape[1]:]

    os.makedirs("./results", exist_ok=True)
    n_figures = min(num_plot, N)

    for i in range(n_figures):
        # Observed positions (mask == 1) within the last T steps; identical
        # for every dimension, so computed once per trajectory.
        unmasked_idx = np.where(mask_np[i, -T:] == 1)[0]

        plt.figure(figsize=(12, 4))
        for d in range(D):
            plt.subplot(1, D, d + 1)

            # Plot true and predicted
            plt.plot(t_np, gt_np[i, :, d], label='Ground Truth', color='black')
            plt.plot(t_np, pred_np[i, :, d], '--', label='Prediction (Extrapolated)', color='blue')

            if len(unmasked_idx) > 0:
                plt.scatter(t_np[unmasked_idx], gt_np[i, unmasked_idx, d],
                            color='green', label='Input (Unmasked)', zorder=10)

            # Mark extrapolation start
            plt.axvline(x=t_np[0], color='gray', linestyle=':', label='Extrapolation Start')

            plt.title(f"Trajectory {i}, Dim {d}")
            plt.xlabel("Time")
            plt.ylabel("Value")
            plt.legend()
            plt.grid(True)
        plt.tight_layout()

        # One file per trajectory; the single-figure name is unchanged.
        suffix = f"_traj{i}" if n_figures > 1 else ""
        plt.savefig(f"./results/extrapolation_{data_name}_{model_name}_{mask_ratio}{suffix}.png")
        plt.show()

if __name__ == "__main__":
    import os

    # ---- experiment configuration ----
    seed = 28
    data_utils_v12.set_seed(seed)
    latent_dim = 15 # hidden size of the GRU (symmetry dimensions for the latent space)
    batch_size = 32
    num_epochs = 100
    learning_rate = 0.001 #0.0008
    top_k_gates = 1 # not used by this baseline
    if_mut_sym = 1 #  consider the second order correlation of the symmetries or not
    input_length = 30
    output_length = 30
    stride = 4
    mask_ratio = 0.3 # 0.9
    train_ratio = 0.6

    # 3 first-order experts, or 9 when second-order combinations are included
    gate_dim = 3 if if_mut_sym == 0 else 9

    # NOTE: data_name / model_name stay module-level because the plotting
    # code reads them when building the figure file name.
    data_name = "ECG"  # spiral glycolytic lotka load PV power_event air_quality ECG
    model_name = "RNN_delta_t"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Checkpoints and figures are written here; create it up front.
    os.makedirs("./results", exist_ok=True)

    # ---- Step 1: load / generate the dataset ----
    if data_name == "spiral":
        data, time_all = data_utils_v12.generate_spiral_dataset(n_trajectories=80, total_steps=60, visualize=False)  # (80, 60, 2)
    elif data_name == "glycolytic":
        data, time_all = data_utils_v12.generate_glycolytic_dataset()
    elif data_name == "lotka":
        data, time_all = data_utils_v12.generate_lotka_dataset()
    elif data_name == "load":
        data, time_all = data_utils_v12.generate_load_dataset()
    elif data_name == "PV":
        data, time_all = data_utils_v12.generate_PV_dataset()
    elif data_name == "power_event":
        data, time_all = data_utils_v12.generate_power_event_dataset()
    elif data_name == "air_quality":
        data, time_all = data_utils_v12.generate_AirQuality_dataset()
        # this dataset uses shorter windows
        input_length = 10
        output_length = 10
    elif data_name == "ECG":
        data, time_all = data_utils_v12.generate_ECG_dataset()
    else:
        # Fail with a clear message instead of a NameError on `data` below.
        raise ValueError(f"Unknown data_name: {data_name!r}")

    print(f"{data_name} data shape:", data.shape)  # (trajectories, steps, features)
    print("time_all", time_all.shape)

    # split train and test along the trajectory axis
    N = data.shape[0]
    train_N = int(train_ratio * N)

    data_train = data[:train_N]
    data_test = data[train_N:]

    # ---- Step 2: create windowed data from training trajectories ----
    windows_train = data_utils_v12.create_tagged_windows(
        data_train, time_all,
        input_length=input_length,
        output_length=output_length,
        mask_ratio=mask_ratio,
        stride=stride,
        seed=42
    )

    # ---- Step 3: bucket windows by start index ----
    buckets_train = data_utils_v12.bucket_windows_by_start(windows_train)

    # ---- Step 4: build DataLoaders ----
    bucketed_loaders_train = data_utils_v12.build_bucket_dataloaders(buckets_train, batch_size=batch_size)

    # The model predicts the same features it consumes.
    input_dim = data.shape[2]
    output_dim = data.shape[2]

    def build_model():
        """Construct a fresh, identically configured model on `device`."""
        return GRUDTMoSInterface(
            input_dim=input_dim,
            output_dim=output_dim,
            input_length=input_length,
            output_length=output_length,
            hidden_dim=latent_dim,
            gate_dim=gate_dim,  # 3 or 9, same as before
        ).to(device)

    model = build_model()

    # ---- Step 5: train the model ----
    best_model_state_path = f"./results/extra_{model_name}_{int(mask_ratio * 100)}mask_{data_name}.pth"
    train_model(model, bucketed_loaders_train, epochs=num_epochs, lr=learning_rate, save_path=best_model_state_path)

    # ---- Step 6: reload the best checkpoint, test and visualize ----
    best_model = build_model()
    best_model.load_state_dict(torch.load(best_model_state_path, map_location=device))
    best_model.eval()

    # Test mode 2: slide windows over the test set and average the overlaps
    test_and_visualize_model_win_extra(
        model=best_model,
        data_test=data_test,
        time_all=time_all,
        input_length=input_length,
        output_length=output_length,
        stride=stride,
        mask_ratio=mask_ratio,
        device=device.type,
        num_plot=1
    )

    print("mask_ratio", mask_ratio)

