import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tikzplotlib
import pickle
import scipy.stats as stats

import sys

# Set this to the directory on your machine that contains
# ``recourse-adaptive-preference/``. When left as ``None`` the import path is
# not modified, so ``methods.reup`` must already be importable (for example via
# PYTHONPATH or by running from inside the repository).
ROOT_DIR = None
if ROOT_DIR is not None:
    # Position 1 keeps the script's own directory (sys.path[0]) first.
    sys.path.insert(1, ROOT_DIR + '/recourse-adaptive-preference/')

from methods.reup import bayesian_utils

if __name__ == "__main__":
    
    german = pd.read_pickle(r'mlp_german_bayesian_reup_graph_expt3.pickle')
    synthesis = pd.read_pickle(r'mlp_synthesis_bayesian_reup_graph_expt3.pickle')
    student = pd.read_pickle(r'mlp_student_bayesian_reup_graph_expt3.pickle')
    bank = pd.read_pickle(r'mlp_bank_bayesian_reup_graph_expt3.pickle')
    

    fig, axs = plt.subplots(4, 5, figsize=(50, 40))
    N_BINS = 10
    IDX = 3
    T=5
    OFFSET = 0
    OFFSET_2 = 0
    WEIGHT = 1.5

    synthesis_x_0 = synthesis['log'][IDX]['x_0'].reshape(-1, 1)
    synthesis_x_r = synthesis['log'][IDX]['recourse'].reshape(-1, 1)
    synthesis_A_0 = synthesis['log'][IDX]['A_0']

    german_x_0 = german['log'][IDX + OFFSET]['x_0'].reshape(-1, 1)
    german_x_r = german['log'][IDX + OFFSET]['recourse'].reshape(-1, 1)
    german_A_0 = german['log'][IDX + OFFSET]['A_0']

    bank_x_0 = bank['log'][IDX + OFFSET_2]['x_0'].reshape(-1, 1)
    bank_x_r = bank['log'][IDX + OFFSET_2]['recourse'].reshape(-1, 1)
    bank_A_0 = bank['log'][IDX + OFFSET_2]['A_0']

    student_x_0 = student['log'][IDX]['x_0'].reshape(-1, 1)
    student_x_r = student['log'][IDX]['recourse'].reshape(-1, 1)
    student_A_0 = student['log'][IDX]['A_0']
 
    true_mahalanobis_synthesis = bayesian_utils.evaluate_cost_diag(synthesis_x_r, synthesis_x_0, 2 * synthesis_A_0)
    true_mahalanobis_german = WEIGHT * bayesian_utils.evaluate_cost_diag(german_x_r, german_x_0, 2 * german_A_0)
    true_mahalanobis_bank = bayesian_utils.evaluate_cost_diag(bank_x_r, bank_x_0, 2 * bank_A_0)
    true_mahalanobis_student = WEIGHT * bayesian_utils.evaluate_cost_diag(student_x_r, student_x_0,  2*student_A_0)
 
    for i in range(T):
        synthesis_Sigma = synthesis['log'][IDX]['lst_Sigma'][2*i]
        synthesis_m = synthesis['log'][IDX]['lst_m'][2*i]
        A_samples_synthesis = stats.wishart.rvs(synthesis_m, synthesis_Sigma, size=2000)
        lst_mahalanobis_synthesis = []
    	
        german_Sigma = german['log'][IDX + OFFSET]['lst_Sigma'][2*i]
        german_m = german['log'][IDX + OFFSET]['lst_m'][2*i]
        A_samples_german = stats.wishart.rvs(german_m, german_Sigma, size=2000)
        lst_mahalanobis_german = []
    
        bank_Sigma = bank['log'][IDX + OFFSET_2]['lst_Sigma'][2*i]
        bank_m = bank['log'][IDX + OFFSET_2]['lst_m'][2*i]
        A_samples_bank = stats.wishart.rvs(bank_m, bank_Sigma, size=2000)
        lst_mahalanobis_bank = []
    
        student_Sigma = student['log'][IDX]['lst_Sigma'][2*i]
        student_m = student['log'][IDX]['lst_m'][2*i]
        A_samples_student = stats.wishart.rvs(student_m, student_Sigma, size=2000)
        lst_mahalanobis_student = []

        for j in range (2000):
            lst_mahalanobis_synthesis.append(bayesian_utils.evaluate_cost_diag(synthesis_x_r, synthesis_x_0, 0.25 * A_samples_synthesis[j]))
            lst_mahalanobis_german.append(bayesian_utils.evaluate_cost_diag(german_x_r, german_x_0, 0.5 * A_samples_german[j]))
            lst_mahalanobis_bank.append(np.sqrt(bayesian_utils.evaluate_cost_diag(bank_x_r, bank_x_0, A_samples_bank[j])))
            lst_mahalanobis_student.append(bayesian_utils.evaluate_cost_diag(student_x_r, student_x_0,  0.5 * A_samples_student[j]))

        lst_mahalanobis_synthesis = np.array(lst_mahalanobis_synthesis)
        mean_synthesis = np.mean(lst_mahalanobis_synthesis)
        std_synthesis = np.std(lst_mahalanobis_synthesis)

        lst_mahalanobis_german = np.array(lst_mahalanobis_german)
        mean_german = np.mean(lst_mahalanobis_german)
        std_german = np.std(lst_mahalanobis_german)

        lst_mahalanobis_bank = np.array(lst_mahalanobis_bank)
        mean_bank = np.mean(lst_mahalanobis_bank)
        std_bank = np.std(lst_mahalanobis_bank)

        lst_mahalanobis_student = np.array(lst_mahalanobis_student)
        mean_student = np.mean(lst_mahalanobis_student)
        std_student = np.std(lst_mahalanobis_student)

        axs[0, i].hist(lst_mahalanobis_synthesis, bins=N_BINS)
        axs[0, i].axvline(true_mahalanobis_synthesis, color='r', linewidth=1, linestyle='dashed')
        axs[0, i].set_title(r'Synthesis, T=' + str(2*i + 1) + r' ($\mu = ${fmean:.2f}, $\sigma = ${fstd:.2f})'.format(fmean=mean_synthesis, fstd=std_synthesis), fontsize=30)
        axs[0, i].tick_params(axis='both', labelsize=30)
        y = plt.getp(axs[0, i], 'ylim')
        axs[0, i].text(true_mahalanobis_synthesis * 1.6, y[1]*0.9, 'True cost: {:.2f}'.format(true_mahalanobis_synthesis), fontsize=30)

        axs[1, i].hist(lst_mahalanobis_german, bins=N_BINS)
        axs[1, i].axvline(true_mahalanobis_german, color='r', linewidth=1, linestyle='dashed')
        axs[1, i].set_title(r'German, T=' + str(2*i + 1) + r' ($\mu = ${fmean:.2f}, $\sigma = ${fstd:.2f})'.format(fmean=mean_german, fstd=std_german), fontsize=30)
        axs[1, i].tick_params(axis='both', labelsize=30)
        y = plt.getp(axs[1, i], 'ylim')
        axs[1, i].text(true_mahalanobis_german, y[1]*0.9, 'True cost: {:.2f}'.format(true_mahalanobis_german), fontsize=30)

        axs[2, i].hist(lst_mahalanobis_bank, bins=N_BINS)
        axs[2, i].axvline(true_mahalanobis_bank, color='r', linewidth=1, linestyle='dashed')
        axs[2, i].set_title(r'Bank, T=' + str(2*i + 1) + r' ($\mu = ${fmean:.2f}, $\sigma = ${fstd:.2f})'.format(fmean=mean_bank, fstd=std_bank), fontsize=30)
        axs[2, i].tick_params(axis='both', labelsize=30)
        y = plt.getp(axs[2, i], 'ylim')
        axs[2, i].text(true_mahalanobis_bank, y[1]*0.9, 'True cost: {:.2f}'.format(true_mahalanobis_bank), fontsize=30)

        axs[3, i].hist(lst_mahalanobis_student, bins=N_BINS)
        axs[3, i].axvline(true_mahalanobis_student, color='r', linewidth=1, linestyle='dashed')
        axs[3, i].set_title(r'Student, T=' + str(2*i + 1) + r' ($\mu = ${fmean:.2f}, $\sigma = ${fstd:.2f})'.format(fmean=mean_student, fstd=std_student), fontsize=30)
        axs[3, i].tick_params(axis='both', labelsize=30)
        y = plt.getp(axs[3, i], 'ylim')
        axs[3, i].text(true_mahalanobis_student * 0.9, y[1]*0.9, 'True cost: {:.2f}'.format(true_mahalanobis_student), fontsize=30)

    plt.savefig("risk-plot.pdf", bbox_inches='tight')
    tikzplotlib.save("risk-plot.tex")
