import torch
import torch.nn as nn
import numpy as np
import torchvision.datasets as datasets
from time import time
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
import torchvision.transforms as tvt
import torchvision
import random
from sklearn.metrics import confusion_matrix, accuracy_score
import xgboost as xgb
from xgboost import XGBClassifier
from tqdm import tqdm
import random

# Compute device shared by the whole script: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def seed_everything(seed):
    """
    Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch with the
    same ``seed``, then switches cuDNN to deterministic mode and turns off its
    benchmark mode.
    """
    # Same order as always: Python, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    


# Column indices into the attribute arrays; the selected columns are combined
# by put_into_class() into a single group id per sample.
# NOTE(review): presumably CelebA attribute indices — confirm against the dataset.
u = [19, 21, 31, 32]  


# Training split: feature representations and the per-sample attribute table.
x_train = np.load('celebA_representation/train.npy')
train_attribute = np.load('celebA_representation/train_attribute.npy')


# Validation split; kept on its own (x_valid / valid_attribute) for the MMD
# scoring further down.
x_valid = np.load('celebA_representation/valid.npy')
valid_attribute = np.load('celebA_representation/valid_attribute.npy')

# The validation rows are also appended to the training arrays, so from here
# on x_train / train_attribute hold train + valid.
x_train = np.concatenate((x_train, x_valid), axis = 0)
train_attribute = np.concatenate((train_attribute, valid_attribute), axis = 0)


# Test split.
x_test = np.load('celebA_representation/test.npy')
test_attribute = np.load('celebA_representation/test_attribute.npy')

# Per-split values of the grouping columns listed in `u`.
bold_u_train = train_attribute[:, u]
bold_u_valid = valid_attribute[:, u]
bold_u_test = test_attribute[:, u]


def put_into_class(label):
    """
    Collapse the columns of ``label`` into one id per row.

    Column ``i`` is weighted by ``2**i`` and the weighted columns are added
    up, so a row of 0/1 values is read as a binary number with column 0 as
    the least significant bit. Returns the plain integer 0 when ``label`` has
    no columns.
    """
    category = 0
    for bit in range(label.shape[1]):
        category = category + (2 ** bit) * label[:, bit]
    return category


# One group id per sample, derived from the grouping columns of each split.
train_groups = put_into_class(bold_u_train)
valid_groups = put_into_class(bold_u_valid)
test_groups = put_into_class(bold_u_test)
# Bare expression: the result is discarded when run as a script
# (it only displays something in an interactive session).
np.unique(valid_groups) 


# Sensitive attribute: column 20 of the attribute table, kept as an (N, 1) column.
train_sensitive = train_attribute[:, 20][:, np.newaxis]
valid_sensitive = valid_attribute[:, 20][:, np.newaxis]
test_sensitive = test_attribute[:, 20][:, np.newaxis]

# Prediction target: column 9 of the attribute table, kept as an (N, 1) column.
train_target = train_attribute[:, 9][:, np.newaxis]
valid_target = valid_attribute[:, 9][:, np.newaxis]
test_target = test_attribute[:, 9][:, np.newaxis]

# Disabled alternative configuration: sensitive attribute from column 39,
# target from column 3.
# train_sensitive = train_attribute[:, 39][:, np.newaxis]
# valid_sensitive = valid_attribute[:, 39][:, np.newaxis]
# test_sensitive = test_attribute[:, 39][:, np.newaxis]

# train_target = train_attribute[:, 3][:, np.newaxis]
# valid_target = valid_attribute[:, 3][:, np.newaxis]
# test_target = test_attribute[:, 3][:, np.newaxis]

train_data_np = x_train
data_np = x_valid
# Standard deviation of every feature dimension over the validation split,
# plotted for visual inspection (plt.show() blocks until the window is closed
# on interactive backends).
sig = data_np.std(0)
plt.plot(sig)
plt.show()



