from typing import List, Optional, Dict

import numpy as np
import pandas as pd

from torch.utils.tensorboard import SummaryWriter

from env.pomnist.rl_env import PartialMNISTEnv
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from matplotlib.patches import Rectangle

from config import Config
from constants import *


def create_df_analysis(env, actions):
    """
    Create a dataframe with columns ['agent', 'target', 'data_ind', '0', '1'...] (depending on number of steps).
    Numeric columns indicate the actions taken in respective steps.

    One row is produced per (agent, sample) pair. Rows are ordered by agent, then by sample index, and the result
    has a fresh 0..n-1 index. All column names are strings.

    :param env: The environment (reads ``targets``, ``num_agents``, ``max_steps`` and ``data``)
    :param actions: All actions of all agents (expected shape (num_agents, num_samples, num_steps))
    :return: dataframe
    """
    targets = env.targets
    num_agents = env.num_agents
    num_steps = env.max_steps
    columns = ['agent', 'target', 'data_ind'] + list(range(num_steps))
    sample_indices = np.arange(0, env.data.shape[0])

    # Collect one frame per agent and concatenate once. Concatenating inside the loop is quadratic, and
    # concatenating onto an empty frame is deprecated in recent pandas versions.
    frames = []
    for agent_id in range(num_agents):
        df_agent = pd.DataFrame(actions[agent_id])
        df_agent['agent'] = agent_id
        df_agent['target'] = targets
        df_agent['data_ind'] = sample_indices
        frames.append(df_agent)

    if frames:
        # reindex restores the documented column order (and adds NaN columns if fewer steps were recorded)
        df = pd.concat(frames, ignore_index=True).reindex(columns=columns)
    else:
        # pd.concat raises on an empty list; keep returning an empty frame with the expected columns
        df = pd.DataFrame(columns=columns)

    df.columns = df.columns.astype(str)
    return df


def create_agent_view_img(env: PartialMNISTEnv, values: List[float]) -> np.ndarray:
    """
    Creates an image with the same dimensions as the source data and fills the partial view of the agents with
    the given values (for each agent). Pixels not covered by any agent view stay 0.

    :param env: The environment (reads ``image_shape`` and ``agents_view_rect``, whose entries are unpacked as
        ``(x, y, width, height)``)
    :param values: Contains a value for each agent, in the order of ``env.agents_view_rect``
    :return: A float64 ndarray of shape ``env.image_shape`` where the given values are inserted in the respective
        partial agent view
    """
    # ``np.float`` was merely an alias of the builtin ``float`` and was removed in NumPy 1.24
    img = np.zeros(env.image_shape, dtype=float)

    for i, (x, y, split_w, split_h) in enumerate(env.agents_view_rect):
        # rows are indexed by y, columns by x
        img[y:y + split_h, x:x + split_w] = values[i]

    return img


def add_centred_text(ax, txt, x, y, fontsize):
    """
    Draw white text centred on (x, y), given in axes-fraction coordinates, outlined by a black stroke.

    :param ax: The axis to draw on
    :param txt: The string to display
    :param x: Horizontal position as a fraction of the axis width
    :param y: Vertical position as a fraction of the axis height
    :param fontsize: The font size; the outline is a tenth of it wide
    :return: The created matplotlib text artist
    """
    text_artist = ax.text(
        x, y, txt,
        transform=ax.transAxes,
        horizontalalignment='center',
        verticalalignment='center',
        color='white',
        fontsize=fontsize,
    )

    # draw the black stroke first, then the normal (white) glyphs on top of it
    outline = path_effects.Stroke(linewidth=fontsize / 10, foreground='black')
    text_artist.set_path_effects([outline, path_effects.Normal()])
    return text_artist


def annotate_agent_text(env: PartialMNISTEnv, ax, text_list: List[str], fontsize=20):
    """
    Place one label at the centre of each agent's view rectangle.

    :param env: The environment (reads ``image_shape`` and ``agents_view_rect``)
    :param ax: The axis to annotate
    :param text_list: One string per agent, in the order of ``env.agents_view_rect``
    :param fontsize: The font size
    """
    img_h, img_w = env.image_shape[0], env.image_shape[1]

    for agent_idx, rect in enumerate(env.agents_view_rect):
        left, top, width, height = rect
        # convert the pixel-space rectangle centre to axes fractions; the y axis is flipped
        # (assumes the image is drawn with imshow's default origin='upper')
        centre_x = (left + width / 2) / img_w
        centre_y = 1 - (top + height / 2) / img_h
        add_centred_text(ax, text_list[agent_idx], centre_x, centre_y, fontsize)


