import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np 
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F
import pandas as pd

from scipy.stats import ttest_ind
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from numpy import mean
from copy import deepcopy
from matplotlib.table import Table
from scipy.stats import ttest_ind

from data_utils import pre_process_dataset
from reward_trees.reward_trees import RewardTree, RewardNet


def _sweep_values(states, feature_idx):
    """Return the integer grid of values to try for one feature column.

    The grid runs from the column minimum (inclusive) to the column maximum
    (exclusive) in steps of ``max // 10``, with the step clamped to at
    least 1 so that features whose maximum is below 10 do not produce a
    zero step (which makes ``range`` raise ``ValueError``).
    """
    if len(states) == 0:
        # No rows: nothing to sweep, and min()/max() would fail on empty input.
        return range(0)

    feature_column = states.T[feature_idx]
    low = feature_column.min().int().item()
    high = feature_column.max().int().item()
    return range(low, high, max(1, high // 10))


def main():
    """Sweep numerical features and record custom vs. standard model outputs.

    For every dataset/model pair this loads the pickled "custom" and
    "normal" models, and for each configured numerical feature and each row
    of ``model_custom.states``:

    * sets the feature to every value of an integer grid (see
      ``_sweep_values``),
    * builds ``x_prime`` by adding ``round(std / 2)`` to that feature,
    * evaluates both models on ``(x, x_prime - x, x_prime)``.

    One output row per (dataset, model, feature) is written to
    ``data/desiderata2.csv``; the ``data_custom`` / ``data_standard`` cells
    hold a list (one entry per state row) of lists (one output per grid
    value).

    Raises:
        TypeError: If a dataset name has no feature configuration.
    """
    columns = ['dataset', 'model', 'data_custom', 'data_standard', 'feature']

    # Rows are gathered in a plain list and turned into a DataFrame once at
    # the end: DataFrame.append was removed in pandas 2.0 and copied the
    # whole frame on every call in earlier versions.
    records = []

    for dataset in ['heloc', 'adult', 'german_credit']:
        for model in ['mlp', 'tree']:

            model_custom_path = f"models/{dataset}.{model}.10.0.custom"
            model_standard_path = f"models/{dataset}.{model}.10.0.normal"

            # NOTE(review): pickle.load can execute arbitrary code; only run
            # this against model files produced by a trusted pipeline.
            with open(model_custom_path, "rb") as f:
                model_custom = pickle.load(f)

            with open(model_standard_path, "rb") as f:
                model_standard = pickle.load(f)

            # Only the fourth return value is used here; presumably the
            # per-feature standard deviations — confirm against data_utils.
            _, _, _, std, _ = pre_process_dataset(0, dataset, 1000, training=False)

            if dataset == 'heloc':
                numerical_feature_idxs = [0, 1, 2, 3]
                feature_names = ['MSinceMostRecentInqexcl7days', 'NumRevolvingTradesWBalance', 'NumTradesOpeninLast12M', 'NumInqLast6M']
            elif dataset == 'adult':
                numerical_feature_idxs = [1, 4, 5]
                feature_names = ['isMale', 'age', 'native-country-United-States',
                                 'marital-status-Married', 'education-num', 'hours-per-week', 'workclass-Private', 'isWhite']
            elif dataset == 'german_credit':
                numerical_feature_idxs = [0, 1]
                # Keep only entries 1 and 4 of std so that std can be indexed
                # with the same [0, 1] positions as feature_names below.
                std = [std[1], std[4]]
                feature_names = ['duration', 'amount']
            else:
                raise TypeError('wrong dataset name')

            states = model_custom.states

            # iterate each numerical feature
            for numerical_feature_idx in numerical_feature_idxs:

                custom_people = []
                standard_people = []

                # Both the value grid and the perturbation size depend only
                # on the feature, not on the row, so compute them once.
                sweep_values = _sweep_values(states, numerical_feature_idx)
                # TODO: change for german credit
                # NOTE(review): for german_credit the same [0, 1] positions
                # index both the state row and the reduced std list — confirm
                # that the state columns really are 'duration' and 'amount'.
                shift = round(std[numerical_feature_idx] / 2)

                # Iterate through each row
                for row in states:

                    custom_feature = []
                    standard_feature = []

                    for value in sweep_values:

                        # Feature change for first scenario
                        x = row.clone()
                        x[numerical_feature_idx] = value

                        x_prime = x.clone()
                        x_prime[numerical_feature_idx] += shift

                        delta_x = x_prime - x
                        custom_output = model_custom(x.unsqueeze(0), delta_x.unsqueeze(0), x_prime.unsqueeze(0))
                        standard_output = model_standard(x.unsqueeze(0), delta_x.unsqueeze(0), x_prime.unsqueeze(0))

                        # Record the model outputs for this grid value
                        custom_feature.append(custom_output.item())
                        standard_feature.append(standard_output.item())

                    custom_people.append(custom_feature)
                    standard_people.append(standard_feature)

                records.append({
                    'dataset': dataset,
                    'model': model,
                    'data_custom': custom_people,
                    'data_standard': standard_people,
                    'feature': feature_names[numerical_feature_idx],
                })

    df = pd.DataFrame(records, columns=columns)

    # Save the DataFrame to a CSV file
    df.to_csv('data/desiderata2.csv', index=False)
                      

def get_train_test_accuracy(the_model):
    """Print and return preference-prediction accuracy on train and test sets.

    The model is called once on ``(states, actions, next_states)`` to obtain
    one score per index. Each preference ``p`` refers to two indices
    (``p.i`` and ``p.j``) and a label ``p.y``:

    * ``p.y > 0.5``: counted correct when the score at ``i`` is larger,
    * ``p.y < 0.5``: counted correct when the score at ``i`` is smaller,
    * otherwise: counted correct when both scores are exactly equal.

    Returns:
        A two-element list ``[train_accuracy, test_accuracy]``.
    """
    scores = the_model(the_model.states, the_model.actions, the_model.next_states)

    def _hit(pref):
        # 1 when the score ordering agrees with the preference label, else 0.
        first = scores[pref.i].item()
        second = scores[pref.j].item()
        if pref.y > 0.5:
            return int(first > second)
        if pref.y < 0.5:
            return int(first < second)
        return int(first == second)

    splits = {
        "TRAIN": the_model.preferences,
        "TEST": the_model.test_preferences,
    }

    accuracies = []
    for label, prefs in splits.items():
        hits = [_hit(pref) for pref in prefs]
        accuracy = mean(hits)
        print(f"{label} ACC =", accuracy)
        accuracies.append(accuracy)
    return accuracies


# Run the feature sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
