import numpy as np
import cvxpy as cp
import copy
from itertools import product

class Recursion:
    """Value-iteration recursions over a multi-agent environment ``env``.

    The methods below read these ``env`` attributes: ``n_dim``,
    ``active_n_dim``, ``n_agents``, ``n_actions``, ``action_space_i``,
    ``joint_action_space``, ``gamma``, ``length``, ``R(s, a)``, ``T(s, a)``,
    ``V_star``, ``jq_es``, ``jq_lo`` and ``jq_up``.

    Conventions used throughout:

    * Value tables are dicts over ``range(env.n_dim)`` initialised to 0.
      Only states ``0 .. env.active_n_dim - 1`` are ever updated, so the
      remaining states keep value 0 (presumably terminal states -- confirm
      with env).
    * ``U_old[env.T(s, a)]`` treats ``T`` as returning one successor state.
      ``env.R`` and ``env.T`` are assumed to be pure, because each backup is
      computed once per state and sweep and then reused.
    * Policy tables are indexed ``[s, agent, action]``.
    * For a coalition ``W`` (a sequence of agent indices), a coalition joint
      action ``a_W`` has one entry per member: ``a_W[k]`` is the action of
      agent ``W[k]``. ``W'`` denotes the remaining agents in ascending order.
    """

    def __init__(self, env):
        self.env = env

    # Shared value-iteration driver: sweeps until the largest change < tol.
    def _value_iteration(self, update, tol):
        """Repeat sweeps ``U[s] <- update(s, U_old)`` until convergence.

        Each sweep computes every active state from a frozen copy ``U_old`` of
        the previous sweep. Iteration stops once the largest absolute change
        within one sweep is below ``tol``.

        Returns:
            dict mapping every state in ``range(env.n_dim)`` to its value.
        """
        env = self.env
        U = {s: 0 for s in range(env.n_dim)}
        while True:
            # Values are immutable numbers, so a shallow copy suffices.
            U_old = dict(U)
            delta = 0
            for s in range(env.active_n_dim):
                U[s] = update(s, U_old)
                delta = max(delta, abs(U[s] - U_old[s]))
            if delta < tol:
                return U

    def _backup(self, s, a, U_old):
        """One-step backup ``env.R(s, a) + env.gamma * U_old[env.T(s, a)]``."""
        env = self.env
        return env.R(s, a) + env.gamma * U_old[env.T(s, a)]

    def _state_backups(self, s, U_old):
        """Backups of state ``s`` for each action of ``env.joint_action_space``, in order."""
        return [self._backup(s, a, U_old) for a in self.env.joint_action_space]

    def _joint_actions(self, n):
        """Return all joint actions (as lists) of ``n`` agents over ``env.action_space_i``.

        For ``n == 0`` this is ``[[]]``, the single empty joint action.
        """
        return [list(a) for a in product(*([self.env.action_space_i] * n))]

    @staticmethod
    def _expected_value(table, prefix, joint_actions, backups):
        """Return ``sum_a prod_i table[prefix + (i, a[i])] * backups[a]``.

        ``table[prefix + (i, a_i)]`` is read as the probability that the
        ``i``-th agent of the joint action plays ``a_i``. ``backups`` is
        aligned index-by-index with ``joint_actions``.
        """
        return np.sum([
            np.prod([table[prefix + (i, a_i)] for i, a_i in enumerate(a)]) * q
            for a, q in zip(joint_actions, backups)
        ])

    # Known joint policy - evaluation
    def recursion_1_a(self, jq, tol=1e-5):
        """Evaluate the known joint policy ``jq``.

        ``jq[s, i, a_i]`` is read as the probability that agent ``i`` plays
        ``a_i`` in state ``s``. The probability of a joint action is the
        product of these marginals.

        Returns:
            dict mapping every state in ``range(env.n_dim)`` to its value.
        """
        joint_actions = self.env.joint_action_space

        def update(s, U_old):
            return self._expected_value(
                jq, (s,), joint_actions, self._state_backups(s, U_old))

        return self._value_iteration(update, tol)

    # Optimal joint policy
    def recursion_1_b(self, tol=1e-5):
        """Return the optimal joint value ``env.V_star``.

        ``tol`` is unused. It is kept so the signature matches the other
        recursions.
        """
        return self.env.V_star

    # Known joint policy - best response to W
    def recursion_1_c(self, W, jq, tol=1e-5):
        """Value of the agents outside ``W`` best-responding to ``W``'s known policy.

        The coalition ``W`` follows its slice of ``jq``, with joint
        probabilities given by products of marginals (see
        ``coalition_policy``). In each active state, the remaining agents W'
        pick the deterministic joint action ``a_W'`` that maximises the
        expected backup over ``a_W``.

        Returns:
            dict mapping every state in ``range(env.n_dim)`` to its value.
        """
        env = self.env
        jq_W = self.coalition_policy(W, jq)
        joint_action_space_W = self._joint_actions(len(W))
        joint_action_space_W_pr = self._joint_actions(env.n_agents - len(W))

        def update(s, U_old):
            def value_against(a_W_pr):
                # Merge once per (a_W, a_W') pair instead of once per R and T call.
                backups = [
                    self._backup(s, self.merge_actions(W, a_W, a_W_pr), U_old)
                    for a_W in joint_action_space_W
                ]
                return self._expected_value(
                    jq_W, (s,), joint_action_space_W, backups)

            return max([value_against(a_W_pr)
                        for a_W_pr in joint_action_space_W_pr])

        return self._value_iteration(update, tol)

    # Find optimizer of J(q) (for valid approach)
    def recursion_2_a(self, tol=1e-5):
        """Search the confidence set ``[env.jq_lo, env.jq_up]`` for the largest value.

        In each active state, every joint action ``a`` yields a candidate
        marginal table, built as a copy of ``env.jq_es[s]``. For each agent
        ``i``, the probability of ``a[i]`` is set to its upper bound and the
        probability of the other action to its lower bound. The candidate with
        the largest expected backup wins.

        A candidate must strictly exceed ``-env.length``. If none does, the
        state's value is ``-env.length`` and ``env.jq_es[s]`` is kept.

        NOTE: assumes two actions per agent. The "other" action of ``a_i`` is
        ``(a_i + 1) % 2``.

        Returns:
            ``(U, jq_max)``: the value dict, and a copy of ``env.jq_es`` whose
            active-state slices hold the winning candidates of the final sweep.
        """
        env = self.env
        jq_es, jq_lo, jq_up = env.jq_es, env.jq_lo, env.jq_up
        joint_actions = env.joint_action_space
        # state -> winning candidate table from the most recent sweep (or None)
        best = {}

        def update(s, U_old):
            # Backups do not depend on the candidate, so compute them once.
            backups = self._state_backups(s, U_old)
            max_val = -env.length
            best[s] = None
            for a in joint_actions:
                jq_max_s = copy.deepcopy(jq_es[s])
                for i in range(env.n_agents):
                    other = (a[i] + 1) % 2
                    jq_max_s[i, a[i]] = jq_up[s, i, a[i]]
                    jq_max_s[i, other] = jq_lo[s, i, other]
                val = self._expected_value(jq_max_s, (), joint_actions, backups)
                if val > max_val:
                    max_val = val
                    best[s] = jq_max_s
            return max_val

        U = self._value_iteration(update, tol)
        jq_max = copy.deepcopy(jq_es)
        for s, jq_max_s in best.items():
            if jq_max_s is None:
                continue
            for i in range(env.n_agents):
                for a_i in range(env.n_actions):
                    jq_max[s, i, a_i] = jq_max_s[i, a_i]
        return U, jq_max

    # Known confidence set - LB best response to W
    def recursion_3_a(self, W, tol=1e-5):
        """Lower-bound value for the agents outside ``W`` when ``W``'s policy is uncertain.

        ``W``'s policy is only known to lie in ``[env.jq_lo, env.jq_up]``. For
        each active state, an LP (solved with ECOS) chooses a distribution
        ``jprobs`` over W's joint actions. Each entry is bounded by the
        products of the members' marginal bounds. The LP minimises the
        maximum, over the other agents' joint actions ``a_W'``, of the
        expected backup. The LP optimum becomes the state's new value.

        Raises:
            RuntimeError: if an LP is not solved to (possibly inaccurate)
                optimality, e.g. because the bounds admit no distribution.
        """
        env = self.env
        jq_lo, jq_up = env.jq_lo, env.jq_up
        joint_action_space_W = self._joint_actions(len(W))
        joint_action_space_W_pr = self._joint_actions(env.n_agents - len(W))

        def update(s, U_old):
            jprobs = cp.Variable(len(joint_action_space_W), nonneg=True)
            z = cp.Variable()
            constraints = [cp.sum(jprobs) == 1]
            for k, a_W in enumerate(joint_action_space_W):
                jlo_k = np.prod([jq_lo[s, W[i], a_i] for i, a_i in enumerate(a_W)])
                jup_k = np.prod([jq_up[s, W[i], a_i] for i, a_i in enumerate(a_W)])
                constraints.append(jprobs[k] >= jlo_k)
                constraints.append(jprobs[k] <= jup_k)
            # z is an upper bound on every W' response, so minimising z gives min-max.
            for a_W_pr in joint_action_space_W_pr:
                A = [self._backup(s, self.merge_actions(W, a_W, a_W_pr), U_old)
                     for a_W in joint_action_space_W]
                constraints.append(z >= jprobs @ A)
            problem = cp.Problem(cp.Minimize(z), constraints)
            solution = problem.solve(solver=cp.ECOS)
            # solve() returns inf/-inf/None instead of raising on failure;
            # storing that would corrupt the value table.
            if problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
                raise RuntimeError(
                    f"LP for state {s} was not solved to optimality "
                    f"(status: {problem.status})")
            return solution

        return self._value_iteration(update, tol)

    # Known confidence set - UB best response to W
    def recursion_3_b(self, W, tol=1e-5):
        """Upper-bound value of a best response to coalition ``W``.

        In each active state, every joint action ``a`` yields a candidate
        marginal table, built as a copy of ``env.jq_es[s]``:

        * each member ``i`` of ``W`` gets its upper bound on ``a[i]`` and its
          lower bound on the other action;
        * every other agent plays ``a[i]`` with probability 1.

        The state's value is the largest expected backup among candidates, or
        ``-env.length`` if none exceeds it.

        NOTE: assumes two actions per agent. The "other" action of ``a_i`` is
        ``(a_i + 1) % 2``.

        Returns:
            dict mapping every state in ``range(env.n_dim)`` to its value.
        """
        env = self.env
        jq_es, jq_lo, jq_up = env.jq_es, env.jq_lo, env.jq_up
        joint_actions = env.joint_action_space
        members = set(W)

        def update(s, U_old):
            # Backups do not depend on the candidate, so compute them once.
            backups = self._state_backups(s, U_old)
            max_val = -env.length
            for a in joint_actions:
                jq_max_s = copy.deepcopy(jq_es[s])
                for i in range(env.n_agents):
                    other = (a[i] + 1) % 2
                    if i in members:
                        jq_max_s[i, a[i]] = jq_up[s, i, a[i]]
                        jq_max_s[i, other] = jq_lo[s, i, other]
                    else:
                        jq_max_s[i, a[i]] = 1
                        jq_max_s[i, other] = 0
                val = self._expected_value(jq_max_s, (), joint_actions, backups)
                if val > max_val:
                    max_val = val
            return max_val

        return self._value_iteration(update, tol)

    # Merge actions a_W and a_W' to a
    def merge_actions(self, W, a_W, a_W_pr):
        """Assemble a full joint action from coalition and complement actions.

        ``a_W[k]`` is assigned to agent ``W[k]``, the same pairing used by
        ``coalition_policy`` and ``recursion_3_a``, and this holds for any
        order of ``W``. ``a_W_pr`` is assigned to the agents not in ``W`` in
        ascending index order.

        Returns:
            list of length ``env.n_agents``.
        """
        env = self.env
        a = [0] * env.n_agents
        members = set(W)
        for k, agent in enumerate(W):
            a[agent] = a_W[k]
        others = [i for i in range(env.n_agents) if i not in members]
        for k, agent in enumerate(others):
            a[agent] = a_W_pr[k]
        return a

    # Joint policy of agents in W
    def coalition_policy(self, W, jq):
        """Restrict the joint policy ``jq`` to the agents in ``W``.

        Returns:
            array of shape ``(env.n_dim, len(W), env.n_actions)``. Entry
            ``[s, k, a_i]`` is ``jq[s, W[k], a_i]`` for ``a_i`` in
            ``env.action_space_i``; all other entries stay 0.
        """
        env = self.env
        jq_W = np.zeros(shape=(env.n_dim, len(W), env.n_actions))
        for s in range(env.n_dim):
            for k, agent in enumerate(W):
                for a_i in env.action_space_i:
                    jq_W[s, k, a_i] = jq[s, agent, a_i]
        return jq_W