def create_reward_figure(env: PartialMNISTEnv, rewards: np.ndarray, fontsize=20):
    """
    Build a three-panel figure that shows, painted onto each agent's view, the accuracy, the mean reward and the
    reward standard deviation of that agent.

    :param env: The environment
    :param rewards: All rewards which were gathered during an evaluation run. Expected shape: (num_agents, len_data)
    :param fontsize: The fontsize for the per-agent annotations
    :return: The created figure
    """
    fig, axes = plt.subplots(1, 3, figsize=(13, 4))

    def draw_panel(ax, numbers, title, vmin, vmax, cmap=None):
        # paint every agent view with its number, label it, and attach a colour bar
        ax.set_title(title)
        image = ax.imshow(create_agent_view_img(env, numbers), vmin=vmin, vmax=vmax, cmap=cmap)

        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)

        labels = ["{}\n{:.4f}".format(agent_idx, numbers[agent_idx]) for agent_idx in range(env.num_agents)]
        annotate_agent_text(env, ax, labels, fontsize)
        fig.colorbar(image, ax=ax)

    hits = env.reward_to_accuracy(rewards)
    panels = [
        (hits.mean(axis=1), "Accuracy", 0, 1, None),
        (rewards.mean(axis=1), "Mean Reward", -1, 1, None),
        (rewards.std(axis=1), "Std Reward", 0, 1, plt.get_cmap('inferno')),
    ]
    for panel_ax, (numbers, title, vmin, vmax, cmap) in zip(axes, panels):
        draw_panel(panel_ax, numbers, title, vmin, vmax, cmap=cmap)

    return fig


def visualize_q_values(agent_q_values, ax=None):
    """
    Draw a bar chart of one agent's Q-values, one bar per class 0-9.

    :param agent_q_values: The q-values of the agent for this particular sample (expected shape (10))
    :param ax: The axis to draw on; defaults to the current matplotlib axis
    """
    assert len(agent_q_values) == 10, "Incorrect length of q values!"

    target_ax = plt.gca() if ax is None else ax
    classes = np.arange(0, 10)

    target_ax.bar(classes, agent_q_values)
    target_ax.set_xticks(classes)
    target_ax.set_xticklabels(classes)


def visualize_q_stat(q_values, labels=None, ax=None):
    """
    Draw the mean Q-value per action as bars with standard-deviation error bars.

    :param q_values: Array of q-values (expected shape (:, 10)); the statistics are taken over axis 0
    :param labels: Bar positions and tick labels; defaults to 0..9
    :param ax: The axis to draw on; defaults to the current matplotlib axis
    """
    ax = plt.gca() if ax is None else ax
    bar_positions = np.arange(0, 10) if labels is None else labels

    means = q_values.mean(axis=0)
    spreads = q_values.std(axis=0)
    ax.bar(bar_positions, means, yerr=spreads, align='center', alpha=0.5, ecolor='black', capsize=5)

    for set_axis_label, text in ((ax.set_ylabel, 'Q values'), (ax.set_xlabel, 'Target')):
        set_axis_label(text, fontsize=8)

    ax.set_xticks(bar_positions)
    ax.set_xticklabels(bar_positions)
    ax.set_ylim([-1, 1])


def annotate_subplots(axes, col_labels=None, row_labels=None):
    """
    Write a heading above each column of a subplot grid and a vertical label to the left of each row.

    :param axes: A 2D array of axes (indexed as ``axes[0]`` and ``axes[:, 0]``)
    :param col_labels: One string per column, placed above the top row
    :param row_labels: One string per row, placed left of the first column's y-axis label
    """
    if col_labels:
        for column_label, top_ax in zip(col_labels, axes[0]):
            top_ax.annotate(column_label, xy=(0.5, 1), xytext=(0, 5), xycoords='axes fraction',
                            textcoords='offset points', size='large', ha='center', va='baseline')

    if not row_labels:
        return

    for row_label, left_ax in zip(row_labels, axes[:, 0]):
        # anchor to the y-axis label and shift further left by its padding plus a small gap
        x_offset = -left_ax.yaxis.labelpad - 5
        left_ax.annotate(row_label, xy=(0, 0.5), xytext=(x_offset, 0), xycoords=left_ax.yaxis.label,
                         textcoords='offset points', size='large', ha='right', va='center', rotation='vertical')


