from typing import List, Optional

import torch
from torch.utils.tensorboard import SummaryWriter

import constants
from channel import Channel
from constants import CONTINUOUS, TRAFFIC_JUNCTION
from env.pomnist.rl_env import PartialMNISTEnv
from env.pomnist.vis_util import create_q_stats_msg_size, create_reward_figure, create_max_q_figure, \
    create_q_stats_per_digit, create_q_stats_per_action, visualize_message_effect, \
    visualize_message_effect_on_targets, create_df_analysis
from config import Config
from modules.messages import MessageEncoder, get_message_sizes
from training import StepRecord, run_env
from util import set_title

import logging
from typing import Tuple

import pandas as pd
import numpy as np


def get_listening_effect(env, actions, first_step=0, second_step=1):
    """
    Measure how often receiving messages changed an action for the better or for the worse.

    The action taken at ``first_step`` (without messages) is compared with the action taken at ``second_step``
    (with messages). Among the samples whose action changed, a sample counts as positive listening if the new
    action equals the target and as negative listening if the old action equalled the target. Both counts are
    divided by the total number of rows of the analysis table. The rates (and the share of unchanged actions)
    are also printed.

    :param env: The environment (handed to ``create_df_analysis``)
    :param actions: All actions of all agents (expected shape (num_agents, num_samples, num_steps))
    :param first_step: The step whose action is used as the "before messaging" reference
    :param second_step: The step whose action is used as the "after messaging" result
    :return: tuple (positive listening rate, negative listening rate)
    """
    # the analysis table is addressed by the step numbers as column names
    before_col = str(first_step)
    after_col = str(second_step)

    analysis = create_df_analysis(env, actions)
    total = len(analysis)

    # only rows where the action actually changed between the two steps
    changed = analysis[(analysis[before_col] != analysis[after_col])]

    pos_listening = sum(changed.target == changed[after_col]) / total
    neg_listening = sum(changed.target == changed[before_col]) / total

    print("Positive listening: ", pos_listening)
    print("Negative listening: ", neg_listening)
    print("Unaffected actions: ", sum(analysis[before_col] == analysis[after_col]) / total)
    return pos_listening, neg_listening


def get_speaker_consistency(actions: np.ndarray, messages: np.ndarray, normalization):
    """
    Get the (normalized) speaker consistency for the given actions and messages.

    The speaker consistency is the mutual information (natural logarithm) between actions and messages,
    estimated from the empirical joint frequencies of the given samples.

    :param actions: ndarray of shape (samples, 1) or (samples)
    :param messages: ndarray of shape (samples, message_content)
    :param normalization: available normalization modes: [None or 'none', 'joint_entropy', 'min_entropy']
    :return: the (normalized) speaker consistency; 0 if the normalizing entropy is too small to divide by
    :raises ValueError: if the normalization mode is unknown
    """
    def _is_distribution(table):
        # probabilities have to sum up to 1 (with some numerical slack)
        return 0.999 <= table.sum().item() <= 1.001

    def _entropy(probabilities):
        return -(probabilities * np.log(probabilities)).sum()

    if actions.ndim == 1:
        actions = actions[:, np.newaxis]

    action_cols = ['action']
    message_cols = [f'msg_{i}' for i in range(messages.shape[1])]
    all_cols = action_cols + message_cols

    # one row per sample: [action, msg_0, msg_1, ...]
    samples = pd.DataFrame(np.concatenate([actions, messages], axis=1), columns=all_cols)

    # P(a, m): relative frequency of every observed (action, message) combination
    joint = samples.groupby(all_cols).size().reset_index(name='P_am').set_index(all_cols)
    joint['P_am'] = joint['P_am'].astype(np.double) / len(samples)
    assert _is_distribution(joint)

    # P(m): marginalize the actions out
    message_marginal = joint.groupby(message_cols).sum().rename(columns={'P_am': 'P_m'})
    assert _is_distribution(message_marginal)

    # P(a): marginalize the messages out
    action_marginal = joint.groupby(action_cols).sum().rename(columns={'P_am': 'P_a'})
    assert _is_distribution(action_marginal)

    p_am = joint['P_am']
    p_m = message_marginal['P_m']
    p_a = action_marginal['P_a']

    # pandas aligns the marginals with the joint table via the shared index level names
    mutual_information = (p_am * np.log((p_am / p_m) / p_a)).sum()

    if not normalization or normalization == 'none':
        return mutual_information

    if normalization == 'min_entropy':
        min_entropy = min(_entropy(p_a), _entropy(p_m))
        if min_entropy < 0.00001:
            # avoid division by 0 and numerical issues
            logging.warning(f"Min entropy is too small to normalize speaker consistency: {min_entropy}.")
            return 0
        return mutual_information / min_entropy

    if normalization == 'joint_entropy':
        joint_entropy = _entropy(p_am)
        if joint_entropy < 0.00001:
            # avoid division by 0 and numerical issues
            logging.warning(f"Joint entropy is too small to normalize speaker consistency: {joint_entropy}.")
            return 0
        return mutual_information / joint_entropy

    raise ValueError(f"Unknown normalization mode '{normalization}'.")


