from simulation.optim_utils import get_Py1_new, get_Px_new, get_kernel_from_policy
from simulation.data_fico import get_Py1x_given,  get_Ps_given
import numpy as np
import os




def run_simulation_diff_setups(setup_list, steps, num_cat, kernel, Pd1):
    """Simulate feature and qualification-rate trajectories for several setups.

    For every initial distribution P(X_0=x|S=s) in ``setup_list`` the feature
    distribution is pushed through ``kernel`` (via ``get_Px_new``) for
    ``steps`` transitions, and the group qualification rate
    P(Y=1|S=s) = sum_x P(Y=1|X=x, S=s) P(X=x|S=s) is recomputed at every step
    (via ``get_Py1_new``).

    Args:
        setup_list: Sequence of initial distributions P(X_0=x|S=s), one per setup.
        steps: Number of transitions; each trajectory has ``steps + 1`` entries
            (the initial state plus one per transition).
        num_cat: Number of feature categories.
        kernel: Transition kernel passed to ``get_Px_new``.
        Pd1: Policy argument. It is not used here and is kept only for API
            compatibility.

    Returns:
        ``(Px_lists, Py1_lists)``: one entry per setup. Each entry is the
        trajectory stacked by ``permutate_list``, i.e. with the time axis
        moved to the last position.
    """
    Py1x = get_Py1x_given(num_cat)

    # Px_lists[i]: P(X=x|S=s) over time for setup i.
    Px_lists = []
    # Py1_lists[i]: P(Y=1|S=s) over time for setup i.
    Py1_lists = []

    for initial_Px in setup_list:
        _Px = initial_Px
        _Py1 = get_Py1_new(_Px, Py1x, num_cat)
        Px_evolution = [_Px]
        Py1_evolution = [_Py1]

        for _ in range(steps):
            _Px = get_Px_new(kernel, _Px)
            _Py1 = get_Py1_new(_Px, Py1x, num_cat)
            Px_evolution.append(_Px)
            Py1_evolution.append(_Py1)

        Px_lists.append(permutate_list(Px_evolution))
        Py1_lists.append(permutate_list(Py1_evolution))

    return Px_lists, Py1_lists

def permutate_list(mylist):
    """Stack ``mylist`` into one array and move the stacking axis to the end.

    A list of T arrays of shape ``shape`` becomes an array of shape
    ``shape + (T,)``, so the last index selects the list position.
    """
    stacked = np.array(mylist)
    return np.moveaxis(stacked, source=0, destination=-1)

# read a np array from a csv file called file_name
def read_csv(file_name):
    """Load the comma-separated text file ``file_name`` into a NumPy array."""
    loaded = np.genfromtxt(file_name, delimiter=',')
    return loaded


