from typing import Optional, Union, Dict, List, Iterator

import torch
import torch.utils.data
import matplotlib.pyplot as plt
import numpy as np
from torch.nn import Parameter
from torch.utils.tensorboard import SummaryWriter


def safe_int_cast(val):
    """
    Converts a value to int, refusing any conversion that would lose information.

    :param val: the value to convert (e.g. ``4.0`` or ``4``)
    :return: ``int(val)``
    :raises ValueError: if ``int(val)`` does not compare equal to ``val`` (e.g. ``2.5``)
    """
    converted = int(val)
    if not converted == val:
        raise ValueError(f"Could not cast {val} to int.")
    return converted


def get_split_shape(x_splits, y_splits, width, height):
    """
    Computes the size of one tile when an image is cut by the given number of splits.

    ``x_splits`` cuts along the width give ``x_splits + 1`` columns, and ``y_splits`` cuts
    along the height give ``y_splits + 1`` rows.

    :param x_splits: number of cuts along the width
    :param y_splits: number of cuts along the height
    :param width: image width
    :param height: image height
    :return: ``(tile_height, tile_width)`` as ints
    :raises ValueError: if a dimension does not divide evenly (height is checked first)
    """
    tile_height = safe_int_cast(height / (y_splits + 1))
    tile_width = safe_int_cast(width / (x_splits + 1))
    return tile_height, tile_width


def check_valid_split(x_splits, y_splits, w=28, h=28):
    """
    Validates that a ``w`` x ``h`` image can be cut into equally sized tiles.

    :param x_splits: number of cuts along the width (``x_splits + 1`` columns)
    :param y_splits: number of cuts along the height (``y_splits + 1`` rows)
    :param w: image width
    :param h: image height
    :raises ValueError: if the width or the height is not evenly divisible (width is checked first)
    """
    columns, rows = x_splits + 1, y_splits + 1

    width_divisible = (w / columns).is_integer()
    if not width_divisible:
        raise ValueError(f"Invalid split {(x_splits, y_splits)}: Could not evenly split width {w}.")

    height_divisible = (h / rows).is_integer()
    if not height_divisible:
        raise ValueError(f"Invalid split {(x_splits, y_splits)}: Could not evenly split height {h}.")


def check_valid_splits(splits, w=28, h=28):
    """
    Validates every ``(x_splits, y_splits)`` pair in ``splits`` with check_valid_split.

    :param splits: iterable of ``(x_splits, y_splits)`` pairs
    :param w: image width
    :param h: image height
    :raises ValueError: for the first pair that cannot evenly split a ``w`` x ``h`` image
    """
    for split in splits:
        x_splits, y_splits = split
        check_valid_split(x_splits, y_splits, w, h)


def crop(array: np.ndarray, x, y, width, height):
    """
    Returns the ``height`` x ``width`` window of ``array`` whose top-left corner is (``x``, ``y``).

    The first axis is indexed by ``y`` (rows) and the second by ``x`` (columns); any further
    axes are kept whole. Only basic slicing is used, so numpy returns a view.
    """
    row_range = slice(y, y + height)
    col_range = slice(x, x + width)
    return array[row_range, col_range]


def crop_batch(array: np.ndarray, x, y, width, height):
    """
    Applies the same ``height`` x ``width`` crop at (``x``, ``y``) to every element of a batch.

    Axis 0 (the batch) is kept whole, axis 1 is indexed by ``y`` and axis 2 by ``x``.
    """
    whole_batch = slice(None)
    row_range = slice(y, y + height)
    col_range = slice(x, x + width)
    return array[whole_batch, row_range, col_range]


def crop_tensor(t: torch.Tensor, x, y, width, height):
    """
    Crops the ``height`` x ``width`` window at (``x``, ``y``) out of the last two dimensions of ``t``.

    The last two dimensions are treated as (rows, columns); any leading dimensions (e.g. channels,
    or batch and channels) are kept whole. Indexing with ``...`` keeps the crop on the last two
    axes for every rank >= 2, not just for 2-d and 3-d tensors.

    :param t: tensor with at least two dimensions
    :param x: first column of the window
    :param y: first row of the window
    :param width: number of columns to keep
    :param height: number of rows to keep
    :return: the cropped tensor (a view of ``t``)
    """
    return t[..., y:(y + height), x:(x + width)]


def show(sample, ax=plt):
    """
    Displays one 2-d image from ``sample`` with ``ax.imshow``.

    A 4-d input is reduced to its first batch element, and a 3-d input to its first
    (color) channel, so that a 2-d image is what finally gets drawn.

    :param sample: a numpy array or torch tensor
    :param ax: anything with an ``imshow`` method (defaults to ``matplotlib.pyplot``)
    """
    if isinstance(sample, torch.Tensor):
        # .numpy() refuses tensors that require grad or live on a GPU, so detach and move to CPU first
        sample = sample.detach().cpu().numpy()

    if len(sample.shape) == 4:
        # select the first element of the batch
        sample = sample[0]

    if len(sample.shape) == 3:
        # select the color channel
        sample = sample[0]

    ax.imshow(sample)


