import logging
import os

import torch
from torch.utils.tensorboard import SummaryWriter

import constants
from channel import StochasticChannel, PerfectChannel, BlockingChannel, PermutationChannel, \
    SelectiveChannel
from constants import *
from env.env_parallelizer import SynchronousEnvParallelizer
from env.max_steps_wrapper import MaxStepsWrapper
from env.pomnist.rl_env import PartialMNISTEnv
from env.traffic_junction.traffic_junction_env import TrafficJunctionEnv
from env.traffic_junction_wrapper import TrafficJunctionWrapper
from evaluation import evaluate, evaluate_pomnist
from config import Config, ConfigPOMNIST, ConfigTrafficJunction, LinearSchedule
from modules.model import AdaComm, joint_inference
from training import visualize_multi_run_return, train
from util import add_eval_stats_text, eval_stats_to_metric_dict, add_hparams_to


def joint_random_policy(observations, messages, states, high=10):
    """
    Joint baseline policy that samples uniformly random actions in [0, high) and sends all-zero messages.

    :param observations: Incoming observations; only the first two dimensions of their shape are used, they
        determine the shape of the returned action tensors
    :param messages: Incoming messages; only used as a template for the all-zero outgoing messages
    :param states: Unused by this policy
    :param high: One above the highest action value that will be returned.
    :return: random actions, all-zero message actions, all-zero messages, empty dict
    """
    # one action / message action per entry of the two leading observation dimensions
    leading_shape = observations.shape[:2]

    random_actions = torch.randint(0, high, size=leading_shape)
    silent_message_actions = torch.zeros(leading_shape)
    silent_messages = torch.zeros_like(messages)

    return random_actions, silent_message_actions, silent_messages, {}


def joint_fixed_policy(observations, messages, states, action):
    """
    Joint baseline policy that always picks the same given action and sends all-zero messages.

    :param observations: Incoming observations; only the first two dimensions of their shape are used, they
        determine the shape of the returned action tensors
    :param messages: Incoming messages; only used as a template for the all-zero outgoing messages
    :param states: Unused by this policy
    :param action: Fixed action for all agents
    :return: given action (as a tensor filled with it), all-zero message actions, all-zero messages, empty dict
    """
    # one action / message action per entry of the two leading observation dimensions
    leading_shape = observations.shape[:2]

    # scaling a tensor of ones keeps the result in torch's default float dtype
    constant_actions = torch.ones(leading_shape) * action
    silent_message_actions = torch.zeros(leading_shape)
    silent_messages = torch.zeros_like(messages)

    return constant_actions, silent_message_actions, silent_messages, {}


def create_channel(config: Config, num_agents: int):
    """
    Build the communication channel selected by ``config.comm_channel_type``.

    The checks run in a fixed order: blocking and selective channels are handled first, then a perfect channel is
    returned whenever no channel size is configured or the type is Perfect, and only afterwards are the
    size-limited stochastic and permutation channels considered.

    :param config: The config
    :param num_agents: (max) number of agents; length of the mask handed to the selective channel
    :return: A communication channel
    :raises ValueError: If the configured channel type is not handled here
    """
    channel_type = config.comm_channel_type

    if channel_type == ChannelType.Blocking:
        return BlockingChannel()

    if channel_type == ChannelType.SelectiveChannel:
        allowed_agents = config.comm_channel_selective_allowed_agents
        assert allowed_agents is not None

        # start from an all-ones mask and zero the entries of the listed agents
        agent_mask = torch.ones(num_agents)
        for agent_id in allowed_agents:
            agent_mask[agent_id] = 0

        return SelectiveChannel(agent_mask)

    channel_size = config.comm_channel_size
    if channel_size is None or channel_type == ChannelType.Perfect:
        return PerfectChannel()

    # size-limited channels: (channel type, channel class, extra keyword arguments)
    sized_channels = (
        (ChannelType.Stochastic, StochasticChannel, {'use_msg_size_spacing': False}),
        (ChannelType.StochasticSpacing, StochasticChannel, {'use_msg_size_spacing': True}),
        (ChannelType.Permutation, PermutationChannel, {'priority_mode': 'equal'}),
        (ChannelType.PermutationPrioritized, PermutationChannel, {'priority_mode': 'inverted_size'}),
    )
    for candidate_type, channel_cls, options in sized_channels:
        if channel_type == candidate_type:
            return channel_cls(channel_size, **options)

    raise ValueError(f"Unknown channel type '{config.comm_channel_type}'")