def plot_equilibria_Py1_setups(beta_lists, exp_name, legend, setups_dict=None):
    """Plot P(Y=1|S=0) against P(Y=1|S=1) trajectories for several setups.

    Each setup's trajectory is drawn as a coloured line and its final point is
    marked with a black star. The figure is saved to
    ``../figures/<exp_name>/equilibria_Py1.png`` and shown. The final point of
    the last setup is printed.

    Args:
        beta_lists: One entry per setup; ``beta_lists[i][s]`` is the sequence
            of P(Y=1|S=s) values over time.
        exp_name: Sub-directory of ``../figures`` to save into.
        legend: If true, draw a legend labelled from ``setups_dict``.
        setups_dict: Mapping from setup index to legend label. It must be
            non-empty when ``legend`` is true. Otherwise the labels default to
            ``setup_<i>``. The caller's dict is never modified.

    Raises:
        AssertionError: if ``legend`` is true and ``setups_dict`` is empty/None.
    """
    if legend:
        assert setups_dict, "setups_dict is empty"
        labels = setups_dict
    else:
        # Build default labels locally so neither the caller's dict nor a shared
        # default argument is mutated between calls.
        labels = {i: f"setup_{i}" for i in range(len(beta_lists))}

    setup_len = len(beta_lists)
    sns.set_palette("colorblind")
    plt.figure()
    plt.axis('square')
    font_size = 30
    label_size = 20
    colors = sns.color_palette("colorblind", n_colors=setup_len)

    for i, beta in enumerate(beta_lists):
        # beta[s] is the time series of P(Y=1|S=s); the last point gets a star.
        plt.plot(beta[0][:-1], beta[1][:-1], 'o-', color=colors[i], label=f"{labels[i]}", linewidth=3, markersize=10)
        plt.plot(beta[0][-1], beta[1][-1], '-*', c='black', markersize=30)

    plt.xlabel(f'$P(Y=1|S=0)$', fontsize=font_size)
    plt.ylabel(f'$P(Y=1|S=1)$', fontsize=font_size)

    # Diagonal: both groups have the same qualification rate.
    plt.plot([0, 1], [0, 1], '--', c='black')
    plt.tick_params(axis='both', which='major', labelsize=label_size)
    plt.yticks([0.6, 0.8, 1])
    plt.ylim(0.4, 1)
    plt.xlim(0.4, 1)
    if legend:
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=font_size)
    plt.grid(color='grey', linestyle='-', linewidth=0.25, alpha=0.5)

    my_path = f'../figures/{exp_name}'
    os.makedirs(my_path, exist_ok=True)
    plt.savefig(os.path.join(my_path, 'equilibria_Py1.png'), bbox_inches='tight')
    if beta_lists:
        # Report the final point of the last setup.
        last = beta_lists[-1]
        print(f'$P(Y=1|S=0) = {last[0][-1].round(4)}$')
        print(f'$P(Y=1|S=1) = {last[1][-1].round(4)}$')
    plt.show()



def plot_equilibria_Px_setups(alpha_lists, exp_name, num_cat, legend, setups_dict=None):
    """Plot P(X=x|S=0) against P(X=x|S=1) trajectories, one panel per category x.

    The figure is saved to ``../figures/<exp_name>/equilibria_Px.png`` and
    shown. For every category, the final point of the last setup is printed.

    Args:
        alpha_lists: One entry per setup, indexed ``alpha_lists[i][s][x][t]``.
        exp_name: Sub-directory of ``../figures`` to save into.
        num_cat: Number of feature categories, which is also the number of panels.
        legend: If true, draw a legend (on the right-most panel) labelled from
            ``setups_dict``.
        setups_dict: Mapping from setup index to legend label. It must be
            non-empty when ``legend`` is true. Otherwise the labels default to
            ``setup_<i>``. The caller's dict is never modified.

    Raises:
        AssertionError: if ``legend`` is true and ``setups_dict`` is empty/None.
    """
    if legend:
        assert setups_dict, "setups_dict is empty"
        labels = setups_dict
    else:
        # Build default labels locally; never mutate a shared/caller dict.
        labels = {i: f"setup_{i}" for i in range(len(alpha_lists))}

    setup_len = len(alpha_lists)
    sns.set_palette("colorblind")
    colors = sns.color_palette("colorblind", n_colors=setup_len)
    ncols = num_cat
    nrows = 1
    # squeeze=False always yields a 2-D axes array, so indexing also works when
    # num_cat == 1 (plt.subplots would otherwise return a single Axes).
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*5, nrows*5), squeeze=False)
    axes = axes[0]
    fig.tight_layout(pad=1)
    font_size = 30
    label_font_size = 20
    ticks = [0.2, 0.4, 0.6, 0.8]

    for x, ax in enumerate(axes):
        for i in range(setup_len):
            # alpha_lists[i][s][x] is the time series of P(X=x|S=s) for setup i.
            ax.plot(alpha_lists[i][0][x][:-1], alpha_lists[i][1][x][:-1], 'o-', color=colors[i], linewidth=3, markersize=10, label=f"{labels[i]}")
            ax.plot(alpha_lists[i][0][x][-1], alpha_lists[i][1][x][-1], '-*', c='black', linewidth=3, markersize=40)

        ax.set_xlabel(f'$P(X=x|S=0)$', fontsize=font_size)
        if setup_len:
            # Report the final point of the last setup for this category.
            last = alpha_lists[-1]
            print(f'$P(X=x|S=0) = {last[0][x][-1].round(4)}$', f'$P(X=x|S=1) = {last[1][x][-1].round(4)}$')

        ax.set_xlim(0, 0.8)
        ax.set_ylim(0, 0.8)
        ax.set_title(f'$x ={x+1}$', fontsize=font_size)
        ax.tick_params(axis='both', which='major', labelsize=label_font_size)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        if x == 0:
            ax.set_ylabel(f'$P(X=x|S=1)$', fontsize=font_size)
        else:
            # Only the leftmost panel shows y tick labels (works for any num_cat).
            ax.set_yticklabels([])

        ax.plot([0, 1], [0, 1], '--', c='black')
        ax.grid(color='grey', linestyle='-', linewidth=0.25, alpha=0.5)

    if legend:
        axes[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=font_size)

    my_path = f'../figures/{exp_name}'
    os.makedirs(my_path, exist_ok=True)
    plt.savefig(os.path.join(my_path, 'equilibria_Px.png'), bbox_inches='tight')
    plt.show()

