import numpy as np 
import pickle
import torch.nn.functional as F
import pandas as pd
import torch

from numpy import mean
from copy import deepcopy
from scipy.stats import ttest_ind

from reward_trees.reward_trees import RewardTree, RewardNet
from data_utils import pre_process_dataset


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only use this on model files produced locally by this project, never on
    untrusted input.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _score(reward_model, x, x_prime):
    """Evaluate ``reward_model`` on the single transition ``x -> x_prime``.

    The model is called as ``reward_model(x, x_prime - x, x_prime)`` with a
    leading batch dimension of size 1 added to each tensor via ``unsqueeze(0)``.
    """
    delta_x = x_prime - x
    return reward_model(x.unsqueeze(0), delta_x.unsqueeze(0), x_prime.unsqueeze(0))


def main():
    """Compare custom vs. standard reward models on a context-dependent change.

    For each dataset ('heloc', 'adult', 'german_credit') and model type
    ('mlp', 'tree'), load the pickled custom and standard models from
    ``models/``. For every state stored on the custom model, build two
    scenarios:

    * scenario 1 with ``mutate_before``;
    * scenario 2 with ``mutate_after``.

    Each scenario applies the same feature step under different context
    features. For each model, record ``output(scenario 1) - output(scenario 2)``.

    All records are written to ``data/desiderata3.csv`` with columns
    ``dataset``, ``model``, ``data`` (the output difference) and ``prompt``
    ('custom' or 'standard').
    """
    records = []

    for dataset in ['heloc', 'adult', 'german_credit']:
        for model in ['mlp', 'tree']:
            model_custom = _load_pickle(f"models/{dataset}.{model}.10.0.custom")
            model_standard = _load_pickle(f"models/{dataset}.{model}.10.0.normal")

            # Only the fourth return value is needed here. It is passed to the
            # mutate_* helpers as `std` (presumably per-feature standard
            # deviations -- TODO confirm against pre_process_dataset).
            _, _, _, std, _ = pre_process_dataset(0, dataset, 1000, training=False)

            # Pure inference: every output is reduced with .item(), so skip
            # building an autograd graph for each row.
            with torch.no_grad():
                # The custom model's stored states serve as evaluation rows for
                # BOTH models so the comparison is on identical inputs.
                for row in model_custom.states:
                    # Clone per scenario: the mutate_* helpers modify x in place.
                    x1, x1_prime = mutate_before(dataset, row.clone(), std)
                    custom_output_1 = _score(model_custom, x1, x1_prime)
                    standard_output_1 = _score(model_standard, x1, x1_prime)

                    x2, x2_prime = mutate_after(dataset, row.clone(), std)
                    custom_output_2 = _score(model_custom, x2, x2_prime)
                    standard_output_2 = _score(model_standard, x2, x2_prime)

                    records.append({
                        'dataset': dataset,
                        'model': model,
                        'data': (custom_output_1 - custom_output_2).item(),
                        'prompt': 'custom',
                    })
                    records.append({
                        'dataset': dataset,
                        'model': model,
                        'data': (standard_output_1 - standard_output_2).item(),
                        'prompt': 'standard',
                    })

    # Build the frame in one shot. The previous per-row DataFrame + concat was
    # slow and raised on an empty list. Explicit columns keep the CSV header
    # stable even if no rows were produced.
    df = pd.DataFrame(records, columns=['dataset', 'model', 'data', 'prompt'])
    df.to_csv('data/desiderata3.csv', index=False)
    
        
def get_train_test_accuracy(the_model):
    """Print and return pairwise-preference accuracy on the train and test sets.

    The model is evaluated once on all of its stored transitions
    (``the_model.states``, ``the_model.actions``, ``the_model.next_states``).
    Each preference ``p`` in ``the_model.preferences`` (train) and
    ``the_model.test_preferences`` (test) compares two of those costs,
    ``costs[p.i]`` and ``costs[p.j]``. A preference counts as correct when:

    * ``p.y > 0.5`` and cost ``i`` is strictly greater than cost ``j``;
    * ``p.y < 0.5`` and cost ``i`` is strictly less than cost ``j``;
    * otherwise (``p.y == 0.5``), when the two costs are exactly equal.

    Args:
        the_model: callable reward model exposing the attributes above.

    Returns:
        ``[train_accuracy, test_accuracy]``. A split with no preferences
        yields ``nan`` instead of triggering numpy's empty-mean warning.
    """
    costs = the_model(the_model.states, the_model.actions, the_model.next_states)
    accs = []

    for split, prefs in (("TRAIN", the_model.preferences), ("TEST", the_model.test_preferences)):
        correct = []
        for p in prefs:
            c_i, c_j = costs[p.i].item(), costs[p.j].item()
            if p.y > 0.5:
                correct.append(int(c_i > c_j))
            elif p.y < 0.5:
                correct.append(int(c_i < c_j))
            else:
                # Tie label: only an exact cost tie is counted as correct.
                correct.append(int(c_i == c_j))
        acc = mean(correct) if correct else float("nan")
        print(f"{split} ACC =", acc)
        accs.append(acc)
    return accs


