import traceback
import ray
import numpy as np

from expground.types import Sequence, Any, Dict, AgentID, PolicyID, Tuple, List
from expground.logger import Log
from expground.utils import rollout


def dispatch_resources_for_agent(
    agents: Sequence[str], use_cuda: bool = False, resource_limit=None
) -> Dict[str, Any]:
    """Dispatch computing resources evenly among learners.

    The resources reported by ``ray.available_resources()`` are overridden
    with the entries of ``resource_limit`` and then divided equally by the
    number of agents.

    Args:
        agents (Sequence[str]): A sequence of agents.
        use_cuda (bool, optional): Use CUDA or not. Defaults to False.
        resource_limit (Dict[str, Any], optional): Entries that override (or
            extend) the resources reported by Ray, keyed like Ray's own
            report (this function reads ``"CPU"`` and ``"GPU"``). Defaults to
            None, meaning no override.

    Returns:
        Dict[str, Any]: The resource share of a single agent, with keys
            ``num_gpus`` and ``num_cpus``.

    Raises:
        KeyError: If ``"CPU"`` (or ``"GPU"`` while ``use_cuda`` is True) is
            present neither in Ray's report nor in ``resource_limit``.
        ZeroDivisionError: If ``agents`` is empty.
    """

    resources = ray.available_resources()
    # ``None`` replaces the former mutable ``{}`` default so that no dict
    # object is shared between calls.
    resources.update(resource_limit or {})
    Log.info("Ray available resources: %s", resources)
    n_agent = len(agents)
    res = {
        # "GPU" is only looked up when CUDA is requested.
        "num_gpus": resources["GPU"] / n_agent if use_cuda else 0,
        "num_cpus": resources["CPU"] / n_agent,
        # "memory": resources["memory"] / n_agent,
        # "object_store_memory": resources["object_store_memory"] / n_agent
    }
    return res


def generate_random_from_shapes(
    equilibrium: Dict[AgentID, Dict[PolicyID, float]]
) -> Dict[AgentID, Dict[PolicyID, float]]:
    """Draw a random value for every policy of every agent.

    The result mirrors the nesting and the keys of ``equilibrium``. The
    values stored in ``equilibrium`` are ignored; each one is replaced by a
    sample from ``np.random.rand``. Samples are returned as drawn, without
    any normalization.

    Args:
        equilibrium (Dict[AgentID, Dict[PolicyID, float]]): Mapping from
            agent id to a per-policy mapping; only its keys are used.

    Returns:
        Dict[AgentID, Dict[PolicyID, float]]: A new mapping with the same
            keys and random float values.
    """

    def _sample_like(policy_dist):
        # One vectorized draw per agent, then paired with the policy ids in
        # their iteration order.
        draws = np.random.rand(len(policy_dist)).tolist()
        return {pid: value for pid, value in zip(policy_dist, draws)}

    return {
        agent: _sample_like(policy_dist)
        for agent, policy_dist in equilibrium.items()
    }


def _train_agents(trainer, sampler, time_step: int) -> Dict[str, Any]:
    """Call every agent's trainer once on ``sampler`` and key the results by agent id."""
    return {
        aid: agent_trainer(sampler, agent_filter=[aid], time_step=time_step)
        for aid, agent_trainer in trainer.items()
    }


def work_flow(
    optimize_runtime_config,
    rollout_runtime_config,
    evaluation_runtime_config,
    time_step_th: int = 0,
    episode_th: int = 0,
) -> Tuple[List, List, int, int]:
    """Run one rollout / training / evaluation cycle.

    The rollout caller is driven as a generator: after each yielded item all
    trainers are invoked once. When the generator finishes, its return value
    supplies the episode and timestep counts that are added to the given
    thresholds. Afterwards, if the sampler is ready and training is episodic,
    ``mini_epoch`` extra training passes are run, and finally the evaluator
    is executed.

    Errors raised while rolling out, training or evaluating are logged and
    swallowed; whatever was collected up to that point is returned.

    Args:
        optimize_runtime_config: Mapping read for the keys ``trainer``
            (mapping of agent id to trainer callable), ``episodic``,
            ``train_every`` and ``mini_epoch``.
        rollout_runtime_config: Mapping read for the keys ``caller``,
            ``sampler``, ``behavior_policies``, ``agent_interfaces``,
            ``env_desc``, ``fragment_length`` and ``max_step``.
        evaluation_runtime_config: Mapping read for the keys
            ``policy_mappings``, ``max_step``, ``fragment_length``,
            ``agent_interfaces``, ``caller``, ``env_desc`` and ``seed``.
        time_step_th (int, optional): Timestep count to start from.
            Defaults to 0.
        episode_th (int, optional): Episode count to start from.
            Defaults to 0.

    Returns:
        Tuple[List, List, int, int]: ``(epoch_training_statistic,
            epoch_evaluation_statistic, total_timesteps, total_episodes)``.
            The evaluation statistic is an empty list when training or
            evaluation in the final stage failed.
    """
    caller = rollout_runtime_config["caller"]
    sampler = rollout_runtime_config["sampler"]

    generator = caller(
        sampler=sampler,
        agent_policy_mapping=rollout_runtime_config["behavior_policies"],
        agent_interfaces=rollout_runtime_config["agent_interfaces"],
        env_description=rollout_runtime_config["env_desc"],
        fragment_length=rollout_runtime_config["fragment_length"],
        max_step=rollout_runtime_config["max_step"],
        episodic=optimize_runtime_config["episodic"],
        train_every=optimize_runtime_config["train_every"],
        evaluate=False,
    )

    trainer = optimize_runtime_config["trainer"]

    total_episodes = episode_th
    total_timesteps = time_step_th
    epoch_training_statistic = []
    # Bound up front so the final ``return`` cannot hit an unbound local when
    # the training/evaluation stage below fails before assigning it.
    epoch_evaluation_statistic = []

    try:
        start_timesteps = 0
        while True:
            info = next(generator)
            # Trainers see the timestep count *before* this fragment is added.
            step_statistic = _train_agents(trainer, sampler, start_timesteps)
            start_timesteps += info["timesteps"]
            epoch_training_statistic.append(step_statistic)
    except StopIteration as e:
        # The generator's return value carries the rollout totals.
        info = e.value
        total_episodes += info["num_episode"]
        total_timesteps += info["total_timesteps"]
    except Exception:
        Log.error(traceback.format_exc())

    try:
        if sampler.is_ready() and optimize_runtime_config["episodic"]:
            for _ in range(optimize_runtime_config["mini_epoch"]):
                # NOTE(review): each mini epoch replaces the statistics
                # gathered so far instead of appending - confirm intended.
                epoch_training_statistic = [
                    _train_agents(trainer, sampler, total_timesteps)
                ]

        # run evaluation
        epoch_evaluation_statistic = rollout.Evaluator.run(
            policy_mappings=evaluation_runtime_config["policy_mappings"],
            max_step=evaluation_runtime_config["max_step"],
            fragment_length=evaluation_runtime_config["fragment_length"],
            agent_interfaces=evaluation_runtime_config["agent_interfaces"],
            rollout_caller=evaluation_runtime_config["caller"],
            env_desc=evaluation_runtime_config["env_desc"],
            seed=evaluation_runtime_config["seed"],
        )
    except Exception:
        Log.error(traceback.format_exc())

    return (
        epoch_training_statistic,
        epoch_evaluation_statistic,
        total_timesteps,
        total_episodes,
    )