def get_multi_speaker_consistency(agent_actions: np.ndarray, agent_messages: np.ndarray, msg_sizes: Tuple[int]):
    """
    Get speaker consistency considering multiple message sizes.

    The samples are split by the message size that was used, the min-entropy normalized speaker consistency
    is computed per (non-zero) message size and the results are combined, weighted by the share of samples of
    each size among all samples with a non-zero message size. The result is also printed.

    :param agent_actions: agent actions of shape (samples, 1) or (samples)
    :param agent_messages: agent messages of shape (samples, message_container_size)
    :param msg_sizes: the message sizes
    :return: tuple (weighted speaker consistency value, speaker consistency per message size,
             number of messages per message size)
    """
    size_count = len(msg_sizes)
    per_size_consistency = np.zeros(size_count)
    per_size_count = np.zeros(size_count)

    # share of samples that carry the empty (size 0) message
    share_of_empty = 0.0

    for idx, size in enumerate(msg_sizes):
        if size == 0:
            # message size 0 = all content & mask is 0
            selected = (agent_messages[:] == 0).prod(axis=-1)
        else:
            # the mask is at the end of the message [msg, mask];
            # a leading message size 0 has no mask entry of its own
            leading_zero = 1 if msg_sizes[0] == 0 else 0
            mask_len = size_count - leading_zero
            mask_index = idx - leading_zero
            selected = agent_messages[:, -mask_len + mask_index] == 1

        per_size_count[idx] = selected.sum()
        if per_size_count[idx] == 0:
            # nothing to measure for this message size
            continue

        if size == 0:
            assert share_of_empty == 0.0, "Can't have 2x msg size 0!"
            share_of_empty = per_size_count[idx] / len(agent_actions)
            continue

        # exclude mask and padding from the selected messages
        per_size_consistency[idx] = get_speaker_consistency(
            agent_actions[selected], agent_messages[selected, 0:size], normalization='min_entropy'
        )

    # weight speaker consistency according to samples (message size 0 is ignored)
    combined = 0.0
    for size, count, consistency in zip(msg_sizes, per_size_count, per_size_consistency):
        if size == 0:
            continue
        if not (share_of_empty < 1 and count > 0):
            continue
        share = count / len(agent_actions)
        combined += (share / (1.0 - share_of_empty)) * consistency

    print(f"Speaker consistency: {combined} - individual: {per_size_consistency} "
          f"(msg count {per_size_count})")
    return combined, per_size_consistency, per_size_count


