import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
import random
from typing import Optional
import data_utils_interpolation as data_utils


# ----------------------------------------------------------------
# 0. Utility functions
# ----------------------------------------------------------------
def set_seed(seed_value: int):
    """Seed Python's, NumPy's and PyTorch's RNGs so runs are repeatable.

    When CUDA is present, the CUDA generators are seeded as well and cuDNN is
    switched to deterministic, non-benchmarking mode (trading speed for
    reproducibility).

    Args:
        seed_value: Seed passed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# ----------------------------------------------------------------
# 1. GRU-Δ-T Layer
# ----------------------------------------------------------------

class GRUDTLayer(nn.Module):
    """Single-layer GRU whose hidden state decays with elapsed time (GRU-Δ-T).

    Before consuming step ``i`` (``i > 0``), the previous hidden state is
    scaled by ``exp(-Δt)``, where ``Δt = t[i] - t[i - 1]`` clamped to be
    non-negative. The decayed state is then updated by a standard
    ``nn.GRUCell`` with input ``x[:, i]``. Step 0 starts from ``h0`` (zeros by
    default) without any decay.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.GRUCell(input_size, hidden_size)

    def forward(
            self,
            x: torch.Tensor,  # B × L × D
            t_vec: torch.Tensor,  # L   or   B × L
            h0: Optional[torch.Tensor] = None,
    ):
        """Run the time-decayed GRU over a sequence.

        Args:
            x: Input sequence of shape (B, L, D).
            t_vec: Timestamps, either shared across the batch ``(L,)`` or
                per sample ``(B, L)``.
            h0: Optional initial hidden state of shape (B, hidden_size);
                zeros when omitted.

        Returns:
            ``(hidden_seq, h)``: ``hidden_seq`` is (B, L, hidden_size) with the
            state after every step, ``h`` is the final state (B, hidden_size).
        """
        B, L, _ = x.shape
        if h0 is None:
            h = torch.zeros(B, self.hidden_size, device=x.device, dtype=x.dtype)
        else:
            h = h0

        if L == 0:
            # Nothing to consume: empty sequence, initial state unchanged.
            return x.new_zeros(B, 0, self.hidden_size), h

        # Elapsed time between consecutive steps, shape (L-1,) or (B, L-1).
        # Clamped at 0 so unsorted/duplicated timestamps never amplify h
        # (exp(-dt) <= 1 whenever dt >= 0).
        dt_all = torch.clamp(torch.diff(t_vec, dim=-1), min=0.0)
        decay_all = torch.exp(-dt_all)  # exponential decay, as in GRU-D variants

        hidden_seq = []
        for i in range(L):
            if i > 0:
                decay = decay_all[..., i - 1]  # 0-dim (shared t) or (B,)
                if decay.dim() == 1:
                    decay = decay.unsqueeze(-1)  # (B, 1): broadcast over hidden units
                # Decay the carried state, then let the GRU cell update it once.
                # Note: original GRU-D also decays/imputes the inputs; not done here.
                h = decay * h
            h = self.cell(x[:, i, :], h)
            hidden_seq.append(h)

        return torch.stack(hidden_seq, dim=1), h  # B × L × H, B × H


