"""Reimplement TimeGAN-pytorch Codebase.

Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
"Time-series Generative Adversarial Networks,"
Neural Information Processing Systems (NeurIPS), 2019.

Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks

Last updated Date: October 18th 2021
Code author: Zhiwei Zhang (bitzzw@gmail.com)

-----------------------------

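Note: A post-hoc RNN classifier is trained to distinguish original data from synthetic data.

Output: discriminative score (np.abs(classification accuracy - 0.5)), together with the
classifier's accuracy on the synthetic and on the original test samples.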

The code has been reimplemented in pure PyTorch to avoid the original TensorFlow dependency.

"""

import warnings

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm

# --- Helper functions: train/test splitting and mini-batch sampling ---
def train_test_divide(data_x, data_x_hat, train_rate=0.8):
    """Randomly split original and synthetic samples into train/test parts.

    Each collection is shuffled independently with ``np.random.permutation``
    (the original data first, then the synthetic data). The first
    ``int(len * train_rate)`` shuffled items form the training part and the
    remaining items form the test part.

    Args:
        data_x: indexable collection of original samples.
        data_x_hat: indexable collection of synthetic samples.
        train_rate: fraction of each collection assigned to training.

    Returns:
        ``(train_x, train_x_hat, test_x, test_x_hat)``, each a plain list.
    """
    def _split(samples):
        # One permutation per collection, so the two splits are independent.
        order = np.random.permutation(len(samples))
        cut = int(len(samples) * train_rate)
        shuffled = [samples[k] for k in order]
        return shuffled[:cut], shuffled[cut:]

    train_x, test_x = _split(data_x)
    train_x_hat, test_x_hat = _split(data_x_hat)
    return train_x, train_x_hat, test_x, test_x_hat

def batch_generator(data, batch_size):
    """Draw a random mini-batch of at most ``batch_size`` items without replacement.

    Every call draws a fresh ``np.random.permutation`` over all indices of
    ``data`` and keeps its first ``batch_size`` entries.

    Returns:
        A list of the selected items. It is shorter than ``batch_size`` when
        ``data`` has fewer items.
    """
    chosen = np.random.permutation(len(data))[:batch_size]
    return [data[k] for k in chosen]


def _as_float_tensor(data):
    """Convert array-like sequence data to a float32 CPU tensor.

    Unlike the legacy ``torch.Tensor(...)`` constructor, this also accepts
    tensors living on a GPU. It converts lists of NumPy arrays in one step.
    """
    if isinstance(data, torch.Tensor):
        return data.detach().to(device="cpu", dtype=torch.float32)
    return torch.as_tensor(np.asarray(data, dtype=np.float32))


class _PostHocDiscriminator(nn.Module):
    """One-layer GRU classifier that scores a sequence as real (1) or synthetic (0)."""

    def __init__(self, inp_dim, hidden_dim):
        super().__init__()
        self.rnn = nn.GRU(input_size=inp_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(logit, probability)``, each of shape ``(batch, 1)``.

        ``x`` is expected as ``(batch, seq_len, inp_dim)`` because the GRU is
        built with ``batch_first=True``.
        """
        _, h_n = self.rnn(x)  # h_n: (num_layers, batch, hidden_dim)
        # Use the last layer's final hidden state. Unlike squeeze(0), indexing
        # keeps the batch axis regardless of batch size.
        logit = self.linear(h_n[-1])
        return logit, torch.sigmoid(logit)


def discriminative_score_metrics(ori_data, generated_data, model_training_iterations : int = None, device='cuda'):
    """Compute the TimeGAN post-hoc discriminative score.

    A GRU classifier is trained on 80% of each set to tell original sequences
    (label 1) from generated ones (label 0). It is then evaluated on the
    remaining 20% of each set.

    Args:
        ori_data: original sequences, array-like or tensor of shape
            ``(no, seq_len, dim)``.
        generated_data: synthetic sequences with the same feature dimension.
        model_training_iterations: number of mini-batch training steps
            (default 2000).
        device: torch device for the classifier. If CUDA is requested but not
            available, the function warns and falls back to CPU.

    Returns:
        ``(discriminative_score, fake_acc, real_acc)``, where:
        - ``discriminative_score`` is ``|accuracy - 0.5|`` on the combined
          test set.
        - ``fake_acc`` is the accuracy on the generated test samples.
        - ``real_acc`` is the accuracy on the original test samples.
        All three are NaN (with a warning) if either data set is too small
        to give both a non-empty train split and a non-empty test split.
    """
    ori_data = _as_float_tensor(ori_data)
    generated_data = _as_float_tensor(generated_data)
    _no, _seq_len, dim = ori_data.shape  # unpacking enforces 3-D input
    # At least one hidden unit: dim // 2 is 0 for univariate series.
    hidden_dim = max(1, dim // 2)
    iterations = 2000 if model_training_iterations is None else model_training_iterations
    batch_size = 32

    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        warnings.warn("CUDA is not available; running the discriminator on CPU.", stacklevel=2)
        device = torch.device('cpu')

    train_x, train_x_hat, test_x, test_x_hat = train_test_divide(ori_data, generated_data)

    # Check before training. An empty train split would leave the model
    # untrained, and an empty test split leaves nothing to score.
    if not (train_x and train_x_hat and test_x and test_x_hat):
        warnings.warn(
            "One or more train/test splits are empty (too few samples). Returning NaN for scores.",
            stacklevel=2,
        )
        return np.nan, np.nan, np.nan

    model = _PostHocDiscriminator(dim, hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.BCEWithLogitsLoss()

    model.train()
    for _ in tqdm(range(iterations), desc='Training Discriminator'):
        x_mb = torch.stack(batch_generator(train_x, batch_size)).to(device)
        x_hat_mb = torch.stack(batch_generator(train_x_hat, batch_size)).to(device)

        logit_real, _ = model(x_mb)
        logit_fake, _ = model(x_hat_mb)
        d_loss = (loss_fn(logit_real, torch.ones_like(logit_real))
                  + loss_fn(logit_fake, torch.zeros_like(logit_fake)))

        optimizer.zero_grad()
        d_loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        _, prob_real = model(torch.stack(test_x).to(device))
        _, prob_fake = model(torch.stack(test_x_hat).to(device))

    # Flatten to 1-D integer labels so the metric sees one clean binary target.
    pred_real = (prob_real.cpu().numpy().ravel() > 0.5).astype(int)
    pred_fake = (prob_fake.cpu().numpy().ravel() > 0.5).astype(int)
    true_real = np.ones_like(pred_real)
    true_fake = np.zeros_like(pred_fake)

    acc = accuracy_score(np.concatenate((true_real, true_fake)),
                         np.concatenate((pred_real, pred_fake)))
    discriminative_score = abs(0.5 - acc)
    real_acc = accuracy_score(true_real, pred_real)
    fake_acc = accuracy_score(true_fake, pred_fake)

    return discriminative_score, fake_acc, real_acc