import torch
import torch.nn.functional as F
import torchvision

import numpy as np
from scipy import special
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("whitegrid")

import csl
import csl.datasets

import os
import logging
import itertools
import sys
import pickle
import random

# Check if GPU is available
# Everything below targets the first CUDA device when one is visible,
# otherwise the CPU.
theDevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): os.path.abspath('') resolves to the current working directory,
# so this slice drops the last three characters of that directory path. It
# looks like a leftover from stripping '.py' off a script name - confirm.
exp_label = os.path.abspath('')[0:-3]
# NOTE(review): exp_label is an absolute path, and os.path.join discards
# earlier components when a later one is absolute, so 'Results' does not end
# up in results_label - confirm that this is intended.
results_label = os.path.join('Results', exp_label)

####################################
# LOG                              #
####################################
# Set up logging to file...
logger = logging.getLogger(exp_label)

# Handlers are attached only when none are found (presumably so that
# re-running this code does not duplicate output).
# NOTE(review): hasHandlers() also counts handlers on ancestor loggers, so
# nothing is attached here when e.g. the root logger is already configured.
if not logger.hasHandlers():
    # mode='w' truncates any existing '<exp_label>.log' on every run.
    hfile = logging.FileHandler(exp_label + '.log', mode='w')
    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
                                  datefmt='%m-%d %H:%M')
    hfile.setFormatter(formatter)
    logger.addHandler(hfile)

    # ... and console
    # The console handler uses a shorter format without the timestamp.
    hconsole = logging.StreamHandler()
    formatter = logging.Formatter('%(levelname)-8s: %(message)s')
    hconsole.setFormatter(formatter)
    logger.addHandler(hconsole)

# INFO and above are emitted; DEBUG records are dropped by the logger itself.
logger.setLevel(logging.INFO)