# For each run, find the first time step at which each group's trajectory lies within nu of its final value, and report the shortest and longest of these convergence times over the two groups.
def shortest_longest_time_until_convergence(Py1_lists, nu = 10e-50):
    """Return the shortest and longest group convergence time of every run.

    For run ``j`` and group ``s`` in {0, 1}, the convergence time of the
    trajectory ``Py1_lists[j][s]`` is ``i + 1``. Here ``i`` is the first index
    whose value lies strictly within ``nu`` of that trajectory's final value.
    A trajectory that only reaches the tolerance at its last entry (or never)
    gets ``len(trajectory)``.

    Args:
        Py1_lists: Sequence of runs; ``Py1_lists[j][s]`` is a 1-D sequence of
            P(Y=1|S=s) values over time.
        nu: Absolute tolerance. Note that ``10e-50`` equals ``1e-49``.

    Returns:
        ``(shortest, longest, nu)``. ``shortest[j]`` and ``longest[j]`` are
        the min and max of the two groups' convergence times for run ``j``,
        and ``nu`` is the tolerance used.
    """
    converge_list_shortest = []
    converge_list_longest = []
    for run in Py1_lists:
        times = [_first_convergence_time(run[s], nu) for s in (0, 1)]
        converge_list_shortest.append(min(times))
        converge_list_longest.append(max(times))

    return converge_list_shortest, converge_list_longest, nu


def _first_convergence_time(trajectory, nu):
    """1-based index of the first entry within ``nu`` of the last one (``len`` if none earlier)."""
    values = np.asarray(trajectory, dtype=float)
    if values.size == 0:
        return 0
    close = np.flatnonzero(np.abs(values - values[-1]) < nu)
    return int(close[0]) + 1 if close.size else len(values)


import seaborn as sns
import matplotlib.pyplot as plt


