import numpy as np
import copy
import itertools


class Gridworld(object):
    """Two-agent gridworld on a fixed 8x8 grid.

    Agent A1 picks a move (left, right, up, down).  Agent A2 either lets that
    move through (``a2 == 0``) or replaces it with the single-agent optimal
    move at an extra cost ``C`` (``a2 == 1``).  The constructor computes the
    single-agent optimal values/moves and instantiates the true, estimated,
    lower-bound and upper-bound behavior policies of both agents.  Policy
    construction draws from the global ``np.random`` state; seed it for
    reproducible instances.

    Args:
        alpha: weight moved onto the optimal move in A1's behavior policy.
        alpha_pr: the same weight, for the A1 policy that A2 best-responds to.
        est_err_1: total probability mass perturbed (at most) when estimating
            A1's personal policy; also the half-width of its lower/upper bounds.
        est_err_2: estimation error of A2's policy; must be 0 unless
            ``q_2_stochastic`` is True.
        q_2_stochastic: if True, A2's behavior policy is a randomized version
            of its deterministic best response.
    """

    def __init__(self,
                 alpha=0.3,
                 alpha_pr=0.1,
                 est_err_1=0.2,
                 est_err_2=0,
                 q_2_stochastic=False):
        h = -0.5   # heavily penalized cells
        f = -0.02  # mildly penalized cells
        C = 0.05   # extra cost charged whenever A2 overrides A1 (see R)
        self.C = C
        # Reward received on entering each cell (see R_w); the bottom-right
        # cell holds the +1 reward and is also the terminal state.
        self.grid = np.array(
            [[-0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01],
             [-0.01, -0.01, f, -0.01, h, -0.01, -0.01, -0.01],
             [-0.01, -0.01, -0.01, h, -0.01, -0.01, f, -0.01],
             [-0.01, f, -0.01, -0.01, -0.01, h, -0.01, f],
             [-0.01, -0.01, -0.01, h, -0.01, -0.01, f, -0.01],
             [-0.01, h, h, -0.01, f, -0.01, h, -0.01],
             [-0.01, h, -0.01, -0.01, h, -0.01, h, -0.01],
             [-0.01, -0.01, -0.01, h, -0.01, f, -0.01, +1]])

        self.terminal_state = 63  # row-major index of the bottom-right cell
        self.alpha = alpha
        self.alpha_pr = alpha_pr
        self.est_err_1 = est_err_1
        self.est_err_1_alpha = (1 - alpha) * est_err_1
        self.est_err_2 = est_err_2
        # in paper experiments Agent A2 is deterministic
        self.q_2_stochastic = q_2_stochastic
        self.n_actions_1 = 4
        self.n_actions_2 = 2
        # state space
        self.n_dim = np.prod(self.grid.shape) + 1  # +1 for absorbing state
        # move: left, right, up, down
        self.mapping_reversed = {(0, -1): 0, (0, 1): 1, (-1, 0): 2, (1, 0): 3}
        self.mapping = {val: key for key, val in self.mapping_reversed.items()}
        self.gamma = .99
        self.V_star = self.value_iteration()
        self.opt_move = self.create_opt_moves()

        self.instantiate_behavior_policy_1()
        self.instantiate_behavior_policy_2()

    def T(self, s, a1, a2):
        """Joint transition function: next state index for actions (a1, a2).

        ``a2 == 0`` executes A1's move ``a1``; ``a2 == 1`` ignores ``a1`` and
        executes the single-agent optimal move ``self.opt_move[s]``.
        """
        assert a1 in range(self.n_actions_1)
        assert a2 in range(self.n_actions_2)
        if a2 == 0:
            return self.T_w(s, a1)
        return self.T_w(s, self.opt_move[s])

    def R(self, s, a1, a2):
        """Joint reward function for actions (a1, a2) in state ``s``.

        Mirrors ``T``: with ``a2 == 1`` the optimal move's reward is paid,
        minus the override cost ``self.C``.
        """
        assert a1 in range(self.n_actions_1)
        assert a2 in range(self.n_actions_2)
        if a2 == 0:
            return self.R_w(s, a1)
        return self.R_w(s, self.opt_move[s]) - self.C

    def expected_performance(self, V):
        """Average of ``V`` over the first-row and first-column states.

        The top-left state (index 0) belongs to both and is counted once.
        (Presumably these are the start states — confirm with callers.)

        Args:
            V: indexable by state index (dict or array).

        Returns:
            The mean value over those ``n_rows + n_cols - 1`` states.
        """
        n_rows, n_cols = self.grid.shape
        J = -V[0]  # cancel the double count of the shared corner state
        for col in range(n_cols):          # first row
            J += V[col]
        for row in range(n_rows):          # first column (row stride = n_cols)
            J += V[row * n_cols]
        return J / (n_rows + n_cols - 1)

    def T_w(self, s, a1):
        """Single-agent transition: state index reached from ``s`` via ``a1``.

        States are row-major indices into ``self.grid``.  A move that would
        leave the grid keeps the agent in place.
        """
        assert a1 in range(self.n_actions_1)
        n_cols = self.grid.shape[1]
        # row-major: the row stride is the number of columns, not rows
        state = divmod(s, n_cols)

        new_state = self.vector_add(state, self.mapping[a1])
        if not self.is_valid(new_state):
            new_state = state

        return new_state[0] * n_cols + new_state[1]

    def R_w(self, s, a1):
        """Single-agent reward: grid value of the cell entered via ``a1``."""
        assert a1 in range(self.n_actions_1)
        row, col = divmod(self.T_w(s, a1), self.grid.shape[1])
        return self.grid[row, col]

    def value_iteration(self, tol=1e-5):
        """Single-agent optimal value function via synchronous value iteration.

        Each sweep reads only the previous sweep's values.  The terminal
        state's value is pinned to 0.

        Args:
            tol: stop once the largest per-state change in a sweep is < tol.

        Returns:
            dict mapping every grid state index to its value.
        """
        gamma = self.gamma
        n_states = self.n_dim - 1
        U = {s: 0 for s in range(n_states)}
        while True:
            U_old = dict(U)  # values are immutable scalars; shallow copy suffices
            delta = 0
            for s in range(n_states):
                U[s] = max(
                    self.R_w(s, a1) + gamma * U_old[self.T_w(s, a1)]
                    for a1 in range(self.n_actions_1)
                ) * (s != self.terminal_state)

                delta = max(delta, abs(U[s] - U_old[s]))

            if delta < tol:
                return U

    def create_opt_moves(self):
        """Greedy single-agent move in every state w.r.t. ``self.V_star``.

        Ties go to the lowest action index (``np.argmax`` semantics).
        """
        gamma = self.gamma
        return {
            s: np.argmax(np.array([
                self.R_w(s, a1) + gamma * self.V_star[self.T_w(s, a1)]
                for a1 in range(self.n_actions_1)
            ]))
            for s in range(self.n_dim - 1)
        }

    def _expected_return_ag2(self, q_1, U, s, a2):
        """One-step lookahead value of A2 action ``a2`` in ``s``, averaged over q_1[s]."""
        return np.sum([
            q_1[s, a1] * (self.R(s, a1, a2) + self.gamma * U[self.T(s, a1, a2)])
            for a1 in range(self.n_actions_1)
        ])

    def best_resp_ag2(self, q_1, tol=1e-5):
        """Deterministic best response of agent A2 to A1's behavior policy.

        Runs value iteration over A2's actions with A1's action averaged out
        under ``q_1``, then picks the greedy A2 action in every state.

        Args:
            q_1: array indexed ``q_1[s, a1]`` with A1's action probabilities.
            tol: convergence threshold on the largest per-sweep value change.

        Returns:
            array of shape ``(n_dim - 1, n_actions_2)`` with a single 1.0 per
            row marking the chosen A2 action.
        """
        n_states = self.n_dim - 1
        U = {s: 0 for s in range(n_states)}
        while True:
            U_old = dict(U)
            delta = 0
            for s in range(n_states):
                U[s] = max([
                    self._expected_return_ag2(q_1, U_old, s, a2)
                    for a2 in range(self.n_actions_2)
                ]) * (s != self.terminal_state)

                delta = max(delta, abs(U[s] - U_old[s]))

            if delta < tol:
                break

        # One column per A2 action, matching the shape of the stochastic q_2 arrays.
        best_resp_2 = np.zeros(shape=(n_states, self.n_actions_2))
        for s in range(n_states):
            a2_max = np.argmax(np.array([
                self._expected_return_ag2(q_1, U, s, a2)
                for a2 in range(self.n_actions_2)
            ]))
            best_resp_2[s, a2_max] = 1.0

        return best_resp_2

    def _mix_with_opt_moves(self, q_per, alpha):
        """Return ``(1 - alpha) * q_per`` with ``alpha`` added on each state's optimal move."""
        q = (1.0 - alpha) * q_per
        for s in range(q.shape[0]):
            q[s, self.opt_move[s]] += alpha
        return q

    def _perturb_row(self, probs, max_est_err):
        """Return a noisy copy of one state's action distribution ``probs``.

        Removes a random amount (at most ``max_est_err`` in total, never more
        than an action's own probability) from the actions in random order,
        then redistributes exactly the removed mass over the actions in a
        fresh random order, so the row still sums to the same total.
        """
        est = probs.copy()
        # subtract probs
        budget = max_est_err
        order = list(range(self.n_actions_1))
        np.random.shuffle(order)
        for a1 in order:
            e = min(probs[a1], np.random.rand() * budget)
            est[a1] = est[a1] - e
            budget = budget - e
        # add probs back
        removed = max_est_err - budget
        order = list(range(self.n_actions_1))
        np.random.shuffle(order)
        for a1 in order[:-1]:
            e = np.random.rand() * removed
            est[a1] = est[a1] + e
            removed = removed - e
        est[order[-1]] = est[order[-1]] + removed
        return est

    def instantiate_behavior_policy_1(self):
        """Build A1's behavior policies.

        Sets ``q_1_per_{tr,es,lo,up}`` (A1's personal policy: true, randomly
        perturbed estimate, and lower/upper bounds of width ``est_err_1``
        around the estimate) and ``q_1_{tr,es,lo,up}`` (the same arrays mixed
        with the optimal move using weight ``self.alpha``).  Draws from
        ``np.random``.
        """
        # personal policy of agent A1: one row per state in row-major order
        # (each pair of lines is one grid row); columns: left, right, up, down
        self.q_1_per_tr = np.array(
            [[0, .5, 0, .5], [0, .5, 0, .5], [0, .8, 0, .2], [0, .6, 0, .4],
             [0, .8, 0, .2], [0, .5, 0, .5], [0, .5, 0, .5], [0, 0, 0, 1.0],
             [0, .5, 0, .5], [0, .2, 0, .8], [0, .4, 0, .6], [0, .5, 0, .5],
             [0, .5, 0, .5], [0, .5, 0, .5], [0, .8, 0, .2], [0, 0, 0, 1.0],
             [0, .5, 0, .5], [0, .8, 0, .2], [0, .2, 0, .8], [0, .5, 0, .5],
             [0, .4, 0, .6], [0, .5, 0, .5], [0, .5, 0, .5], [0, 0, 0, 1.0],
             [0, .8, 0, .2], [0, .5, 0, .5], [0, .6, 0, .4], [0, .8, 0, .2],
             [0, .2, 0, .8], [0, .4, 0, .6], [0, .5, 0, .5], [0, 0, 0, 1.0],
             [0, .5, 0, .5], [0, .8, 0, .2], [0, .5, 0, .5], [0, .6, 0, .4],
             [0, .8, 0, .2], [0, .2, 0, .8], [0, .8, 0, .2], [0, 0, 0, 1.0],
             [0, .5, 0, .5], [0, .5, 0, .5], [0, .5, 0, .5], [0, .5, 0, .5],
             [0, .8, 0, .2], [0, .3, 0, .7], [0, .8, 0, .2], [0, 0, 0, 1.0],
             [0, .2, 0, .8], [0, .5, 0, .5], [0, .5, 0, .5], [0, .5, 0, .5],
             [0, .5, 0, .5], [0, .5, 0, .5], [0, .5, 0, .5], [0, 0, 0, 1.0],
             [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0],
             [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [.25, .25, .25, .25]])

        # induce estimation error
        max_est_err = self.est_err_1
        self.q_1_per_es = self.q_1_per_tr.copy()
        for s in range(self.n_dim - 1):
            self.q_1_per_es[s] = self._perturb_row(self.q_1_per_tr[s], max_est_err)
            row_sum = np.sum(self.q_1_per_es[s])
            assert .999999 <= row_sum <= 1.000001, row_sum

        # LB1 and UB1, clipped to [0, 1]
        self.q_1_per_lo = np.maximum(0.0, self.q_1_per_es - max_est_err)
        self.q_1_per_up = np.minimum(1.0, self.q_1_per_es + max_est_err)

        # parameterize behavior policy of A1 with alpha
        self.q_1_tr = self._mix_with_opt_moves(self.q_1_per_tr, self.alpha)
        self.q_1_es = self._mix_with_opt_moves(self.q_1_per_es, self.alpha)
        self.q_1_lo = self._mix_with_opt_moves(self.q_1_per_lo, self.alpha)
        self.q_1_up = self._mix_with_opt_moves(self.q_1_per_up, self.alpha)

    def instantiate_behavior_policy_2(self):
        """Build A2's behavior policies ``q_2_{opt_tr,tr,es,lo,up}``.

        The base policy ``q_2_opt_tr`` is the deterministic best response to
        A1's personal policy mixed with the optimal move using weight
        ``self.alpha_pr``.  If ``self.q_2_stochastic`` is set, each state puts
        a random probability in [0.1, 0.2) on the non-best action; the
        estimate is shifted by up to ``self.est_err_2`` and the bounds are
        ``est_err_2`` either side of it, clipped to [0, 1].  Otherwise all
        four arrays are copies of ``q_2_opt_tr``.

        Raises:
            AssertionError: deterministic policy with nonzero ``est_err_2``.
        """
        if not self.q_2_stochastic:
            # Check before the comparatively expensive best-response computation.
            assert self.est_err_2 == 0, "can not have estimation error when policy determinitstic"

        # parameterize behavior policy of A1 with alpha'
        q_1_pr = self._mix_with_opt_moves(self.q_1_per_tr, self.alpha_pr)

        # compute deterministic best response to q_1_pr
        self.q_2_opt_tr = self.best_resp_ag2(q_1_pr)

        if not self.q_2_stochastic:
            self.q_2_tr = self.q_2_opt_tr.copy()
            self.q_2_es = self.q_2_opt_tr.copy()
            self.q_2_lo = self.q_2_opt_tr.copy()
            self.q_2_up = self.q_2_opt_tr.copy()
            return

        n_states = self.n_dim - 1
        # true probability of the non-best action: uniform in [.1, .2)
        prob_tr = np.random.rand(n_states) * .1 + .1
        # estimation noise: uniform in [-est_err_2, est_err_2)
        prob_es = np.random.rand(n_states) * 2 * self.est_err_2 - self.est_err_2

        # set true/estimated/LB/UB behavior policy q_2
        shape = (n_states, self.n_actions_2)
        self.q_2_tr = np.zeros(shape=shape)
        self.q_2_es = np.zeros(shape=shape)
        self.q_2_lo = np.zeros(shape=shape)
        self.q_2_up = np.zeros(shape=shape)
        for s in range(n_states):
            for a2 in range(self.n_actions_2):
                if self.q_2_opt_tr[s, a2] == 0:
                    self.q_2_tr[s, a2] = prob_tr[s]
                    self.q_2_es[s, a2] = max(0, self.q_2_tr[s, a2] + prob_es[s])
                else:
                    self.q_2_tr[s, a2] = 1.0 - prob_tr[s]
                    self.q_2_es[s, a2] = min(1.0, self.q_2_tr[s, a2] - prob_es[s])
                self.q_2_lo[s, a2] = max(0, self.q_2_es[s, a2] - self.est_err_2)
                self.q_2_up[s, a2] = min(1.0, self.q_2_es[s, a2] + self.est_err_2)

    @staticmethod
    def vector_add(x, y):
        """Component-wise sum of two 2-tuples."""
        return (x[0] + y[0], x[1] + y[1])

    def is_valid(self, x):
        """True if the (row, col) pair ``x`` lies inside the grid."""
        return (0 <= x[0] < self.grid.shape[0]) and (0 <= x[1] < self.grid.shape[1])