
import numpy as np
from simulation.data_fico import get_Py1x_given, get_Px_given, get_Ps_given, get_Py1x_estimated
from simulation.dynamics import get_Pxx_given_onemed, get_Pxx_given_oneslow, get_Pxx_given_onefast, get_Pxx_given_default, \
    get_Pxx_random, get_Pxx_given_twomed_discouraged, get_Pxx_given_twomed_recourse, get_Pxx_estimated
from simulation.constants import Cte
def get_kernel_from_policy(Pd1, num_cat, seed, dynamics, estimation):
    """Build the per-group transition kernel over X induced by a policy.

    For each group ``s`` and each pair ``(from_x, to_x)`` the kernel entry is

        sum over d, y in {0, 1} of
            Pd[d][s][from_x] * Pyx[y][s][from_x] * Pxx[s][d][y][from_x][to_x]

    where ``Pd = [1 - Pd1, Pd1]`` and ``Pyx = [1 - Py1x, Py1x]``.

    Args:
        Pd1: numpy array holding P(D=1 | X=x, S=s); it is reshaped to
            ``(2, num_cat)`` (group first, category second).
        num_cat: number of categories of X.
        seed: forwarded to ``get_Pxx_random``; only used when
            ``dynamics == Cte.RANDOM``.
        dynamics: one of the ``Cte`` constants ONEMED, ONESLOW, ONEFAST,
            DEFAULT, RANDOM, RECOURSE or DISCOURAGED, selecting which
            ``get_Pxx_*`` helper provides the transition tensor.
        estimation: ``"_true"`` to use ``get_Py1x_given(num_cat)``, or one of
            ``"_est-random"``, ``"_est-threshold"``, ``"_est-biased"`` to use
            ``get_Py1x_estimated(estimation)``.

    Returns:
        Float array of shape ``(2, num_cat, num_cat)`` indexed as
        ``kernel[s][from_x][to_x]``.

    Raises:
        NotImplementedError: if ``estimation`` or ``dynamics`` is not one of
            the supported values. The message names the offending value.
    """
    # Validate `estimation` first so an unsupported value fails before any
    # other argument is touched.
    if estimation == "_true":
        Py1x = get_Py1x_given(num_cat)
    elif estimation in ("_est-random", "_est-threshold", "_est-biased"):
        Py1x = get_Py1x_estimated(estimation)
    else:
        raise NotImplementedError(f"estimation {estimation} not found")

    # Stack the complement so the first axis indexes the binary value:
    # Pyx[y][s][x] and Pd[d][s][x].
    Pyx = np.array([1 - Py1x, Py1x])
    Pd1 = Pd1.reshape(2, num_cat)
    Pd = np.array([1 - Pd1, Pd1])

    if dynamics == Cte.ONEMED:
        Pxx = get_Pxx_given_onemed()
    elif dynamics == Cte.ONESLOW:
        Pxx = get_Pxx_given_oneslow()
    elif dynamics == Cte.ONEFAST:
        Pxx = get_Pxx_given_onefast()
    elif dynamics == Cte.DEFAULT:
        Pxx = get_Pxx_given_default()
    elif dynamics == Cte.RANDOM:
        Pxx = get_Pxx_random(num_cat, seed)
    elif dynamics == Cte.RECOURSE:
        Pxx = get_Pxx_given_twomed_recourse()
    elif dynamics == Cte.DISCOURAGED:
        Pxx = get_Pxx_given_twomed_discouraged()
    else:
        raise NotImplementedError(f"dynamics {dynamics} not found")

    kernel = np.zeros((2, num_cat, num_cat))

    # NOTE(review): the explicit loops are kept (rather than an einsum) because
    # the container type and exact shape returned by the get_Pxx_* helpers are
    # not visible here; only chained indexing is known to work.
    for s in (0, 1):
        for from_x in range(num_cat):
            # Weight of each (decision, outcome) pair at this state; it does
            # not depend on to_x, so compute it once per (s, from_x).
            weights = [
                (d, y, Pd[d][s][from_x] * Pyx[y][s][from_x])
                for d in (0, 1)
                for y in (0, 1)
            ]
            for to_x in range(num_cat):
                kernel[s][from_x][to_x] = sum(
                    w * Pxx[s][d][y][from_x][to_x] for d, y, w in weights
                )

    return kernel


def get_stationary_dist_from_kernel(kernel):
    """Return the stationary distribution of a row-stochastic matrix.

    The rows of ``kernel`` are expected to sum to 1. The stationary
    distribution is then an eigenvector of the transposed matrix for an
    eigenvalue of 1, rescaled so that its entries sum to one.

    Approach taken from
    https://ninavergara2.medium.com/calculating-stationary-distribution-in-python-3001d789cd4b

    Args:
        kernel: square 2-D numpy array (a single transition matrix).

    Returns:
        1-D array with the normalised eigenvector; the imaginary part is
        dropped when it is negligible (``np.real_if_close``).

    Raises:
        IndexError: if no eigenvalue is within ``np.isclose`` tolerance of 1.
    """
    values, vectors = np.linalg.eig(kernel.T)

    # Column positions of the eigenvalues that are numerically equal to 1;
    # the first such eigenvector is the one used.
    unit_positions = np.flatnonzero(np.isclose(values, 1))
    unit_vector = vectors[:, unit_positions[0]]

    # Rescale the eigenvector so its elements can be read as probabilities.
    normalised = unit_vector / sum(unit_vector)
    return np.real_if_close(normalised)

#Pd1[s][x] - P(D=1|X=x, S=s)
def get_Pd1_given(num_cat=4):
    """Return the uniform random policy, P(D=1 | X=x, S=s) = 0.5 everywhere.

    Args:
        num_cat: number of categories of X.

    Returns:
        Float array of shape ``(2, num_cat)`` filled with 0.5; row ``s`` is
        the policy for group ``s`` (both rows are identical).
    """
    # An increasing policy could be used instead, e.g.
    # np.linspace(0.1, 0.7, num_cat) for each group.
    return np.full((2, num_cat), 0.5)

def get_Py1_new(Px, Py1x, num_cat):
    """Marginalise P(Y=1 | X, S) over X for each of the two groups.

    Computes P(Y=1 | S=s) = sum_x P(Y=1 | X=x, S=s) * P(X=x | S=s).

    Args:
        Px: ``Px[s][x]`` is P(X=x | S=s).
        Py1x: ``Py1x[s][x]`` is P(Y=1 | X=x, S=s).
        num_cat: number of categories of X to sum over.

    Returns:
        Two-element list ``[P(Y=1 | S=0), P(Y=1 | S=1)]``.
    """
    return [
        sum(Py1x[group][cat] * Px[group][cat] for cat in range(num_cat))
        for group in (0, 1)
    ]

def get_EOP_new(Px, Py1x, num_cat, Pd1):
    """Compute the per-group ratio used as the equal-opportunity quantity.

    For each group ``s`` the value is

        sum_x Pd1[s][x] * Py1x[s][x] * Px[s][x]
        ---------------------------------------
             sum_x Py1x[s][x] * Px[s][x]

    With ``Pd1[s][x] = P(D=1 | X=x, S=s)``, ``Py1x[s][x] = P(Y=1 | X=x, S=s)``
    and ``Px[s][x] = P(X=x | S=s)``, this is the share of positive decisions
    among individuals with Y=1 in group ``s``.

    Args:
        Px: distribution over X per group.
        Py1x: probability of Y=1 given X per group.
        num_cat: number of categories of X to sum over.
        Pd1: probability of D=1 given X per group.

    Returns:
        Two-element list with the ratio for group 0 and group 1.
    """
    ratios = []
    for group in (0, 1):
        cats = range(num_cat)
        # Numerator: mass that is both accepted (D=1) and positive (Y=1).
        accepted_positive = sum(
            Pd1[group][cat] * Py1x[group][cat] * Px[group][cat] for cat in cats
        )
        # Denominator: total positive mass, P(Y=1 | S=group).
        positive = sum(Py1x[group][cat] * Px[group][cat] for cat in cats)
        ratios.append(accepted_positive / positive)
    return ratios


def get_Px_new(kernel, Px):
    """Advance the per-group distribution over X by one step of the kernel.

    For each of the two groups ``s`` the new distribution is
    ``kernel[s].T @ Px[s]``, rounded to 4 decimals.

    Args:
        kernel: ``kernel[s]`` is the transition matrix of group ``s``
            (a numpy array, since it is transposed with ``.T``).
        Px: ``Px[s]`` is the current distribution over X of group ``s``.

    Returns:
        Array stacking the two updated, rounded distributions (group 0 first).
    """
    stepped = []
    for group in (0, 1):
        next_dist = np.dot(kernel[group].T, Px[group])
        stepped.append(next_dist.round(4))
    return np.array(stepped)





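# NOTE: everything below is a commented-out earlier copy of the functions
# above (its get_kernel_from_policy has no `estimation` argument); it is
# never executed.
#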
# import numpy as np
# from simulation.data_fico import get_Py1x_given, get_Px_given, get_Ps_given
# from simulation.dynamics import get_Pxx_given_onemed, get_Pxx_given_oneslow, get_Pxx_given_onefast, get_Pxx_given_default, get_Pxx_random, get_Pxx_given_twomed_discouraged, get_Pxx_given_twomed_recourse
# from simulation.constants import Cte
#
# def get_kernel_from_policy(Pd1, num_cat, seed, dynamics):
#
#     Py1x = get_Py1x_given(num_cat)
#     Pyx = np.array([1 - Py1x, Py1x])
#     Pd1 = Pd1.reshape(2, num_cat)
#     Pd = np.array([1 - Pd1, Pd1])
#
#     if dynamics == Cte.ONEMED:
#         Pxx = get_Pxx_given_onemed()
#     elif dynamics == Cte.ONESLOW:
#         Pxx = get_Pxx_given_oneslow()
#     elif dynamics == Cte.ONEFAST:
#         Pxx = get_Pxx_given_onefast()
#     elif dynamics == Cte.DEFAULT:
#         Pxx = get_Pxx_given_default()
#     elif dynamics == Cte.RANDOM:
#         Pxx = get_Pxx_random(num_cat, seed)
#     elif dynamics == Cte.RECOURSE:
#         Pxx = get_Pxx_given_twomed_recourse()
#     elif dynamics == Cte.DISCOURAGED:
#         Pxx = get_Pxx_given_twomed_discouraged()
#     else:
#         # Pxx = get_Pxx_given(num_cat, seed)
#         print(f"dynamics {dynamics} not found")
#         raise NotImplementedError
#     # print(Pxx)
#     # create a 4 x 4 matrix filled with zeros
#
#     kernel = np.zeros((2, num_cat, num_cat)).astype(float)
#     for s in [0, 1]:
#         for to_x in range(num_cat):
#             for from_x in range(num_cat):
#                 _kernel_entry = 0
#                 for d in [0, 1]:
#                     for y in [0, 1]:
#                         _kernel_entry += Pd[d][s][from_x] * Pyx[y][s][from_x] * Pxx[s][d][y][from_x][to_x]
#                 kernel[s][from_x][to_x] = _kernel_entry
#
#     return kernel
#
#
# def get_stationary_dist_from_kernel(kernel):
#     # https://ninavergara2.medium.com/calculating-stationary-distribution-in-python-3001d789cd4b
#     '''
#     Since the sum of each row is 1, our matrix is row stochastic.
#     We'll transpose the matrix to calculate eigenvectors of the stochastic rows.
#     '''
#     # print(kernel)
#     transition_matrix_transp = kernel.T
#     eigenvals, eigenvects = np.linalg.eig(transition_matrix_transp)
#     # '''
#     #     Find the indexes of the eigenvalues that are close to one.
#     #     Use them to select the target eigen vectors. Flatten the result.
# #     '''
#     close_to_1_idx = np.isclose(eigenvals,1)
#     target_eigenvect = eigenvects[:,close_to_1_idx]
#     target_eigenvect = target_eigenvect[:,0]# Turn the eigenvector elements into probabilites
#     stationary_distrib = target_eigenvect / sum(target_eigenvect)
#     stationary_distrib = np.real_if_close(stationary_distrib)
#     return stationary_distrib
#
# #Pd1[s][x] - P(D=1|X=x, S=s)
# def get_Pd1_given(num_cat=4):
#     # random policy, all 0.5
#     a = np.ones(num_cat).astype(float)*0.5
#     # get an array of size num_cat filled with numbers between 01. and 0.7 in increasing order
#     # a = np.linspace(0.1, 0.7, num_cat)
#     # for both s
#     return np.array([a, a])
#
# def get_Py1_new(Px, Py1x, num_cat):
#     # P(Y=1|S) = \sum_x P(Y=1|X=x, S=s)P(X=x|S=s)
#     alpha = []
#     for s in [0, 1]:
#         Py1 = 0
#         for x in range(num_cat):
#             Py1 += Py1x[s][x] * Px[s][x]
#         alpha.append(Py1)
#     return alpha
#
# def get_EOP_new(Px, Py1x, num_cat, Pd1):
#     # P(Y=1|S) = \sum_x P(Y=1|X=x, S=s)P(X=x|S=s)
#
#     EOP = []
#     for s in [0, 1]:
#         sum0 = 0
#         sum1 = 0
#         for x in range(num_cat):  # i is the x in Px
#             sum0 += (Pd1[s][x] * Py1x[s][x] * Px[s][x])
#             sum1 += (Py1x[s][x] * Px[s][x])
#         EOP.append(sum0 / sum1)
#
#     return EOP
#
#
# def get_Px_new(kernel, Px):
#
#     Px_new = np.array([np.dot(kernel[0].T, Px[0]).round(4), np.dot(kernel[1].T, Px[1]).round(4)])
#
#     return Px_new