def mutate_before(dataset, x, std):
    """Build scenario 1 ("before") of a counterfactual pair for ``dataset``.

    Overwrites some context features of ``x`` IN PLACE. Then returns
    ``(x, x_prime)``, where ``x`` is the same object that was passed in and
    ``x_prime`` is a copy of the modified ``x`` with one feature stepped:

    * ``heloc``: sets ``x[3] = 5``; steps ``x_prime[2] += 1``.
    * ``adult``: sets ``x[6] = 1`` (private company, per the author's note);
      steps ``x_prime[5] += 12`` (hours worked).
    * ``german_credit``: sets one-hot groups ``x[2:6] = [0, 1, 0, 0]``
      (bad credit history, per the author's note) and
      ``x[6:11] = [0, 0, 1, 0, 0]``; steps ``x_prime[1] += std[4] / 2``.

    ``std`` is only read for ``german_credit``.

    Raises:
        TypeError: if ``dataset`` is not one of the three names above
            (nothing is modified in that case).
    """
    if dataset not in ('heloc', 'adult', 'german_credit'):
        raise TypeError('wrong dataset name')

    if dataset == 'heloc':
        x[3] = 5.
        step_index, step = 2, 1
    elif dataset == 'adult':
        x[6] = 1.  # set to private company
        step_index, step = 5, 12  # increase hours worked by 12
    else:  # german_credit, bad credit history
        x[2:6] = torch.tensor([0., 1., 0., 0.])
        x[6:11] = torch.tensor([0., 0., 1., 0., 0.])
        # NOTE(review): feature 1 is stepped by half of std[4]. An earlier
        # comment called this "increase education by 1", which looks
        # copy-pasted; confirm what feature 1 is and the index pairing.
        step_index, step = 1, std[4] / 2

    x_prime = x.clone()
    x_prime[step_index] += step
    return x, x_prime


def mutate_after(dataset, x, std):
    """Build scenario 2 ("after") of a counterfactual pair for ``dataset``.

    Mirrors ``mutate_before``: the stepped feature and step size are the same,
    but the context features written into ``x`` (IN PLACE) differ:

    * ``heloc``: sets ``x[3] = 0``; steps ``x_prime[2] += 1``.
    * ``adult``: sets ``x[6] = 0`` (self-employed, per the author's note);
      steps ``x_prime[5] += 12`` (hours worked).
    * ``german_credit``: sets one-hot groups ``x[2:6] = [0, 0, 1, 0]`` and
      ``x[6:11] = [1, 0, 0, 0, 0]``; steps ``x_prime[1] += std[4] / 2``.

    Returns ``(x, x_prime)``, where ``x`` is the same object passed in and
    ``x_prime`` is a stepped copy of it. ``std`` is only read for
    ``german_credit``.

    Raises:
        TypeError: if ``dataset`` is not one of the three names above
            (nothing is modified in that case).
    """
    if dataset not in ('heloc', 'adult', 'german_credit'):
        raise TypeError('wrong dataset name')

    if dataset == 'heloc':
        x[3] = 0.
        step_index, step = 2, 1
    elif dataset == 'adult':
        x[6] = 0.  # set to self-employed
        step_index, step = 5, 12  # increase hours worked by 12
    else:  # german_credit
        x[2:6] = torch.tensor([0., 0., 1., 0.])
        x[6:11] = torch.tensor([1., 0., 0., 0., 0.])
        # NOTE(review): feature 1 is stepped by half of std[4]. An earlier
        # comment called this "increase education by 1", which looks
        # copy-pasted; confirm what feature 1 is and the index pairing.
        step_index, step = 1, std[4] / 2

    x_prime = x.clone()
    x_prime[step_index] += step
    return x, x_prime


if __name__ == "__main__":
    # Script entry point: run the evaluation and write the results CSV.
    main()