def vectorize_step_records(step_records: List["StepRecord"], msg_sizes, n_actions, with_messages=False):
    """
    Converts step record list to multiple ndarrays that contain the data from all steps.
    Note that this does not respect agent masks.

    All arrays are zero-initialized, so entries of steps that were never recorded (e.g. because an episode
    ended before the longest one did, or because a record does not carry the respective value) are 0.

    :param step_records: The step records (must not be empty)
    :param msg_sizes: Used message sizes
    :param n_actions: Number of actions, i.e. the size of the last axis of the Q-value array
    :param with_messages: Also save & return message content
    :return: tuple of
        returns (num_agents, num_episodes),
        Q-values (num_agents, num_episodes, max_steps, n_actions),
        message size Q-values (num_agents, num_episodes, max_steps, len(msg_sizes)) or None,
        actions (num_agents, num_episodes, max_steps),
        message actions (num_agents, num_episodes, max_steps) or None,
        messages (num_agents, num_episodes, max_steps, message container length) or None,
        number of dropped messages (num_episodes, max_steps),
        channel utilization (num_episodes, max_steps),
        success (num_episodes);
        the message size arrays are None unless more than one message size is used and the messages are None
        unless ``with_messages`` is set
    """
    num_agents = step_records[0].observations.shape[0]
    # every record that finishes its episode(s) contributes one episode per environment
    num_episodes = sum(step.get_num_envs() for step in step_records if step.done)
    max_steps = max(step.step for step in step_records) + 1

    all_actions = np.zeros((num_agents, num_episodes, max_steps))
    all_returns = np.zeros((num_agents, num_episodes))
    all_q_values = np.zeros((num_agents, num_episodes, max_steps, n_actions))

    # message size selection only exists if there is more than one size to choose from
    if msg_sizes is not None and len(msg_sizes) > 1:
        # zeros instead of np.empty: steps that are never written must not expose uninitialized memory
        all_q_msg_sizes = np.zeros((num_agents, num_episodes, max_steps, len(msg_sizes)))
        all_msg_actions = np.zeros((num_agents, num_episodes, max_steps))
    else:
        all_q_msg_sizes = None
        all_msg_actions = None

    all_num_dropped = np.zeros((num_episodes, max_steps))
    all_channel_util = np.zeros((num_episodes, max_steps))

    if with_messages:
        all_messages = np.zeros((num_agents, num_episodes, max_steps,
                                 MessageEncoder.get_msg_container_len(msg_sizes)))
    else:
        all_messages = None

    all_success = np.zeros(num_episodes)

    # index of the first episode that is still running
    episode = 0
    for step in step_records:
        step_dim = step.get_num_envs()
        episode_slice = slice(episode, episode + step_dim)
        step_idx = step.step

        all_actions[:, episode_slice, step_idx] = step.actions.numpy()
        all_returns[:, episode_slice] += step.rewards.numpy()
        if step.q_values is not None:
            all_q_values[:, episode_slice, step_idx] = step.q_values.numpy()
        # optional values are only stored if both the target array and the recorded value exist
        if all_q_msg_sizes is not None and step.q_msg_sizes is not None:
            all_q_msg_sizes[:, episode_slice, step_idx] = step.q_msg_sizes.numpy()
        if all_msg_actions is not None and step.msg_actions is not None:
            all_msg_actions[:, episode_slice, step_idx] = step.msg_actions.numpy()
        if step.num_dropped is not None:
            all_num_dropped[episode_slice, step_idx] = step.num_dropped.numpy()
        if step.channel_util is not None:
            all_channel_util[episode_slice, step_idx] = step.channel_util.numpy()

        if all_messages is not None and step.messages is not None:
            all_messages[:, episode_slice, step_idx, :] = step.messages.numpy()

        if step.done:
            # episode_slice still addresses the episodes that just finished
            if step.success is not None:
                all_success[episode_slice] = step.success
            episode += step_dim

    return (all_returns, all_q_values, all_q_msg_sizes, all_actions, all_msg_actions, all_messages,
            all_num_dropped, all_channel_util, all_success)


