# predictive_metrics.py
# ------------------------------------------------------------
# Reimplementation of the TimeGAN codebase in PyTorch.
#
# Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
# "Time-series Generative Adversarial Networks," NeurIPS, 2019.
# Paper: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks
#
# This file: optimized for speed without changing behavior.
# Last updated: 2025-09-24
# ------------------------------------------------------------

from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence, pad_packed_sequence, PackedSequence
from sklearn.metrics import mean_absolute_error
from tqdm.auto import tqdm

from .metric_utils import extract_time


class Predictor(nn.Module):
    """GRU regressor used by the predictive score.

    The first ``dim - 1`` features of every timestep feed a single-layer GRU;
    a linear head followed by a sigmoid maps each hidden state to one value,
    so the network emits one prediction in (0, 1) per timestep.

    Args:
        dim: total number of features per timestep (the GRU consumes ``dim - 1``).
        hidden_dim: GRU hidden size.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        n_inputs = dim - 1
        # Creation order (GRU, then Linear) fixes parameter init under a seed.
        self.gru = nn.GRU(n_inputs, hidden_dim, num_layers=1, batch_first=True)
        self.linear_layer = nn.Linear(in_features=hidden_dim, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Predict one value per timestep for a padded batch.

        Args:
            x: padded inputs, ``[B, T_max, dim - 1]``.
            lengths: true length of each sequence, ``[B]``; must be an int64
                CPU tensor (a ``pack_padded_sequence`` requirement).

        Returns:
            Padded predictions ``[B, T_max, 1]``; steps past each sequence's
            length hold the padding value 0.
        """
        padded_len = x.size(1)
        packed_in = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        hidden_packed, _ = self.gru(packed_in)

        # The head runs on the flat ``.data`` of the packed output, i.e. only
        # on real timesteps, so padding never reaches the GRU or the head.
        scores = self.sigmoid(self.linear_layer(hidden_packed.data))

        packed_out = PackedSequence(
            scores,
            hidden_packed.batch_sizes,
            hidden_packed.sorted_indices,
            hidden_packed.unsorted_indices,
        )
        # Re-pad to the input's own T_max so the output aligns step-for-step with x.
        padded_out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=padded_len)
        return padded_out


def _make_mask(lengths: torch.Tensor, T_max: int, device: torch.device) -> torch.Tensor:
    """Return a float mask ``[B, T_max, 1]``: 1.0 where ``t < lengths[b]``, else 0.0.

    ``lengths`` (shape ``[B]``) may live on any device; a copy on ``device``
    is used for the comparison.
    """
    steps = torch.arange(T_max, device=device)       # [T_max]
    limits = lengths.to(device)                       # [B]
    valid = steps[None, :] < limits[:, None]          # [B, T_max] via broadcasting
    return valid.float()[..., None]


def _resolve_device(device: str | torch.device) -> torch.device:
    """Return the torch device to run on for a requested ``device``.

    A CUDA request (``"cuda"``, ``"cuda:1"``, ...) falls back to CPU when CUDA
    is not available; any other request is used as given.
    """
    requested = torch.device(device)
    if requested.type == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    return requested


def _one_step_pairs(data, indices, times, dim: int):
    """Build one-step-ahead (input, target, length) data for ``data[indices]``.

    For each 2-D sequence ``data[i]`` indexed ``[time, feature]``, the input is
    features ``0 .. dim-2`` at steps ``0 .. T-2``, and the target is feature
    ``dim-1`` at steps ``1 .. T-1``.

    Returns:
        x_list: float32 tensors, each ``[T_i - 1, dim - 1]``.
        y_list: float32 tensors, each ``[T_i - 1, 1]``.
        lengths: CPU LongTensor ``[len(indices)]`` holding ``times[i] - 1``
            (kept on CPU because ``pack_padded_sequence`` requires it).
    """
    x_list = [
        torch.tensor(data[i][:-1, :(dim - 1)], dtype=torch.float32)
        for i in indices
    ]
    y_list = [
        torch.tensor(data[i][1:, (dim - 1)].reshape(-1, 1), dtype=torch.float32)
        for i in indices
    ]
    lengths = torch.tensor([times[i] - 1 for i in indices], dtype=torch.long)
    return x_list, y_list, lengths


