

from simulation.optim_utils import get_stationary_dist_from_kernel


import numpy as np

import argparse
def get_total_qualifications(kernel, Py1x, Ps, num_cat):
    """Sum ``Py1x[s][x] * Pxs[x] * Ps[s]`` over groups 0 and 1 and all x.

    ``Pxs`` is whatever ``get_stationary_dist_from_kernel(kernel[s])``
    returns for group ``s``.

    Args:
        kernel: Indexable by group (0 and 1); each entry is handed to
            ``get_stationary_dist_from_kernel``.
        Py1x: Indexed as ``Py1x[s][x]``.
            NOTE(review): presumably P(Y=1 | X=x, S=s) -- confirm.
        Ps: Indexed as ``Ps[s]``.
            NOTE(review): presumably the group proportions -- confirm.
        num_cat: Number of x values (0 .. num_cat - 1) to iterate over.

    Returns:
        The accumulated sum across both groups.
    """
    total = 0
    for group in (0, 1):
        dist_x = get_stationary_dist_from_kernel(kernel[group])

        for x in range(num_cat):
            # Operand order is kept as (Py1x * Pxs) * Ps so that the
            # floating-point result is unchanged.
            contribution = Py1x[group][x] * dist_x[x] * Ps[group]
            total += contribution
    return total


def get_group_qualifications(kernel, s, Py1x, num_cat):
    """Return the product of ``Py1x[s]`` (as a row) with group ``s``'s Pxs.

    ``Pxs`` is whatever ``get_stationary_dist_from_kernel(kernel[s])``
    returns. ``Py1x[s]`` is reshaped to ``(1, num_cat)`` and
    matrix-multiplied with it; the first entry of the product is rounded
    to 4 decimals.

    Args:
        kernel: Indexable by group; ``kernel[s]`` is handed to
            ``get_stationary_dist_from_kernel``.
        s: Group index.
        Py1x: ``Py1x[s]`` must support ``reshape(1, num_cat)``.
            NOTE(review): presumably a NumPy array of P(Y=1 | X=x, S=s)
            -- confirm.
        num_cat: Number of x values; the length of the reshaped row.

    Returns:
        ``product[0].round(4)``. Whether this is a scalar or an array
        depends on the shape returned by the helper, which is not
        visible in this file.
    """
    dist_x = get_stationary_dist_from_kernel(kernel[s])
    qual_row = Py1x[s].reshape(1, num_cat)
    product = np.matmul(qual_row, dist_x)
    first_entry = product[0]
    return first_entry.round(4)

def get_group_EOP(kernel, s, Py1x, num_cat, Pd1):
    """Return the ratio of two Pxs-weighted sums for group ``s``.

    With ``Pxs = get_stationary_dist_from_kernel(kernel[s])`` this is::

        sum_x Pd1[s][x] * Py1x[s][x] * Pxs[x]
        -------------------------------------
             sum_x Py1x[s][x] * Pxs[x]

    NOTE(review): presumably P(D=1 | Y=1, S=s), i.e. the quantity
    compared across groups for equality of opportunity -- confirm.

    Args:
        kernel: Indexable by group; ``kernel[s]`` is handed to
            ``get_stationary_dist_from_kernel``.
        s: Group index.
        Py1x: Indexed as ``Py1x[s][x]``.
        num_cat: Number of x values (0 .. num_cat - 1) to iterate over.
        Pd1: Indexed as ``Pd1[s][x]``.
            NOTE(review): presumably P(D=1 | X=x, S=s) -- confirm.

    Returns:
        Numerator divided by denominator. No guard is applied when the
        denominator is zero.
    """
    dist_x = get_stationary_dist_from_kernel(kernel[s])
    numerator = 0
    denominator = 0
    for x in range(num_cat):
        decision = Pd1[s][x]
        qualified = Py1x[s][x]
        mass = dist_x[x]
        # Operand order is kept as (Pd1 * Py1x) * Pxs so that the
        # floating-point result is unchanged.
        numerator += (decision * qualified * mass)
        denominator += (qualified * mass)

    return numerator / denominator


def get_total_utility(kernel, Py1x, Ps, Pd1, num_cat, c):
    """Sum ``Pd1[s][x] * (Py1x[s][x] - c) * Pxs[x] * Ps[s]`` over s and x.

    ``Pxs`` is whatever ``get_stationary_dist_from_kernel(kernel[s])``
    returns for group ``s``; groups 0 and 1 are both included.

    Args:
        kernel: Indexable by group (0 and 1); each entry is handed to
            ``get_stationary_dist_from_kernel``.
        Py1x: Indexed as ``Py1x[s][x]``.
        Ps: Indexed as ``Ps[s]``.
            NOTE(review): presumably the group proportions -- confirm.
        Pd1: Indexed as ``Pd1[s][x]``.
            NOTE(review): presumably P(D=1 | X=x, S=s) -- confirm.
        num_cat: Number of x values (0 .. num_cat - 1) to iterate over.
        c: Constant subtracted from ``Py1x[s][x]`` in every term.
            NOTE(review): presumably a cost/threshold -- confirm.

    Returns:
        The accumulated sum across both groups.
    """
    total = 0
    for group in (0, 1):
        dist_x = get_stationary_dist_from_kernel(kernel[group])

        for x in range(num_cat):
            margin = Py1x[group][x] - c
            # Operand order is kept as ((Pd1 * margin) * Pxs) * Ps so that
            # the floating-point result is unchanged.
            contribution = Pd1[group][x] * margin * dist_x[x] * Ps[group]
            total += contribution

    return total