def get_return_stats(all_returns: np.ndarray, print_stats=True, out_dict: dict = None):
    """
    Compute mean and standard deviation of the returns along axis 1 and optionally print / store them.

    :param all_returns: All returns from all agents
    :param print_stats: Whether to print the stats
    :param out_dict: The dictionary where the stats should be saved (keys 'agent_mean_return' and
                     'agent_std_return'); nothing is stored if it is None
    """
    per_agent_mean, per_agent_std = all_returns.mean(axis=1), all_returns.std(axis=1)

    if print_stats:
        print(f"Return: {per_agent_mean} +/- {per_agent_std}")

    if out_dict is None:
        return

    for key, value in (('agent_mean_return', per_agent_mean), ('agent_std_return', per_agent_std)):
        out_dict[key] = value


def evaluate(env, joint_policy, channel: Channel, config: Config, num_episodes, writer: Optional[SummaryWriter] = None):
    """
    Evaluate the given joint policy in the given environment.

    Runs the policy without gradient tracking, vectorizes the recorded steps and derives the return and
    channel statistics from them. Depending on the environment, speaker consistency (POMNIST with
    non-continuous messages), success rate (traffic junction) as well as accuracy and listening effect
    (PartialMNISTEnv) are added. For a PartialMNISTEnv, figures are logged to the writer (if given).

    :param env: environment
    :param joint_policy: policy for all agents
    :param channel: The communication channel
    :param config: The config
    :param num_episodes: number of episodes that are run for the evaluation; if the environment has a
                         ``num_envs`` attribute, ``ceil(num_episodes / num_envs)`` runs are requested instead
    :param writer: writer to log eval figures
    :return: eval stats dictionary
    """
    if hasattr(env, 'num_envs'):
        num_run_episodes = int(np.ceil(num_episodes / env.num_envs))
    else:
        num_run_episodes = num_episodes

    # disable gradients to speed up evaluation & disable graph generation (=> less memory required);
    # the context manager restores the previous grad mode even if run_env raises
    with torch.no_grad():
        step_records = run_env(
            env, config.msg_sizes, joint_policy, num_run_episodes, channel, config.detach_gap, log_messages=True,
            record_device='cpu', show_progress=True
        )

    returns, q_values, q_msg_sizes, actions, msg_actions, messages, num_drops, channel_util, all_success = \
        vectorize_step_records(step_records, config.msg_sizes, env.num_actions, True)
    # the raw records are not needed anymore once they are vectorized
    del step_records

    stats = {}
    get_return_stats(returns, True, stats)

    # ignore the stats of the first step as we measure channel util in the receiving step
    # => they are 0 for the first step by definition
    stats["num_drops"] = num_drops[:, 1:].mean().item()
    stats["channel_util"] = channel_util[:, 1:].mean().item()

    if msg_actions is not None:
        all_msg_sizes = get_message_sizes(torch.Tensor(msg_actions), config.msg_sizes)
        # we don't care about the last selected message size as it never arrives. msg_actions is laid out as
        # (agents, episodes, steps), so the *last* axis has to be cut (cutting axis 1 would drop an episode).
        # NOTE(review): assumes get_message_sizes keeps the layout of its input - confirm.
        if all_msg_sizes.shape[-1] > 1:
            all_msg_sizes = all_msg_sizes[..., :-1]
        stats["mean_selected_msg_size"] = all_msg_sizes.mean().item()
    elif config.msg_sizes is not None and len(config.msg_sizes) == 1:
        stats["mean_selected_msg_size"] = config.msg_sizes[0]
    else:
        stats["mean_selected_msg_size"] = 0

    if config.env_name == constants.POMNIST and config.message_mode != CONTINUOUS:
        # speaker consistency is measured on the actions and messages of the first step
        first_step = 0
        speaker_consistency = np.zeros(env.num_agents)
        msg_speaker_consistency = np.zeros((env.num_agents, len(config.msg_sizes)))
        messages_count = np.zeros((env.num_agents, len(config.msg_sizes)))
        for agent in range(env.num_agents):
            speaker_consistency[agent], msg_speaker_consistency[agent], messages_count[agent] = \
                get_multi_speaker_consistency(
                    actions[agent, :, first_step], messages[agent, :, first_step, :], config.msg_sizes
                )

        stats["positive_signaling"] = speaker_consistency.mean()
        stats["positive_signaling_agents"] = speaker_consistency
        stats["positive_signaling_agents_msg_sizes"] = list(msg_speaker_consistency)
        stats["messages_count"] = list(messages_count.sum(axis=0))
        stats["messages_count_agents"] = list(messages_count)

    if config.env_name == TRAFFIC_JUNCTION:
        print(f"Traffic Juncion success: {all_success.mean()} ({len(all_success)} episodes)")
        stats["traffic_junction_success"] = all_success.mean()

    if isinstance(env, PartialMNISTEnv):
        accuracy = PartialMNISTEnv.reward_to_accuracy(returns.mean(axis=1))
        print(f"Accuracy: {accuracy}")
        stats['agent_mean_accuracy'] = accuracy

        pos_listening, neg_listening = get_listening_effect(env, actions)
        stats["positive_listening"] = pos_listening
        stats["negative_listening"] = neg_listening

        # create rewards figures
        if writer is not None:
            eval_mode = 'train' if env.train else 'eval'
            if q_msg_sizes is not None:
                create_q_stats_msg_size(env, config.msg_sizes, q_msg_sizes, eval_mode, w=writer)
            fig_reward = create_reward_figure(env, returns)
            set_title(fig_reward, "Model Test")
            writer.add_figure(f"eval-{eval_mode}/fig_reward", fig_reward)

            # the max-Q figures are only created for the first agent
            agent_idx = 0
            fig_correct = create_max_q_figure(env, agent_idx, returns, q_values, True)
            set_title(fig_correct, "Correct prediction with highest Q-value")
            writer.add_figure(f"eval-{eval_mode}/fig_correct", fig_correct)

            fig_incorrect = create_max_q_figure(env, agent_idx, returns, q_values, False)
            set_title(fig_incorrect, "Incorrect prediction with highest Q-value")
            writer.add_figure(f"eval-{eval_mode}/fig_incorrect", fig_incorrect)

            # eval_mode is either 'train' or 'eval'; the former comparison against 'test' could never
            # match, so the detailed plots were never created even with log_eval_plots enabled
            if config.log_eval_plots and eval_mode == 'eval':
                create_q_stats_per_digit(env, q_values, w=writer)
                create_q_stats_per_action(env, q_values, actions, w=writer)
                visualize_message_effect(env, actions, w=writer)
                visualize_message_effect_on_targets(env, actions, w=writer)

    return stats


def evaluate_pomnist(env: PartialMNISTEnv, joint_policy, channel: Channel, config: Config,
                     writer: Optional[SummaryWriter] = None):
    """
    Special evaluation function for the POMNIST environment, returns stats for complete train and test dataset.

    The environment is switched to the train data first and to the test data afterwards; in both cases one
    evaluation episode per entry of ``env.data`` is requested.

    :param env: The environment
    :param joint_policy: The joint policy
    :param channel: The communication channel
    :param config: The used config
    :param writer: The eval writer
    :return: dictionary with stats for the respective datasets (keys 'train' and 'test')
    """
    assert isinstance(env, PartialMNISTEnv)

    # (printed banner, key in the result, whether to evaluate on the test data)
    passes = (
        ("> Train", 'train', False),
        ("> Test", 'test', True),
    )

    eval_stats = {}
    for banner, split, use_test_data in passes:
        print(banner)
        env = env.eval(test=use_test_data)
        eval_stats[split] = evaluate(env, joint_policy, channel, config, len(env.data), writer)

    return eval_stats
