"""
Build and run power-allocation experiments with the agent coordinator.
"""
from util import transform_to_2d
from bo_agent import BO_AGENT
from power_allocation_agent_config import get_power_allocation_agent_config
from coordinator import Coordinator
import numpy as np
import GPy
import itertools


class PowerAllocationInstance():
    """Build and run paired power-allocation experiments.

    Each generated instance holds ``num_agents`` pairs of BO agents. The two
    agents in a pair share one agent configuration and differ only in
    ``acq_type`` ('primal_dual' vs. 'penalty'). Each group of agents is
    wrapped in its own ``Coordinator``, and both coordinators are run for
    ``run_horizon`` updates.

    Args:
        num_agents: number of agents handed to each coordinator.
        run_horizon: number of ``update()`` calls made on each coordinator
            by ``run_one_instance``.
        M: large constant. ``run_one_instance`` regenerates an instance while
            the primal-dual coordinator's ``opt_val`` is not below ``0.5 * M``.
        max_power: right-hand side (``affine_b``) of the single affine
            constraint, i.e. the total power budget.
    """

    def __init__(self, num_agents=10, run_horizon=100, M=1e20, max_power=3):
        self.num_agents = num_agents
        self.run_horizon = run_horizon
        self.M = M
        self.max_power = max_power

    def generate_one_agent(self):
        """Create one primal-dual agent and one penalty agent.

        Both agents are built from the same configuration returned by
        ``get_power_allocation_agent_config`` (with ``local_A`` set to
        ``np.ones(1)``). They differ only in ``acq_type``.

        Returns:
            tuple: ``(primal_dual_agent, penalty_agent)``.
        """
        base_config = get_power_allocation_agent_config()
        base_config['local_A'] = np.ones(1)
        # Give each agent its own shallow copy of the config. If BO_AGENT
        # keeps or modifies the dict it receives, one agent then cannot
        # observe the other's acq_type or changes.
        primal_dual_config = dict(base_config, acq_type='primal_dual')
        penalty_config = dict(base_config, acq_type='penalty')

        primal_dual_agent = BO_AGENT(primal_dual_config)
        penalty_agent = BO_AGENT(penalty_config)
        return primal_dual_agent, penalty_agent

    def generate_one_instance(self):
        """Build a primal-dual coordinator and a penalty coordinator.

        Both coordinators receive the same settings (objective, constraint
        data, grid, reference values). They differ only in ``agent_list``.

        Returns:
            tuple: ``(primal_dual_coordinator, penalty_coordinator)``.
        """
        num_agents = self.num_agents
        max_power = self.max_power

        primal_dual_agent_list = []
        penalty_agent_list = []
        for _ in range(num_agents):
            primal_dual_agent, penalty_agent = self.generate_one_agent()
            primal_dual_agent_list.append(primal_dual_agent)
            penalty_agent_list.append(penalty_agent)
        agent_0 = primal_dual_agent_list[0]

        coordinator_config = dict()
        coordinator_config['num_agents'] = num_agents
        coordinator_config['num_blackbox_constraints'] = \
            agent_0.num_black_box_constrs
        coordinator_config['num_affine_constraints'] = 1
        coordinator_config['epsilon'] = 0.2
        coordinator_config['affine_b'] = max_power
        coordinator_config['whole_dim'] = num_agents * agent_0.dim_x

        # NOTE(review): the joint grid only spans the first five agents,
        # although whole_dim covers all num_agents. The full product would
        # have len(x_grid) ** num_agents points. Confirm what Coordinator
        # expects from 'total_grid'. Slicing keeps the original result for
        # five or more agents and avoids an IndexError when there are fewer.
        grid_agents = primal_dual_agent_list[:5]
        coordinator_config['total_grid'] = list(
            itertools.product(*[agent.x_grid for agent in grid_agents]))

        def whole_obj(x):
            # Sum over agents of each agent's first black-box function,
            # evaluated at that agent's coordinate x[i]. It closes over
            # locals, so later changes to self.num_agents cannot alter it.
            return sum(
                primal_dual_agent_list[i].black_box_funcs_list[0](x[i])
                for i in range(num_agents))

        coordinator_config['whole_black_box_funcs'] = [whole_obj]

        # No ground-truth optimum is computed for this problem (no search
        # over total_grid is performed), so opt_val and opt_sol are
        # placeholders.
        opt_val = 0
        opt_val_id = 0

        # Objective value of the even split of max_power across all agents.
        avg_power = float(max_power) / num_agents
        coordinator_config['obj_avg_sol'] = whole_obj([avg_power] * num_agents)
        coordinator_config['opt_val'] = opt_val
        coordinator_config['opt_sol'] = \
            coordinator_config['total_grid'][opt_val_id]

        # Each Coordinator gets its own copy of the settings, so supplying the
        # penalty agents to the second one cannot leak into the first.
        primal_dual_power_allocation_coordinator = Coordinator(
            dict(coordinator_config, agent_list=primal_dual_agent_list))
        penalty_power_allocation_coordinator = Coordinator(
            dict(coordinator_config, agent_list=penalty_agent_list))

        return primal_dual_power_allocation_coordinator, \
            penalty_power_allocation_coordinator

    def run_one_instance(self):
        """Generate an instance and run both of its coordinators.

        The method makes at most ``max_while + 1`` attempts to generate an
        instance. It stops as soon as the primal-dual coordinator's
        ``opt_val`` is below ``0.5 * self.M``. If no attempt qualifies, the
        last generated instance is used.

        Returns:
            tuple: ``(primal_dual_coordinator, penalty_coordinator)``, each
            after ``self.run_horizon`` calls to ``update()``.
        """
        max_while = 50
        for _ in range(max_while + 1):
            pd_pac, penalty_pac = self.generate_one_instance()
            if pd_pac.opt_val < 0.5 * self.M:
                break

        for _ in range(self.run_horizon):
            pd_pac.update()

        for _ in range(self.run_horizon):
            penalty_pac.update()

        return pd_pac, penalty_pac