def visualize_agent_view(env: PartialMNISTEnv, data_index, agent_idx, ax=None):
    """
    Using the given axis, create an image of the sample with index data_index and highlight the view of the agent
    with index agent_idx with a red rectangle.

    :param env: The environment (reads ``data`` and ``agents_view_rect``)
    :param data_index: The data index (in env)
    :param agent_idx: The agent index
    :param ax: The axis; defaults to the current matplotlib axis
    """
    if ax is None:
        ax = plt.gca()

    data = env.data[data_index]
    ax.imshow(data)

    # Unlabelled ticks on the pixel borders (-1.5, -0.5, ..., width + 0.5). They are derived from the image size
    # instead of the former hard-coded MNIST width of 28; for 28x28 images the ticks are unchanged.
    height, width = data.shape[0], data.shape[1]
    x_ticks = np.arange(-1.5, width + 1)
    y_ticks = np.arange(-1.5, height + 1)

    ax.set_xticks(x_ticks)
    ax.set_xticklabels([''] * len(x_ticks))
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([''] * len(y_ticks))

    # imshow centres pixels on integer coordinates, so the rectangle starts half a pixel before (x, y)
    x, y, split_w, split_h = env.agents_view_rect[agent_idx]
    rect = Rectangle((x - 0.5, y - 0.5), split_w, split_h, linewidth=2, edgecolor='r', facecolor='none')
    ax.add_patch(rect)


def create_agent_figure(env: PartialMNISTEnv, data_index, agent_idx, agent_q_values):
    """
    Build a two-panel figure. The left panel shows the sample with the agent's view highlighted; the right panel
    shows the agent's Q-values for that sample.

    :param env: The environment
    :param data_index: The data index (in env)
    :param agent_idx: The agent index (in env and q_values)
    :param agent_q_values: The q-values of the agent for this particular sample (expected shape (10))
    :return: The created figure
    """
    fig, (view_ax, q_ax) = plt.subplots(1, 2, figsize=(9, 4))

    view_title = f"Target class: {env.targets[data_index]} @ {data_index}"
    view_ax.set_title(view_title)
    visualize_agent_view(env, data_index, agent_idx, ax=view_ax)

    q_ax.set_title("Agent Q-values")
    visualize_q_values(agent_q_values, ax=q_ax)

    return fig


def create_max_q_figure(env: PartialMNISTEnv, agent_idx, rewards, q_values, filter_correct: Optional[bool] = True):
    """
    Create a figure of the sample where the agent with index agent_idx predicted the overall highest q value
    in its last step.

    :param env: The environment
    :param agent_idx: The index of the agent
    :param rewards: All rewards of all agents (expected shape (num_agents, num_samples)); a reward of 1 marks a
        correct prediction and -1 an incorrect one
    :param q_values: All Q-values of all agents (expected shape (num_agents, num_samples, num_steps, 10)); only
        the last step is considered
    :param filter_correct: True to consider only correctly predicted samples, False to consider only incorrectly
        predicted ones, None to consider all samples. If no sample qualifies, a message is printed and the filter
        is not applied.
    :return: The created figure
    """
    # highest q value of the last step for every sample
    agent_max_q = np.max(q_values[agent_idx, :, -1, :], axis=1)

    if filter_correct is not None:
        target_reward = 1 if filter_correct else -1
        keep = rewards[agent_idx, :] == target_reward

        if keep.any():
            # Exclude every sample with a different reward. -inf can never win the argmax below. The former
            # additive -99999999 penalty only hit samples whose reward was exactly the opposite sign.
            agent_max_q = np.where(keep, agent_max_q, -np.inf)
        else:
            print(f"Agent has no {'correct' if filter_correct else 'incorrect'} prediction! Cannot apply filter!")

    # select the sample with the highest q value (-> certainty)
    data_idx = np.argmax(agent_max_q)

    # plot it
    return create_agent_figure(env, data_idx, agent_idx, q_values[agent_idx, data_idx, -1])