def plot_utility_fairness_policies(Px_lists, Pd1_list, convergence_time, exp_name, c, legend, setups_dict=None):
    """Plot utility (solid) and EOP unfairness (dashed), in percent, over time.

    Setup ``i`` is evaluated with feature trajectories ``Px_lists[i]`` and
    policy ``Pd1_list[i]`` (via ``get_utility_list`` / ``get_fairness_EOP_list``).
    The figure is saved to ``../figures/<exp_name>/utility-eop.png`` and shown.
    The final values of the last setup are printed.

    Args:
        Px_lists: One entry per setup, indexed ``[s][x][t]``.
        Pd1_list: One policy per setup, shaped (2, num_cat) or flat with
            2*num_cat entries.
        convergence_time: Right end of the time axis. ``None`` falls back to 50.
        exp_name: Sub-directory of ``../figures`` to save into.
        c: Constant subtracted from P(Y=1|X=x,S=s) inside the utility.
        legend: If true, draw a legend labelled from ``setups_dict``.
        setups_dict: Mapping from setup index to legend label. It must be
            non-empty when ``legend`` is true. Otherwise the labels default to
            ``setup_<i>``. The caller's dict is never modified.

    Raises:
        AssertionError: if ``legend`` is true and ``setups_dict`` is empty/None.
    """
    setup_len = len(Px_lists)
    sns.set_palette("colorblind")
    colors = sns.color_palette("colorblind", n_colors=setup_len)
    font_size = 30
    label_size = 20
    num_cat = len(Px_lists[0][0])
    Py1x = get_Py1x_given(num_cat=num_cat)
    if convergence_time is None:
        convergence_time = 50

    if legend:
        assert setups_dict, "setups_dict is empty"
        labels = setups_dict
    else:
        # Keys are the integer setup index because labels are read as labels[i].
        labels = {i: f"setup_{i}" for i in range(setup_len)}

    for i, Px_one_setup in enumerate(Px_lists):  # different initial distributions
        # Scale to percent.
        utility_list = [u * 100 for u in get_utility_list(Pd1_list[i], Py1x, Px_one_setup, c)]
        plt.plot(utility_list, label=f"{labels[i]}", color=colors[i], linewidth=2.5)

        fairness_list = [v * 100 for v in get_fairness_EOP_list(Pd1_list[i], Py1x, Px_one_setup)]
        plt.plot(fairness_list, linestyle='dashed', color=colors[i], linewidth=2.5)

    plt.xlabel("Time", size=font_size)
    plt.xticks(np.arange(0, convergence_time, 10))
    plt.tick_params(axis='both', which='major', labelsize=label_size)
    plt.title(f"Utility (—) / Unfairness (---)", size=font_size)
    plt.ylim(0, 0.13*100)
    plt.xlim(0, convergence_time)
    # y ticks at 0, 2, ..., 12 (percent).
    plt.yticks(np.arange(0, 13, 2))
    plt.grid(color='grey', linestyle='-', linewidth=0.5, alpha=0.5)

    # Final values of the last setup; np.round also accepts plain Python floats.
    print("fairness", np.round(fairness_list[-1], 4))
    print("utility", np.round(utility_list[-1], 4))

    if legend:
        # Anchor the legend just right of the plot.
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=label_size)
    my_path = f'../figures/{exp_name}'
    os.makedirs(my_path, exist_ok=True)
    plt.savefig(os.path.join(my_path, 'utility-eop.png'), bbox_inches='tight')
    plt.show()


