import numpy as np 
import pickle
import torch.nn.functional as F
import pandas as pd
import torch

from numpy import mean
from copy import deepcopy
from scipy.stats import ttest_ind

from reward_trees.reward_trees import RewardTree, RewardNet
from data_utils import pre_process_dataset


def _score_transition(net, x, x_prime):
    """Score the transition ``x -> x_prime`` with ``net`` as a batch of size one."""
    delta_x = x_prime - x
    return net(x.unsqueeze(0), delta_x.unsqueeze(0), x_prime.unsqueeze(0))


def main():
    """Compare how the custom and standard models react to a demographic change.

    For each demographic in ``('age', 'gender', 'race')`` and each row of
    ``model_custom.states``, the same feature change (built by ``mutate_before``
    and ``mutate_after``) is scored twice: once with the demographic set to its
    "before" value and once to its "after" value. The difference of the two
    outputs is recorded for both models and written to
    ``data/desiderata4_<demographic>.csv`` with columns
    ``dataset``, ``model``, ``data`` and ``prompt``.
    """
    dataset = 'adult'
    model = 'mlp'

    model_custom_path = f"models/{dataset}.{model}.10.0.custom"
    model_standard_path = f"models/{dataset}.{model}.10.0.normal"

    # The models and preprocessing statistics do not depend on the demographic,
    # so they are loaded once rather than once per loop iteration.
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files produced by this project.
    with open(model_custom_path, "rb") as f:
        model_custom = pickle.load(f)

    with open(model_standard_path, "rb") as f:
        model_standard = pickle.load(f)

    _, _, _, std, _ = pre_process_dataset(0, dataset, 1000, training=False)

    columns = ['dataset', 'model', 'data', 'prompt']

    for demographic in ['age', 'gender', 'race']:
        df_list = []

        # Only scalar outputs (.item()) are used, so skip building autograd graphs.
        with torch.no_grad():
            for row in model_custom.states:
                # First scenario: demographic set to its "before" value.
                x, x_prime = mutate_before(dataset, row.clone(), std, demographic)
                custom_output_1 = _score_transition(model_custom, x, x_prime)
                standard_output_1 = _score_transition(model_standard, x, x_prime)

                # Second scenario: demographic set to its "after" value.
                x, x_prime = mutate_after(dataset, row.clone(), std, demographic)
                custom_output_2 = _score_transition(model_custom, x, x_prime)
                standard_output_2 = _score_transition(model_standard, x, x_prime)

                # Record the differences between the two scenarios.
                custom_costs = (custom_output_1 - custom_output_2).item()
                standard_costs = (standard_output_1 - standard_output_2).item()

                df_list.append({
                    'dataset': dataset,
                    'model': model,
                    'data': custom_costs,
                    'prompt': 'custom',
                })
                df_list.append({
                    'dataset': dataset,
                    'model': model,
                    'data': standard_costs,
                    'prompt': 'standard',
                })

        # Build the frame in one step; explicit columns keep the header even
        # when there are no rows (pd.concat would raise on an empty list).
        df = pd.DataFrame(df_list, columns=columns)

        df.to_csv(f'data/desiderata4_{demographic}.csv', index=False)
    
        
def get_train_test_accuracy(the_model):
    """Print and return the preference accuracy of ``the_model`` on train and test.

    The model is called once on ``(states, actions, next_states)`` to obtain one
    cost per transition. Each preference ``p`` compares transitions ``p.i`` and
    ``p.j``. It counts as correct when ``c_i > c_j`` for ``p.y > 0.5``, when
    ``c_i < c_j`` for ``p.y < 0.5``, and otherwise when ``c_i == c_j`` (exact
    equality).

    Args:
        the_model: callable exposing ``states``, ``actions``, ``next_states``,
            ``preferences`` and ``test_preferences``.

    Returns:
        ``[train_accuracy, test_accuracy]``. A split with no preferences yields
        ``nan``.
    """
    # Costs are only read through .item(), so no autograd graph is needed.
    with torch.no_grad():
        costs = the_model(the_model.states, the_model.actions, the_model.next_states)

    accs = []
    splits = (("TRAIN", the_model.preferences), ("TEST", the_model.test_preferences))
    for split, prefs in splits:
        correct = []
        for p in prefs:
            c_i, c_j = costs[p.i].item(), costs[p.j].item()
            if p.y > 0.5:
                correct.append(1 if c_i > c_j else 0)
            elif p.y < 0.5:
                correct.append(1 if c_i < c_j else 0)
            else:
                # Ties (and any non-comparable y) require exactly equal costs.
                correct.append(1 if c_i == c_j else 0)

        # mean([]) would warn and return nan; return nan explicitly instead.
        acc = mean(correct) if correct else float("nan")
        print(f"{split} ACC =", acc)
        accs.append(acc)
    return accs


def mutate_before(dataset, x, std, demographic):
    """Set a demographic to its "before" value and build an education step.

    ``x`` is modified in place (the caller in this file passes a clone). The
    returned ``x_prime`` is a copy of the modified ``x`` with feature 4
    (education) increased by 1.

    "Before" values: ``'age'`` sets feature 1 to 65, ``'gender'`` sets
    feature 0 to 0, and ``'race'`` sets feature 7 to 0.

    Args:
        dataset: not used by this function; kept for interface compatibility.
        x: feature tensor (presumably 1-D, one row of features — confirm).
        std: not used by this function. NOTE(review): the values above look
            like raw (unscaled) values even though ``std`` is passed — confirm.
        demographic: one of ``'age'``, ``'gender'``, ``'race'``.

    Returns:
        ``(x, x_prime)``.

    Raises:
        TypeError: if ``demographic`` is not recognised (kept as ``TypeError``
            for compatibility with existing callers).
    """
    # demographic -> (feature index, "before" value)
    settings = {'age': (1, 65.), 'gender': (0, 0.), 'race': (7, 0.)}
    if demographic not in settings:
        raise TypeError(
            f"unknown demographic {demographic!r}; expected one of {sorted(settings)}"
        )

    index, value = settings[demographic]
    x[index] = value

    x_prime = x.clone()
    x_prime[4] += 1  # increase education by 1

    return x, x_prime


def mutate_after(dataset, x, std, demographic):
    """Set a demographic to its "after" value and build an education step.

    ``x`` is modified in place (the caller in this file passes a clone). The
    returned ``x_prime`` is a copy of the modified ``x`` with feature 4
    (education) increased by 1.

    "After" values: ``'age'`` sets feature 1 to 25, ``'gender'`` sets
    feature 0 to 1, and ``'race'`` sets feature 7 to 1.

    Args:
        dataset: not used by this function; kept for interface compatibility.
        x: feature tensor (presumably 1-D, one row of features — confirm).
        std: not used by this function. NOTE(review): the values above look
            like raw (unscaled) values even though ``std`` is passed — confirm.
        demographic: one of ``'age'``, ``'gender'``, ``'race'``.

    Returns:
        ``(x, x_prime)``.

    Raises:
        TypeError: if ``demographic`` is not recognised (kept as ``TypeError``
            for compatibility with existing callers).
    """
    # demographic -> (feature index, "after" value)
    settings = {'age': (1, 25.), 'gender': (0, 1.), 'race': (7, 1.)}
    if demographic not in settings:
        raise TypeError(
            f"unknown demographic {demographic!r}; expected one of {sorted(settings)}"
        )

    index, value = settings[demographic]
    x[index] = value

    x_prime = x.clone()
    x_prime[4] += 1  # increase education by 1

    return x, x_prime


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()