def create_q_stats_per_digit(env, q_values, w=None):
    """
    Create a plot per digit showing the distribution of Q-values with subplots for each agent and for each step.

    :param env: The environment (reads ``targets``, ``num_agents`` and ``max_steps``)
    :param q_values: All Q-values of all agents (expected shape (num_agents, num_samples, num_steps, 10))
    :param w: Optional summary writer; each figure is added under the tag ``q_stats/<digit>``
    """
    targets = env.targets
    num_agents = env.num_agents
    num_steps = env.max_steps

    # the labels do not depend on the digit, so build them once
    col_labels = [f'Agent{i}' for i in range(num_agents)]
    row_labels = [f'Action{i}' for i in range(num_steps)]

    for tar in np.unique(targets):
        # squeeze=False keeps ``axes`` 2D, so a single agent or a single step no longer breaks axes[step, agent]
        fig, axes = plt.subplots(num_steps, num_agents, squeeze=False)
        mask = np.where(targets == tar)[0]
        for agent_id in range(num_agents):
            for step in range(num_steps):
                visualize_q_stat(q_values[agent_id, mask, step, :], ax=axes[step, agent_id])

        annotate_subplots(axes, col_labels=col_labels, row_labels=row_labels)
        fig.suptitle(f'Q-value distributions for Target={tar}')
        fig.tight_layout()
        if w:
            w.add_figure(f"q_stats/{tar}", fig)


def create_q_stats_msg_size(env, msg_sizes, q_values, eval_mode, step=0, w=None):
    """
    Create one grouped bar chart of the message-size Q-values at a given step. There is one group per agent and
    one bar per message size, showing the mean over all samples with standard-deviation error bars.

    :param env: The environment (reads ``num_agents``)
    :param msg_sizes: The message sizes; the i-th entry labels index i of the last q_values axis
    :param q_values: All Q-values of all agents (expected shape (num_agents, num_samples, num_steps, num_messages))
    :param eval_mode: Only used in the summary-writer tag
    :param step: The step whose Q-values are shown
    :param w: Optional summary writer that receives the figure
    """
    agent_ticks = np.arange(env.num_agents)

    fig, ax = plt.subplots()
    q_at_step = q_values[:, :, step, :]
    q_mean, q_std = q_at_step.mean(axis=1), q_at_step.std(axis=1)

    num_sizes = len(msg_sizes)
    width = min(0.2, 0.9 / num_sizes)
    for size_idx in range(num_sizes):
        # shift each message size so that the group of bars is centred on the agent's tick
        offset = width * (size_idx - (num_sizes - 1) / 2)
        ax.bar(agent_ticks + offset, q_mean[:, size_idx], yerr=q_std[:, size_idx], align='center', alpha=0.5,
               ecolor='black', capsize=5, width=width)

    ax.set_ylabel('Q values', fontsize=8)
    ax.set_xlabel('Agents', fontsize=8)
    ax.set_xticks(agent_ticks)
    ax.set_xticklabels(agent_ticks)
    ax.set_ylim([-0.25, 1.25])
    ax.legend(labels=msg_sizes)

    fig.suptitle(f'Q-value distributions for msg_sizes={msg_sizes}')
    fig.tight_layout()
    if w:
        w.add_figure(f"q_msg_size_stats/{eval_mode}_{msg_sizes}", fig)


def create_q_stats_per_action(env, q_values, actions, w=None):
    """
    Create one figure of Q-value distributions for every (target digit, agent) pair. Each figure has one row per
    step and one column per action the agent chose on that digit (in any step). A subplot only uses the samples
    where the agent chose that column's action in that row's step; combinations without samples stay blank.

    :param env: The environment (reads ``targets``, ``num_agents`` and ``max_steps``)
    :param q_values: All Q-values of all agents (expected shape (num_agents, num_samples, num_steps, 10))
    :param actions: All actions of all agents (expected shape (num_agents, num_samples, num_steps))
    :param w: Optional summary writer; figures are added under ``q_stats/tar<target>_agent<agent>``
    """
    targets = env.targets
    num_agents = env.num_agents
    num_steps = env.max_steps
    step_labels = [f'Action{i}' for i in range(num_steps)]

    for tar in np.unique(targets):
        on_target = targets == tar
        for agent_id in range(num_agents):
            agent_actions = actions[agent_id]
            chosen_actions = np.unique(agent_actions[on_target])
            fig, axes = plt.subplots(num_steps, len(chosen_actions), squeeze=False, figsize=(17, 9))

            for step in range(num_steps):
                step_actions = agent_actions[:, step]
                for col, action in enumerate(chosen_actions):
                    selection = on_target & (step_actions == action).flatten()
                    if selection.sum() == 0:
                        continue
                    selected_q = q_values[agent_id, selection, step, :]
                    cell = axes[step, col]
                    visualize_q_stat(selected_q, ax=cell)
                    cell.set_title(f'Action={int(action)} ({len(selected_q)})')

            annotate_subplots(axes, row_labels=step_labels)
            fig.suptitle(f'Q-value distributions of Agent{agent_id} on Target={tar}')
            fig.tight_layout()
            if w:
                w.add_figure(f"q_stats/tar{tar}_agent{agent_id}", fig)


