import sys
from datetime import datetime
from pathlib import Path
from typing import NamedTuple

import torch
import torch.nn.functional
from torch.optim import Adam
from tqdm import tqdm

from channel import Channel
from env.pomnist.vis_util import *
from env.pomnist.vis_util import add_discrete_value_hist
from config import Config
from modules.messages import get_message_sizes, empty_message_tensor
from modules.model import AdaComm, joint_inference
from util import write_dict_scalars, get_grad_norm


class StepRecord(NamedTuple):
    """
    Snapshot of one environment step, taken for all agents in all parallel environments at once.

    Tensors are indexed by agent in the first and by environment in the second dimension (see
    :meth:`get_num_agents` and :meth:`get_num_envs`).
    """

    # transition data
    observations: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    next_observations: torch.Tensor
    done: bool
    # policy outputs (kept attached, so the whole graph stays available for the loss)
    q_values: Optional[torch.Tensor]
    log_pi_a: Optional[torch.Tensor]
    # communication
    messages: Optional[torch.Tensor]  # only logged on demand
    msg_actions: Optional[torch.Tensor]
    q_msg_sizes: Optional[torch.Tensor]
    discrete_msg_selected_q_out: Optional[torch.Tensor]
    # bookkeeping
    step: int
    info: dict
    num_dropped: Optional[torch.Tensor]
    channel_util: Optional[torch.Tensor]
    alive_mask: Optional[torch.Tensor]
    agent_done: Optional[torch.Tensor]
    success: Optional[torch.Tensor]

    def get_q_a(self):
        """
        Look up the (Q-)value belonging to the action each agent selected.

        If the last dimension of ``q_values`` has size 1, the value does not depend on the action and is returned
        with that dimension removed.

        :return: (Q-)values of the selected actions
        """
        q_values = self.q_values
        if q_values.shape[-1] == 1:
            return q_values.squeeze(-1)

        return q_values.gather(-1, self.actions.unsqueeze(-1)).squeeze(-1)

    def get_q_msg(self):
        """
        Look up the Q-value belonging to the message size each agent selected.

        :return: Q-values of the selected message sizes
        """
        return self.q_msg_sizes.gather(-1, self.msg_actions.unsqueeze(-1)).squeeze(-1)

    def get_num_agents(self):
        """
        :return: size of the first (agent) dimension of the observations
        """
        return self.observations.size(0)

    def get_num_envs(self):
        """
        :return: size of the second (environment) dimension of the observations
        """
        return self.observations.size(1)

    def to(self, device):
        """
        Create a copy of this record with all tensors moved to the given device.

        Non-tensor fields (``done``, ``step``, ``info``) are carried over unchanged, optional tensors that are
        ``None`` stay ``None``.

        :param device: target torch device
        :return: a new StepRecord living on ``device``
        """
        always_present = ('observations', 'actions', 'rewards', 'next_observations')
        maybe_missing = ('q_values', 'log_pi_a', 'messages', 'msg_actions', 'q_msg_sizes',
                         'discrete_msg_selected_q_out', 'num_dropped', 'channel_util', 'alive_mask',
                         'agent_done', 'success')

        fields = self._asdict()
        for name in always_present:
            fields[name] = fields[name].to(device)
        for name in maybe_missing:
            if fields[name] is not None:
                fields[name] = fields[name].to(device)

        return StepRecord(**fields)