def predictive_score_metrics(
    ori_data,
    generated_data,
    model_training_iterations: int | None = None,
    device: str = "cuda",
):
    """Report performance of the post-hoc RNN one-step-ahead predictor.

    A ``Predictor`` is trained on ``generated_data`` to predict the last feature
    at step ``t + 1`` from the other features at step ``t``. The score is its
    mean absolute error on ``ori_data``: per-sequence MAE over valid steps,
    averaged over sequences.

    Args:
        ori_data: sequence of 2-D arrays, each ``[T_i, dim]``. Batches are
            padded, so sequences need not share a length, provided
            ``extract_time`` accepts them.
        generated_data: sequence of 2-D arrays, each ``[T_i, dim]``.
        model_training_iterations: number of mini-batch updates (default 5000).
        device: requested torch device, e.g. ``"cuda"``, ``"cuda:1"`` or
            ``"cpu"``. CUDA requests fall back to CPU when CUDA is unavailable.

    Returns:
        predictive_score: MAE of the predictor on the original data.

    Raises:
        ValueError: if ``ori_data`` is empty, has fewer than 2 features per
            step, or training is requested with an empty ``generated_data``.
    """
    no = len(ori_data)
    if no == 0:
        raise ValueError("ori_data must contain at least one sequence")
    # Take dim from the first sequence instead of np.asarray(ori_data).shape,
    # so sequences of different lengths are not forced into one rectangular array.
    dim = np.asarray(ori_data[0]).shape[-1]
    if dim < 2:
        raise ValueError("predictive score needs at least 2 features per step")

    # Extract times (lengths)
    ori_time, _ = extract_time(ori_data)
    gen_time, _ = extract_time(generated_data)

    # Predictor setup
    hidden_dim = dim // 2
    iterations = model_training_iterations if model_training_iterations is not None else 5000
    batch_size = 128

    # Honor the caller's device; only an unavailable CUDA request falls back to CPU.
    torch_device = _resolve_device(device)

    predictor_model = Predictor(dim, hidden_dim).to(torch_device)
    p_solver = torch.optim.Adam(predictor_model.parameters())
    # Keep L1Loss default reduction='mean' to match original behavior with masked tensors
    p_loss = nn.L1Loss()

    # -------------------------
    # Training on synthetic data
    # -------------------------
    G = len(generated_data)
    if iterations > 0 and G == 0:
        raise ValueError("generated_data must contain at least one sequence")

    predictor_model.train()
    for _ in tqdm(range(iterations), desc="Training Predictor on synthetic data", unit="iter"):
        # Random mini-batch (all of G when G < batch_size)
        train_idx = np.random.permutation(G)[:batch_size]
        x_list, y_list, lengths = _one_step_pairs(generated_data, train_idx, gen_time, dim)

        x_mb = pad_sequence(x_list, batch_first=True).to(torch_device)  # [B, T_max, D-1]
        y_mb = pad_sequence(y_list, batch_first=True).to(torch_device)  # [B, T_max, 1]

        p_solver.zero_grad(set_to_none=True)
        y_pred = predictor_model(x_mb, lengths)  # [B, T_max, 1]

        # Masked predictions/targets; the mean runs over ALL elements (pads stay zero)
        mask = _make_mask(lengths, x_mb.size(1), x_mb.device)  # [B, T_max, 1]
        loss = p_loss(y_pred * mask, y_mb * mask)

        loss.backward()
        p_solver.step()

    # -------------------------
    # Evaluation on original data
    # -------------------------
    predictor_model.eval()
    x_list, y_list, lengths = _one_step_pairs(ori_data, range(no), ori_time, dim)
    with torch.no_grad():
        x_all = pad_sequence(x_list, batch_first=True).to(torch_device)
        pred_padded = predictor_model(x_all, lengths).cpu().numpy()  # [no, T_max, 1]

    # Per-sequence MAE over valid steps only, then averaged over sequences
    mae_sum = 0.0
    for y_true, y_hat, true_len in zip(y_list, pred_padded, lengths.tolist()):
        mae_sum += mean_absolute_error(y_true[:true_len].numpy(), y_hat[:true_len, :])

    predictive_score = mae_sum / no
    return predictive_score