def visualize_incorrect_actions(env, actions, w=None):
    """
    Create a plot per digit and per agent showing the target images where the agent's prediction was incorrect.

    The prediction is the agent's action in the final step (``env.max_steps - 1``). It is incorrect when it
    differs from the sample's target. Digit/agent combinations without incorrect predictions produce no figure.

    :param env: The environment
    :param actions: All actions of all agents (expected shape (num_agents, num_samples, num_steps))
    :param w: Optional summary writer; figures are added under ``incorrect_actions/tar<target>_agent<agent>``
    """
    df = create_df_analysis(env, actions)
    final_col = str(env.max_steps - 1)

    # Keep only wrong final predictions (the former version plotted every sample despite its title).
    df_incorrect = df[df[final_col] != df['target']]

    for (tar, agent), df_agent in df_incorrect.groupby(by=['target', 'agent']):
        df_agent = df_agent.reset_index(drop=True)
        cols = int(np.ceil(np.sqrt(len(df_agent))))
        rows = int(np.ceil(len(df_agent) / cols))
        # squeeze=False keeps ``ax`` 2D even for a single image or a single row
        fig, ax = plt.subplots(rows, cols, figsize=(15, 9), squeeze=False)
        for i, row in df_agent.iterrows():
            cell = ax[i // cols, i % cols]
            # the dataframe may hold these as floats; indexing needs ints
            visualize_agent_view(env, int(row['data_ind']), int(agent), ax=cell)
            cell.set_title(f"action_{int(row[final_col])}", fontsize=8)
            cell.axis('off')
        fig.suptitle(f'Incorrect actions of Agent{agent} on Target={tar}')
        fig.tight_layout()
        if w:
            w.add_figure(f"incorrect_actions/tar{tar}_agent{agent}", fig)


def visualize_message_effect(env, actions, first_step=0, second_step=1, w=None):
    """
    Create a plot showing the number of positively and negatively effected predictions by the messages.

    A prediction is "positive" when the action changed between the two steps and the second one equals the
    target. It is "negative" when it changed and the first one equalled the target.

    :param env: The environment
    :param actions: All actions of all agents (expected shape (num_agents, num_samples, num_steps))
    :param first_step: The action to be considered before the messaging took place (in case we increase the number of
    steps)
    :param second_step: The action to be considered after the messaging took place
    :param w: Optional summary writer; the figure is added under ``msg_effect``
    """
    first_step, second_step = str(first_step), str(second_step)
    df = create_df_analysis(env, actions)
    df_effected = df[(df[first_step] != df[second_step])]
    dict_mask = {'positive': (df_effected.target == df_effected[second_step]),
                 'negative': (df_effected.target == df_effected[first_step])}
    agents = np.arange(env.num_agents)
    targets = np.arange(10)

    fig, ax = plt.subplots(2, 1)
    width = 0.15
    for i, (eff, mask) in enumerate(dict_mask.items()):
        # Number of affected samples per (target, agent). Missing combinations count as 0. The former
        # apply(len)/reset_index approach produced NaN heights and failed when no sample was affected.
        counts = df_effected[mask].groupby(by=['target', 'agent']).size()
        for a in agents:
            heights = [counts.get((t, a), 0) for t in targets]
            # offset each agent's bars so that the group is centred on the target tick
            ax[i].bar(targets + width * (a - agents.max() / 2), heights, width, label=f'Agent{a}')
        ax[i].set_title(f'{eff} effect')
        ax[i].set_ylabel('Number of samples')
        ax[i].set_xlabel('Target')
        ax[i].set_xticks(targets)
        ax[i].set_xticklabels(targets)

    ax[0].legend()
    fig.tight_layout()
    if w:
        w.add_figure("msg_effect", fig)


def visualize_message_effect_on_targets(env, actions, first_step=0, second_step=1, w=None):
    """
    Create a plot per digit showing an example of positively and negatively effected samples.

    For each digit, the first positively and the first negatively affected sample are shown (one row each), with
    one column per agent titled "<action before>--><action after>".

    :param env: The environment
    :param actions: All actions of all agents (expected shape (num_agents, num_samples, num_steps))
    :param first_step: The action to be considered before the messaging took place (in case we increase the number of
    steps)
    :param second_step: The action to be considered after the messaging took place
    :param w: Optional summary writer; figures are added under ``msg_effect/<digit>``
    """
    first_step, second_step = str(first_step), str(second_step)
    df = create_df_analysis(env, actions)
    targets = np.arange(10)
    num_agents = env.num_agents

    df_effected = df[(df[first_step] != df[second_step])]
    dict_mask = {'positive': (df_effected.target == df_effected[second_step]),
                 'negative': (df_effected.target == df_effected[first_step])}

    for tar in targets:
        # squeeze=False keeps ``axes`` 2D even with a single agent
        fig, axes = plt.subplots(2, num_agents, squeeze=False)
        is_tar = df_effected.target == tar
        for i, (eff, mask) in enumerate(dict_mask.items()):
            # combine both masks on the same frame instead of applying a full-length mask to a subset
            df_tar_eff = df_effected[mask & is_tar]
            if len(df_tar_eff) == 0:
                print(f'There are no {eff} effects of messages on target class {tar}!')
                continue
            ind = int(df_tar_eff.iloc[0]['data_ind'])
            df_frame = df[df.data_ind == ind]
            for agent_id in range(num_agents):
                agent_row = df_frame[df_frame.agent == agent_id]
                cell = axes[i, agent_id]
                visualize_agent_view(env, ind, agent_id, cell)
                cell.set_title(f"{agent_row[first_step].values[0]}-->"
                               f"{agent_row[second_step].values[0]}")
                cell.set_yticks([])
                cell.set_xticks([])

        annotate_subplots(axes, row_labels=list(dict_mask.keys()))
        fig.suptitle(f'An example of the message effect on Target={tar}')
        fig.tight_layout()
        if w:
            w.add_figure(f"msg_effect/{tar}", fig)


def add_multi_run_return_to_ax(all_run_stats, config: Config, ax):
    """
    Plots the mean return per iteration (with std) of multiple runs using the given axis.

    :param all_run_stats: One entry per run; each run is indexable per iteration and each iteration's
        ``'avg_returns'`` entry must support ``.mean()``
    :param config: The config (reads ``num_iterations``)
    :param ax: The axis
    """
    num_runs = len(all_run_stats)
    num_iterations = config.num_iterations

    # Rows are iterations, columns are runs. Iterations missing from a shorter run stay NaN; previously such a run
    # raised an IndexError right after the warning below. Extra iterations of a longer run are ignored.
    avg_returns_mean = np.full((num_iterations, num_runs), np.nan)
    for r, run_stats in enumerate(all_run_stats):
        if num_iterations != len(run_stats):
            print(f"Invalid run stats length in run {r}! Expected {config.num_iterations} and got {len(run_stats)}!")

        for i in range(min(num_iterations, len(run_stats))):
            avg_returns_mean[i, r] = run_stats[i]['avg_returns'].mean()

    add_multi_run_stat_to_ax(avg_returns_mean, config, ax)


def add_multi_run_stat_to_ax(avg_stat, config: Config, ax, compared_params=None):
    """
    Plots the mean of given stats per iteration (with std) of multiple runs using the given axis.

    :param avg_stat: The stats to be plotted, shape (num_iterations, num_runs)
    :param config: The config
    :param ax: The axis
    :param compared_params: Optional names of config attributes. If given, the line label is built from their
        values (joined with '_', None shown as "Unlimited") instead of the default runs/agents description.
    """
    x = np.arange(0, config.num_iterations)
    y = avg_stat.mean(axis=1)
    std = avg_stat.std(axis=1)
    if (compared_params is None) or (len(compared_params) == 0):
        if config.env_name == POMNIST:
            num_agents = (config.x_splits + 1) * (config.y_splits + 1)
        else:
            num_agents = config.num_agents
        # runs are the second axis; len(avg_stat) is the number of iterations, not runs
        num_runs = avg_stat.shape[1]
        label = f'{num_runs} runs {num_agents} agents:{config.msg_sizes}'
    else:
        param_values = (getattr(config, attr) for attr in compared_params)
        label = '_'.join("Unlimited" if value is None else str(value) for value in param_values)
    p = ax.plot(x, y, label=label)
    # shade mean +- std in the same colour as the line
    ax.fill_between(x, (y - std), (y + std), color=p[0].get_color(), alpha=.3)


def add_single_run_stat_to_ax(avg_stat, config: Config, ax, compared_params=None):
    """
    Plots the stats per iteration of a single run using the given axis.

    :param avg_stat: The stats to be plotted, one value per iteration (length ``config.num_iterations``)
    :param config: The config
    :param ax: The axis
    :param compared_params: Optional names of config attributes. If given, the line label is built from their
        values (joined with '_', None shown as "Unlimited", as in the multi-run variant) instead of the default
        description.
    """
    x = np.arange(0, config.num_iterations)
    if (compared_params is None) or (len(compared_params) == 0):
        if config.env_name == POMNIST:
            num_agents = (config.x_splits + 1) * (config.y_splits + 1)
        else:
            num_agents = config.num_agents
        # exactly one run is plotted; len(avg_stat) is the number of iterations, not runs
        label = f'1 run {num_agents} agents:{config.msg_sizes}'
    else:
        # A list label is rejected by recent matplotlib when it has more entries than plotted datasets,
        # so build one string.
        param_values = (getattr(config, attr) for attr in compared_params)
        label = '_'.join("Unlimited" if value is None else str(value) for value in param_values)
    ax.plot(x, avg_stat, label=label)


def visualize_multi_run_return(all_run_stats, config, w: Optional[SummaryWriter] = None):
    """
    Plot the mean return per iteration (with std) of one or several multi-runs into a single figure.

    :param all_run_stats: All run stats of one multi-run, or a list of such lists (one per configuration)
    :param config: The config, or a list of configs matching ``all_run_stats`` when several multi-runs are given
    :param w: The figure will be added to this summary writer if not None
    :return: The created figure
    """
    fig, ax = plt.subplots()

    first_entry = all_run_stats[0]
    is_grouped = isinstance(first_entry, list) and isinstance(first_entry[0], list)
    if not is_grouped:
        add_multi_run_return_to_ax(all_run_stats, config, ax)
    else:
        # several multi-runs with different configurations share one figure
        assert all(isinstance(group, list) for group in all_run_stats)
        assert isinstance(config, list) and len(all_run_stats) == len(config)

        for group, group_config in zip(all_run_stats, config):
            add_multi_run_return_to_ax(group, group_config, ax)

    fig.suptitle('Mean return')
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Mean agent return")
    ax.legend()
    fig.tight_layout()

    if w:
        w.add_figure("multi_run/return", fig)

    return fig


def add_discrete_value_hist(tag: str, value_counts: Dict, writer: SummaryWriter, global_step: int):
    """
    Adds a (hacky) histogram to tensorboard that shows the number of values in value_counts.

    Each value gets a narrow bucket (width 0.0002) around it, preceded by an empty bucket, so that discrete values
    appear as separate spikes. Nothing is written when value_counts is empty (previously that crashed).

    :param tag: The tag for this histogram
    :param value_counts: The value counts dict, contains entries of the form {value: count}
    :param writer: The SummaryWriter
    :param global_step: The global step for this histogram
    """
    if not value_counts:
        # no values -> no limits can be derived; the old code failed with NameError/TypeError here
        return

    values = sorted(value_counts)
    num = 0
    total = 0
    total_sq = 0

    # bucket_limits are the right edges of the buckets (TensorBoard histogram convention)
    bucket_counts = []
    bucket_limits = []

    for val in values:
        count = value_counts[val]

        # an empty bucket ending just below val, then a bucket holding val's count ending just above it
        bucket_counts.extend((0, count))
        bucket_limits.extend((val - 0.0001, val + 0.0001))

        total += val * count
        total_sq += val * val * count
        num += count

    # closing empty bucket after the largest value
    bucket_counts.append(0)
    bucket_limits.append(values[-1] + 0.1)

    writer.add_histogram_raw(
        tag=tag,
        min=values[0] - 0.5,
        max=values[-1] + 0.5,
        num=num,
        sum=total,
        sum_squares=total_sq,
        bucket_limits=bucket_limits,
        bucket_counts=bucket_counts,
        global_step=global_step)