def mmd_GPU(x, y, sigma):
    """
    Gaussian-kernel MMD score between two sample sets, computed on ``device``.

    Args:
        x: tensor of shape (n, d).
        y: tensor of shape (m, d); must have the same ``d`` as ``x``.
        sigma: bandwidth of the kernel ``exp(-dist**2 / (2 * sigma**2))``.

    Returns:
        A 0-d tensor: mean-like within-x term plus within-y term minus twice
        the cross term.

    Note:
        The within-set sums include the diagonal of the kernel matrix (each
        entry ``1 + 1e-5``) while being divided by ``n * (n - 1)`` and
        ``m * (m - 1)``; a set with a single row therefore divides by zero.
    """
    num_x, dim_x = x.shape
    num_y, dim_y = y.shape
    assert dim_x == dim_y

    # Pool both sets so one pairwise-distance call covers every block.
    pooled = torch.cat([x.detach(), y.detach()], dim=0).to(device)
    pairwise = torch.cdist(pooled, pooled, p=2.0)

    # Small value added on the diagonal of the kernel matrix.
    jitter = torch.eye(num_x + num_y, device=device) * 1e-5
    kernel = torch.exp((-1 / (2 * sigma ** 2)) * pairwise ** 2) + jitter

    within_x = kernel[:num_x, :num_x].sum() / (num_x * (num_x - 1))
    within_y = kernel[num_x:, num_x:].sum() / (num_y * (num_y - 1))
    cross = 2 * kernel[:num_x, num_x:].sum() / (num_x * num_y)
    return within_x + within_y - cross



# Step 2: compute per-dimension MMD scores between the two sensitive-label classes within each group.
def mmd_u(group_id, percentage_to_use=66.67, sigma=2):
    """
    Per-dimension MMD between the two sensitive-label classes of one group.

    A random subset of the validation split is drawn without replacement,
    restricted to the rows whose group id equals ``group_id`` and split by
    sensitive label (0 vs 1). For every feature dimension the two resulting
    1-D samples are compared with ``mmd_GPU``.

    Args:
        group_id: group id to score (compared against ``valid_groups``).
        percentage_to_use: percentage of validation rows to subsample.
        sigma: kernel bandwidth forwarded to ``mmd_GPU``.

    Returns:
        A list with one 0-d NumPy array per feature dimension.
    """
    score = []
    total_indices = np.arange(x_valid.shape[0])
    selected_indices = np.random.choice(
        total_indices,
        size=int(percentage_to_use / 100 * len(total_indices)),
        replace=False,
    )

    selected_data = x_valid[selected_indices]
    selected_labels = valid_sensitive[selected_indices]
    selected_groups = valid_groups[selected_indices]

    in_group = selected_groups == group_id
    group_data = selected_data[in_group]
    # reshape(-1) instead of np.squeeze: squeeze turns a single-row (1, 1)
    # label block into a 0-d array, which cannot be used as a row mask.
    group_labels = selected_labels[in_group].reshape(-1)

    print(group_data.shape)

    selected_data_0 = torch.tensor(group_data[group_labels == 0]).to(device)
    selected_data_1 = torch.tensor(group_data[group_labels == 1]).to(device)
    print(selected_data_0.shape, selected_data_1.shape)

    for i in tqdm(range(selected_data_0.shape[1])):
        # Keep a trailing axis so mmd_GPU receives (n, 1) / (m, 1) inputs.
        sensitive0_subsample = selected_data_0[:, i:i + 1]
        sensitive1_subsample = selected_data_1[:, i:i + 1]
        score_for_dim = mmd_GPU(sensitive0_subsample, sensitive1_subsample, sigma)
        score.append(score_for_dim.cpu().numpy())
    return score