def plot_utility_fairness_inequity_policies(Px_lists, Pd1_list, convergence_time, exp_name, c, legend, setups_dict=None):
    """Plot utility, inequity and EOP unfairness (percent) over time in three panels.

    Panel 0 shows the utility of policy ``Pd1_list[i]`` on ``Px_lists[i]``.
    Panel 1 shows the inequity |P(Y=1|S=0) - P(Y=1|S=1)|, where
    P(Y=1|S=s) = sum_x P(Y=1|X=x,S=s) P(X=x|S=s) at each time step.
    Panel 2 shows the equal-opportunity gap (dashed).
    Only the first ``convergence_time`` steps are drawn. The figure is saved to
    ``figures/<exp_name>/utility-eop.png`` and shown.

    Args:
        Px_lists: One entry per setup, indexed ``[s][x][t]``.
        Pd1_list: One policy per setup, shaped (2, num_cat) or flat with
            2*num_cat entries.
        convergence_time: Number of time steps to draw. ``None`` falls back to 50.
        exp_name: Sub-directory of ``figures`` to save into.
        c: Constant subtracted from P(Y=1|X=x,S=s) inside the utility.
        legend: If true, draw a legend (right of the last panel) labelled from
            ``setups_dict``.
        setups_dict: Mapping from setup index to legend label. It must be
            non-empty when ``legend`` is true. Otherwise the labels default to
            ``setup_<i>``. The caller's dict is never modified.

    Raises:
        AssertionError: if ``legend`` is true and ``setups_dict`` is empty/None.
    """
    setup_len = len(Px_lists)
    sns.set_palette("colorblind")
    colors = sns.color_palette("colorblind", n_colors=setup_len)
    font_size = 30
    num_cat = len(Px_lists[0][0])
    Py1x = get_Py1x_given(num_cat=num_cat)
    num_plots = 3
    fig, axs = plt.subplots(1, num_plots, figsize=(15, 5))
    if convergence_time is None:
        convergence_time = 50

    if legend:
        assert setups_dict, "setups_dict is empty"
        labels = setups_dict
    else:
        # Keys are the integer setup index because labels are read as labels[i].
        labels = {i: f"setup_{i}" for i in range(setup_len)}

    for i in range(len(Pd1_list)):  # different initial distributions
        Px_one_setup = Px_lists[i]
        label = f"{labels[i]}"

        # Utility, scaled to percent.
        utility_list = [u * 100 for u in get_utility_list(Pd1_list[i], Py1x, Px_one_setup, c)]
        axs[0].plot(utility_list[:convergence_time], label=label, color=colors[i], linewidth=2.5)

        # Per-group qualification rate at each step:
        # sum_x Py1x[s][x] * Px[s][x][t]. get_fairness_Py1_list expects
        # these (group, time) series, not the policy.
        n_steps = len(Px_one_setup[0][0])
        Py1_one_setup = [
            [sum(Py1x[s][x] * Px_one_setup[s][x][t] for x in range(num_cat)) for t in range(n_steps)]
            for s in (0, 1)
        ]
        inequity_list = [q * 100 for q in get_fairness_Py1_list(Py1_one_setup)]
        axs[1].plot(inequity_list[:convergence_time], label=label, color=colors[i], linewidth=2.5)

        fairness_list = [v * 100 for v in get_fairness_EOP_list(Pd1_list[i], Py1x, Px_one_setup)]
        axs[2].plot(fairness_list[:convergence_time], linestyle='dashed', label=label, color=colors[i], linewidth=2.5)

    # Shared panel styling, applied once after all setups are drawn.
    for ax, title in zip(axs, ["Utility", "Inequity", "Unfairness"]):
        ax.set_title(title, fontsize=font_size + 20)
        ax.tick_params(axis='both', which='major', labelsize=font_size)
        ax.set_xticks([10, 20, 30, 40])
        ax.grid(axis='y', alpha=.3)
        ax.grid(axis='x', alpha=.3)
        ax.set_box_aspect(1)

    if legend:
        axs[num_plots - 1].legend(bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0,
                                  fontsize=font_size + 10)

    # Create exactly the directory that is written to.
    # NOTE(review): this figure is saved under figures/ rather than ../figures/;
    # confirm the intended location.
    my_path = os.path.join('figures', exp_name)
    os.makedirs(my_path, exist_ok=True)
    plt.savefig(os.path.join(my_path, 'utility-eop.png'), bbox_inches='tight')
    plt.show()


def plot_utility_fairness(Px_lists, Pd1, convergence_time, exp_name, c):
    """Plot utility (solid) and EOP unfairness (dashed), in percent, for one policy.

    The same policy ``Pd1`` is evaluated on every setup in ``Px_lists`` (via
    ``get_utility_list`` / ``get_fairness_EOP_list``). The figure is saved to
    ``../figures/<exp_name>/utility-eop.png`` and shown. The final values of the
    last setup are printed.

    Args:
        Px_lists: One entry per setup, indexed ``[s][x][t]``.
        Pd1: Policy, shaped (2, num_cat) or flat with 2*num_cat entries.
        convergence_time: Right end of the time axis. ``None`` falls back to 50.
        exp_name: Sub-directory of ``../figures`` to save into.
        c: Constant subtracted from P(Y=1|X=x,S=s) inside the utility.
    """
    setup_len = len(Px_lists)
    sns.set_palette("colorblind")
    colors = sns.color_palette("colorblind", n_colors=setup_len)
    font_size = 30
    label_size = 20
    num_cat = len(Px_lists[0][0])
    Py1x = get_Py1x_given(num_cat=num_cat)
    if convergence_time is None:
        convergence_time = 50

    for i, Px_one_setup in enumerate(Px_lists):  # different initial distributions
        # Scale to percent.
        utility_list = [u * 100 for u in get_utility_list(Pd1, Py1x, Px_one_setup, c)]
        plt.plot(utility_list, label=f"setup={i}", color=colors[i], linewidth=2.5)

        fairness_list = [v * 100 for v in get_fairness_EOP_list(Pd1, Py1x, Px_one_setup)]
        plt.plot(fairness_list, linestyle='dashed', color=colors[i], linewidth=2.5)

    plt.xlabel("Time", size=font_size)
    plt.xticks(np.arange(0, convergence_time, 10))
    plt.tick_params(axis='both', which='major', labelsize=label_size)
    plt.title(f"Utility (—) / Unfairness (---)", size=font_size)
    plt.xlim(0, convergence_time)

    # Final values of the last setup; np.round also accepts plain Python floats.
    print("fairness", np.round(fairness_list[-1], 4))
    print("utility", np.round(utility_list[-1], 4))
    plt.grid(color='grey', linestyle='-', linewidth=0.5, alpha=0.5)

    my_path = f'../figures/{exp_name}'
    os.makedirs(my_path, exist_ok=True)
    plt.savefig(os.path.join(my_path, 'utility-eop.png'), bbox_inches='tight')
    plt.show()



