import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from aif360.datasets import CompasDataset, GermanDataset
from fairlearn.metrics import demographic_parity_ratio
from sklearn.cluster import KMeans
import time
import matplotlib.pyplot as plt
import json
import argparse
import sys
import os
from datetime import datetime

seed = 1235

# Reproducibility: seed the NumPy and PyTorch RNGs with the same value.
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    # Seed every CUDA device and make cuDNN pick deterministic kernels
    # (disabling autotuning, which can select different algorithms per run).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class Logger:
    """Tee writer: echoes every message to the original stdout and a log file.

    It exposes ``write``/``flush`` like a text stream, so it can stand in for
    ``sys.stdout``. The terminal stream is captured at construction time.
    Usable as a context manager; the log file is closed on exit.
    """

    def __init__(self, log_file):
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII messages do not depend on the platform's
        # locale encoding (which can raise UnicodeEncodeError on some systems).
        self.log = open(log_file, 'w', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush eagerly so the log on disk stays current even if the run dies.
        self.flush()

    def flush(self):
        """Flush the terminal, and the log file if it is still open."""
        self.terminal.flush()
        # After close(), callers (e.g. a final sys.stdout.flush()) may still
        # flush this object; skip the closed file instead of raising ValueError.
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file (the terminal stream is left untouched)."""
        self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

def normalize_features(X):
    """Standardize each column of ``X`` to zero mean and unit sample std.

    Statistics are taken over dim 0 (assumed rows x features — confirm with
    callers). Columns with zero standard deviation are only mean-centered, so
    they come out as zeros instead of NaN/inf. With a single row the sample
    standard deviation is undefined, so every column is treated as constant.

    Args:
        X: tensor to normalize.

    Returns:
        A new tensor with the same shape as ``X``.
    """
    mean = X.mean(dim=0, keepdim=True)
    if X.size(0) > 1:
        std = X.std(dim=0, keepdim=True)
    else:
        # torch.std applies Bessel's correction and returns NaN for one row;
        # NaN would slip past the zero check below.
        std = torch.ones_like(mean)
    # Constant columns: divide by 1 instead of 0.
    std = std.masked_fill(std == 0, 1.0)
    return (X - mean) / std

def load_data(dataset_name='COMPAS'):
    """Load an aif360 dataset as normalized torch tensors on ``device``.

    'COMPAS' loads ``CompasDataset``; any other name loads ``GermanDataset``.
    Whenever the German dataset is loaded, its labels (1.0/2.0) are shifted to
    0/1 so they can be used directly as class indices. Previously this only
    happened for the exact name 'Credit'.

    Args:
        dataset_name: 'COMPAS', or anything else for the German credit data.

    Returns:
        ``(X, y, s)``: standardized float32 features, int64 labels, and
        float32 protected attributes, all on ``device``.
    """
    is_compas = dataset_name == 'COMPAS'
    dataset = CompasDataset() if is_compas else GermanDataset()

    X = dataset.features
    y = dataset.labels.ravel()
    s = dataset.protected_attributes

    if not is_compas:
        # Tie the remap to the dataset actually loaded, not to one spelling of
        # its name; un-shifted labels would give class index 2.
        y = (y - 1.0).astype(int)

        print(f"Remapped German Credit labels from [1.0, 2.0] to [0, 1]")

        print(f"Unique labels after remapping: {np.unique(y)}")

    X = torch.tensor(X, dtype=torch.float32).to(device)
    y = torch.tensor(y, dtype=torch.long).to(device)
    s = torch.tensor(s, dtype=torch.float32).to(device)

    X = normalize_features(X)

    return X, y, s

def compute_mutual_info_loss(estimator, z, s):
    """Return the negated Donsker-Varadhan estimate of I(z; s).

    loss = log(mean(exp(T(z, s_shuffled)))) - mean(T(z, s)),
    where ``T`` is ``estimator``. Minimizing this loss maximizes the
    mutual-information lower bound with respect to the estimator.

    Args:
        estimator: callable ``T(z, s)`` returning a tensor of scores.
        z: batch of latent codes; dim 0 indexes samples.
        s: batch of sensitive attributes aligned with ``z`` along dim 0.

    Returns:
        Scalar tensor.
    """
    import math

    joint_term = torch.mean(estimator(z, s))

    # Permuting s across the batch breaks the pairing with z, approximating
    # samples from the product of the marginals.
    idx = torch.randperm(z.size(0), device=z.device)
    marginal_score = estimator(z, s[idx])

    # log(mean(exp(x))) == logsumexp(x) - log(N). The logsumexp form does not
    # overflow to inf when scores are large.
    flat = marginal_score.reshape(-1)
    marginal_term = torch.logsumexp(flat, dim=0) - math.log(flat.numel())

    return marginal_term - joint_term

def intersectional_fairness_loss(latent, sensitive):
    """Penalize gaps between per-group latent means and the batch mean.

    Each distinct row of ``sensitive`` defines one intersectional group. The
    loss is the average over groups of the MSE between that group's mean
    latent vector and the mean over the whole batch.

    Args:
        latent: tensor whose dim 0 indexes samples.
        sensitive: (N, A) tensor of sensitive attributes, or a 1-D (N,) tensor,
            which is treated as a single attribute column.

    Returns:
        Scalar tensor; zero for an empty batch.
    """
    if sensitive.dim() == 1:
        sensitive = sensitive.unsqueeze(1)
    if sensitive.size(0) == 0:
        return latent.new_zeros(())

    unique_groups = torch.unique(sensitive, dim=0)
    overall_mean = torch.mean(latent, dim=0)

    loss = latent.new_zeros(())
    for group in unique_groups:
        # Groups come from torch.unique, so every mask selects >= 1 row.
        mask = (sensitive == group).all(dim=1)
        loss = loss + F.mse_loss(torch.mean(latent[mask], dim=0), overall_mean)

    return loss / len(unique_groups)

def intersectional_fairness_loss_class(latent, sensitive, labels, lambda_class=1.0):
    """Global plus class-conditional intersectional mean-matching loss.

    global_loss: average over groups of MSE(group mean, batch mean).
    class_loss:  for each class c and group g with members in c, the
                 MSE(mean of (g, c) rows, mean of class c), summed and divided
                 by (#groups * #classes).

    Args:
        latent: tensor whose dim 0 indexes samples.
        sensitive: (N, A) sensitive attributes, or 1-D (N,), which is treated
            as a single attribute column; each distinct row is a group.
        labels: (N,) class labels aligned with ``latent``.
        lambda_class: weight of the class-conditional term.

    Returns:
        Scalar tensor ``global_loss + lambda_class * class_loss``; zero for an
        empty batch.
    """
    if sensitive.dim() == 1:
        sensitive = sensitive.unsqueeze(1)
    if sensitive.size(0) == 0:
        return latent.new_zeros(())

    unique_groups = torch.unique(sensitive, dim=0)
    unique_classes = torch.unique(labels)

    # Group membership does not depend on the class, so build the masks once
    # instead of once per (class, group) pair.
    group_masks = [(sensitive == group).all(dim=1) for group in unique_groups]

    overall_mean = torch.mean(latent, dim=0)
    global_loss = latent.new_zeros(())
    for group_mask in group_masks:
        global_loss = global_loss + F.mse_loss(torch.mean(latent[group_mask], dim=0), overall_mean)
    global_loss = global_loss / len(group_masks)

    class_loss = latent.new_zeros(())
    for c in unique_classes:
        # Classes come from torch.unique(labels), so this mask is never empty.
        class_mask = labels == c
        class_mean = torch.mean(latent[class_mask], dim=0)
        for group_mask in group_masks:
            joint_mask = group_mask & class_mask
            if joint_mask.any():
                class_loss = class_loss + F.mse_loss(torch.mean(latent[joint_mask], dim=0), class_mean)
    class_loss = class_loss / (len(unique_groups) * len(unique_classes))

    return global_loss + lambda_class * class_loss

def fpr_regularizer(probs, labels, sensitive, eps=1e-6):
    """Differentiable penalty on the spread of soft false-positive rates.

    For each intersectional group (distinct row of ``sensitive``) with at
    least one negative (``labels == 0``) sample, the soft FPR is
    sum(probs[neg, 1]) / (#neg + eps). The result is the population variance
    of these per-group rates.

    The previous version detached ``probs`` and rebuilt the value from NumPy,
    so it carried no gradient and could not influence training. This version
    stays in torch so gradients flow back into ``probs``.

    Args:
        probs: per-sample class probabilities; column 1 is used as the
            positive-class probability (assumed (N, C) with C >= 2).
        labels: (N,) ground-truth labels; 0 marks a negative sample.
        sensitive: (N, A) sensitive attributes, or 1-D (N,) as one column.
        eps: added to each group's negative count in the denominator.

    Returns:
        float32 scalar tensor on ``probs.device``; zero when fewer than two
        groups contain negatives.
    """
    # Keep every mask on the same device as probs so boolean indexing works.
    sensitive = sensitive.to(probs.device)
    labels = labels.to(probs.device)
    if sensitive.dim() == 1:
        sensitive = sensitive.unsqueeze(1)

    pos_probs = probs[:, 1]
    negatives = labels == 0

    group_fprs = []
    for group in torch.unique(sensitive, dim=0):
        neg_in_group = (sensitive == group).all(dim=1) & negatives
        neg_count = neg_in_group.sum()
        if neg_count > 0:
            group_fprs.append(pos_probs[neg_in_group].sum() / (neg_count + eps))

    if len(group_fprs) > 1:
        fprs = torch.stack(group_fprs)
        # Population variance (ddof=0), matching the former np.var.
        return ((fprs - fprs.mean()) ** 2).mean().to(torch.float32)

    return torch.tensor(0.0, dtype=torch.float32, device=probs.device)
    
def evaluate_representations(encoder, classifier, data_loader, dataset_name='COMPAS'):
    """Report accuracy and fairness metrics for an encoder/classifier pair.

    Runs every ``(x, y, s)`` batch of ``data_loader`` through
    ``encoder.encode`` and ``classifier`` (without gradients), then prints:
      * overall accuracy of argmax predictions,
      * the false positive rate of each intersectional group (distinct row of
        ``s``) that has at least one label-0 sample, plus the max-min gap,
      * fairlearn demographic parity ratios per sensitive column, computed on
        2-cluster KMeans assignments of the latent codes (not on predictions).

    Both models are switched to eval mode and left there.

    Args:
        encoder: object with an ``encode(x)`` method.
        classifier: callable mapping latent codes to logits.
        data_loader: iterable of ``(x, y, s)`` batches.
        dataset_name: 'COMPAS' or 'Credit' selects readable group/attribute
            names; other values fall back to generic names.

    Returns:
        ``(accuracy, group_fprs, dp_ratios)``.
    """
    group_label_table = {
        'COMPAS': {
            (0, 0): "Male Not Caucasian",
            (0, 1): "Male Caucasian",
            (1, 0): "Female Not Caucasian",
            (1, 1): "Female Caucasian",
        },
        'Credit': {
            (0, 0): "Female Young",
            (0, 1): "Female Old",
            (1, 0): "Male Young",
            (1, 1): "Male Old",
        },
    }
    attribute_label_table = {
        'COMPAS': ["Gender", "Race"],
        'Credit': ["Gender", "Age"],
    }
    name_for_group = group_label_table.get(dataset_name, {})
    attr_labels = attribute_label_table.get(dataset_name, ["Attribute 1", "Attribute 2"])

    encoder.eval()
    classifier.eval()

    # Collect per-batch outputs; concatenated once the loader is exhausted.
    chunks = {'pred': [], 'label': [], 'sens': [], 'z': []}
    with torch.no_grad():
        for x, y, s in data_loader:
            z = encoder.encode(x)
            chunks['pred'].append(torch.argmax(classifier(z), dim=1))
            chunks['label'].append(y)
            chunks['sens'].append(s)
            chunks['z'].append(z)

    preds = torch.cat(chunks['pred']).cpu()
    labels = torch.cat(chunks['label']).cpu()
    sensitive_np = torch.cat(chunks['sens']).cpu().numpy()
    latent_np = torch.cat(chunks['z']).cpu().numpy()

    accuracy = (preds == labels).float().mean().item()
    print(f"Test Accuracy: {accuracy:.4f}")

    preds_np = preds.numpy()
    labels_np = labels.numpy()

    # --- Per-group false positive rates -------------------------------------
    group_fprs = []
    group_names = []

    print("\nFalse Positive Rates by Group:")
    print("------------------------------")

    is_negative = labels_np == 0
    predicted_positive = preds_np == 1
    for group in np.unique(sensitive_np, axis=0):
        in_group = np.all(sensitive_np == group, axis=1)
        if not in_group.any():
            continue
        neg_count = is_negative[in_group].sum()
        if neg_count <= 0:
            # FPR is undefined for a group with no negatives.
            continue
        fp = np.sum(predicted_positive[in_group] & is_negative[in_group])
        group_fpr = fp / neg_count
        group_fprs.append(group_fpr)
        group_name = name_for_group.get(tuple(group.tolist()), f"Group {group}")
        group_names.append(group_name)
        print(f"{group_name}: FPR = {group_fpr:.4f} (n={neg_count})")

    if len(group_fprs) > 1:
        fpr_diff = np.max(group_fprs) - np.min(group_fprs)
        print(f"\nFPR Difference (max-min): {fpr_diff:.4f}")
        hi, lo = np.argmax(group_fprs), np.argmin(group_fprs)
        print(f"Largest gap: {group_names[hi]} ({group_fprs[hi]:.4f}) vs {group_names[lo]} ({group_fprs[lo]:.4f})")

    # --- Demographic parity of latent clusters ------------------------------
    latent_clusters = KMeans(n_clusters=2, random_state=42).fit_predict(latent_np)

    dp_ratios = []

    print("\nDemographic Parity Ratios by Attribute:")
    print("--------------------------------------")

    for col in range(sensitive_np.shape[1]):
        ratio = demographic_parity_ratio(
            y_true=labels_np,
            y_pred=latent_clusters,
            sensitive_features=sensitive_np[:, col],
        )
        dp_ratios.append(ratio)
        label = attr_labels[col] if col < len(attr_labels) else f"Attribute {col}"
        print(f"{label}: {ratio:.4f}")

    return accuracy, group_fprs, dp_ratios