class ElementWiseRunningAverage:
    """
    Running average over a tensor where every element keeps its own sample count.
    """
    def __init__(self, shape):
        # both buffers start at zero, which is the same state reset() produces
        self.running_average = torch.zeros(shape)
        self.num_samples = torch.zeros(shape)

    def reset(self):
        """
        Forget all samples by replacing both buffers with fresh zero tensors of the same shape.
        """
        self.running_average = torch.zeros_like(self.running_average)
        self.num_samples = torch.zeros_like(self.num_samples)

    def update(self, new_values, mask=None):
        """
        Fold the given values into the running average.

        :param new_values: new observed values
        :param mask: optional boolean mask, only selected elements of the new values & running average tensors are
                     updated (all elements if omitted)
        """
        if mask is None:
            mask = torch.ones_like(new_values, dtype=torch.bool)

        # boolean indexing copies the selected elements, so the old state is read once before anything is written
        counts = self.num_samples[mask]
        previous = self.running_average[mask]

        self.running_average[mask] = (counts * previous + new_values[mask]) / (counts + 1)
        self.num_samples[mask] = counts + 1

    def total_samples(self):
        """
        :return: number of samples over all elements (as tensor)
        """
        return self.num_samples.sum()

    def mean(self, dim):
        """
        Reduce the element-wise averages along ``dim``, weighting every element by its number of samples.

        :param dim: dimension(s) to reduce
        :return: sample-weighted mean; NaN where the reduced elements have no samples at all
        """
        weighted_sum = (self.running_average * self.num_samples).sum(dim=dim)
        return weighted_sum / self.num_samples.sum(dim=dim)


def run_env(env: PartialMNISTEnv, msg_sizes, joint_policy, num_episodes, channel: Channel, detach_gap: int,
            log_messages=False, record_device=None, show_progress=False) -> List[StepRecord]:
    """
    Runs the environment(s) using the given models for the specified number of episodes.

    Every episode starts with ``env.reset()`` and is stepped until the environment reports done or, if the
    environment has a ``max_steps`` attribute, until that many steps were taken. From the second step of an episode
    on, the messages produced in the previous step are passed through the channel (if there is one) before they are
    handed to the policy.

    :param env: The MNIST environment
    :param msg_sizes: Message sizes used by the policy
    :param joint_policy: The joint policy to get actions, messages (and additional stats) for all agents. Called as
                         ``joint_policy(observations, messages, states)`` and expected to return the tuple
                         ``(actions, msg_actions, messages, states, policy_info)``
    :param num_episodes: The number of episodes (does not account for parallel environments!)
    :param channel: The communication channel; if None, messages are forwarded without passing a channel
    :param detach_gap: Interval (in steps) after which message & rnn state tensors are detached. Detaching is only
                       performed in steps where the environment info contains 'agent_done'
    :param log_messages: Whether to log messages in step records
    :param record_device: Torch device used for storing records. If None, records are stored as they are (no
                          transfer to another device)
    :param show_progress: Whether to show the progress
    :return: A list with the records of all steps
    """
    step_records = []

    episode_range = range(0, num_episodes)
    if show_progress:
        episode_range = tqdm(episode_range, file=sys.stdout)

    for _ in episode_range:
        done = False
        observations = torch.tensor(env.reset(), dtype=torch.float)

        # per-episode state; first observation dimension is used as number of agents, second as batch size
        step = 0
        messages = empty_message_tensor(n_agents=observations.shape[0], msg_sizes=msg_sizes,
                                        batch_size=observations.shape[1])
        msg_actions = None
        num_dropped = None
        channel_util = None
        alive_mask = None
        agent_done = None
        success = None
        states = None
        while not done:
            # there is nothing to transmit in the first step, so the channel is only applied from step 1 on
            if step > 0 and channel is not None and (msg_sizes is not None and len(msg_sizes) > 0):
                if msg_actions is not None:
                    message_sizes = get_message_sizes(msg_actions, msg_sizes)
                else:
                    assert len(msg_sizes) == 1, f"Expected message actions for msg sizes {msg_sizes}"
                    if messages is None:
                        messages = empty_message_tensor(n_agents=observations.shape[0], msg_sizes=msg_sizes,
                                                        batch_size=observations.shape[1])
                    # all agents choose the first (and only) message size
                    message_sizes = torch.ones(messages.shape[:-1]) * msg_sizes[0]

                # fix message sizes for deleted messages (agents that were not alive sent nothing)
                if alive_mask is not None:
                    message_sizes[~alive_mask] = 0

                messages, num_dropped, channel_util = channel(messages, message_sizes)

            actions, msg_actions, messages, states, policy_info = joint_policy(observations, messages, states)
            next_observations, rewards, done, info = env.step(actions)

            if 'agent_done' in info:
                assert 'alive_mask' in info, "We need an alive_mask if agents have individual done signals!"
                alive_mask = torch.Tensor(info['alive_mask']).bool()
                agent_done = torch.Tensor(info['agent_done']).bool()

                # NOTE(review): messages and states are only detached inside this branch, i.e. never for
                # environments without individual done signals - confirm that this is intended
                if messages is not None:
                    # only alive agents can send messages
                    messages[~alive_mask] = 0
                    if (step + 1) % detach_gap == 0:
                        messages.detach_()

                if states is not None:
                    # mask states of dead and done agents, so they don't bleed into new episodes
                    # (states is indexed and re-assigned per agent, so it has to be a mutable sequence)
                    for i in range(0, alive_mask.shape[0]):
                        mask = alive_mask[i].unsqueeze(0).unsqueeze(-1)
                        # LSTM
                        if isinstance(states[i], tuple):
                            mask = mask.to(states[i][0].device)
                            states[i] = tuple([s_i_j * mask for s_i_j in states[i]])
                            if (step + 1) % detach_gap == 0:
                                for s_i_j in states[i]:
                                    s_i_j.detach_()
                        # assume GRU
                        else:
                            mask = mask.to(states[i].device)
                            states[i] = states[i] * mask
                            if (step + 1) % detach_gap == 0:
                                states[i].detach_()

            if 'success' in info:
                success = torch.Tensor(info['success']).bool()

            # enforce the step limit of the environment, if it has one
            done = done or (hasattr(env, 'max_steps') and step + 1 >= env.max_steps)

            next_observations = torch.tensor(next_observations, dtype=torch.float)
            rewards = torch.tensor(rewards, dtype=torch.float)

            # num_dropped/channel_util, alive_mask/agent_done and success keep their last known value if the
            # current step did not produce a new one
            record = StepRecord(
                observations=observations, actions=actions, rewards=rewards, next_observations=next_observations,
                done=done, q_values=policy_info.get('q_values', None), log_pi_a=policy_info.get('log_pi_a', None),
                messages=messages if log_messages else None, msg_actions=msg_actions,
                q_msg_sizes=policy_info.get('q_msg_sizes', None),
                discrete_msg_selected_q_out=policy_info.get('discrete_msg_selected_q_out', None), step=step, info=info,
                num_dropped=num_dropped, channel_util=channel_util, alive_mask=alive_mask, agent_done=agent_done,
                success=success
            )

            if record_device is not None:
                step_records.append(record.to(record_device))
            else:
                step_records.append(record)

            observations = next_observations
            step += 1

    return step_records