# Step 2 (target variant): compute per-dimension MMD scores between the two target-label classes within each group.
def mmd_u_y(group_id, percentage_to_use=66.67, sigma=2):
    """
    Per-dimension MMD between the two target-label classes of one group.

    A random subset of the validation split is drawn without replacement,
    restricted to the rows whose group id equals ``group_id`` and split by
    target label (0 vs 1). For every feature dimension the two resulting
    1-D samples are compared with ``mmd_GPU``.

    Args:
        group_id: group id to score (compared against ``valid_groups``).
        percentage_to_use: percentage of validation rows to subsample.
        sigma: kernel bandwidth forwarded to ``mmd_GPU``.

    Returns:
        A list with one 0-d NumPy array per feature dimension.
    """
    score = []
    total_indices = np.arange(x_valid.shape[0])
    selected_indices = np.random.choice(
        total_indices,
        size=int(percentage_to_use / 100 * len(total_indices)),
        replace=False,
    )

    selected_data = x_valid[selected_indices]
    selected_labels = valid_target[selected_indices]
    selected_groups = valid_groups[selected_indices]

    in_group = selected_groups == group_id
    group_data = selected_data[in_group]
    # reshape(-1) instead of np.squeeze: squeeze turns a single-row (1, 1)
    # label block into a 0-d array, which cannot be used as a row mask.
    group_labels = selected_labels[in_group].reshape(-1)

    print(group_data.shape)  # (rows of this group in the subsample, n_features)

    selected_data_0 = torch.tensor(group_data[group_labels == 0]).to(device)
    selected_data_1 = torch.tensor(group_data[group_labels == 1]).to(device)
    print('u',selected_data_0.shape, selected_data_1.shape)

    for i in tqdm(range(selected_data_0.shape[1])):
        # Keep a trailing axis so mmd_GPU receives (n, 1) / (m, 1) inputs.
        target0_subsample = selected_data_0[:, i:i + 1]
        target1_subsample = selected_data_1[:, i:i + 1]
        score_for_dim = mmd_GPU(target0_subsample, target1_subsample, sigma)
        score.append(score_for_dim.cpu().numpy())
    return score


# mmd_u / mmd_u_y draw a random validation subsample on every call. Seed the
# generators first so the scores (and the feature selection built on them)
# are the same on every run; the same seed is applied again before training.
seed_everything(64)

# One row per group, in np.unique(valid_groups) order; one column per feature
# dimension.
score = []
y_score = []
for i in np.unique(valid_groups):
    score.append(mmd_u(i))
    y_score.append(mmd_u_y(i))

sen_score = np.array(score)
print(sen_score)
tar_score = np.array(y_score)
print(tar_score)




# Multiplier applied to the average count of above-mean sensitive scores to
# decide how many dimensions are zeroed per group.
lambda_d = 1.5
sen_score = np.array(score, dtype=np.float64)

sen_score_tensor = torch.tensor(sen_score)
target_score_tensor = torch.tensor(tar_score)

print(sen_score_tensor.shape)

# Per group: how many dimensions score above that group's mean score.
mean_value = torch.mean(sen_score_tensor, dim=1, keepdim=True)
mean_value_target = torch.mean(target_score_tensor, dim=1, keepdim=True)
count_larger_than_mean = torch.sum(sen_score_tensor > mean_value, dim=1)
print(count_larger_than_mean)
count_larger_than_mean_target = torch.sum(target_score_tensor > mean_value_target, dim=1)

# The same budgets are used for every group: the average above-mean count
# over all groups (scaled by lambda_d for the deletion budget).
number_of_dim_to_delete = torch.mean(count_larger_than_mean.float())
number_of_dim_to_keep = int(torch.mean(count_larger_than_mean_target.float()))
print('number_of_dim_to_keep', number_of_dim_to_keep)

number_of_dim_to_delete = int(lambda_d * number_of_dim_to_delete)
print(number_of_dim_to_delete)

indices = []
indices_to_zero = []

# The score tensors have one row per entry of np.unique(valid_groups), in
# that order. Iterate over row positions rather than group id values so the
# lookup stays correct when the ids are not exactly 0..K-1.
for row in range(sen_score_tensor.shape[0]):
    sorted_values, sorted_indices = torch.sort(sen_score_tensor[row], dim=0, descending=True)
    delete_indices = sorted_indices[:number_of_dim_to_delete]
    surviving = sorted_indices[number_of_dim_to_delete:]
    print(surviving)
    indices.append(surviving)
    indices_to_zero.append(delete_indices)

remaining_indices = torch.stack(indices)
remaining_indices, _ = remaining_indices.sort(dim=1)

# Row r: dimensions with the highest sensitive scores for the r-th group.
del_indices = torch.stack(indices_to_zero)
print(del_indices.shape)


indices_to_keep = []
for row in range(target_score_tensor.shape[0]):
    sorted_values, sorted_indices = torch.sort(target_score_tensor[row], dim=0, descending=True)
    indices_to_keep.append(sorted_indices[:number_of_dim_to_keep])

# Row r: dimensions with the highest target scores for the r-th group.
keep_indices = torch.stack(indices_to_keep)

print(keep_indices.shape)