def get_utility_list(Pd1, Py1x, Px_one_setup, c):
    """Return the utility of policy ``Pd1`` at every time step.

    For time step ``t`` the value is
    ``sum_{s in {0,1}} sum_x Pd1[s][x] * (Py1x[s][x] - c) * Px_one_setup[s][x][t] * Ps[s]``,
    where ``Ps = get_Ps_given()``. A flat ``Pd1`` with ``2 * num_cat`` entries
    is reshaped to ``(2, num_cat)`` first.
    """
    n_categories = len(Py1x[0])
    group_weights = get_Ps_given()
    if len(Pd1) == 2 * n_categories:
        Pd1 = np.reshape(Pd1, (2, n_categories))
    n_steps = len(Px_one_setup[0][0])
    return [
        sum(
            Pd1[s][x] * (Py1x[s][x] - c) * Px_one_setup[s][x][t] * group_weights[s]
            for s in (0, 1)
            for x in range(n_categories)
        )
        for t in range(n_steps)
    ]

def get_fairness_Py1_list(Py1_list):
    """Return ``|Py1_list[0][t] - Py1_list[1][t]|`` for every index ``t`` of the first series."""
    return [np.abs(Py1_list[0][t] - Py1_list[1][t]) for t in range(len(Py1_list[0]))]

def get_fairness_EOP_list(Pd1, Py1x, Px_one_setup):
    """Return the equal-opportunity gap between the two groups at every time step.

    For group ``s`` at time ``t`` the rate is
    ``sum_x Pd1[s][x] * Py1x[s][x] * Px[s][x][t] / sum_x Py1x[s][x] * Px[s][x][t]``.
    The returned value is ``|rate_0 - rate_1|``. A flat ``Pd1`` with
    ``2 * num_cat`` entries is reshaped to ``(2, num_cat)`` first.
    """
    n_categories = len(Py1x[0])
    if len(Pd1) == 2 * n_categories:
        Pd1 = np.reshape(Pd1, (2, n_categories))

    def _group_rate(group, step):
        numerator = 0
        denominator = 0
        for cat in range(n_categories):
            numerator += (Pd1[group][cat] * Py1x[group][cat] * Px_one_setup[group][cat][step])
            denominator += (Py1x[group][cat] * Px_one_setup[group][cat][step])
        return numerator / denominator

    gaps = []
    for step in range(len(Px_one_setup[0][0])):
        rate_0 = _group_rate(0, step)
        rate_1 = _group_rate(1, step)
        gaps.append(np.abs(rate_0 - rate_1))
    return gaps


