import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np 
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F
import pandas as pd

from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from numpy import mean
from copy import deepcopy
from matplotlib.table import Table

from data_utils import pre_process_dataset
from reward_trees.reward_trees import RewardTree, RewardNet


# Used by main() only as the numeric suffix of the output CSV
# (data/model_test_accuracies_<NUM_LINKS>.csv). Presumably the number of links
# the evaluated models were built with — TODO confirm. Note that main() builds
# model file names with a literal "10" rather than this constant.
NUM_LINKS = 10


def main():
    """Evaluate every saved cost model and dump the test accuracies to CSV.

    For each (dataset, prompt, model type) combination, the pickled model at
    ``models/<dataset>.<model_type>.10.0.<prompt>`` is loaded and scored with
    ``get_train_test_accuracy``. One row per model, with columns
    ``test_acc``, ``dataset``, ``prompt`` and ``model``, is written to
    ``data/model_test_accuracies_<NUM_LINKS>.csv``.
    """
    dataset_names = ('heloc', 'adult', 'german_credit')
    prompt_kinds = ('normal', 'custom')
    model_kinds = ('mlp', 'tree')

    rows = []
    for dataset_name in dataset_names:
        for prompt_kind in prompt_kinds:
            for kind in model_kinds:
                # NOTE(review): the "10" segment is a literal while the CSV
                # name uses NUM_LINKS — confirm whether they should match.
                full_name = f"{dataset_name}.{kind}.10.0.{prompt_kind}"

                # pickle can run arbitrary code on load; only use trusted files.
                with open(f"models/{full_name}", "rb") as handle:
                    loaded_model = pickle.load(handle)

                _train_accuracy, test_accuracy = get_train_test_accuracy(loaded_model)
                rows.append(dict(
                    test_acc=test_accuracy,
                    dataset=dataset_name,
                    prompt=prompt_kind,
                    model=full_name,
                ))

    out_path = f"data/model_test_accuracies_{NUM_LINKS}.csv"
    pd.DataFrame(rows).to_csv(out_path, index=False)



def get_train_test_accuracy(the_model):
    """Return the pairwise-preference accuracy of a cost model on train and test data.

    The model is called once on ``(the_model.states, the_model.actions,
    the_model.next_states)`` to obtain a cost per item, indexed by the
    preferences' ``i``/``j`` fields. For each preference ``p`` in
    ``the_model.preferences`` (train) and ``the_model.test_preferences`` (test):

    * ``p.y > 0.5`` counts as correct when ``cost[p.i] > cost[p.j]``;
    * ``p.y < 0.5`` counts as correct when ``cost[p.i] < cost[p.j]``;
    * otherwise (``p.y == 0.5``) it counts as correct only when the two costs
      are exactly equal.

    Each accuracy is printed as ``"<SPLIT> ACC = <value>"`` as it is computed.

    Args:
        the_model: callable model exposing ``states``, ``actions``,
            ``next_states``, ``preferences`` and ``test_preferences``.

    Returns:
        ``[train_acc, test_acc]``. A split with no preferences yields ``nan``.
    """
    # Pure evaluation: avoid building an autograd graph over every state.
    with torch.no_grad():
        costs = the_model(the_model.states, the_model.actions, the_model.next_states)

    accs = []
    splits = (("TRAIN", the_model.preferences), ("TEST", the_model.test_preferences))
    for label, prefs in splits:
        correct = []
        for p in prefs:
            c_i, c_j = costs[p.i].item(), costs[p.j].item()
            if p.y > 0.5:
                correct.append(int(c_i > c_j))
            elif p.y < 0.5:
                correct.append(int(c_i < c_j))
            else:
                # Tie labels require exactly equal costs.
                correct.append(int(c_i == c_j))
        # numpy's mean of an empty list warns and returns nan; return nan directly.
        acc = mean(correct) if correct else float("nan")
        print(f"{label} ACC =", acc)
        accs.append(acc)
    return accs


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()