def data_deduction(group_id, input, label, sensitive, train):
    """
    Extract one group's samples and reduce their feature dimensions.

    The rows belonging to ``group_id`` are selected, the dimensions listed in
    that group's row of ``del_indices`` are set to zero, and the result is
    then restricted to the dimensions in that group's row of
    ``keep_indices`` (in that order).

    Args:
        group_id: group id to extract.
        input: feature matrix of the split, one row per sample.
        label: target labels aligned with ``input``.
        sensitive: sensitive labels aligned with ``input``.
        train: truthy to use ``train_groups`` for membership, falsy to use
            ``test_groups``.

    Returns:
        ``(select_input, select_label, select_sensitive)`` for the group.
    """
    group = train_groups if train else test_groups

    # del_indices / keep_indices are stacked in np.unique(valid_groups) order,
    # so translate the group id into its row position instead of using the id
    # itself as a row number.
    row = int(np.searchsorted(np.unique(valid_groups), group_id))

    members = group == group_id
    # Boolean-mask indexing returns a copy for NumPy arrays, so zeroing
    # columns below does not modify the caller's array.
    select_input = input[members]
    select_input[:, del_indices[row]] = 0
    select_input = select_input[:, keep_indices[row]]

    select_label = label[members]
    select_sensitive = sensitive[members]
    return select_input, select_label, select_sensitive
    

# Collect one chunk per group and concatenate once at the end; building a
# tensor from a Python list of per-row arrays is far slower.
train_x_parts, train_y_parts, train_a_parts = [], [], []
test_x_parts, test_y_parts, test_a_parts = [], [], []

for i in np.unique(valid_groups):
    data, label, sensitive = data_deduction(i, x_train, train_target, train_sensitive, train=True)
    train_x_parts.append(data)
    train_y_parts.append(label)
    train_a_parts.append(sensitive)

    testdata, testlabel, testsensitive = data_deduction(i, x_test, test_target, test_sensitive, train=False)
    test_x_parts.append(testdata)
    test_y_parts.append(testlabel)
    test_a_parts.append(testsensitive)

# Rows are now ordered group by group, not in the original sample order.
training_data = torch.tensor(np.concatenate(train_x_parts, axis=0))
print(training_data.shape)

test_data = torch.tensor(np.concatenate(test_x_parts, axis=0))
print(test_data.shape)

train_y = np.concatenate(train_y_parts, axis=0)
train_a = np.concatenate(train_a_parts, axis=0)
test_y = np.concatenate(test_y_parts, axis=0)
test_a = np.concatenate(test_a_parts, axis=0)

# Labels re-ordered the same way as the features above.
train_target = torch.tensor(train_y)
train_sensitive = torch.tensor(train_a)

test_target = torch.tensor(test_y)
test_sensitive = torch.tensor(test_a)

# Shuffled copy of the training set. The labels are taken from the
# group-ordered train_target built above so they stay aligned with
# shuffled_data (the original, ungrouped label array is in a different order).
num_samples = training_data.shape[0]
shuffled_indices = torch.randperm(num_samples)

shuffled_data = training_data[shuffled_indices]
shuffled_labels = train_target[shuffled_indices]


def compute_fairness(cf1, cf2):
    """
    Compare two confusion matrices and return fairness gaps.

    For each matrix (rows are summed for false negatives, columns for false
    positives) the per-class predicted-positive rate, true-positive rate and
    false-positive rate are computed. The gaps are the absolute differences
    between the two matrices.

    Returns:
        ``(DP, EoP, EoD)`` as per-class arrays: predicted-rate gap, TPR gap,
        and the average of the FPR gap and the TPR gap.
    """

    def _rates(cf):
        # Per-class one-vs-rest counts read off the confusion matrix.
        hits = np.diag(cf)
        misses = cf.sum(axis=1) - np.diag(cf)
        false_alarms = cf.sum(axis=0) - np.diag(cf)
        rejections = cf.sum() - (misses + false_alarms + hits)

        predicted_rate = (hits + false_alarms) / (rejections + false_alarms + misses + hits)
        tpr = hits / (hits + misses)
        fpr = false_alarms / (false_alarms + rejections)
        return predicted_rate, tpr, fpr

    rate_a, tpr_a, fpr_a = _rates(cf1)
    rate_b, tpr_b, fpr_b = _rates(cf2)

    dp_gap = abs(rate_a - rate_b)
    eop_gap = abs(tpr_a - tpr_b)
    eod_gap = 0.5 * (abs(fpr_a - fpr_b) + abs(tpr_a - tpr_b))
    return dp_gap, eop_gap, eod_gap


