"""
Implement the agent coordinator.
"""
from util import transform_to_2d
import numpy as np
import GPy


class Coordinator:
    """Coordinate the primal and dual updates of a list of agents.

    The coordinator owns two dual-variable vectors (one entry per black-box
    constraint, one per affine constraint). On every ``update`` it asks each
    agent for a local primal step and, when the agents use the
    ``'primal_dual'`` acquisition type, moves the dual variables using the
    quantities the agents returned.
    """

    def __init__(self, coordinator_config):
        """Copy the configuration and start both dual vectors at zero.

        Args:
            coordinator_config: mapping providing the keys ``num_agents``,
                ``agent_list``, ``num_blackbox_constraints``,
                ``num_affine_constraints``, ``epsilon``, ``affine_b``,
                ``whole_dim``, ``total_grid``, ``whole_black_box_funcs``,
                ``opt_val`` and ``opt_sol``.

        Raises:
            KeyError: if any of the keys above is missing.
        """
        self.coordinator_config = coordinator_config
        self.num_agents = coordinator_config['num_agents']
        self.agent_list = coordinator_config['agent_list']
        self.num_blackbox_constraints = \
            coordinator_config['num_blackbox_constraints']
        self.num_affine_constraints = \
            coordinator_config['num_affine_constraints']

        # Dual variables, one entry per constraint, initialised to zero.
        self.blackbox_dual = np.zeros(self.num_blackbox_constraints)
        self.affine_dual = np.zeros(self.num_affine_constraints)
        self.epsilon = coordinator_config['epsilon']
        self.affine_b = coordinator_config['affine_b']

        self.whole_dim = coordinator_config['whole_dim']
        self.total_grid = coordinator_config['total_grid']
        self.whole_black_box_funcs = \
            coordinator_config['whole_black_box_funcs']

        self.opt_val = coordinator_config['opt_val']
        self.opt_sol = coordinator_config['opt_sol']

    def primal_update(self, eval_and_udt=True):
        """Run one local primal step on every agent.

        The step that is invoked depends on each agent's ``acq_type``:

        * ``'primal_dual'``: ``local_pd_primal_update`` is called with the
          current dual variables; its two return values are accumulated
          over the agents.
        * ``'cei'``: ``local_cei_primal_update`` is called with the list of
          per-constraint sums of every agent's latest ``Y`` entry.
        * ``'penalty'``: ``local_penalty_primal_update`` is called with the
          agent's current ``old_x``.

        Agents with any other ``acq_type`` are skipped.

        Args:
            eval_and_udt: forwarded unchanged (as a keyword argument) to the
                agents' local update methods.

        Returns:
            ``(local_Ax_sum, local_constr_lg_sum)``: the sums accumulated
            over ``'primal_dual'`` agents, or ``(None, None)`` once a
            ``'cei'`` or ``'penalty'`` agent has been processed.
        """
        agent_list = self.agent_list
        agent_indices = range(self.num_agents)

        # For each black-box constraint k, add up the latest observation of
        # every agent's (k + 1)-th GP; index 0 is skipped here.
        constraints_list = [
            sum(
                agent_list[i].black_box_func_gp_list[k + 1].Y[-1]
                for i in agent_indices
            )
            for k in range(self.num_blackbox_constraints)
        ]

        local_Ax_sum = 0
        local_constr_lg_sum = 0
        for k in agent_indices:
            bo_agent = agent_list[k]
            if bo_agent.acq_type == 'primal_dual':
                local_Ax, lowerg = bo_agent.local_pd_primal_update(
                    self.blackbox_dual,
                    self.affine_dual,
                    eval_and_udt=eval_and_udt
                )
                local_Ax_sum += local_Ax
                local_constr_lg_sum += lowerg
            elif bo_agent.acq_type == 'cei':
                bo_agent.local_cei_primal_update(
                    constraints_list,
                    eval_and_udt=eval_and_udt
                )
                local_Ax_sum = None
                local_constr_lg_sum = None
            elif bo_agent.acq_type == 'penalty':
                bo_agent.local_penalty_primal_update(
                    bo_agent.old_x,
                    eval_and_udt=eval_and_udt
                )
                local_Ax_sum = None
                local_constr_lg_sum = None

            # NOTE(review): this refresh runs after *every* agent's step, so
            # later agents in the loop already see the refreshed ``old_x``.
            # Kept as in the original; confirm this ordering is intended.
            if agent_list[0].acq_type == 'penalty':
                self._refresh_penalty_old_x()

        return local_Ax_sum, local_constr_lg_sum

    def _refresh_penalty_old_x(self):
        """Recompute ``old_x`` of every agent from the latest ``X`` rows."""
        num_agents = self.num_agents
        agents = [self.agent_list[i] for i in range(num_agents)]

        # Builtin ``sum`` is used on purpose: ``np.sum`` on a generator is
        # deprecated by NumPy.
        latest_x_sum = sum(
            agent.black_box_func_gp_list[0].X[-1] for agent in agents
        )
        mean_mu_dual = sum(agent.mu_dual for agent in agents) / num_agents

        # The penalty weight ``rho`` of the first agent is the one used for
        # the shared multiplier.
        opt_lamda = mean_mu_dual + \
            self.agent_list[0].rho * (latest_x_sum - self.affine_b) \
            / num_agents

        for agent in agents:
            agent.old_x = \
                (agent.mu_dual - opt_lamda) / agent.rho + \
                agent.black_box_func_gp_list[0].X[-1]

    def dual_update(self, local_Ax_sum, local_constr_lg_sum):
        """Move the dual variables by one step.

        ``blackbox_dual`` is increased by ``local_constr_lg_sum + epsilon``
        and then clipped from below at zero; ``affine_dual`` is increased by
        ``local_Ax_sum - affine_b`` without clipping.

        Args:
            local_Ax_sum: first value returned by ``primal_update``.
            local_constr_lg_sum: second value returned by ``primal_update``.
        """
        self.blackbox_dual += (local_constr_lg_sum + self.epsilon)
        self.blackbox_dual = np.maximum(self.blackbox_dual, 0)
        self.affine_dual += (local_Ax_sum - self.affine_b)

    def update(self):
        """Run a primal update, then a dual update for primal-dual agents.

        The dual step is only taken when the first agent's ``acq_type`` is
        ``'primal_dual'``.
        """
        local_Ax_sum, local_constr_lg_sum = self.primal_update()
        if self.agent_list[0].acq_type == 'primal_dual':
            self.dual_update(local_Ax_sum, local_constr_lg_sum)
