import numpy as np
import copy
from itertools import product

class Graph(object):
    """Layered multi-agent graph environment with a finite horizon.

    Every agent starts at node ``-1`` and, at each step, moves on to the next
    layer.  Layer ``L`` (``0 <= L < length``) consists of nodes ``2 * L`` and
    ``2 * L + 1``; an agent's binary action selects which of the two nodes of
    the next layer it lands on.  From the last layer every agent moves to the
    absorbing final node ``2 * length``.

    Joint positions are encoded as integer states:

    * ``0`` -- every agent at the initial node ``-1``;
    * ``1 + L * 2**n_agents + sum_i (pos_i % 2) * 2**i`` -- every agent in
      layer ``L``;
    * ``terminal_state == n_dim - 1`` -- every agent at the final node.

    Args:
        version: ``0`` (or any value ``<= 0``) selects the robustness
            experiment (reward ``R``, behaviour policy ``jq_tr`` plus the
            noisy estimate ``jq_es`` and its bounds ``jq_lo``/``jq_up``); a
            positive value selects the coordination experiment and is the
            reward threshold used by ``R_v``.
        n_agents: number of agents.
        length: number of layers.
        alpha: in the robustness experiment, agent ``i`` follows its
            preferred action with probability ``1 - alpha * i``.
        est_err: half-width of the uniform noise used to build ``jq_es``
            (robustness experiment only).

    Raises:
        ValueError: in the robustness experiment, if ``1 - alpha * i`` is
            negative for some agent ``i`` (checked on the states of layers
            ``0 .. length - 2``).
    """

    def __init__(self, version=0, n_agents=4, length=4, alpha=.2, est_err=.1):
        self.n_agents = n_agents
        self.length = length
        self.n_actions = 2
        self.action_space_i = [0, 1]
        # independent copy of the per-agent action list for every agent, so
        # mutating one entry does not silently change all the others
        self.action_space = [
            list(self.action_space_i) for _ in range(self.n_agents)
        ]
        self.joint_action_space = [
            list(a) for a in product(*self.action_space)
        ]

        n_layer_states = 2 ** self.n_agents
        # state space: initial state + `length` layers + terminal state
        self.n_dim = self.length * n_layer_states + 2
        # active state space: the initial state and layers 0 .. length - 2;
        # the last layer and the terminal state always earn zero reward
        self.active_n_dim = self.n_dim - n_layer_states - 1
        self.terminal_state = self.n_dim - 1
        # node indices: -1 (initial), 0, ..., self.nodes - 2 (final)
        self.nodes = 2 * self.length + 2
        self.initial_node = -1
        self.final_node = self.nodes - 2
        self.initial_pos = [self.initial_node] * self.n_agents
        self.final_pos = [self.final_node] * self.n_agents
        # discount factor used by value_iteration
        self.gamma = .99
        # alpha here is an environment parameter related to the agents'
        # policies at the robustness experiment
        self.alpha = alpha
        self.est_err = est_err
        # version parameter determines the variant we consider:
        # version <= 0 is the robustness experiment, version > 0 the
        # coordination experiment (with `version` as reward threshold)
        self.version = version

        self.V_star = self.value_iteration()
        if version > 0:
            # behavior policies Exp 2 (coordination)
            self.instantiate_behavior_joint_policy_c()
        else:
            # behavior policies Exp 3 (robustness)
            self.instantiate_behavior_joint_policy_r()

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _is_balanced(self, count):
        """Return True when ``count`` of ``n_agents`` agents is an even split.

        For an even number of agents this means exactly ``n_agents // 2``;
        for an odd number, ``n_agents // 2`` or ``n_agents // 2 + 1``.
        """
        return self.n_agents // 2 <= count <= (self.n_agents + 1) // 2

    def _is_reward_free(self, s):
        """Return True for states of the last layer and the terminal state."""
        return s >= self.terminal_state - 2 ** self.n_agents

    # ------------------------------------------------------------------
    # state encoding
    # ------------------------------------------------------------------
    def pos_to_state(self, pos):
        """Encode a joint position (one node per agent) as an integer state.

        ``pos`` may be any sequence.  Outside the initial/final positions the
        layer is read from agent 0, i.e. all agents are assumed to be in the
        same layer (as they are for every position produced by ``T``).
        """
        pos = list(pos)
        if pos == self.initial_pos:
            return 0
        if pos == self.final_pos:
            return self.terminal_state
        layer = pos[0] // 2
        bits = sum((pos[i] % 2) * 2 ** i for i in range(self.n_agents))
        return bits + layer * 2 ** self.n_agents + 1

    def state_to_pos(self, s):
        """Decode an integer state into a new list of per-agent nodes.

        A fresh list is returned every time, so callers may mutate it without
        affecting ``initial_pos`` / ``final_pos``.
        """
        if s == 0:
            return list(self.initial_pos)
        if s == self.terminal_state:
            return list(self.final_pos)
        layer, bits = divmod(s - 1, 2 ** self.n_agents)
        pos = []
        for _ in range(self.n_agents):
            # bit i of the in-layer index selects node 2L (0) or 2L + 1 (1)
            pos.append(2 * layer + bits % 2)
            bits = bits // 2
        return pos

    # ------------------------------------------------------------------
    # rewards
    # ------------------------------------------------------------------
    def R(self, s, a):
        """Reward of joint action ``a`` in state ``s`` (dispatches on version).

        Coordination experiment (``version > 0``): delegates to ``R_v``.
        Robustness experiment: ``0`` in the last layer and the terminal
        state; otherwise ``+1`` when the number of agents choosing action 1
        is a balanced split (see ``_is_balanced``), else ``-1``.
        """
        # coordination experiment
        if self.version > 0:
            return self.R_v(s, a)
        # robustness experiment
        if self._is_reward_free(s):
            return 0
        return 1 if self._is_balanced(np.sum(a)) else -1

    def R_v(self, s, a):
        """Coordination reward: threshold test on ``sum_i (i + 1) * a[i]``.

        Returns ``0`` in the last layer and the terminal state, ``+1`` when
        the weighted sum reaches ``version``, and ``-1`` otherwise.
        """
        if self._is_reward_free(s):
            return 0
        weighted = np.sum([a[i] * (i + 1) for i in range(self.n_agents)])
        return 1 if weighted >= self.version else -1

    # ------------------------------------------------------------------
    # dynamics
    # ------------------------------------------------------------------
    def T(self, s, a):
        """Deterministic joint transition: apply ``T_a`` to every agent."""
        pos = self.state_to_pos(s)
        new_pos = [self.T_a(pos[i], a[i]) for i in range(self.n_agents)]
        return self.pos_to_state(new_pos)

    def T_a(self, pos_i, a_i):
        """Single-agent transition from node ``pos_i`` under action ``a_i``."""
        # the final node is absorbing
        if pos_i == self.final_node:
            return pos_i
        # the last layer always leads to the final node
        if pos_i >= self.final_node - 2:
            return self.final_node
        # nodes 2L and 2L + 1 both lead to node 2(L + 1) + a_i; the initial
        # node -1 is odd in Python (-1 % 2 == 1) and therefore leads to a_i
        if pos_i % 2 == 0:
            return pos_i + 2 + a_i
        return pos_i + 1 + a_i

    # ------------------------------------------------------------------
    # planning
    # ------------------------------------------------------------------
    def value_iteration(self, tol=1e-5):
        """Compute the optimal joint value function by value iteration.

        Each sweep updates the active states ``0 .. active_n_dim - 1`` from
        the previous sweep's values; all other states keep value ``0``.

        Args:
            tol: stop once every active state changed by less than ``tol``
                during a sweep.

        Returns:
            dict mapping every state in ``range(n_dim)`` to its value.
        """
        gamma = self.gamma
        U = {s: 0 for s in range(self.n_dim)}
        while True:
            # values are plain numbers, so a shallow copy is sufficient
            U_old = dict(U)
            delta = 0
            for s in range(self.active_n_dim):
                U[s] = max(
                    self.R(s, a) + gamma * U_old[self.T(s, a)]
                    for a in self.joint_action_space
                )
                delta = max(delta, abs(U[s] - U_old[s]))
            if delta < tol:
                return U

    def expected_performance(self, U):
        """Return the value of the initial state (state 0) under ``U``."""
        return U[0]

    # ------------------------------------------------------------------
    # behaviour policies
    # ------------------------------------------------------------------
    def instantiate_behavior_joint_policy_c(self):
        """Coordination behaviour policy: every agent always plays action 0.

        Sets ``jq_tr`` with shape ``(n_dim, n_agents, n_actions)``.
        """
        self.jq_tr = np.zeros(shape=(self.n_dim, self.n_agents, self.n_actions))
        self.jq_tr[:, :, 0] = 1

    def instantiate_behavior_joint_policy_r(self):
        """Robustness behaviour policy ``jq_tr`` plus a noisy estimate of it.

        ``jq_tr[s, i, a]`` is agent ``i``'s probability of action ``a`` in
        state ``s``.  In state 0, the last layer and the terminal state every
        agent plays uniformly.  Elsewhere, with ``b`` the parity bits of the
        current joint position, agent ``i`` prefers

        * ``b[i]`` if ``b`` is a balanced split (see ``_is_balanced``),
        * action 1 if fewer than ``n_agents // 2`` agents are on odd nodes,
        * action 0 otherwise,

        and plays it with probability ``1 - alpha * i``.

        ``jq_es`` adds uniform noise in ``[-est_err, est_err)`` (drawn from
        the global ``np.random`` state) to ``jq_tr[..., 0]``, clipped to
        ``[0, 1]``; ``jq_lo`` / ``jq_up`` are ``jq_es`` minus / plus
        ``est_err``, clipped to ``[0, 1]``.

        Raises:
            ValueError: if ``1 - alpha * i`` is negative for some agent.
        """
        shape = (self.n_dim, self.n_agents, self.n_actions)
        self.jq_tr = np.zeros(shape=shape)
        self.jq_es = np.zeros(shape=shape)
        self.jq_lo = np.zeros(shape=shape)
        self.jq_up = np.zeros(shape=shape)

        max_est_err = self.est_err
        for s in range(self.n_dim):
            b = [p % 2 for p in self.state_to_pos(s)]
            n_odd = np.sum(b)
            uniform = s == 0 or self._is_reward_free(s)
            for i in range(self.n_agents):
                if uniform:
                    for a_i in range(self.n_actions):
                        self.jq_tr[s, i, a_i] = .5
                else:
                    if self._is_balanced(n_odd):
                        preferred = b[i]
                    elif n_odd < self.n_agents // 2:
                        preferred = 1
                    else:
                        preferred = 0
                    prob = 1.0 - self.alpha * i
                    # explicit check (not `assert`) so it survives `python -O`;
                    # `not >=` also rejects NaN
                    if not prob >= 0:
                        raise ValueError("alpha or n_agents too big")
                    self.jq_tr[s, i, preferred] = prob
                    self.jq_tr[s, i, 1 - preferred] = 1.0 - self.jq_tr[s, i, preferred]

                # one uniform draw per (state, agent), in the same order as
                # before so seeded runs reproduce the same estimates
                noise = np.random.rand() * 2 * max_est_err - max_est_err
                self.jq_es[s, i, 0] = min(1.0, max(0, self.jq_tr[s, i, 0] + noise))
                self.jq_es[s, i, 1] = 1.0 - self.jq_es[s, i, 0]

                for a_i in range(self.n_actions):
                    self.jq_lo[s, i, a_i] = max(0, self.jq_es[s, i, a_i] - max_est_err)
                    self.jq_up[s, i, a_i] = min(1.0, self.jq_es[s, i, a_i] + max_est_err)