# ────────────────────────────────────────────────────────────────
# 2.  Core GRU‑Δ‑T model that outputs a *sequence* prediction
# ────────────────────────────────────────────────────────────────
class GRUDTModel(nn.Module):
    """GRU-Δ-T encoder followed by a linear readout applied at every step."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        # Time-aware recurrent encoder producing one hidden state per step.
        self.encoder = GRUDTLayer(input_size, hidden_size)
        # Shared linear projection from hidden space to the output features.
        self.readout = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, t_vec: torch.Tensor):
        """Predict an output vector for every time step.

        x     : B × L × D_in input sequence
        t_vec : timestamps, L (shared) or B × L (per sample)
        returns : B × L × output_size
        """
        encoded = self.encoder(x, t_vec)
        per_step_hidden = encoded[0]  # B × L × H; final state is not needed
        return self.readout(per_step_hidden)


if __name__ == "__main__":
    # ----------------------------------------------------------------
    # 0. Parameters
    # ----------------------------------------------------------------
    seed = 200
    latent_dim = 15  # hidden_size for GRU-DT
    batch_size = 32  # Number of trajectories per batch
    num_epochs = 150
    learning_rate = 0.0008
    mask_ratio = 0.3  # Fraction of time steps hidden from the model (30% masked)
    train_ratio = 0.6  # Proportion of data for training

    # Data properties (used by the synthetic "spiral" generator)
    n_total_trajectories = 80
    sequence_length = 60  # L, total_steps in spiral generation
    model_name = "RNN_delta"
    data_name = "ECG"

    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create results directory
    os.makedirs("./results", exist_ok=True)

    # ----------------------------------------------------------------
    # 1. Data Loading and Preparation
    # ----------------------------------------------------------------
    if data_name == "spiral":
        data_np, time_all = data_utils.generate_spiral_dataset(
            n_trajectories=n_total_trajectories, total_steps=sequence_length, visualize=False
        )  # (N, L, 2)
    elif data_name == "glycolytic":
        data_np, time_all = data_utils.generate_glycolytic_dataset()
    elif data_name == "lotka":
        data_np, time_all = data_utils.generate_lotka_dataset()
    elif data_name == "load":
        data_np, time_all = data_utils.generate_load_dataset()
    elif data_name == "PV":
        data_np, time_all = data_utils.generate_PV_dataset()
    elif data_name == "power_event":
        data_np, time_all = data_utils.generate_power_event_dataset()
    elif data_name == "air_quality":
        data_np, time_all = data_utils.generate_AirQuality_dataset()
    elif data_name == "ECG":
        data_np, time_all = data_utils.generate_ECG_dataset()
    else:
        # Without this, an unknown name surfaces later as a NameError on data_np.
        raise ValueError(f"Unknown data_name: {data_name!r}")

    input_dim = data_np.shape[-1]  # D, feature dimension of the dataset
    output_dim = data_np.shape[-1]  # D_out, model reconstructs the same features

    print(f"{data_name} data shape: {data_np.shape}")
    print(f"Time points shape: {time_all.shape}")
    # Convert once and move to the device; time_points is shared by all trajectories.
    full_data = torch.from_numpy(data_np).float().to(device)
    time_points = torch.from_numpy(time_all).float().to(device)

    # Split data (first train_ratio of trajectories for training, rest for test)
    N = full_data.shape[0]
    train_N = int(train_ratio * N)

    train_data = full_data[:train_N]
    test_data = full_data[train_N:]

    print(f"Training data shape: {train_data.shape}")
    print(f"Test data shape: {test_data.shape}")

    # Create DataLoader for training
    train_dataset = TensorDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # ----------------------------------------------------------------
    # 2. Model Initialization
    # ----------------------------------------------------------------
    model = GRUDTModel(input_size=input_dim, hidden_size=latent_dim, output_size=output_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # reduction='none' gives per-element losses so they can be averaged over observed points only
    criterion = nn.MSELoss(reduction='none')

    # ----------------------------------------------------------------
    # 3. Training Loop
    # ----------------------------------------------------------------
    best_train_loss = float('inf')
    best_model_state_path = f"./results/{model_name}_{int(mask_ratio * 100)}mask_{data_name}.pth"
    # Tracks whether THIS run wrote a checkpoint, so we never load a stale file
    # left over from an earlier run (or a missing one).
    saved_checkpoint = False

    print("\nStarting training...")
    for epoch in range(num_epochs):
        model.train()
        epoch_total_loss = 0.0
        epoch_observed_points = 0

        for (x_batch,) in train_loader:
            x_batch = x_batch.to(device)  # (B, L, D)
            B, L, D_in = x_batch.shape
            if B == 0 or L == 0:
                continue  # nothing to train on

            # Observation mask: True = observed, False = hidden from the model.
            observed_mask = torch.rand(B, L, device=device) < (1.0 - mask_ratio)
            # Guarantee at least one observed point so the masked loss is defined.
            if not observed_mask.any():
                observed_mask[0, 0] = True

            # Model input: zero out unobserved time steps (all features of that step).
            # (Other imputation like forward fill could be used too.)
            input_x_batch = x_batch.clone()
            input_x_batch[~observed_mask] = 0.0

            optimizer.zero_grad()

            # Forward pass: model uses shared time_points for all batch elements
            predictions = model(input_x_batch, time_points)  # (B, L, D_out)

            # Loss ONLY on observed points; boolean indexing flattens to (n_obs, D_out).
            masked_loss = criterion(predictions, x_batch)[observed_mask]
            loss = masked_loss.mean()
            loss.backward()
            optimizer.step()

            epoch_total_loss += loss.item() * masked_loss.numel()  # weight by element count
            epoch_observed_points += masked_loss.numel()

        avg_epoch_loss = (epoch_total_loss / epoch_observed_points
                          if epoch_observed_points > 0 else float('nan'))
        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Training Loss (on observed): {avg_epoch_loss:.6f}")

        # NOTE: model selection uses the (randomly masked) training loss; NaN never counts as better.
        if avg_epoch_loss < best_train_loss:
            best_train_loss = avg_epoch_loss
            torch.save(model.state_dict(), best_model_state_path)
            saved_checkpoint = True
            print(f"Saved new best model to {best_model_state_path}")

    print("Training finished.")

    # ----------------------------------------------------------------
    # 4. Testing and Evaluation
    # ----------------------------------------------------------------
    print("\nStarting testing...")
    if saved_checkpoint:
        model.load_state_dict(torch.load(best_model_state_path, map_location=device))
    else:
        print("Warning: no checkpoint was saved this run (training loss never finite); "
              "evaluating the final model instead.")
    model.eval()

    test_trajs = test_data  # (N_test, L, D), already on device
    N_test, L_test, D_test = test_trajs.shape

    test_observed_mask = torch.rand(N_test, L_test, device=device) < (1.0 - mask_ratio)
    if not test_observed_mask.any() and N_test > 0 and L_test > 0:
        test_observed_mask[0, 0] = True

    test_input_x = test_trajs.clone()
    test_input_x[~test_observed_mask] = 0.0  # Zero out unobserved points

    with torch.no_grad():
        pred_full = model(test_input_x, time_points)  # (N_test, L, D_out)

    unobserved_points_mask_for_mse = ~test_observed_mask  # True for points that were NOT in input

    if unobserved_points_mask_for_mse.any():
        mse_miss = ((pred_full[unobserved_points_mask_for_mse]
                     - test_trajs[unobserved_points_mask_for_mse]) ** 2).mean().item()
    else:
        mse_miss = float('nan')  # no hidden points to evaluate

    # MSE on all points
    mse_all = ((pred_full - test_trajs) ** 2).mean().item()

    print(f'[Test] Interpolation MSE ({mask_ratio} Missing): {mse_miss:7.6f}')
    print(f'[Test] Interpolation MSE (All Points):     {mse_all:7.6f}')

    # ----------------------------------------------------------------
    # 5. Visualization of a sample result
    # ----------------------------------------------------------------
    print("\nVisualizing a sample trajectory...")
    if N_test > 0:
        sample_idx_to_plot = 0  # Plot the first trajectory in the test set

        gt_np = test_trajs[sample_idx_to_plot].cpu().numpy()  # (L, D)
        pred_np = pred_full[sample_idx_to_plot].cpu().numpy()  # (L, D)
        t_np = time_points.cpu().numpy()  # (L,)

        observed_points_for_viz = test_observed_mask[sample_idx_to_plot].cpu().numpy()  # (L,)

        plt.figure(figsize=(12, 4))
        for d_idx in range(D_test):
            plt.subplot(1, D_test, d_idx + 1)

            plt.plot(t_np, gt_np[:, d_idx], label='Ground Truth', color='black', linewidth=2)
            plt.plot(t_np, pred_np[:, d_idx], '--', label='Prediction (GRUDT)', color='blue', linewidth=2)

            plt.scatter(t_np[observed_points_for_viz],
                        gt_np[observed_points_for_viz, d_idx],
                        color='green', label='Observed Input', zorder=10, s=50)

            plt.title(f"Trajectory {sample_idx_to_plot}, Dimension {d_idx + 1}")
            plt.xlabel("Time")
            plt.ylabel("Value")
            if d_idx == 0:
                plt.legend()
            plt.grid(True)

        plt.tight_layout()
        save_path = f"./results/interpolation_{model_name}_{mask_ratio}.png"
        plt.savefig(save_path)
        print(f"Saved visualization to {save_path}")
        plt.show()
    else:
        print("No test data to visualize.")