def get_losses(step_records: List[StepRecord], discount_factor, receive_own_message, act_entropy_coefficient=0,
               act_loss_only_when_episode_done=False):
    """
    Compute per-agent action and message losses from a sequence of step records.

    The records are walked backwards in time so that discounted returns can be accumulated on the fly. Returns are
    reset wherever an agent (or the whole episode) is done or the agent is not alive.

    :param step_records: A list of step records (in chronological order)
    :param discount_factor: The discount factor
    :param receive_own_message: Whether agents receive their own messages
    :param act_entropy_coefficient: Entropy coefficient for action selection (higher value => more entropy)
    :param act_loss_only_when_episode_done: Only apply action loss when the episode (!) is done (for debugging).
    :return: A tuple ``(agent_act_loss, act_loss_count, agent_msg_loss, msg_loss_count)``. ``agent_msg_loss`` is
             None if the records contain neither message size Q-values nor discrete message Q-values.
    """
    first_record = step_records[0]
    num_agents = first_record.get_num_agents()
    num_envs = first_record.get_num_envs()
    grid_shape = (num_agents, num_envs)

    has_q_msg = first_record.q_msg_sizes is not None
    has_discrete_msg = first_record.discrete_msg_selected_q_out is not None
    uses_msg_loss = has_q_msg or has_discrete_msg

    # the accumulators have to start at zero because several loss terms are added up in them
    act_loss = torch.zeros(grid_shape)
    act_loss_count = torch.zeros(num_agents)
    msg_loss = torch.zeros(grid_shape)
    msg_loss_count = torch.zeros(num_agents)

    rewards = torch.zeros(grid_shape)
    returns = torch.zeros(grid_shape)

    # at the top of each iteration this still holds the mask of the temporally following step
    alive_mask = None
    for step in reversed(step_records):
        # a message only gets a target if somebody is alive in the following step to receive it
        has_receiver_mask = None
        if uses_msg_loss and alive_mask is not None and not step.done:
            alive_per_env = alive_mask.sum(dim=0)
            if receive_own_message:
                # any agent receives the message
                has_receiver_mask = (alive_per_env >= 1).unsqueeze(0)
            else:
                # only other agents receive the message
                has_receiver_mask = (alive_per_env.unsqueeze(0) - alive_mask.float()) >= 1

        done_mask = torch.ones(grid_shape).bool() * step.done
        alive_mask = torch.ones(grid_shape).bool()
        if step.agent_done is not None:
            done_mask |= step.agent_done
            alive_mask &= step.alive_mask

        num_alive = alive_mask.sum(dim=-1)
        still_running = alive_mask & ~done_mask

        # keep the values of the following step, then roll the discounted return one step back in time
        # (nothing below modifies these tensors in place, so no copy is needed)
        next_returns = returns
        next_rewards = rewards
        rewards = step.rewards.detach()
        returns = rewards + discount_factor * (returns * still_running)

        # message losses: squared error against the mean return of the following step
        if has_receiver_mask is not None:
            msg_mask = alive_mask & has_receiver_mask
            msg_samples = msg_mask.sum(dim=-1)

            target = next_returns.sum(dim=0)
            # if we don't receive our own message, the current message cannot have any effect on our next reward
            if not receive_own_message:
                target = target - next_rewards
            target = target / num_agents

            if has_q_msg:
                # message size selection
                msg_loss += (step.get_q_msg() - target).pow(2) * msg_mask
                msg_loss_count += msg_samples

            if has_discrete_msg:
                # discrete messages
                msg_loss += (step.discrete_msg_selected_q_out.squeeze(-1) - target).pow(2) * msg_mask
                msg_loss_count += msg_samples

        # action losses (optionally restricted to the final step of an episode)
        if act_loss_only_when_episode_done and not step.done:
            continue

        action_target = returns

        # q-value loss
        act_loss += (step.get_q_a() - action_target).pow(2) * alive_mask
        act_loss_count += num_alive

        if step.log_pi_a is None:
            continue

        # policy loss
        advantage = action_target - step.q_values.detach().max(dim=-1)[0]
        selected_log_pi = step.log_pi_a.gather(-1, step.actions.unsqueeze(-1)).squeeze(-1)
        act_loss += -selected_log_pi * advantage * alive_mask
        act_loss_count += num_alive

        # entropy bonus: sum(pi * log pi) is the negative entropy, so minimizing it increases the entropy
        negative_entropy = (step.log_pi_a * torch.exp(step.log_pi_a) * alive_mask.unsqueeze(-1)).sum(dim=-1)
        act_loss += act_entropy_coefficient * negative_entropy
        act_loss_count += act_entropy_coefficient * num_alive

    # mean error (over environment dimension, not agents)
    agent_act_loss = act_loss.sum(dim=-1) / act_loss_count.clamp(min=1)
    agent_msg_loss = msg_loss.sum(dim=-1) / msg_loss_count.clamp(min=1) if uses_msg_loss else None

    return agent_act_loss, act_loss_count, agent_msg_loss, msg_loss_count


