import torch
import numpy as np 
import pickle
import random
import pandas as pd

from torch.utils.data import DataLoader, TensorDataset
from numpy import mean
from copy import deepcopy
from scipy.stats import ttest_ind

from data_utils import pre_process_dataset
from reward_trees.reward_trees import RewardTree, RewardNet


def _load_pickled_model(path):
    """Load and return the object pickled at ``path``."""
    # NOTE: unpickling can execute arbitrary code; only load model files
    # that come from a trusted source.
    with open(path, "rb") as f:
        return pickle.load(f)


def _perturb_row(x, feature_idx, numerical_feature_idxs, std, categorical_slice):
    """Return a copy of ``x`` in which only feature ``feature_idx`` is changed.

    Args:
        x: 1-D tensor holding one row of features.
        feature_idx: index of the feature being perturbed.
        numerical_feature_idxs: indices treated as numerical features.
        std: sequence indexed by ``feature_idx`` for numerical features; the
            rounded entry is added to the feature value.
        categorical_slice: ``[start, stop]`` bounds of the block of ``x`` that
            encodes this feature, or ``None`` when the feature is stored in
            the single position ``feature_idx``.

    Returns:
        The perturbed clone of ``x``; ``x`` itself is left untouched.
    """
    x_prime = x.clone()
    if feature_idx in numerical_feature_idxs:
        # Numerical feature: shift by the rounded per-feature std entry.
        x_prime[feature_idx] += round(std[feature_idx])
    elif categorical_slice is not None:
        # Block-encoded feature: move the active entry of the block to a
        # different, randomly chosen position.
        start, stop = categorical_slice
        x_prime[start:stop] = move_one_to_new_position(x_prime[start:stop])
    else:
        # Remaining features: decrease the value by one.
        x_prime[feature_idx] -= 1.
    return x_prime


def main():
    """Compare the custom and standard models under single-feature changes.

    For every dataset/model pair, the two pickled models
    (``models/<dataset>.<model>.10.0.custom`` and ``...normal``) are loaded.
    Then, for each feature and each row of ``model_custom.states``, a copy of
    the row with only that feature perturbed is built, and both models are
    evaluated on ``(x, x_prime - x, x_prime)``. The per-row outputs and their
    means are stored as one record per (dataset, model, feature) and written
    to ``data/desiderata1.csv``.

    Raises:
        TypeError: if a dataset name has no feature configuration.
    """
    rows = []

    for dataset in ['heloc', 'adult', 'german_credit']:
        for model in ['mlp', 'tree']:

            model_custom = _load_pickled_model(f"models/{dataset}.{model}.10.0.custom")
            model_standard = _load_pickled_model(f"models/{dataset}.{model}.10.0.normal")

            # Only the fourth returned value (std) is used here. The call is
            # deliberately kept inside the model loop so the sequence of
            # calls matches the original script.
            _, _, _, std, _ = pre_process_dataset(0, dataset, 1000, training=False)

            # Set only for datasets whose non-numerical features are encoded
            # as blocks of several positions.
            categorical_slices = None

            if dataset == 'heloc':
                numerical_feature_idxs = [0, 1, 2, 3]
                feature_names = ['MSinceMostRecentInqexcl7days', 'NumRevolvingTradesWBalance', 'NumTradesOpeninLast12M', 'NumInqLast6M']
            elif dataset == 'adult':
                numerical_feature_idxs = [1, 4, 5]
                feature_names = ['isMale', 'age', 'native-country-United-States',
                                 'marital-status-Married', 'education-num', 'hours-per-week', 'workclass-Private', 'isWhite']
            elif dataset == 'german_credit':
                numerical_feature_idxs = [0, 1]
                categorical_slices = [[2, 6], [6, 11], [11, 21]]
                # Re-index std so positions 0 and 1 line up with the two
                # numerical features listed above.
                std = [std[1], std[4]]
                feature_names = ['duration', 'amount', 'status', 'credit_history', 'purpose']
            else:
                raise TypeError('wrong dataset name')

            for feature_idx, feature_name in enumerate(feature_names):

                # Block-encoded features follow the numerical ones, so their
                # slice is found by offsetting the feature index.
                categorical_slice = None
                if categorical_slices is not None and feature_idx not in numerical_feature_idxs:
                    categorical_slice = categorical_slices[feature_idx - len(numerical_feature_idxs)]

                custom_people = []
                standard_people = []

                # Both models are evaluated on the states stored in the
                # custom model.
                for row in model_custom.states:
                    x = row.clone()
                    x_prime = _perturb_row(x, feature_idx, numerical_feature_idxs, std, categorical_slice)
                    delta_x = x_prime - x

                    # Add a leading dimension of size one before calling the models.
                    inputs = (x.unsqueeze(0), delta_x.unsqueeze(0), x_prime.unsqueeze(0))
                    custom_output = model_custom(*inputs)
                    standard_output = model_standard(*inputs)

                    custom_people.append(custom_output.item())
                    standard_people.append(standard_output.item())

                # Key order fixes the column order of the CSV file.
                rows.append({
                    'dataset': dataset,
                    'model': model,
                    'data_custom': custom_people,
                    'standard_mean': sum(standard_people)/len(standard_people),
                    'custom_mean': sum(custom_people)/len(custom_people),
                    'data_standard': standard_people,
                    'feature': feature_name,
                })

    # One record per (dataset, model, feature) combination.
    df = pd.DataFrame(rows)

    # Save the DataFrame to a CSV file
    df.to_csv('data/desiderata1.csv', index=False)

    
