import pandas as pd
import numpy as np

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tikzplotlib
import pickle
from matplotlib.ticker import FormatStrFormatter

import sys
sys.path.insert(1, '/scratch/work/sinagam1/recourse-adaptive-preference-rebuttal/recourse-adaptive-preference/')

from methods.reup import q_determine

# Script-wide settings for the mean-rank figure.
T = 10  # number of entries of each run's 'lst_Sigma' that are plotted (x-axis is 1..T)
TOP_K = 5  # k forwarded to compute_mean_rank for synthesis, german and bank
NUM_INSTANCES = 26  # number of entries of each experiment's 'log' that are averaged


def diagonal_only(matrix):
    """Return a new float matrix that keeps only the diagonal of ``matrix``.

    The result is built from ``np.eye(matrix.shape[0])`` whose diagonal is
    overwritten with ``matrix.diagonal()``; ``matrix`` itself is not modified.
    """
    diag = np.eye(matrix.shape[0])
    np.fill_diagonal(diag, matrix.diagonal())
    return diag


def load_experiment(path):
    """Load one experiment pickle and select its rows labelled ``1``.

    Args:
        path: Path of a pickle holding a mapping with 'data', 'label' and
            'log' entries.

    Returns:
        ``(result, positives)`` where ``result`` is the unpickled object and
        ``positives`` is ``result['data'][result['label'] == 1]``.
    """
    # NOTE: unpickling can execute arbitrary code; only load files produced
    # by your own experiment runs.
    result = pd.read_pickle(path)
    positives = result['data'][result['label'] == 1]
    return result, positives


def mean_rank_curves(experiments, rank_fn, horizon, num_instances):
    """Average ``rank_fn`` over logged instances for every round.

    Args:
        experiments: Sequence of ``(result, positives, top_k)`` tuples, where
            ``result['log'][i]`` provides 'x_0', 'A_0' and 'lst_Sigma'.
        rank_fn: Callable invoked as
            ``rank_fn(positives, x_0, A_0, sigma, top_k)``.
        horizon: Number of rounds ``t`` to read from each 'lst_Sigma'.
        num_instances: Number of leading 'log' entries to average over.

    Returns:
        One list per experiment (same order as ``experiments``), each holding
        ``horizon`` values: the mean of ``rank_fn`` over the instances.
    """
    curves = [[] for _ in experiments]
    for t in range(horizon):
        per_round = [[] for _ in experiments]
        for i in range(num_instances):
            # Experiments are visited in the given order inside each (t, i)
            # step so the sequence of rank_fn calls is stable.
            for slot, (result, positives, top_k) in enumerate(experiments):
                entry = result['log'][i]
                x_0 = entry['x_0'].reshape(-1, 1)  # column vector
                # Only the diagonals of A_0 and of the round-t Sigma are
                # passed on; fresh matrices are built for every call.
                a_0 = diagonal_only(entry['A_0'])
                sigma = diagonal_only(entry['lst_Sigma'][t])
                per_round[slot].append(
                    rank_fn(positives, x_0, a_0, sigma, top_k)
                )
        for slot, ranks in enumerate(per_round):
            curves[slot].append(np.array(ranks).mean())
    return curves


def plot_mean_rank(panels, x_axis):
    """Draw one line plot per panel and write the PDF and TikZ outputs.

    Args:
        panels: Sequence of ``(title, curve)`` pairs, left to right.
        x_axis: x coordinates shared by all curves.
    """
    plt.rcParams.update({'font.size': 20})
    fig, axs = plt.subplots(1, len(panels), figsize=(35, 10))

    for ax, (title, curve) in zip(axs, panels):
        ax.set_title(title, fontsize=35)
        ax.plot(x_axis, curve, marker='o')
        ax.tick_params(axis='both', labelsize=27)
        ax.set_xlabel(r'$T$', fontsize=35)

    # Only the left-most panel carries the shared y-axis label.
    axs[0].set_ylabel('Mean Rank', fontsize=35)
    for ax in axs[:-1]:
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    # NOTE(review): the last panel moves its ticks to the right and, unlike
    # the others, gets no '%.3f' formatter - kept as in the original script.
    axs[-1].yaxis.tick_right()

    plt.savefig("Mean-rank.pdf", bbox_inches='tight')
    tikzplotlib.save("Mean-rank.tex")


def main():
    """Load the four experiment logs, compute their curves and plot them."""
    # Listed in the order the rank metric is evaluated within each step.
    specs = [
        ('synthesis', r'mlp_synthesis_bayesian_reup_graph_expt3.pickle', TOP_K),
        ('german', r'mlp_german_bayesian_reup_graph_expt3.pickle', TOP_K),
        # NOTE(review): student uses k=1 instead of TOP_K - confirm this is
        # intentional.
        ('student', r'mlp_student_bayesian_reup_graph_expt3.pickle', 1),
        ('bank', r'mlp_bank_bayesian_reup_graph_expt3.pickle', TOP_K),
    ]

    experiments = []
    for _, path, top_k in specs:
        result, positives = load_experiment(path)
        experiments.append((result, positives, top_k))

    curves = mean_rank_curves(
        experiments, q_determine.compute_mean_rank, T, NUM_INSTANCES
    )
    curve_of = {spec[0]: curve for spec, curve in zip(specs, curves)}

    # Panel order differs from evaluation order: bank is drawn before student.
    panels = [
        (r'Synthetic', curve_of['synthesis']),
        (r'German', curve_of['german']),
        (r'Bank', curve_of['bank']),
        (r'Student', curve_of['student']),
    ]
    plot_mean_rank(panels, np.arange(T) + 1)


if __name__ == "__main__":
    main()