def get_step_record_stats(step_records: List[StepRecord], msg_sizes, num_actions) -> dict:
    """
    Aggregate the given step records into summary statistics.

    :param step_records: A list of step records (in chronological order)
    :param msg_sizes: Message sizes
    :param num_actions: Number of actions
    :return: A dict containing different statistics (which could be printed)
    """
    first_record = step_records[0]
    num_agents = first_record.get_num_agents()
    num_envs = first_record.get_num_envs()
    per_agent_env = (num_agents, num_envs)

    has_q_msgs = first_record.q_msg_sizes is not None
    tracks_alive = first_record.alive_mask is not None
    tracks_success = first_record.success is not None

    avg_actions = ElementWiseRunningAverage((num_agents, num_envs, num_actions))
    avg_returns = ElementWiseRunningAverage(per_agent_env)
    avg_q = ElementWiseRunningAverage(per_agent_env)
    avg_msg_size = ElementWiseRunningAverage(per_agent_env)
    avg_num_drops = ElementWiseRunningAverage(num_envs)
    avg_channel_util = ElementWiseRunningAverage(num_envs)
    avg_q_msgs = ElementWiseRunningAverage(per_agent_env) if has_q_msgs else None
    avg_alive = ElementWiseRunningAverage(per_agent_env) if tracks_alive else None
    avg_success = ElementWiseRunningAverage(num_envs) if tracks_success else None

    cumulative_reward = torch.zeros(per_agent_env)
    agent_steps = 0

    for step in step_records:
        alive = step.alive_mask

        # one-hot encoding of the taken actions
        one_hot_actions = torch.zeros_like(avg_actions.running_average)
        one_hot_actions.scatter_(-1, step.actions.unsqueeze(-1), 1.0)

        if alive is None:
            avg_actions.update(one_hot_actions)
            cumulative_reward += step.rewards.detach()
            avg_q.update(step.get_q_a().detach())
            agent_steps += num_envs * num_agents
        else:
            avg_actions.update(one_hot_actions, mask=(alive | step.agent_done))
            cumulative_reward += step.rewards.detach() * alive
            avg_q.update(step.get_q_a().detach(), alive)
            # ATTENTION: We count the last step of an episode as agent_done => agents are not alive anymore
            if not step.done:
                avg_alive.update(alive.float())
                agent_steps += alive.sum()

        if step.agent_done is not None:
            # log the return of agents that just finished, then restart their reward accumulation
            avg_returns.update(cumulative_reward, step.agent_done)
            cumulative_reward *= ~step.agent_done

        if step.done:
            if step.agent_done is None:
                avg_returns.update(cumulative_reward)
            else:
                # ATTENTION: We count the last step of an episode as agent_done => log reward
                avg_returns.update(cumulative_reward, step.agent_done + alive)

            # the episode is over for everybody, start accumulating from scratch
            cumulative_reward.zero_()

            if avg_success is not None:
                avg_success.update(step.success.float())
        else:
            if has_q_msgs:
                # only consider message q values which had an actual effect on the episode/other agents (not done)
                avg_q_msgs.update(step.get_q_msg().detach())

            if step.msg_actions is None:
                chosen_msg_sizes = torch.ones((num_agents, num_envs)) * msg_sizes[0]
            else:
                chosen_msg_sizes = get_message_sizes(step.msg_actions, msg_sizes)

            if alive is not None:
                chosen_msg_sizes *= alive

            avg_msg_size.update(chosen_msg_sizes.float(), alive)

        # step > 0 as there are no messages in the initial step
        if step.channel_util is not None and step.step > 0:
            avg_num_drops.update(step.num_dropped.float())
            avg_channel_util.update(step.channel_util.float())

    # the insertion order below determines the order in which the statistics are printed
    result_dict = {
        "agent_steps": agent_steps,
        "avg_returns": avg_returns.mean(-1),
        "avg_q": avg_q.mean(-1),
    }
    if has_q_msgs:
        result_dict["avg_q_msgs"] = avg_q_msgs.mean(-1)

    if avg_channel_util.total_samples() > 0:
        result_dict["avg_channel_util"] = avg_channel_util.mean(-1)
        result_dict["avg_num_drops"] = avg_num_drops.mean(-1)

    result_dict["avg_msg_size"] = avg_msg_size.mean(-1)

    if avg_alive is not None:
        result_dict["avg_alive"] = avg_alive.mean(-1).sum()

    if avg_success is not None:
        result_dict["avg_success"] = avg_success.mean(0)

    result_dict["avg_actions"] = avg_actions.mean(dim=(0, 1))

    return result_dict