# seed everything
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` and ``torch.cuda.manual_seed``), then asks cuDNN
    for deterministic kernels with autotuning switched off.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

####################################
# FUNCTIONS                        #
####################################
def accuracy(pred, y):
    """Return the fraction of predictions that equal their labels.

    Args:
        pred: Predicted labels; the size of its first dimension is used as
            the denominator.
        y: Reference labels, compared element-wise against ``pred``.

    Returns:
        Number of matching entries divided by ``pred.shape[0]``.
    """
    matches = pred == y
    n_matches = matches.sum().item()
    n_samples = pred.shape[0]
    return n_matches / n_samples


def fairness(x, model, cf_indices):
    """Fraction of samples whose prediction is unchanged by a counterfactual edit.

    Args:
        x: Batch of inputs indexed as ``x[:, columns]`` (samples along the
            first dimension).
        model: Object exposing ``predict(x)`` that returns one label per row.
        cf_indices: Iterable of ``(idx1, idx2)`` pairs describing the edit.
            When both entries compare equal, those column(s) are replaced by
            ``1 - value``; otherwise the two columns are swapped.

    Returns:
        Number of rows where the predictions on ``x`` and on its edited copy
        agree, divided by the number of predictions.
    """
    # Move the batch once so both forward passes run on the same device;
    # moving only the counterfactual copy fails when ``x`` arrives on another
    # device than the model.
    x = x.to(theDevice)
    x_cf = x.clone().detach()

    for idx1, idx2 in cf_indices:
        if idx1 == idx2:
            # Same index on both sides: flip the (assumed 0/1) indicator.
            x_cf[:, idx1] = 1 - x[:, idx1]
        else:
            # Different indices: exchange the two columns.
            x_cf[:, [idx1, idx2]] = x[:, [idx2, idx1]]

    pred = model.predict(x)
    pred_cf = model.predict(x_cf)

    n_unchanged = (pred == pred_cf).sum().item()

    return n_unchanged / pred.shape[0]

def fairness_single(x, model, cf_indices):
    """Count whether a single sample keeps its prediction under a counterfactual edit.

    Args:
        x: One input indexed as ``x[columns]`` (no batch dimension).
        model: Object exposing ``predict(x)``.
        cf_indices: Iterable of ``(idx1, idx2)`` pairs describing the edit.
            When both entries compare equal, those entries are replaced by
            ``1 - value``; otherwise the two entries are swapped.

    Returns:
        Number of positions where the predictions on ``x`` and on its edited
        copy agree (a count, not a fraction).
    """
    # Move the sample once so both forward passes run on the same device;
    # moving only the counterfactual copy fails when ``x`` arrives on another
    # device than the model.
    x = x.to(theDevice)
    x_cf = x.clone().detach()

    for idx1, idx2 in cf_indices:
        if idx1 == idx2:
            # Same index on both sides: flip the (assumed 0/1) indicator.
            x_cf[idx1] = 1 - x[idx1]
        else:
            # Different indices: exchange the two entries.
            x_cf[[idx1, idx2]] = x[[idx2, idx1]]

    pred = model.predict(x)
    pred_cf = model.predict(x_cf)

    return (pred == pred_cf).sum().item()

# Settings object handed to every csl.TorchedPrimalDual solver built below.
solver_setting = csl.settings(
    iterations=400,
    batch_size=256,
    lr_p0=0.1,
    lr_d0=2,
    lr_d_decay=0.5,
    lr_d_period=50,
    lambdas0=1,
    logger=logger,
)

# When False, the joint "gender x race" constraints and test metrics are skipped.
genderxrace = False
#nn['unconstrained'] = {'problem': fairNN(trainset, n_features = 12),
#                        'solver': csl.TorchedPrimalDual(primal_solver = lambda params: torch.optim.Adam(params),
#                                                        dual_solver = None,  
#                                                      settings_ = solver_setting)}



# Feature counts to sweep: 23 down to 5, descending by 1.
nf_list = list(range(23, 4, -1))
print(nf_list)

for seed in [4545, 5656, 6767, 7878]:
    # Per-seed containers; entries are keyed by 'average_<n_features>'.
    nn = {}
    results = {}
    solution = {}
    testsets = {}
    trainsets = {}
    seed_everything(seed)
    torch.cuda.empty_cache()

    # One shuffled ordering of the 23 input columns per seed, reused by every
    # n_features setting below.
    random_feature_selec_og = np.arange(23).tolist()
    np.random.shuffle(random_feature_selec_og)
    for n_features in nf_list:
        set_key = f'average_{n_features}'

        # Split the shuffled columns into n_features groups and encode the
        # grouping as a 0/1 matrix: row g has ones at the columns of group g.
        random_feature_selec = np.array_split(random_feature_selec_og, n_features)
        random_feature_selection = torch.zeros([n_features, 23], dtype = torch.float, requires_grad=False)
        for group, columns in enumerate(random_feature_selec):
            random_feature_selection[group, columns] = 1

        # The matrix deliberately stays where torch.zeros created it. The
        # former `random_feature_selection.to(theDevice)` statement discarded
        # its result (Tensor.to is not in-place), so the dataset transforms
        # below have always received this tensor unchanged.

        preprocess = torchvision.transforms.Compose([
            csl.datasets.utils.Drop(['age_cat', 'is_recid', 'is_violent_recid', 'score_text', 'v_score_text', 'v_decile_score', 'decile_score']),
            csl.datasets.utils.Recode('race', {'Other': ['Other', 'Asian', 'Native American']}),
            csl.datasets.utils.QuantileBinning('age', 5),
            csl.datasets.utils.Binning('priors_count', bins = [0,0.99,1,2,3,4,1000]),
            csl.datasets.utils.Binning('juv_misd_count', bins = [0,0.99,1,1000]),
            csl.datasets.utils.Binning('juv_other_count', bins = [0,0.99,1,1000]),
            csl.datasets.utils.Dummify(csl.datasets.COMPAS.categorical + ['age', 'priors_count', 'juv_misd_count', 'juv_other_count']),
            ])

        # Train and test splits each get their own transform pipeline:
        # convert to a float tensor, then apply the selection matrix.
        trainsets[set_key] = csl.datasets.COMPAS(root = 'data', split=0.8, train = True,
                                    target_name = 'two_year_recid', preprocess = preprocess,
                                    transform = torchvision.transforms.Compose([csl.datasets.utils.ToTensor(dtype = torch.float), csl.datasets.utils.RandomLinearMap(random_feature_selection)]),
                                    target_transform = csl.datasets.utils.ToTensor(dtype = torch.long))

        testsets[set_key] = csl.datasets.COMPAS(root = 'data', split=0.2, train = False,
                                    target_name = 'two_year_recid', preprocess = preprocess,
                                    transform = torchvision.transforms.Compose([csl.datasets.utils.ToTensor(dtype = torch.float), csl.datasets.utils.RandomLinearMap(random_feature_selection)]),
                                    target_transform = csl.datasets.utils.ToTensor(dtype = torch.long))

        # Untransformed copy of the data, loaded only to read column names.
        fullset = csl.datasets.COMPAS(root = 'data', split=1, train = True,
                                    target_name = 'two_year_recid',
                                    preprocess = preprocess)

        # Classes
        # Positions of the protected-attribute columns in the untransformed frame.
        var_names = fullset[0][0].columns
        gender_idx = [idx for idx, name in enumerate(var_names) if name.startswith('sex')]
        race_idx = [idx for idx, name in enumerate(var_names) if name.startswith('race')]


        # Fair classification problem
        class counterfactualFairness(csl.ConstrainedLearningProblem):
            """Learning problem whose constraints penalise counterfactual prediction changes.

            One constraint is created for flipping the gender column(s), one for
            each pair of race columns (swap), and - when ``genderxrace`` is true -
            one for each race swap combined with the gender flip. The column
            positions come from the module-level ``gender_idx`` / ``race_idx``.

            Args:
                model: Callable network mapping a batch of inputs to outputs.
                data: Indexable dataset; ``data[i]`` yields ``(x, y, _)``.
                similarity: Callable ``similarity(yhat, yhat_cf)`` returning one
                    value per sample.
                rhs: Bound for the averaged constraints, or ``None`` to skip them.
                pointwise_rhs: Bound for the per-sample constraints, or ``None``
                    to skip them.
                genderxrace: Whether to add the joint gender/race constraints.
            """

            def __init__(self, model, data, similarity, rhs = None, pointwise_rhs = None, genderxrace = True):
                self.model = model
                self.data = data
                self.data_size = len(data)
                self.genderxrace = genderxrace

                flip_gender = (gender_idx, gender_idx)
                race_pairs = list(itertools.combinations(race_idx, 2))

                # Constraint order is significant (multipliers are later read
                # by position): gender, then every race pair, then optionally
                # gender combined with every race pair.
                cf_specs = [[flip_gender]] + [[pair] for pair in race_pairs]
                if self.genderxrace:
                    cf_specs += [[flip_gender, pair] for pair in race_pairs]

                # The number of bounds follows the number of constraints
                # instead of assuming exactly six race pairs.
                if rhs is not None:
                    self.constraints = [self.CounterfactualFairness(self.data, self.model, spec, similarity, aggregate = True)
                                        for spec in cf_specs]
                    self.rhs = [rhs] * len(cf_specs)

                if pointwise_rhs is not None:
                    self.pointwise = [self.CounterfactualFairness(self.data, self.model, spec, similarity)
                                      for spec in cf_specs]
                    self.pointwise_rhs = [pointwise_rhs*torch.ones(self.data_size, requires_grad = False)
                                          for _ in cf_specs]

            class CounterfactualFairness:
                """Callable constraint comparing outputs on a batch and on its edited copy.

                Args:
                    data: Indexable dataset; ``data[i]`` yields ``(x, y, _)``.
                    model: Callable network.
                    cf_indices: List of ``(idx1, idx2)`` pairs; equal entries mean
                        "replace by 1 - value", different entries mean "swap".
                    similarity: Callable ``similarity(yhat, yhat_cf)``.
                    aggregate: When true, ``__call__`` returns the mean over the
                        batch instead of per-sample values.
                """

                def __init__(self, data, model, cf_indices, similarity, aggregate = False):
                    self.data = data
                    self.model = model
                    self.cf_indices = cf_indices
                    self.aggregate = aggregate
                    self.similarity = similarity

                def __call__(self, batch_idx):
                    """Evaluate the constraint on the samples selected by ``batch_idx``."""
                    x, y, _ = self.data[batch_idx]
                    # NOTE(review): the indices in cf_indices are column
                    # positions taken from the untransformed frame, while x
                    # comes from a dataset whose transform includes
                    # RandomLinearMap - confirm they still address the
                    # intended features.

                    x_cf = x.clone().detach()

                    # A scalar label means a single sample was fetched: add
                    # the batch dimension the column indexing below expects.
                    if len(y.shape) == 0:
                        x = x.unsqueeze(0)
                        x_cf = x_cf.unsqueeze(0)

                    for idx1,idx2 in self.cf_indices:
                        if idx1 == idx2:
                            x_cf[:, idx1] = 1 - x[:, idx1]
                        else:
                            x_cf[:, [idx1, idx2]] = x[:, [idx2, idx1]]

                    yhat = self.model(x.to(theDevice))
                    yhat_cf = self.model(x_cf.to(theDevice))

                    if self.aggregate:
                        return self.similarity(yhat, yhat_cf).mean()
                    else:
                        return self.similarity(yhat, yhat_cf)
        


        #%% ################################
        # NEURAL NETWORK                   #
        ####################################
        class fairNN(counterfactualFairness):
            """Counterfactual-fairness problem solved with a one-hidden-layer network.

            Args:
                data: Dataset forwarded to ``counterfactualFairness``.
                rhs: Bound for the averaged constraints (``None`` disables them).
                pointwise_rhs: Bound for the per-sample constraints (``None``
                    disables them).
                genderxrace: Whether to add the joint gender/race constraints.
                n_features: Input width of the network.
            """

            def __init__(self, data, rhs = None, pointwise_rhs = None, genderxrace = True, n_features = 23):
                def kl_similarity(p, q):
                    """Per-sample KL divergence between softmax(p) and softmax(q)."""
                    return (F.softmax(p, dim=1) * F.log_softmax(p, dim=1)).sum(dim=1) \
                        - (F.softmax(p, dim=1) * F.log_softmax(q, dim=1)).sum(dim=1)

                # Zero-argument super() binds to this class object itself; the
                # name `fairNN` is rebound on every loop iteration, so looking
                # it up by name is fragile.
                super().__init__(self.NN(n_features).float().to(theDevice),
                                 data,
                                 kl_similarity,
                                 rhs, pointwise_rhs, genderxrace)

                self.parameters = list(self.model.parameters())
                self.obj_function = self.loss

            # Model
            class NN(torch.nn.Module):
                """Network with one 64-unit sigmoid hidden layer and two outputs."""

                def __init__(self, n_features):
                    super().__init__()
                    self.fc1 = torch.nn.Linear(n_features, 64)
                    self.fc2 = torch.nn.Linear(64, 2)

                def forward(self, x):
                    """Return the two raw output scores for each row of ``x``."""
                    x = torch.sigmoid(self.fc1(x))
                    x = self.fc2(x)

                    return x

                def predict(self, x):
                    """Return the index of the larger output for each row of ``x``."""
                    # Accept a single unbatched sample as well.
                    if len(x.shape) == 1:
                        x = x.unsqueeze(0)
                    _, predicted = torch.max(self(x), 1)
                    return predicted

            def loss(self, batch_idx):
                """Cross-entropy objective on the samples selected by ``batch_idx``."""
                # Evaluate objective
                x, y, _ = self.data[batch_idx]
                yhat = self.model(x.to(theDevice))

                return F.cross_entropy(yhat, y.to(theDevice))
            

        # Register the constrained problem and its primal-dual solver for this
        # feature count; training happens after all entries are created.
        nn[f'average_{n_features}'] = dict(
            problem=fairNN(
                trainsets[f'average_{n_features}'],
                rhs=1e-3,
                genderxrace=genderxrace,
                n_features=n_features,
            ),
            solver=csl.TorchedPrimalDual(
                primal_solver=lambda params: torch.optim.Adam(params),
                dual_solver=lambda params: torch.optim.Adam(params),
                settings_=solver_setting,
            ),
        )

    # Train every registered problem, then report accuracy and counterfactual
    # fairness on the matching test split.
    for key, value in nn.items():
        x_test, y_test, _ = testsets[key][:]
        print(f"Training model {key}")
        value['solver'].solve(value['problem'])
        #value['solver'].plots()
        solution[key] = {'model': value['problem'].model,
                        'lambdas': value['problem'].lambdas,
                        'mus': value['problem'].mus}

        # Test
        results[key] = {}
        print(f'Model: {key}')
        with torch.no_grad():
            trained_model = solution[key]['model']
            # Move the test inputs once instead of before every metric.
            x_eval = x_test.to(theDevice)

            yhat = trained_model.predict(x_eval)
            acc_test = accuracy(yhat, y_test.to(theDevice))
            print(f'Test accuracy: {100*acc_test:.2f}')
            results[key]['Test Accuracy'] = 100*acc_test

            # Multipliers are read by constraint position: gender first, then
            # each race pair, then each gender x race pair.
            lambda_value = solution[key]['lambdas']
            # NOTE(review): var_names, gender_idx and race_idx are whatever
            # the last n_features iteration left behind; they index the
            # untransformed columns - confirm they match x_test's layout.
            race_pairs = list(itertools.combinations(race_idx, 2))

            fair_test = fairness(x_eval, trained_model, [(gender_idx,gender_idx)])
            suffix = f' / lambda = {lambda_value[0]}' if lambda_value else ''
            print(f'Gender: {100*fair_test:.2f}{suffix}')
            results[key]['Gender'] =  100*fair_test

            for ii, idx in enumerate(race_pairs):
                pair_label = f'{var_names[idx[0]]} <-> {var_names[idx[1]]}'
                fair_test = fairness(x_eval, trained_model, [idx])
                suffix = f' / lambda = {lambda_value[ii+1]}' if lambda_value else ''
                print(f'{pair_label}: {100*fair_test:.2f}{suffix}')
                results[key][pair_label] = 100*fair_test

            if genderxrace:
                # The joint constraints follow the gender constraint and all
                # race-pair constraints, so their offset depends on the number
                # of pairs rather than on a fixed 7.
                joint_offset = 1 + len(race_pairs)
                for ii, idx in enumerate(race_pairs):
                    pair_label = f'Gender x {var_names[idx[0]]} <-> {var_names[idx[1]]}'
                    fair_test = fairness(x_eval, trained_model, [(gender_idx, gender_idx),idx])
                    suffix = f' / lambda = {lambda_value[ii+joint_offset]}' if lambda_value else ''
                    print(f'{pair_label}: {100*fair_test:.2f}{suffix}')
                    results[key][pair_label] = 100*fair_test

        results[key]['slacks_evol'] = np.array(value['solver'].slack_evolution)

    # save results
    torch.save(results, 'results'+str(seed)+'.pt')