def run(config, log_dir=None, num_eval_episodes=2000):
    """
    Execute a training and evaluation run with the given configuration.

    The environment and the agent model are built from the config. If ``config.load_model`` is None the model is
    trained; otherwise its weights are loaded from that path (mapped onto the device used for this run) and
    training is skipped. Afterwards the model is evaluated with all exploration epsilons set to 0.

    Side effects on ``config``: ``use_messages`` is set based on ``msg_sizes`` and, for Traffic Junction,
    ``num_agents`` is overwritten with the number of agents reported by the environment.

    :param config: The configuration
    :param log_dir: Log directory passed to the tensorboard SummaryWriter
    :param num_eval_episodes: Number of eval episodes; not passed to the POMNIST evaluation
    :return: run stats (result of ``train``, None if a model was loaded), eval stats (for Traffic Junction a dict
        with the single key "test", for POMNIST the result of ``evaluate_pomnist``), tensorboard log dir
        (None if no writer was created)
    :raises ValueError: If ``config.env_name`` is not a known environment
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Device: {device}")

    if config.env_name == POMNIST:
        # the number of agents is inferred from the environment splits
        env = PartialMNISTEnv(config.x_splits, config.y_splits, max_steps=config.max_steps, num_envs=config.num_envs)
        # example observation handed to the model constructor below
        dummy_input = torch.as_tensor(env.observation_space.sample()[0]).float()
    elif config.env_name == TRAFFIC_JUNCTION:
        def make_env(i):
            e = TrafficJunctionEnv()
            e.init_ic3net_default(config.mode, config.train_curriculum)
            e.multi_agent_init()
            # adaptive curriculum based on number of iterations
            # (rescales the environment's curriculum bounds by num_iterations / 2000)
            e.curr_start = int(e.curr_start / 2000.0 * config.num_iterations)
            e.curr_end = int(e.curr_end / 2000.0 * config.num_iterations)
            # Get the number of agents from the environment
            config.num_agents = e.nagents
            # log the curriculum only once, for the first environment
            if config.train_curriculum and i == 0:
                logging.info(f"Traffic junction curriculum: add rate {e.add_rate_min} ({e.curr_start}) to "
                             f"{e.add_rate_max} ({e.curr_end})")
            return MaxStepsWrapper(TrafficJunctionWrapper(e), e.max_steps)

        env = SynchronousEnvParallelizer(make_env, config.num_envs)
        # example observation sampled from the first observation space, with a leading dimension added
        dummy_input = torch.as_tensor(env.observation_space[0].sample()).float().unsqueeze(0)
    else:
        raise ValueError(f"Unknown environment name '{config.env_name}'! Implement the environment.")

    # create models
    # messages are disabled if no message size or only the size 0 is configured
    config.use_messages = config.msg_sizes not in ((None,), (0,), ())

    if config.use_messages:
        # forced random selection takes precedence over softmax, egreedy is the fallback
        if config.force_random_msg_size_selection:
            msg_size_selection = 'random'
        elif config.softmax_msg_size_selection:
            msg_size_selection = 'softmax'
        else:
            msg_size_selection = 'egreedy'

        agent_model = AdaComm(
            config.env_name, dummy_input, env.num_actions, env.num_agents, use_location_input=config.use_location_input,
            msg_sizes=config.msg_sizes, msg_decode_embedding_len=config.msg_decode_embedding_len,
            msg_decode_mode=config.msg_decode_mode, message_mode=config.message_mode,
            msg_size_selection=msg_size_selection, receive_own_message=config.receive_own_message,
            recurrent=config.model_recurrent,
            with_policy_head=config.model_with_policy_head
        ).to(device)
    else:
        agent_model = AdaComm(config.env_name, dummy_input, env.num_actions, env.num_agents,
                              use_location_input=config.use_location_input).to(device)

    # tensorboard logging is only used for training runs, not when a stored model is evaluated
    if config.enable_tensorboard and config.load_model is None:
        writer = SummaryWriter(log_dir=log_dir, comment=config.run_comment, flush_secs=10)
        writer.add_text("config", str(config))
    else:
        writer = None

    channel = create_channel(config, env.num_agents)

    logging.info(f"Using channel {channel}")

    if config.load_model is None:
        # run training
        logging.info("Training")
        run_stats = train(env, agent_model, channel, config, device, writer)
    else:
        # load the model
        logging.info(f"Loading model at {config.load_model}")
        run_stats = None
        # map_location lets a checkpoint saved on another device (e.g. CUDA) be loaded on this run's device.
        # NOTE: torch.load unpickles the file, only load checkpoints from trusted sources.
        agent_model.load_state_dict(torch.load(config.load_model, map_location=device))

    # run evaluation
    agent_model.eval()

    def joint_model_policy(observations, messages, states):
        # greedy evaluation policy: all exploration epsilons are 0
        return joint_inference(agent_model, observations, messages, states, epsilon=0, msg_size_epsilon=0,
                               discrete_msg_epsilon=0, device=device)

    logging.info("Evaluation")

    if config.env_name == POMNIST:
        eval_stats = evaluate_pomnist(env, joint_model_policy, channel, config, writer)
    else:
        print("> Test")
        eval_stats = {'test': evaluate(env, joint_model_policy, channel, config, num_eval_episodes, writer)}

    if writer is not None:
        add_eval_stats_text(eval_stats, writer)
        metric_dict = eval_stats_to_metric_dict(eval_stats, type_filter='test', key_filter='agent_mean_accuracy')
        add_hparams_to(writer, config.to_str_dict(), metric_dict)
        writer.close()

    del agent_model
    return run_stats, eval_stats, writer.get_logdir() if writer is not None else None


def multi_run(config: Config, num_runs: int, base_log_dir=None):
    """
    Execute multiple runs with the given configuration.

    Individual runs are placed into new subdirectories within base_log_dir.
    If base_log_dir is None, the run comment gets extended by the run id.

    ``config.run_comment`` is modified while the runs execute and is always restored to its original value
    afterwards, also when a run raises.

    :param config: the configuration
    :param num_runs: the number of runs
    :param base_log_dir: base log directory
    :return: lists of everything a single run returns (run stats, eval stats, log dirs), one entry per run
    """
    all_run_stats = []
    all_eval_stats = []
    all_log_dirs = []

    original_comment = config.run_comment

    try:
        # execute the runs
        for run_id in range(num_runs):
            if base_log_dir is None:
                # log all runs in default directory (using comment)
                config.run_comment = f"{original_comment}-run-{run_id}"
                run_log_dir = None
            else:
                # put all runs into given log directory (comment is ignored)
                run_log_dir = os.path.join(base_log_dir, f"run-{run_id}")

            run_stats, eval_stats, used_log_dir = run(config, log_dir=run_log_dir)

            all_run_stats.append(run_stats)
            all_eval_stats.append(eval_stats)
            all_log_dirs.append(used_log_dir)
    finally:
        # restore original comment, even if a run failed, so the caller's config is not left with a run suffix
        config.run_comment = original_comment

    return all_run_stats, all_eval_stats, all_log_dirs


def main():
    """
    Debug entry point: build a hard-coded config for the selected environment and execute it.

    The environment, the run comment suffix and the number of runs are fixed in the function body. With more
    than one run, aggregated plots and eval stats are written into the log directory of the last run.

    :raises ValueError: If no config is defined here for the selected debug environment
    """
    logging.basicConfig(level=logging.INFO, format='%(levelname)s > %(message)s')

    global_run_comment_suffix = "default"
    debug_env = constants.POMNIST

    if debug_env == constants.POMNIST:
        config = ConfigPOMNIST(
            x_splits=1,
            y_splits=1,
            num_iterations=2000,
            learning_rate=1e-3,
            comm_channel_type=ChannelType.StochasticSpacing,
            comm_channel_size=8,
            msg_sizes=(0, 1, 2, 4),
            softmax_msg_size_selection=True,
            message_mode=PSEUDOGRADIENT,
            receive_own_message=False,
            log_eval_plots=False,
            run_comment="POM"
        )
    elif debug_env == constants.TRAFFIC_JUNCTION:
        config = ConfigTrafficJunction(
            # config with only 2 agents (sanity check)
            # num_iterations=200, num_envs=64, num_episodes=1, mode='super-easy', train_curriculum=False, msg_sizes=(8,), learning_rate=1e-3, entropy_schedule=ConstantSchedule(0),
            # actual environment
            num_iterations=2000, num_envs=128, num_episodes=1, mode='easy', train_curriculum=True, msg_sizes=(0, 32, 128), learning_rate=1e-3,
            discount_factor=1.0,
            msg_loss_weight=0.1,
            gradient_clip_max_norm=0.1,
            use_location_input=False,
            comm_channel_type=ChannelType.StochasticSpacing,
            comm_channel_size=512,
            softmax_msg_size_selection=True,
            message_mode=PSEUDOGRADIENT,
            receive_own_message=False,
            log_eval_plots=False,
            run_comment="TJ"
        )
        # tag traffic junction runs with their difficulty mode
        config.run_comment += f"_{config.mode}"
    else:
        raise ValueError(f"Unknown debug env {debug_env}")

    config.set_default_schedules()
    config.run_comment += f"_{global_run_comment_suffix}"

    print(config)

    # the number of runs with your config. Will plot aggregated stats if > 1
    num_runs = 1

    # single run: nothing to aggregate
    if num_runs == 1:
        run(config)
        return

    assert num_runs > 1, "Number of runs must be >= 1!"

    config.log_eval_plots = False
    stats_per_run, eval_stats_per_run, log_dirs = multi_run(config, num_runs)

    # add aggregated plots to the last run
    aggregate_writer = SummaryWriter(log_dir=log_dirs[-1])
    visualize_multi_run_return(stats_per_run, config, aggregate_writer)
    add_eval_stats_text(eval_stats_per_run, aggregate_writer)
    aggregate_writer.close()


# Script entry point: run the debug configuration only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()