def get_step_record_msg_action_counts(step_records: List["StepRecord"], msg_sizes, filter_step=0) -> dict:
    """
    Counts how often each message size was selected in the given step records.

    Only records whose ``step`` equals ``filter_step`` are considered. Counts are accumulated over all matching
    records, so with several episodes every episode contributes.

    :param step_records: The step records
    :param msg_sizes: The message sizes
    :param filter_step: Only count msg actions in some specific step
    :return: A dict mapping each message size to the total number of uses inside the given step records.
    """
    msg_action_counts = dict.fromkeys(msg_sizes, 0)
    single_choice = len(msg_sizes) == 1

    for step in step_records:
        if step.step != filter_step:
            continue

        if single_choice:
            # there is just one choice, so every agent in every environment used it; accumulate (instead of
            # overwriting) so that all matching records are counted, like in the multi-size case below
            msg_action_counts[msg_sizes[0]] += step.actions.numel()
        else:
            for i, size in enumerate(msg_sizes):
                msg_action_counts[size] += (step.msg_actions == i).sum().item()

    return msg_action_counts


def train(env, agent_model: AdaComm, channel: Channel, config: Config, device, writer: Optional[SummaryWriter] = None):
    """
    Trains the model in the given environment and uses the given channel to limit communication.

    Every iteration collects ``config.num_episodes`` episodes with the current (epsilon-greedy) joint policy,
    computes the action and message losses over the collected records and performs a single optimizer step.

    :param env: The environment
    :param agent_model: The (shared) agent model that is used for all agents
    :param channel: The communication channel
    :param config: The training config
    :param device: The device to use
    :param writer: If provided, this function writes training logs to the writer and (if enabled in the config)
                   stores model checkpoints in the writer's log directory
    :returns: list of run statistics for each training iteration
    """
    config.check_msg_size(config.msg_sizes)
    optimizer = Adam(agent_model.parameters(), lr=config.learning_rate)

    agent_model.train()
    time_start = datetime.now()
    total_steps = 0
    run_stats = []

    # training loop
    for it in tqdm(range(0, config.num_iterations), file=sys.stdout):
        stats = {}
        optimizer.zero_grad()

        # update training iteration if environment keeps track
        if hasattr(env, 'epoch'):
            env.epoch = it

        # exploration rates of the current iteration
        agent_epsilon = config.agent_epsilon_schedule.get_value(it)
        msg_size_epsilon = config.msg_size_epsilon_schedule.get_value(it)
        if config.message_mode == DISCRETE:
            discrete_msg_epsilon = config.discrete_message_epsilon_schedule.get_value(it)
        else:
            discrete_msg_epsilon = 0

        def joint_policy(observations, messages, states):
            return joint_inference(agent_model, observations, messages, states, epsilon=agent_epsilon,
                                   msg_size_epsilon=msg_size_epsilon, discrete_msg_epsilon=discrete_msg_epsilon,
                                   receive_own_message=config.receive_own_message, device=device)

        records = run_env(env, config.msg_sizes, joint_policy, config.num_episodes, channel, config.detach_gap)
        records_stats = get_step_record_stats(records, config.msg_sizes, agent_model.n_actions)

        total_steps += int(records_stats['agent_steps'])
        stats['total_steps'] = total_steps
        stats.update(records_stats)

        if agent_model.use_msg and it % 10 == 0:
            msg_action_counts = get_step_record_msg_action_counts(records, config.msg_sizes)
            # NOTE(review): writer may be None here - confirm that add_discrete_value_hist copes with that
            add_discrete_value_hist("msg_action_size", msg_action_counts, writer, it)

        entropy_coefficient = config.entropy_schedule.get_value(it)
        stats['entropy_coefficient'] = entropy_coefficient

        agent_act_loss, act_loss_count, agent_msg_loss, msg_loss_count = \
            get_losses(records, config.discount_factor, config.receive_own_message,
                       act_entropy_coefficient=entropy_coefficient,
                       act_loss_only_when_episode_done=config.agent_q_loss_only_done)

        stats['agent_act_loss'] = agent_act_loss.detach().clone()

        # agents that were alive for more steps should also contribute more to the loss. The total count is
        # clamped to at least 1 so that the weights become 0 instead of NaN (0 / 0) when no loss term was
        # recorded at all; totals of 1 or more are not affected by the clamp
        act_weights = act_loss_count / act_loss_count.sum().clamp(min=1)
        mean_act_loss = (agent_act_loss * act_weights).sum()

        if agent_msg_loss is not None:
            stats['agent_msg_loss'] = agent_msg_loss.detach().clone()
            # same guard as above, e.g. there is no message target at all if every episode has a single step
            msg_weights = msg_loss_count / msg_loss_count.sum().clamp(min=1)
            mean_msg_loss = (agent_msg_loss * msg_weights).sum()
            loss = (1 - config.msg_loss_weight) * mean_act_loss + config.msg_loss_weight * mean_msg_loss
        else:
            loss = mean_act_loss

        stats['loss'] = loss.detach().clone()

        # add additional stats
        stats['epsilon_agent'] = agent_epsilon
        stats['epsilon_msg_size'] = msg_size_epsilon
        stats['epsilon_discrete_msg'] = discrete_msg_epsilon

        # print stats (the gradient norm is not known yet, it only ends up in the writer/run stats)
        stats_str = ", ".join([f"{key} = {stats[key]}" for key in stats])
        tqdm.write(f" It {it} > {stats_str}")

        # learning step
        loss.backward()
        if config.gradient_clip_max_norm is not None:
            stats["grad_norm"] = torch.nn.utils.clip_grad_norm_(agent_model.parameters(),
                                                                max_norm=config.gradient_clip_max_norm, norm_type=2.0)
        else:
            stats["grad_norm"] = get_grad_norm(agent_model.parameters(), 2.0)

        optimizer.step()

        if writer is not None:
            write_dict_scalars(writer, stats, it, category='train')

        # drop the records (and the graph they hold on to) before the next iteration collects new ones
        del records
        run_stats.append(stats)

        # intermediate checkpoint; the final model is stored separately after the loop
        if writer is not None and config.save_model and config.save_model_interval is not None \
                and (it + 1) % config.save_model_interval == 0 and it + 1 < config.num_iterations:
            torch.save(agent_model.state_dict(), Path(writer.get_logdir()) / f"model_{it + 1}.pt")

    if writer is not None and config.save_model:
        torch.save(agent_model.state_dict(), Path(writer.get_logdir()) / "model_final.pt")

    seconds_passed = (datetime.now() - time_start).total_seconds()
    print(f"Passed time: {seconds_passed}s")
    # without any iteration there is nothing to average over (and the divisions below would fail)
    if config.num_iterations > 0:
        print(f"> per iteration: {seconds_passed / config.num_iterations * 1000}ms")
        if config.env_name == POMNIST:
            print(f"> per episode: "
                  f"{seconds_passed / (config.num_iterations * config.num_episodes * env.num_envs) * 1000}ms")
        else:
            print(f"> per episode: {seconds_passed / (config.num_iterations * config.num_episodes) * 1000}ms")

    return run_stats
