import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from models.cvae import BetaCVAE

def calculate_reconstruction_error(generator, features, labels):
    """Return the per-sample reconstruction MSE of ``generator`` on ``features``.

    The generator is switched to eval mode (and left in it) and run under
    ``torch.no_grad()``.

    Args:
        generator: Conditional VAE called as ``generator(features, condition)``
            and returning a 4-tuple whose first element is the reconstruction.
            For hard labels it must also expose ``decoder_input.in_features``
            and ``latent_dim`` so the number of classes can be inferred.
        features: Tensor whose first dimension indexes samples.
        labels: Either a 1-D tensor of integer class indices (hard labels),
            which is one-hot encoded here, or an already-encoded condition
            tensor (e.g. soft/mixed labels), which is passed through.

    Returns:
        1-D NumPy array with one mean-squared error per sample, averaged over
        all non-batch dimensions of ``features``.
    """
    generator.eval()
    with torch.no_grad():
        # labels should be one-hot for CVAE in Li-FIL
        if labels.dim() == 1:
            # Hard labels: infer the class count from the decoder input width,
            # presumably laid out as concat(latent, condition) — confirm in BetaCVAE.
            num_classes = generator.decoder_input.in_features - generator.latent_dim
            # F.one_hot only accepts int64 indices; int32 labels would raise.
            condition = F.one_hot(labels.long(), num_classes=num_classes)
        else:
            # Soft labels (e.g. mixed labels) are already a per-class condition.
            condition = labels
        # Give the condition the same dtype as the features (float32 in the
        # common case, but also correct for e.g. half-precision features).
        condition = condition.to(features.dtype)

        recons, _, _, _ = generator(features, condition)
        # Error is measured in the same space as `features` (reconstruction vs.
        # input), averaged over every non-batch dimension.
        error = (recons - features).pow(2).flatten(start_dim=1).mean(dim=1)
    return error.cpu().numpy()

def _plot_mia_results(task_id, member_errors, non_member_errors, fpr, tpr, roc_auc, out_path):
    """Save a two-panel figure (error histograms, attack ROC curve) to ``out_path``."""
    fig, (ax_hist, ax_roc) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        ax_hist.hist(member_errors, bins=30, alpha=0.5, label='Members', density=True, color='blue')
        ax_hist.hist(non_member_errors, bins=30, alpha=0.5, label='Non-members', density=True, color='red')
        ax_hist.set_title(f'Task {task_id} Error Distribution')
        ax_hist.set_xlabel('Reconstruction MSE')
        ax_hist.legend()

        ax_roc.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC (AUC = {roc_auc:.2f})')
        # Chance-level diagonal for reference.
        ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        ax_roc.set_title(f'Task {task_id} MIA ROC')
        ax_roc.set_xlabel('FPR')
        ax_roc.set_ylabel('TPR')
        ax_roc.legend(loc="lower right")

        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure; running this once per task would
        # otherwise accumulate open matplotlib figures.
        plt.close(fig)


def run_task_mia(task_id, *, noise_std=1.0):
    """Run a reconstruction-error membership-inference attack for one task.

    Loads ``exp_data/task_{task_id}_privacy_exp.pt`` and
    ``checkpoints/task_{task_id}_models.pt``, scores members and simulated
    non-members by generator reconstruction error, prints the ROC AUC and
    saves a plot to ``mia_attack_task_{task_id}.png``.

    Args:
        task_id: Task index used to build the input/output file names.
        noise_std: Standard deviation of the Gaussian noise added to member
            features to simulate non-members (default 1.0).

    Returns:
        The attack ROC AUC as a float, or None if an input file is missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_path = f'exp_data/task_{task_id}_privacy_exp.pt'
    ckpt_path = f'checkpoints/task_{task_id}_models.pt'

    if not os.path.exists(data_path) or not os.path.exists(ckpt_path):
        print(f"Skipping Task {task_id}: Missing files.")
        return None

    # map_location lets files saved from a GPU run load on a CPU-only machine.
    data = torch.load(data_path, map_location=device)
    ckpt = torch.load(ckpt_path, map_location=device)

    # 1. Initialize Generator and load weights
    generator = BetaCVAE(input_dim=512, latent_dim=512, condition_dim=10).to(device)
    generator.load_state_dict(ckpt['generator'])

    # 2. Define Members and Non-members
    # Members: the 'full_lifil' features that were used to update the CVAE.
    member_features = data['full_lifil'].to(device)
    member_labels = data['mixed_labels'].to(device)

    # Non-members: a rigorous MIA needs data the CVAE has NEVER seen (e.g.
    # raw_features from a different task). Here they are simulated by
    # perturbing member features with Gaussian noise under the same labels.
    # NOTE(review): noise of this scale may raise non-member error on its own,
    # so the AUC may reflect the perturbation rather than memorization —
    # confirm with genuinely held-out data.
    non_member_features = member_features + torch.randn_like(member_features) * noise_std
    non_member_labels = member_labels

    # 3. Calculate Errors
    member_errors = calculate_reconstruction_error(generator, member_features, member_labels)
    non_member_errors = calculate_reconstruction_error(generator, non_member_features, non_member_labels)

    # 4. Metrics
    all_errors = np.concatenate([member_errors, non_member_errors])
    # Member = 1 (expected low error), Non-member = 0 (expected high error)
    all_gt = np.concatenate([np.ones(len(member_errors)), np.zeros(len(non_member_errors))])

    # Score with -error: lower error -> higher membership likelihood.
    fpr, tpr, _ = roc_curve(all_gt, -all_errors)
    roc_auc = auc(fpr, tpr)

    # 5. Plotting
    out_path = f'mia_attack_task_{task_id}.png'
    _plot_mia_results(task_id, member_errors, non_member_errors, fpr, tpr, roc_auc, out_path)
    print(f"Task {task_id} MIA AUC: {roc_auc:.4f} (Saved to {out_path})")
    return float(roc_auc)

if __name__ == "__main__":
    # Script entry point: evaluate only the first task.
    run_task_mia(task_id=0)