# Per-feature statistics of the training set, applied to both splits.
mean = training_data.mean(dim=0)
std = training_data.std(dim=0)
epsilon = 1e-8  # avoids division by zero for constant features
standardized_data = (training_data - mean) / (std + epsilon)
standardized_test = (test_data - mean) / (std + epsilon)
# NOTE(review): the datasets below are built from the un-standardized
# training_data / test_data; standardized_* are not used in this section —
# confirm whether that is intended.


from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

seed_everything(64)

# Keep the CPU fallback: an unconditional "cuda:0" fails on hosts without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Each item is (features, target label, sensitive label).
train_dataset = TensorDataset(training_data, train_target, train_sensitive)
test_dataset = TensorDataset(test_data, test_target, test_sensitive)
train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True)
# No shuffling at test time so predictions stay in dataset order.
test_loader = DataLoader(test_dataset, batch_size=4096, shuffle=False)

class LogisticRegressionModel(nn.Module):
    """
    Binary classifier with one hidden layer.

    ``fc_seq`` is Linear(num_features, 512) -> BatchNorm1d(512) -> ReLU and
    ``head`` is Linear(512, 1); the forward pass applies a sigmoid to the
    head output.
    """

    def __init__(self, num_features):
        super().__init__()
        hidden_width = 512
        hidden_layers = [
            nn.Linear(num_features, hidden_width),
            nn.BatchNorm1d(hidden_width),
            nn.ReLU(),
        ]
        self.fc_seq = nn.Sequential(*hidden_layers)
        self.head = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Return the sigmoid of the head output for a batch of feature rows."""
        return torch.sigmoid(self.head(self.fc_seq(x)))


# Size the input layer from the data: the number of kept dimensions is
# computed at run time, so a hard-coded width can disagree with it.
model = LogisticRegressionModel(num_features=training_data.shape[1]).to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
num_epochs = 100

for epoch in range(num_epochs):
    # Evaluation below switches to eval mode; switch back every epoch so
    # BatchNorm keeps using batch statistics while training.
    model.train()
    train_loop = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for inputs, targets, _ in train_loop:
        outputs = model(inputs.to(device))
        # reshape(-1) keeps a 1-D shape even for a batch of size one.
        loss = criterion(outputs.reshape(-1), targets.reshape(-1).float().to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loop.set_postfix(loss=loss.item())

    # Evaluate on the full test set after every epoch.
    model.eval()
    all_preds = []
    all_targets = []
    all_sensitive = []
    with torch.no_grad():
        test_loop = tqdm(test_loader, desc='Evaluating')
        for inputs, targets, sensitive in test_loop:
            outputs = model(inputs.to(device)).reshape(-1)
            all_preds.append(outputs.cpu())
            all_targets.append(targets.reshape(-1).cpu())
            all_sensitive.append(sensitive.reshape(-1).cpu())
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    all_sensitive = torch.cat(all_sensitive)

    # Work on flat NumPy arrays from here on (scikit-learn inputs).
    binary_preds = (all_preds > 0.5).int().numpy()
    targets_np = all_targets.long().numpy()
    sensitive_np = all_sensitive.numpy()

    # Split by sensitive label: 0 -> "female", 1 -> "male" (naming kept from
    # the original script).
    is_female = sensitive_np == 0
    is_male = sensitive_np == 1
    test_target_female = targets_np[is_female]
    test_target_male = targets_np[is_male]
    predict_female = binary_preds[is_female]
    predict_male = binary_preds[is_male]

    # labels=[0, 1] keeps both matrices 2x2 even if a subgroup is missing a
    # class in its targets and predictions.
    cm_0 = confusion_matrix(test_target_female, predict_female, labels=[0, 1])
    cm_1 = confusion_matrix(test_target_male, predict_male, labels=[0, 1])

    DP, EoP, EoD = compute_fairness(cm_0 , cm_1)

    print('ACC', accuracy_score(targets_np, binary_preds))
    print('DP:', max(DP))
    print('EOp' , max(EoP))
    print('EOd', max(EoD))