def write_dict_scalars(writer: "SummaryWriter", values: dict, global_step, category='values'):
    """
    Adds all the scalars (and tensors) in the values dict to the summary writer.

    category/key <= All scalars will be added in like this
                    If there are tensors in the values dict, this function also adds mean and std

    category-details/key_a <= All individual values of tensors with > 1 elements will be added like this
                              (a is the index into the flattened tensor, so multi-dimensional
                              tensors are logged element by element)

    ``None`` values are skipped. Numpy arrays are converted with ``torch.Tensor`` (float32), and
    non-floating-point tensors are cast to float before mean/std are computed.

    :param writer: A summary writer
    :param values: The values dict which can contain scalars and tensors
    :param global_step: The global step for the writer
    :param category: Category under which all values should be added (must not contain '/')
    """
    assert category.find('/') == -1, "Please don't use any '/' in the category name."

    for key, elem in values.items():
        if elem is None:
            continue

        if isinstance(elem, np.ndarray):
            elem = torch.Tensor(elem)

        if isinstance(elem, torch.Tensor) and elem.numel() > 1:
            # flatten so that indices 0..numel-1 address single elements (elem[a] on a
            # multi-dimensional tensor would return rows and run out of range)
            flat = elem.detach().flatten()
            if not flat.is_floating_point():
                # mean/std are not defined for integer/bool tensors
                flat = flat.float()

            for index, value in enumerate(flat):
                writer.add_scalar(f"{category}-details/{key}_{index}", value, global_step)

            writer.add_scalar(f"{category}/{key}_mean", flat.mean(), global_step)
            writer.add_scalar(f"{category}/{key}_std", flat.std(), global_step)
            continue

        writer.add_scalar(f"{category}/{key}", elem, global_step)


def set_title(fig, title: str):
    """
    Sets ``title`` as both the figure's suptitle and the title of its window.

    ``FigureCanvasBase.set_window_title`` was deprecated in matplotlib 3.4 and removed in 3.6;
    the window title is set through the canvas' figure manager instead. A canvas without a
    manager (e.g. a Figure constructed directly, outside pyplot) has no window, so only the
    suptitle is set in that case.

    :param fig: a matplotlib figure
    :param title: the title text
    """
    fig.suptitle(title)

    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title(title)
    elif hasattr(fig.canvas, "set_window_title"):
        # fallback for old canvases that expose set_window_title but no manager
        fig.canvas.set_window_title(title)


def to_html_linebreak(text: str):
    """
    Replaces every line break in ``text`` with an HTML ``<br/>`` tag.

    Windows-style CRLF pairs are normalised to a single newline first, so each yields one tag.
    """
    normalized = text.replace("\r\n", "\n")
    return "<br/>".join(normalized.split("\n"))


def to_markdown_code_list(values_dict: dict):
    """
    Renders a dict as markdown list items of the form "- key: `value`".

    Each item is terminated by an HTML ``<br/>`` instead of a newline. Items appear in the
    dict's iteration order.

    :param values_dict: the values to render
    :return: the concatenated items ("" for an empty dict)
    """
    # str.join builds the result in one pass instead of repeated string concatenation
    return "".join(f"- {key}: `{value}`<br/>" for key, value in values_dict.items())


def eval_stats_to_metric_dict(eval_stats: Dict, type_filter=None, key_filter=None):
    """
    Converts the given eval stats dict to a dictionary that is compatible with
    torch.utils.tensorboard.writer.SummaryWriter.add_hparams.

    Every entry is keyed ``zhparam/{type}_{key}``. Single-valued stats are stored as a plain
    scalar. These are one-element sequences as well as 0-d values such as python numbers,
    numpy scalars or 0-d arrays/tensors. Multi-valued stats are stored as ``..._mean``,
    ``..._std`` and one ``..._{index}`` entry per element.

    :param eval_stats: the eval stats dict, ``{type: {key: stats}}``
    :param type_filter: filter the eval stats for a specific type (train/test); None keeps all
    :param key_filter: filter the eval stats for a specific key; None keeps all
    :return: a metric dict compatible with add_hparams that contains the (filtered) eval stats.
    """
    scalar_dict = {}
    for stats_type, type_stats in eval_stats.items():
        # skip types if they are not in the filter
        if type_filter is not None and stats_type not in type_filter:
            continue

        for key, stats in type_stats.items():
            # skip keys if they are not in the filter
            if key_filter is not None and key not in key_filter:
                continue

            name = "zhparam/" + stats_type + "_" + key

            if np.ndim(stats) == 0:
                # 0-d values have no len(); store them directly, unwrapping numpy/torch scalars
                scalar_dict[name] = stats.item() if hasattr(stats, "item") else stats
            elif len(stats) == 1:
                scalar_dict[name] = stats.item()
            else:
                # add aggregated stats
                scalar_dict[name + "_mean"] = stats.mean()
                scalar_dict[name + "_std"] = stats.std()

                # and individual values
                for agent_index, value in enumerate(stats):
                    scalar_dict[name + "_" + str(agent_index)] = value

    return scalar_dict


