import cvxpy as cp
import numpy as np

class Cost_Allocation(object):
    def __init__(self, env):
        self.env = env

    # Shapley Value
    def SV(self, J_L, J_U):
        env = self.env
        N = [i for i in range(env.n_agents)]
        beta_sv = [0] * env.n_agents
        for i in range(env.n_agents):
            S_i = [W for W in self.powerset(N) if i in W]
            for W in S_i:
                W_pr = [j for j in W if j != i]
                a = float(self.fact(len(W_pr)) * self.fact(env.n_agents 
                        - len(W_pr) - 1)) / self.fact(env.n_agents)
                beta_sv[i] += a * max(0, J_L[str(W)] - J_U[str(W_pr)])

        return beta_sv

    # Average Participation
    def AP(self, J_L, J_U, beta_sv):
        env = self.env
        N = [i for i in range(env.n_agents)]
        beta_ap = [0] * env.n_agents

        if J_L == J_U:
            for i in range(env.n_agents):
                if beta_sv[i] == 0:
                    continue
                S_i = [W for W in self.powerset(N) if i in W]
                for W in S_i:
                    sum_c = np.sum([beta_sv[i] > 0 for i in W])
                    a = 1 / (sum_c * (2 ** env.n_agents - 1))
                    beta_ap[i] += a * max(0, J_L[str(W)] - J_U[str([])])
        else:
            for i in range(env.n_agents):
                if beta_sv[i] == 0:
                    continue
                S_i = [W for W in self.powerset(N) if i in W]
                for W in S_i:
                    a = 1 / (len(W) * (2 ** env.n_agents - 1))
                    beta_ap[i] += a * max(0, J_L[str(W)] - J_U[str([])])

        return beta_ap

    # Banzhaf Index
    def BI(self, J_L, J_U):
        env = self.env
        N = [i for i in range(env.n_agents)]
        beta_bi = [0] * env.n_agents
        for i in range(env.n_agents):
            S_i = [W for W in self.powerset(N) if i in W]
            for W in S_i:
                W_pr = [j for j in W if j != i]
                a = 1 / 2 ** (env.n_agents - 1)
                beta_bi[i] += a * max(0, J_L[str(W)] - J_U[str(W_pr)])

        return beta_bi


    # Marginal Contribution
    def MC(self, J_L, J_U):
        env = self.env
        beta_mc = [0] * env.n_agents
        for i in range(env.n_agents):
            beta_mc[i] = max(0, J_L[str([i])] - J_U[str([])])

        return beta_mc

    # Max-Efficient Rationality
    def MR(self, J_L, J_U):
        env = self.env

        beta_mr = cp.Variable(env.n_agents, nonneg=True)
        objective = cp.Maximize(cp.sum(beta_mr))
        constraints = []

        N = [i for i in range(env.n_agents)]
        S = [W for W in self.powerset(N) if W]
        for W in S:
            A = np.zeros(env.n_agents)
            for i in W:
                A[i] = 1
            constraints.append(A @ beta_mr <= max(0, J_L[str(W)] - J_U[str([])]))

        problem = cp.Problem(objective, constraints)
        solution = problem.solve(solver=cp.OSQP, eps_abs=1e-5)

        return beta_mr.value

    def powerset(self, s):
        x = len(s)
        masks = [1 << i for i in range(x)]
        for i in range(1 << x):
            yield [ss for mask, ss in zip(masks, s) if i & mask]
    
    def fact(self, n):
        fact = 1
        for i in range(1, n + 1):
            fact = fact * i
        
        return fact