def get_train_test_accuracy(the_model):
    """Print and return the preference accuracy on the train and test sets.

    The model is called once on its own ``states``, ``actions`` and
    ``next_states``. Each preference ``p`` then compares the outputs at
    indices ``p.i`` and ``p.j``:

    * ``p.y > 0.5`` counts as correct when the output at ``i`` is strictly
      greater than the output at ``j``;
    * ``p.y < 0.5`` counts as correct when it is strictly smaller;
    * otherwise the two outputs must be equal.

    Returns:
        A two-element list: the mean correctness over ``the_model.preferences``
        followed by the mean over ``the_model.test_preferences``.
    """
    predicted = the_model(the_model.states, the_model.actions, the_model.next_states)

    def _score(pref):
        # 1 when the predicted ordering agrees with the label, else 0.
        left = predicted[pref.i].item()
        right = predicted[pref.j].item()
        if pref.y > 0.5:
            return int(left > right)
        if pref.y < 0.5:
            return int(left < right)
        return int(left == right)

    splits = {"TRAIN": the_model.preferences, "TEST": the_model.test_preferences}
    accuracies = []
    for split_name, split_prefs in splits.items():
        accuracy = mean([_score(pref) for pref in split_prefs])
        print(f"{split_name} ACC =", accuracy)
        accuracies.append(accuracy)
    return accuracies


def move_one_to_new_position(tensor):
    """Move the largest entry's "1" to a different, randomly chosen position.

    The position returned by ``torch.argmax`` is set to 0 and another
    position, drawn with ``random.randint``, is set to 1. The change is made
    in place and the same tensor object is returned.

    Args:
        tensor: 1-D tensor with at least two entries.

    Returns:
        The input ``tensor`` after modification.

    Raises:
        ValueError: if ``tensor`` has fewer than two entries, because no
            position other than the current one exists.
    """
    size = len(tensor)
    if size < 2:
        # Without this guard a single-entry input would loop forever below,
        # since every draw equals the current position.
        raise ValueError(
            f"need at least 2 positions to move the active entry, got {size}"
        )

    # Find the current position of the 1
    current_position = torch.argmax(tensor).item()

    # Draw until a position different from the current one comes up
    new_position = current_position
    while new_position == current_position:
        new_position = random.randint(0, size - 1)

    # Set the current position to 0 and the new position to 1
    tensor[current_position] = 0.
    tensor[new_position] = 1.

    return tensor


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()

