import numpy as np
import copy

class Recursion(object):
    """Fixed-point (value-iteration) recursions for a two-agent environment.

    Every recursion reads its model from ``self.env``:

    * ``n_dim`` -- values are kept for the state indices ``range(n_dim - 1)``;
    * ``terminal_state`` -- its value is forced to 0 after every backup;
    * ``gamma`` -- discount applied to the successor value;
    * ``n_actions_1`` / ``n_actions_2`` -- action counts of agents 1 and 2;
    * ``R(s, a1, a2)`` -- one-step reward, and ``T(s, a1, a2)`` -- successor
      state, which must be one of the state indices above (it is used as a
      key into the value dict);
    * ``q_1_es/lo/up``, ``q_2_es/lo/up``, ``est_err_1_alpha``, ``V_star`` --
      read only by the recursions that need them.

    Policy tables are indexed as ``q[s, a]``.
    """

    def __init__(self, env):
        self.env = env

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------

    def _value_iteration(self, backup, tol, new_sweep=None):
        """Repeat synchronous sweeps of ``backup`` until they change by < ``tol``.

        Each sweep snapshots the current values into ``U_old`` and sets, for
        every state index ``s``::

            U[s] = backup(s, U_old, scratch) * (s != env.terminal_state)

        ``backup`` is still evaluated for the terminal state (so its side
        effects on ``scratch`` happen there too); only the value is zeroed.
        ``scratch`` is the result of ``new_sweep()`` called at the start of
        each sweep, or ``None`` when ``new_sweep`` is not given.

        NOTE(review): there is no iteration cap -- the only exit is the
        ``delta < tol`` test, so a non-contracting model never returns.

        Returns:
            ``(U, scratch)`` from the first sweep whose largest per-state
            change is below ``tol``.
        """
        env = self.env
        states = range(env.n_dim - 1)
        U = {s: 0 for s in states}
        while True:
            # Jacobi-style: every backup in this sweep reads only U_old.
            U_old = dict(U)  # values are scalars, so a shallow copy suffices
            scratch = None if new_sweep is None else new_sweep()
            delta = 0
            for s in states:
                U[s] = backup(s, U_old, scratch) * (s != env.terminal_state)
                delta = max(delta, abs(U[s] - U_old[s]))
            if delta < tol:
                return U, scratch

    def _lookahead(self, U, s, a1, a2):
        """Return ``R(s, a1, a2) + gamma * U[T(s, a1, a2)]``."""
        env = self.env
        return env.R(s, a1, a2) + env.gamma * U[env.T(s, a1, a2)]

    def _joint_value(self, U, s, q_1, q_2):
        """Lookahead at ``s`` weighted by ``q_1[s, a1] * q_2[s, a2]``."""
        env = self.env
        return np.sum([
            q_1[s, a1] * q_2[s, a2] * self._lookahead(U, s, a1, a2)
            for a1 in range(env.n_actions_1)
            for a2 in range(env.n_actions_2)
        ])

    def _a1_values(self, U, s, q_2):
        """For each agent-1 action, the lookahead weighted by ``q_2[s, :]``."""
        env = self.env
        return [
            np.sum([q_2[s, a2] * self._lookahead(U, s, a1, a2)
                    for a2 in range(env.n_actions_2)])
            for a1 in range(env.n_actions_1)
        ]

    def _a2_values(self, U, s, q_1):
        """For each agent-2 action, the lookahead weighted by ``q_1[s, :]``."""
        env = self.env
        return [
            np.sum([q_1[s, a1] * self._lookahead(U, s, a1, a2)
                    for a1 in range(env.n_actions_1)])
            for a2 in range(env.n_actions_2)
        ]

    def _rank_a1(self, U, s, descending):
        """Agent-1 actions sorted by their lookahead with ``a2 = 0``.

        The sort is stable, so tied actions keep ascending index order
        (matching a repeated first-occurrence argmax/argmin selection).
        NOTE(review): only agent 2's action 0 is used for the ranking --
        confirm this is intended.
        """
        env = self.env
        values = [self._lookahead(U, s, a1, 0)
                  for a1 in range(env.n_actions_1)]
        return sorted(range(env.n_actions_1), key=values.__getitem__,
                      reverse=descending)

    def _a2_extreme(self, U, s, a1, pick):
        """Agent-2 action selected by ``pick`` (np.argmax/np.argmin) given ``a1``."""
        env = self.env
        return pick(np.array([self._lookahead(U, s, a1, a2)
                              for a2 in range(env.n_actions_2)]))

    def _shift_q1(self, s, q_1, order):
        """Shift up to ``est_err_1_alpha`` of agent 1's mass in row ``s``.

        Starting from ``q_1_es``, mass is added to the actions in ``order``
        (each capped at ``q_1_up``). The amount actually added is then
        removed from the actions in reverse ``order`` (each floored at
        ``q_1_lo``). Row ``s`` of ``q_1`` is written in place.

        Raises:
            AssertionError: if the row no longer sums to ~1 or leaves the
                ``[q_1_lo, q_1_up]`` box.
        """
        env = self.env
        q_1_es, q_1_lo, q_1_up = env.q_1_es, env.q_1_lo, env.q_1_up

        # add estimation error, most preferred action first
        budget = env.est_err_1_alpha
        for a1 in order:
            e = min(budget, q_1_up[s, a1] - q_1_es[s, a1])
            q_1[s, a1] = q_1_es[s, a1] + e
            budget = budget - e

        # subtract the mass that was actually added, least preferred first
        to_remove = env.est_err_1_alpha - budget
        for a1 in reversed(order):
            e = min(to_remove, q_1[s, a1] - q_1_lo[s, a1])
            q_1[s, a1] = q_1[s, a1] - e
            to_remove = to_remove - e

        assert np.sum(q_1[s]) <= 1.0001 and np.sum(q_1[s]) >= .9999, (q_1[s], np.sum(q_1[s]), s)
        for a1 in range(env.n_actions_1):
            assert q_1[s, a1] <= q_1_up[s, a1] + 0.00001 and q_1[s, a1] >= q_1_lo[s, a1] - 0.000001

    def _tilt_q2(self, s, q_2, a_up):
        """Put ``q_2_up`` on ``a_up`` and ``q_2_lo`` on the other action in row ``s``.

        Raises:
            AssertionError: if the resulting row does not sum to ~1.
        """
        env = self.env
        q_2[s, a_up] = env.q_2_up[s, a_up]
        # The "other" action is (a_up + 1) % 2, so this assumes agent 2 has
        # exactly two actions.
        a_other = (a_up + 1) % 2
        q_2[s, a_other] = env.q_2_lo[s, a_other]
        assert np.sum(q_2[s]) <= 1.0001 and np.sum(q_2[s]) >= .9999, q_2[s]

    # ------------------------------------------------------------------
    # Public recursions
    # ------------------------------------------------------------------

    # known joint policy - evaluation
    def recursion_1_a(self, q_1, q_2, tol=1e-5):
        """Evaluate the given joint policy ``(q_1, q_2)``.

        Returns:
            dict mapping each state index to its value.
        """
        def backup(s, U_old, _):
            return self._joint_value(U_old, s, q_1, q_2)

        return self._value_iteration(backup, tol)[0]

    # optimal joint policy
    def recursion_1_b(self, tol=1e-5):
        """Return ``env.V_star`` (``tol`` is unused; kept for a uniform signature)."""
        return self.env.V_star

    # known joint policy - best response of A1
    def recursion_1_c_ag1(self, q_2, tol=1e-5):
        """Values of agent 1's best response to agent 2's policy ``q_2``."""
        def backup(s, U_old, _):
            return max(self._a1_values(U_old, s, q_2))

        return self._value_iteration(backup, tol)[0]

    # known joint policy - best response of A2
    def recursion_1_c_ag2(self, q_1, tol=1e-5):
        """Values of agent 2's best response to agent 1's policy ``q_1``."""
        def backup(s, U_old, _):
            return max(self._a2_values(U_old, s, q_1))

        return self._value_iteration(backup, tol)[0]

    # find optimizer of J(q) (for valid approach)
    def recursion_2_a(self, tol=1e-5):
        """Iterate while tilting both policies toward higher lookahead values.

        Each backup shifts agent 1's estimate toward its best-ranked actions,
        then gives agent 2's best reply its upper bound, and evaluates that
        joint policy.

        Returns:
            ``(U, q_1_max, q_2_max)`` -- the values and the tilted policy
            tables built during the final sweep (fresh copies of
            ``q_1_es``/``q_2_es`` each sweep; ``env`` is not modified).
        """
        env = self.env

        def new_sweep():
            return copy.deepcopy(env.q_1_es), copy.deepcopy(env.q_2_es)

        def backup(s, U_old, tables):
            q_1_max, q_2_max = tables
            self._shift_q1(s, q_1_max,
                           self._rank_a1(U_old, s, descending=True))
            a2_max = np.argmax(np.array(self._a2_values(U_old, s, q_1_max)))
            self._tilt_q2(s, q_2_max, a_up=a2_max)
            return self._joint_value(U_old, s, q_1_max, q_2_max)

        U, (q_1_max, q_2_max) = self._value_iteration(backup, tol, new_sweep)
        return U, q_1_max, q_2_max

    # Known confidence set - LB of A1's best response
    def recursion_3_a_ag1(self, tol=1e-5):
        """Agent 1's best response after agent 2 is tilted toward its worst reply."""
        env = self.env

        def backup(s, U_old, q_2_min):
            a1_max = self._rank_a1(U_old, s, descending=True)[0]
            a2_min = self._a2_extreme(U_old, s, a1_max, np.argmin)
            self._tilt_q2(s, q_2_min, a_up=a2_min)
            return max(self._a1_values(U_old, s, q_2_min))

        return self._value_iteration(
            backup, tol, lambda: copy.deepcopy(env.q_2_es))[0]

    # Known confidence set - LB of A2's best response
    def recursion_3_a_ag2(self, tol=1e-5):
        """Agent 2's best response after agent 1 is tilted toward low-value actions."""
        env = self.env

        def backup(s, U_old, q_1_min):
            self._shift_q1(s, q_1_min,
                           self._rank_a1(U_old, s, descending=False))
            return max(self._a2_values(U_old, s, q_1_min))

        return self._value_iteration(
            backup, tol, lambda: copy.deepcopy(env.q_1_es))[0]

    # Known confidence set - UB of A1's best response
    def recursion_3_b_ag1(self, tol=1e-5):
        """Agent 1's best response after agent 2 is tilted toward its best reply."""
        env = self.env

        def backup(s, U_old, q_2_max):
            a1_max = self._rank_a1(U_old, s, descending=True)[0]
            a2_max = self._a2_extreme(U_old, s, a1_max, np.argmax)
            self._tilt_q2(s, q_2_max, a_up=a2_max)
            return max(self._a1_values(U_old, s, q_2_max))

        return self._value_iteration(
            backup, tol, lambda: copy.deepcopy(env.q_2_es))[0]

    # Known confidence set - UB of A2's best response
    def recursion_3_b_ag2(self, tol=1e-5):
        """Agent 2's best response after agent 1 is tilted toward high-value actions."""
        env = self.env

        def backup(s, U_old, q_1_max):
            self._shift_q1(s, q_1_max,
                           self._rank_a1(U_old, s, descending=True))
            return max(self._a2_values(U_old, s, q_1_max))

        return self._value_iteration(
            backup, tol, lambda: copy.deepcopy(env.q_1_es))[0]