from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import random
import csv
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from helpers import *

import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from metrics import *

class ratio_debiasing(object):
    '''
    Post-processing debiaser for a black-box binary classifier.

    For every sample a concept network NN_c followed by a ratio network NN_r
    produce a multiplicative ratio r(x); the debiased probability is
        sigmoid(r(x) * logit(p_BB(x)))
    where p_BB is the black-box probability. A ratio of 1 leaves the black box
    unchanged. The ratios are trained against an adversary NN_s that tries to
    recover the sensitive attribute s from the debiased prediction
    (demographic parity).

    INPUTS
        learning_rate = Adam learning rate (shared by NN_r, NN_s and NN_c)
        batch_size
        lamb_fair = weight of the adversarial fairness term in the loss
        lamb_ratio = weight of the MSE term pulling the ratios towards 1
        lamb_sparse = weight of the L1 penalty on the concept layer (NN_c.fc1) weights
        lamb_diversity = weight of the summed |cosine similarity| between pairs
                         of concept neurons (rows of NN_c.fc1.weight)
        num_epochs
        num_concepts = number of concept neurons penalised (rows of NN_c.fc1.weight)
        NN_r = module class (instantiated with no arguments) mapping concepts to ratios
        NN_s = module class of the adversary predicting the sensitive attribute s
        NN_c = module class producing the concepts; must expose a `fc1` layer
        GPU = device string used when CUDA is available
    '''

    # Epoch from which each extra loss term is switched on (the schedule is cumulative).
    FAIR_START_EPOCH = 50
    RATIO_START_EPOCH = 80
    CONCEPT_REG_START_EPOCH = 100
    # Gradient steps taken by the adversary NN_s on every mini-batch.
    ADV_STEPS = 50
    # Progress is printed every EVAL_EVERY epochs.
    EVAL_EVERY = 200

    def __init__(self, learning_rate, batch_size, lamb_fair, lamb_ratio, lamb_sparse,
                 lamb_diversity, num_epochs, num_concepts, NN_r, NN_s, NN_c, GPU):

        self.lambda_fair = lamb_fair
        self.lambda_ratio = lamb_ratio
        self.lambda_sparse = lamb_sparse
        self.lambda_diversity = lamb_diversity
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        # Fall back to the CPU when CUDA is not available, whatever GPU says.
        self.device = torch.device(GPU if torch.cuda.is_available() else "cpu")
        self.num_concepts = num_concepts

        self.m_NN_r = NN_r().to(self.device)
        self.m_NN_s = NN_s().to(self.device)
        self.m_NN_c = NN_c().to(self.device)

    # ------------------------------------------------------------------ helpers

    def _to_tensor(self, a):
        '''Convert array-like `a` to a contiguous float32 tensor on self.device.'''
        return torch.from_numpy(np.ascontiguousarray(a, dtype=np.float32)).to(self.device)

    @staticmethod
    def _logit(p, eps=0.0):
        '''Return log(p + eps) - log(1 - p + eps); eps > 0 keeps it finite at p in {0, 1}.'''
        return torch.log(p + eps) - torch.log(1 - p + eps)

    def _sparsity_loss(self):
        '''Per-concept L1 norm of the concept layer weights (one entry per concept).'''
        return self.m_NN_c.fc1.weight[:self.num_concepts].abs().sum(dim=1)

    def _diversity_loss(self):
        '''Sum of |cosine similarity| over all unordered pairs of concept neurons.'''
        w = self.m_NN_c.fc1.weight[:self.num_concepts]
        # Row-normalise (same eps as F.cosine_similarity); w_unit @ w_unit.T is then the
        # full cosine matrix, and its strict upper triangle holds each pair exactly once.
        w_unit = F.normalize(w, p=2, dim=1, eps=1e-8)
        return torch.triu(w_unit @ w_unit.t(), diagonal=1).abs().sum()

    def _last_pair_cosine(self):
        '''Cosine similarity between the last two concept neurons (NaN if fewer than two).'''
        if self.num_concepts < 2:
            return float('nan')
        w = self.m_NN_c.fc1.weight
        i, j = self.num_concepts - 2, self.num_concepts - 1
        return F.cosine_similarity(w[i].unsqueeze(0), w[j].unsqueeze(0), dim=1).detach().cpu().numpy()[0]

    def _print_progress(self, epoch, loss, lossY, loss_ratio, lossS, loss_sparse, loss_diversity,
                        X_train, y_train, S_train, y_hat, X_test, y_test, S_test, y_hat_test):
        '''Print the last batch's losses and the fairness/accuracy metrics (test set if given, else train).'''
        def _np(t):
            return t.detach().cpu().numpy()

        if X_test is not None:
            final_predictions_test = self.predict(X_test, y_hat_test)
            print('epoch', epoch, 'loss', _np(loss), 'loss_acc', _np(lossY),
                  'loss_ratio', _np(loss_ratio), 'loss_fair', _np(lossS), 'loss_sparse', _np(loss_sparse),
                  'loss_diversity', _np(loss_diversity), 'P-ruletest', p_rule(final_predictions_test, S_test),
                  'ACC_test', accuracy_score(y_test, final_predictions_test),
                  'proportion of changes', np.mean(final_predictions_test != (np.asarray(y_hat_test) > 0.5)))
        else:
            final_predictions_train = self.predict(X_train, y_hat)
            print('epoch', epoch, 'loss', _np(loss), 'loss_acc', _np(lossY),
                  'loss_ratio', _np(loss_ratio), 'loss_fair', _np(lossS),
                  'P-rule', p_rule(final_predictions_train, S_train), 'ACC_train', accuracy_score(y_train, final_predictions_train),
                  'proportion of changes', np.mean(final_predictions_train != (np.asarray(y_hat) > 0.5)))

    # ----------------------------------------------------------------- training

    def train(self, X_train, y_train, S_train, y_hat, X_test=None, y_test=None,
            S_test=None, y_hat_test=None, plot_losses=False, dem_parity=True, write=False, ablation=False, seed=None):
        '''
        Train NN_c and NN_r (and the adversary NN_s) on the training data.

        INPUTS
            X_train = features, one row per sample
            y_train = true labels
            S_train = value of the sensitive attribute in the training dataset
            y_hat = black-box probability of the positive class for each sample
            X_test, y_test, S_test, y_hat_test = optional held-out data; when given,
                progress and final metrics are reported on it (and written if `write`)
            plot_losses = plot the per-batch loss curves after training
            dem_parity = only the demographic-parity objective is implemented;
                with False the method returns without training
            write = pass the final test results to the helpers' CSV writers
            ablation = use the n_changes writer instead of the main results writer
            seed = optional seed for the per-epoch shuffling (None uses NumPy's global RNG)

        Loss schedule: BCE accuracy loss only; from FAIR_START_EPOCH subtract
        lambda_fair * adversary loss; from RATIO_START_EPOCH add the ratio MSE;
        from CONCEPT_REG_START_EPOCH add the sparsity and diversity penalties.
        '''
        if not dem_parity:
            # Only demographic parity is implemented (equalised odds would also feed y to NN_s).
            return

        # Ceil division: every sample is visited once per epoch and there is no empty
        # trailing batch (`n // bs + 1` produced one, with NaN losses, when bs divides n).
        batch_no = (len(X_train) + self.batch_size - 1) // self.batch_size

        self.optimizer_r = torch.optim.Adam(self.m_NN_r.parameters(), lr=self.learning_rate)
        self.optimizer_s = torch.optim.Adam(self.m_NN_s.parameters(), lr=self.learning_rate)
        self.optimizer_c = torch.optim.Adam(self.m_NN_c.parameters(), lr=self.learning_rate)

        criterion_BCE = torch.nn.BCEWithLogitsLoss(reduction='mean')
        criterion_MSE = torch.nn.MSELoss(reduction='mean')

        loss_acc_all, loss_fair_all, loss_ratio_all, loss_all = [], [], [], []

        # A dedicated RandomState advances between calls, so a fixed seed still gives a
        # new permutation each epoch; None keeps sklearn's default (NumPy global RNG).
        random_state = np.random.RandomState(seed) if seed is not None else None

        y_col = np.expand_dims(y_train, axis=1)
        s_col = np.expand_dims(S_train, axis=1)
        yhat_col = np.expand_dims(y_hat, axis=1)

        loss = lossY = lossS = loss_ratio = loss_sparse = loss_diversity = None
        for epoch in range(self.num_epochs):

            x_train, ytrain, strain, yhat = shuffle(X_train, y_col, s_col, yhat_col,
                                                    random_state=random_state)

            for batch in range(batch_no):
                start = batch * self.batch_size
                end = start + self.batch_size

                x_var = self._to_tensor(x_train[start:end])
                y_var = self._to_tensor(ytrain[start:end])
                s_var = self._to_tensor(strain[start:end])
                y_var_hat = self._to_tensor(yhat[start:end])

                logit_y_var_hat = self._logit(y_var_hat, eps=1e-6)

                # 1) Adversary: fit NN_s on the current, frozen debiased predictions.
                with torch.no_grad():
                    Ypred_var0 = torch.sigmoid(self.m_NN_r(self.m_NN_c(x_var)) * logit_y_var_hat)
                for _ in range(self.ADV_STEPS):
                    self.optimizer_s.zero_grad()
                    lossS = criterion_BCE(self.m_NN_s(Ypred_var0), s_var)
                    lossS.backward()
                    self.optimizer_s.step()

                # 2) Debiaser: update NN_c and NN_r against the adversary.
                self.optimizer_r.zero_grad()
                self.optimizer_c.zero_grad()
                # NOTE(review): assumes NN_r outputs (batch, 1) to match logit_y_var_hat — confirm.
                ratios = self.m_NN_r(self.m_NN_c(x_var))
                logit_pred = ratios * logit_y_var_hat
                Ypred_var = torch.sigmoid(logit_pred)
                Spred_var = self.m_NN_s(Ypred_var)  # concatenate with y for eq odds

                # BCEWithLogitsLoss applies its own sigmoid: feed it the logit, not
                # Ypred_var (which would be squashed twice).
                lossY = criterion_BCE(logit_pred, y_var)  # accuracy loss
                lossS = criterion_BCE(Spred_var, s_var)   # adversary loss (fairness)
                # Pull the ratios towards 1, i.e. towards leaving the black box unchanged.
                loss_ratio = criterion_MSE(ratios, torch.ones_like(ratios))
                loss_sparse = self._sparsity_loss()
                loss_diversity = self._diversity_loss()

                loss = lossY
                if epoch >= self.FAIR_START_EPOCH:
                    # Maximise the adversary's loss so that s is hard to recover.
                    loss = loss - self.lambda_fair * lossS
                if epoch >= self.RATIO_START_EPOCH:
                    loss = loss + self.lambda_ratio * loss_ratio
                if epoch >= self.CONCEPT_REG_START_EPOCH:
                    loss = (loss + self.lambda_sparse * loss_sparse.sum()
                            + self.lambda_diversity * loss_diversity)

                loss.backward()
                self.optimizer_r.step()
                self.optimizer_c.step()

                loss_fair_all.append(lossS.item())
                loss_acc_all.append(lossY.item())
                loss_ratio_all.append(loss_ratio.item())
                loss_all.append(loss.item())

            if epoch % self.EVAL_EVERY == 0 and loss is not None:
                self._print_progress(epoch, loss, lossY, loss_ratio, lossS, loss_sparse, loss_diversity,
                                     X_train, y_train, S_train, y_hat, X_test, y_test, S_test, y_hat_test)

        if X_test is not None:
            final_predictions_test = self.predict(X_test, y_hat_test)
            bb_test = (np.asarray(y_hat_test) > 0.5).astype(int)
            S_arr = np.asarray(S_test)

            # Identify the changes w.r.t. the black-box decisions
            changes = final_predictions_test != bb_test
            p_changes_test = np.mean(changes)
            indices_test = np.where(changes)[0].tolist()

            # Useful changes: s=0 flipped 0 -> 1, or s=1 flipped 1 -> 0
            useful_changes = (((S_arr == 0) & (final_predictions_test == 1) & (bb_test == 0)) |
                              ((S_arr == 1) & (final_predictions_test == 0) & (bb_test == 1)))

            num_changes = np.sum(changes)
            num_useful_changes = np.sum(useful_changes)
            proportion_useful_changes = num_useful_changes / num_changes if num_changes != 0 else 0

            p_rule_test = p_rule(final_predictions_test, S_test)
            acc_test = accuracy_score(y_test, final_predictions_test)

            print('P ratio is:', p_rule_test)
            print('Accuracy is:', acc_test*100)
            print('Proportion of changes is:', p_changes_test)

        if plot_losses:
            plt.figure()
            df_losses = pd.DataFrame({'loss':loss_all, 'loss_acc':loss_acc_all, 'loss_fair':loss_fair_all, 'loss_post':loss_ratio_all})
            sns.lineplot(data=df_losses)
            plt.show()

        # The writers record test-set metrics, so they need test data.
        if write and X_test is not None:
            # Detach and move to CPU so this also works when training ran on a GPU.
            weight = self.m_NN_c.fc1.weight.detach().cpu().numpy()
            rows = (0, 1)
            masks = [np.abs(weight[k]) > 0.0 for k in rows]
            number_of_features = [m.sum() for m in masks]
            indices_features = [list(np.where(m)[0]) for m in masks]
            weights = [weight[k][idx].tolist() for k, idx in zip(rows, indices_features)]

            if ablation:
                write_to_csv_ratio_DP_autoencoder_nchanges('results/n_changes.csv', self.lambda_fair, self.lambda_ratio, p_rule_test,
                                 acc_test, p_changes_test, proportion_useful_changes)
            else:
                write_to_csv_ratio_DP_autoencoder('results/ratio_autoencoder_old.csv', self.lambda_fair, self.lambda_ratio, self.num_epochs - 1, p_rule_test,
                                 p_rule_diff(final_predictions_test, S_test), acc_test, p_changes_test, indices_test,
                                 self._last_pair_cosine(), number_of_features, indices_features, weights)

    # --------------------------------------------------------------- inference

    def predict(self, X, prob_BB, threshold=0.5):
        '''
        Return debiased hard labels (int array).

        threshold = decision threshold on the debiased probability (it used to be
                    ignored, with 0.5 always applied).
        '''
        return (self.predict_proba(X, prob_BB) > threshold).astype(int)

    def predict_proba(self, X, prob_BB):
        '''Return the debiased probability sigmoid(ratio(X) * logit(prob_BB)) as a NumPy array.'''
        with torch.no_grad():
            ratios = self.m_NN_r(self.m_NN_c(self._to_tensor(X))).flatten()
            logit = self._logit(self._to_tensor(prob_BB))
            return torch.sigmoid(ratios * logit).cpu().numpy()

    def predict_ratios(self, X):
        '''Return the per-sample ratios as a flat NumPy array.'''
        with torch.no_grad():
            ratios = self.m_NN_r(self.m_NN_c(self._to_tensor(X)))
        return ratios.flatten().cpu().numpy()

    def predict_concepts(self, X):
        '''Return the concept-network outputs for X as a NumPy array.'''
        with torch.no_grad():
            concepts = self.m_NN_c(self._to_tensor(X))
        return concepts.cpu().numpy()