def loans_granted_per_group(Pd1, Px_one_setup):
    """Return, per group and time step, ``sum_x Pd1[s][x] * Px_one_setup[s][x][t]``.

    Args:
        Pd1: Policy, shaped (2, num_cat) or flat with 2*num_cat entries. The
            input may be a NumPy array or a plain (nested) list.
        Px_one_setup: Feature distribution indexed ``[s][x][t]``.

    Returns:
        ``[s0_list, s1_list]``, each with one value per time step.
    """
    num_cat = len(Px_one_setup[0])
    if len(Pd1) == 2 * num_cat:
        # np.reshape (unlike the ndarray method) also accepts plain lists,
        # consistent with get_utility_list / get_fairness_EOP_list.
        Pd1 = np.reshape(Pd1, (2, num_cat))

    num_steps = len(Px_one_setup[0][0])
    res_list = []
    for s in (0, 1):
        s_list = []
        for t in range(num_steps):
            res = 0
            for x in range(num_cat):
                res += Pd1[s][x] * Px_one_setup[s][x][t]
            s_list.append(res)
        res_list.append(s_list)

    return res_list

def plot_loans_payback(Px_lists, Py1_lists, Pd1, convergence_time, exp_name, s=0):
    """Plot the loan rate (solid) and payback rate (dashed) of group ``s`` over time.

    For setup ``i`` the solid line is ``loans_granted_per_group(Pd1, Px_lists[i])[s]``.
    The dashed line is ``Py1_lists[i][s]``, i.e. P(Y=1|S=s) for the same group.
    Only the first ``convergence_time`` steps are drawn. The figure is saved to
    ``../figures/<exp_name>/loans-payback-S<s>.png`` and shown.
    """
    setup_len = len(Py1_lists)
    sns.set_palette("colorblind")
    colors = sns.color_palette("colorblind", n_colors=setup_len)
    font_size = 25

    for i, Px_one_setup in enumerate(Px_lists):
        loans_per_group = loans_granted_per_group(Pd1, Px_one_setup)
        plt.plot(loans_per_group[s][:convergence_time], label=f"S={s}", color=colors[i])
        # Payback must belong to the same group s as the loan curve.
        plt.plot(Py1_lists[i][s][:convergence_time], linestyle='dashed', color=colors[i])

    # Axis settings do not depend on the setup, so set them once.
    plt.xticks(np.arange(0, convergence_time, 10))
    plt.tick_params(axis='both', which='major', labelsize=20)
    plt.yticks(np.arange(0.4, 1, 0.1))
    plt.ylim(0.3, 1)
    plt.xlabel("Time", size=font_size)
    plt.title(f"Loans (—) / Payback (---), S={s}", size=font_size)

    my_path = f'../figures/{exp_name}'
    os.makedirs(my_path, exist_ok=True)
    plt.savefig(os.path.join(my_path, f'loans-payback-S{s}.png'), bbox_inches='tight')
    plt.show()


def process_policies(policies_list, policies_dict, setup_list, steps, num_cat, estimation='_true'):
    """Simulate every policy and return the first setup's trajectories for each.

    For policy number ``idx``:
    - the dynamics name is the text after the last ``_`` in ``policies_dict[idx]``;
    - the transition kernel is built by ``get_kernel_from_policy`` with ``seed=idx``;
    - the policy, reshaped to ``(2, num_cat)``, is passed to
      ``run_simulation_diff_setups``.

    Returns:
        ``(Px per policy, Py1 per policy)``, both restricted to ``setup_list[0]``.
    """
    all_Px = []
    all_Py1 = []

    for idx, policy in enumerate(policies_list):
        dynamics_name = policies_dict[idx].split('_')[-1]
        kernel = get_kernel_from_policy(policy, num_cat, seed=idx, dynamics=dynamics_name, estimation=estimation)
        Px_per_setup, Py1_per_setup = run_simulation_diff_setups(
            setup_list, steps, num_cat, kernel, policy.reshape(2, num_cat)
        )
        all_Px.append(Px_per_setup)
        all_Py1.append(Py1_per_setup)

    # Keep only the trajectories of the first setup for every policy.
    first_Px = [per_policy[0] for per_policy in all_Px]
    first_Py1 = [per_policy[0] for per_policy in all_Py1]
    return first_Px, first_Py1