"""
Deprecated!
"""
import ray
import copy
import numpy as np

from typing import Callable, Dict, Any, Sequence

from expground.types import SimulationFeedback, MetricType, PolicyID, AgentID
from expground.envs.agent_interface import AgentInterface
from expground.utils.metric import get_metric_handler
from expground.utils.policy_pool import PolicyPool


@ray.remote
class SimulationThreadPool:
    """Ray actor rolling out one fixed policy of ``cur_aid`` against policy
    combinations picked for the other agents.

    One agent interface per agent is fetched from the policy pools when the
    actor is constructed and reused by every call to :meth:`run`.
    """

    def __init__(
        self,
        cur_aid,
        cur_pid,
        other_agents,
        metric_type,
        env_creator,
        env_configs,
        mapping_func,
        selected_population_pairs,
        populations,
        fragment_length,
        num_episode,
    ):
        """Store the rollout settings and fetch the agent interfaces.

        :param cur_aid: id of the agent whose policy is held fixed
        :param cur_pid: policy id played by ``cur_aid`` in every rollout
        :param other_agents: ids of the agents whose policy is chosen per run
        :param metric_type: name handed to ``get_metric_handler``
        :param env_creator: callable building the environment
        :param env_configs: keyword arguments passed to ``env_creator``
        :param mapping_func: maps an agent id to its key in ``populations``
        :param selected_population_pairs: for each other agent, an indexable
            collection of the policy ids it may play
        :param populations: policy pool actor handles, keyed by the output of
            ``mapping_func``
        :param fragment_length: per-agent step budget of a single episode
        :param num_episode: number of episodes rolled out by one ``run`` call
        """
        # Keep every constructor argument as an attribute of the same name.
        self.cur_aid = cur_aid
        self.cur_pid = cur_pid
        self.other_agents = other_agents
        self.metric_type = metric_type
        self.env_creator = env_creator
        self.env_configs = env_configs
        self.mapping_func = mapping_func
        self.selected_population_pairs = selected_population_pairs
        self.populations = populations
        self.fragment_length = fragment_length
        self.num_episode = num_episode

        # Turn each agent's policy pool into an agent interface, the other
        # agents first and the current agent last.
        self.agent_interfaces = {}
        for agent_id in [*other_agents, cur_aid]:
            pool = populations[mapping_func(agent_id)]
            self.agent_interfaces[agent_id] = ray.get(
                pool.to_agent_interface.remote()
            )

    def run(self, agent_combination) -> SimulationFeedback:
        """Roll out ``num_episode`` episodes for one policy combination.

        :param agent_combination: one index per entry of ``other_agents``;
            each index selects that agent's policy id from
            ``selected_population_pairs``
        :return: the feedback object fed with every environment step; rewards
            are divided by ``num_episode`` before being recorded in it
        """
        env = self.env_creator(**self.env_configs)
        metric_handler = get_metric_handler(self.metric_type)(env.possible_agents)
        opponents = self.other_agents
        assert len(opponents) == len(agent_combination), (
            len(agent_combination),
            len(opponents),
        )

        # Which policy id each agent plays during this call.
        policy_of = {
            aid: self.selected_population_pairs[aid][choice]
            for aid, choice in zip(opponents, agent_combination)
        }
        policy_of[self.cur_aid] = self.cur_pid

        feedback = SimulationFeedback(
            self.num_episode,
            {aid: policy_of[aid] for aid in self.agent_interfaces},
        )

        for _ in range(self.num_episode):
            env.reset()
            metric_handler.reset()
            # `action` is not reset between agents: the value logged below is
            # the most recently computed action (None at episode start).
            action = None
            step_budget = self.fragment_length * len(env.possible_agents)
            for aid in env.agent_iter(max_iter=step_budget):
                observation, reward, done, info = env.last()
                metric_handler.step(aid, observation, action, reward, done, info)
                feedback.step(
                    aid, observation, action, reward / self.num_episode, info
                )
                if done:
                    # A finished agent is stepped with a None action.
                    env.step(None)
                    continue

                # Only dict observations carry an action mask.
                if isinstance(observation, dict):
                    action_mask = observation["action_mask"]
                else:
                    action_mask = None
                action = self.agent_interfaces[aid].compute_action(
                    observation, action_mask, policy_id=policy_of[aid]
                )
                env.step(action)

        return feedback


@ray.remote
def run_simulations(
    metric_type: str,
    env_creator: Callable,
    env_config: Dict[str, Any],
    populations: Dict[str, PolicyPool],
    brs: Dict[AgentID, PolicyID],
    fragment_length: int,
    n_simulation: int,
    mapping_func: Callable,
) -> Sequence[SimulationFeedback]:
    """Simulate each best-response policy against every combination of the
    other agents' population policies.

    For every agent in ``brs`` a ``SimulationThreadPool`` actor is created and
    one ``run`` call is submitted per policy combination. The feedbacks of all
    best-response agents are concatenated in ``brs`` iteration order.

    :param metric_type: str, the name of metric handler
    :param env_creator: object, environment creator
    :param env_config: Dict[str, Any], environment configuration
    :param populations: Dict[str, Any], population dictionary
    :param brs: Dict[AgentID, PolicyID], trainable agents, best response ids
    :param fragment_length: int, the fragment length
    :param n_simulation: int, the number of simulation for each policy combination
    :param mapping_func: Callable, mapping agent id to pool id
    :return: a sequence of simulation feedbacks (empty when ``brs`` is empty)
    :raises ValueError: if a key of ``brs`` is not one of the environment's
        possible agents
    """

    env = env_creator(**env_config)

    # Agents without a best response pick their policy from a population.
    # `list.remove` raises ValueError for an unknown best-response agent.
    other_agents = list(env.possible_agents)
    for br_aid in brs:
        other_agents.remove(br_aid)

    # Fetch population sizes and descriptions with one batched get each
    # instead of blocking once per agent.
    population_sizes = ray.get(
        [
            populations[mapping_func(aid)].get_population_size.remote()
            for aid in other_agents
        ]
    )
    population_descs = ray.get(
        [
            populations[mapping_func(aid)].get_population_desc.remote()
            for aid in other_agents
        ]
    )

    # (agent_id, [pair, ...]) for the agents that select from a population
    selected_population_pairs = dict(zip(other_agents, population_descs))

    # Every combination of policy indices, one index per other agent. This
    # does not depend on the best-response agent, so build it only once.
    mesh_grids = [np.arange(size) for size in population_sizes]
    agent_combinations = list(
        zip(*(grid.reshape(-1) for grid in np.meshgrid(*mesh_grids)))
    )

    # Put the shared description into the object store once to reduce the
    # transfer cost across actors.
    selected_population_pairs_id = ray.put(selected_population_pairs)

    results = []
    for cur_aid, cur_pid in brs.items():
        simulation_thread_pool = SimulationThreadPool.remote(
            cur_aid,
            cur_pid,
            other_agents,
            metric_type,
            env_creator,
            env_config,
            mapping_func,
            selected_population_pairs_id,
            populations,
            fragment_length,
            n_simulation,
        )
        result_ids = [
            simulation_thread_pool.run.remote(agent_combination)
            for agent_combination in agent_combinations
        ]
        # Accumulate instead of returning here, so that every best-response
        # agent is simulated rather than only the first one.
        results.extend(ray.get(result_ids))

    return results