def add_hparams_to(writer: SummaryWriter, hparam_dict=None, metric_dict=None, global_step=None):
    """
    Adds hparams to the given writer. Adapted from torch.utils.tensorboard.writer.SummaryWriter.add_hparams.

    The hparams summaries are written to ``writer``'s own event file. Every metric in
    ``metric_dict`` is also logged on ``writer`` as a scalar at ``global_step``.

    :param writer: the summary writer to write to
    :param hparam_dict: the hyperparameters (must be exactly a ``dict``)
    :param metric_dict: the metrics (must be exactly a ``dict``)
    :param global_step: step used for the metric scalars
    :raises TypeError: if ``hparam_dict`` or ``metric_dict`` is not exactly a ``dict``
    """
    if type(hparam_dict) is not dict or type(metric_dict) is not dict:
        raise TypeError('hparam_dict and metric_dict should be dictionary.')
    from torch.utils.tensorboard.summary import hparams
    exp, ssi, sei = hparams(hparam_dict, metric_dict)

    # writer.file_writer is reset to None by writer.close(); SummaryWriter._get_file_writer()
    # (re)creates it on demand, as SummaryWriter's own add_* methods do. Fall back to the
    # attribute for writer-like objects that do not provide it.
    get_file_writer = getattr(writer, "_get_file_writer", None)
    file_writer = get_file_writer() if callable(get_file_writer) else writer.file_writer

    for summary in (exp, ssi, sei):
        file_writer.add_summary(summary)
    for k, v in metric_dict.items():
        writer.add_scalar(k, v, global_step)


def add_eval_stats_text(all_eval_stats: Union[Dict, List[Dict]], w: "SummaryWriter"):
    """
    Adds the provided eval stats to tensorboard as text. Can be eval stats from a single run or from multiple runs.

    Each ``(type, key)`` stat becomes one text entry tagged ``eval_{type}_{key}``. The text is the
    element-wise mean (``np.array(...).mean(axis=0)``) of that stat over all runs containing it.
    When more than one run is given, the tag gets a ``/mean_{number of runs}`` suffix.

    :param all_eval_stats: eval stats dict (``{type: {key: stats}}``) from a single run or list of eval stats
    :param w: the summarywriter
    """
    runs = all_eval_stats if isinstance(all_eval_stats, list) else [all_eval_stats]

    # collect, per type and key, the stats of every run (in input order)
    aggregated_eval_stats = {}
    for single_eval_stats in runs:
        for stats_type, type_stats in single_eval_stats.items():
            for key, stats in type_stats.items():
                aggregated_eval_stats.setdefault(stats_type, {}).setdefault(key, []).append(stats)

    num_stats = len(runs)
    # aggregate these lists and write the values to tensorboard
    for stats_type, type_stats in aggregated_eval_stats.items():
        for key, stat_list in type_stats.items():
            values = np.array(stat_list).mean(axis=0)
            w.add_text(f"eval_{stats_type}_{key}{'/mean_' + str(num_stats) if num_stats > 1 else ''}", str(values))


def epsilon_greedy(q: torch.Tensor, epsilon: float, device):
    """
    Batched epsilon-greedy selection.

    For every row of ``q`` a uniformly random column index is chosen with probability ``epsilon``;
    otherwise the index of the row's highest q-value is used. ``epsilon == 0`` draws no random
    numbers, and ``epsilon >= 1`` skips the argmax entirely.

    :param q: The q values with dimensions (batch_size, q_values)
    :param epsilon: The value for epsilon with 0 <= epsilon <= 1
    :param device: The device on which the random numbers are generated
    :return: A selection of indices according to epsilon-greedy
    """
    if epsilon == 0:
        # pure exploitation
        return q.argmax(dim=-1)

    batch_size = q.shape[0]
    num_choices = q.shape[1]

    if not epsilon < 1:
        # pure exploration
        return torch.randint(0, num_choices, (batch_size,), device=device)

    greedy_choice = q.argmax(dim=-1)
    # True with probability (1 - epsilon): keep the greedy index for that row
    keep_greedy = torch.rand((batch_size,), device=device) >= epsilon
    random_choice = torch.randint(0, num_choices, (batch_size,), device=device)
    return torch.where(keep_greedy, greedy_choice, random_choice)


def get_grad_norm(parameters: Iterator[Parameter], norm_type):
    """
    Get the gradient norm for the given parameters.

    The ``norm_type`` norm is applied to the vector of per-parameter gradient norms. For p-norms
    this equals the norm of all gradients concatenated. Parameters without a gradient are ignored.

    :param parameters: Model parameters (an iterable of parameters, or a single tensor)
    :param norm_type: torch norm type (e.g. 2 or float('inf'))
    :return: the norm as a 0-d tensor; ``tensor(0.)`` if no parameter has a gradient
    """
    if isinstance(parameters, torch.Tensor):
        # a single tensor would otherwise be iterated row by row
        parameters = [parameters]

    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        # torch.stack([]) raises, so report a zero norm when there is nothing to measure
        return torch.tensor(0.0)

    return torch.norm(
        torch.stack([torch.norm(g, norm_type) for g in grads]),
        norm_type
    )
