import functools
from itertools import cycle, islice

import numpy as np
from numba import njit
import torch
from torch.nn import functional as F
from sklearn.metrics import confusion_matrix as _cm


@njit
def jagged_slice(x, centers, window_size):
    """Cut an (h, w) window around every agent's center, zero-padding past the map border.

    Args:
        x: array of shape (T, A, C, H, W).
        centers: integer array of shape (T, A, 2) holding (row, col) window centers.
        window_size: indexable pair (h, w). Odd and even sizes are both supported.

    Returns:
        processed: float32 array (T, A, C, h, w). Window cells that fall outside the
            H x W map are left at 0.
        valid: uint8 array (T, A, h, w); 1 where the window cell lies inside the map.

    The center lands at window index (h // 2, w // 2).
    """
    # WARNING: assumes centers contains integers bounded by H and W (respectively).
    T, A, C, H, W = x.shape
    h, w = window_size[0], window_size[1]
    processed = np.zeros(shape=(T, A, C, h, w), dtype=np.float32)
    valid = np.zeros(shape=(T, A, h, w), dtype=np.uint8)
    for t in range(T):
        for a in range(A):
            center_i, center_j = centers[t, a, 0], centers[t, a, 1]
            from_i, from_j = center_i - (h // 2), center_j - (w // 2)
            # Derive the (exclusive) end from the start so the window always spans exactly
            # h x w. The former `center + h // 2 + 1` is identical for odd sizes but one
            # too long for even sizes, which made the slice assignment below fail.
            to_i, to_j = from_i + h, from_j + w
            valid_from_i, valid_to_i = 0, h
            valid_from_j, valid_to_j = 0, w
            # Clip the source range to the map and shift the destination range accordingly.
            if from_i < 0:
                valid_from_i = -from_i
                from_i = 0
            if from_j < 0:
                valid_from_j = -from_j
                from_j = 0
            if to_i > H:
                overshoot = to_i - H
                valid_to_i = h - overshoot
                to_i = H
            if to_j > W:
                overshoot = to_j - W
                valid_to_j = w - overshoot
                to_j = W
            processed[t, a, :, valid_from_i:valid_to_i, valid_from_j:valid_to_j] = x[t, a, :, from_i:to_i, from_j:to_j]
            valid[t, a, valid_from_i:valid_to_i, valid_from_j:valid_to_j] = 1
    return processed, valid


def flatten_spatial(x):
    """Remove the two trailing spatial dims of ``x``, which must both have size 1.

    (..., 1, 1) -> (...). The result is a ``view`` of ``x``, so it shares storage.
    Raises AssertionError when either trailing dim is not 1.
    """
    size = x.shape
    height, width = size[-2], size[-1]
    assert height == width == 1
    leading = size[:-2]
    return x.view(*leading)


def map_spec_to_int(map_spec):
    """Pack a 16-bit binary channel axis into one value per cell.

    ``map_spec`` is an NTA(16)hw array whose channel axis (dim 3) holds 16 bits, least
    significant first: channel ``k`` contributes ``2 ** k``. Returns the NTAhw array of
    packed values (integer-valued for 0/1 inputs).
    """
    _, _, _, num_bits, _, _ = map_spec.shape
    assert num_bits == 16
    # One weight per channel, laid out to broadcast over the N, T, A, h and w axes.
    bit_weights = np.power(2, np.arange(16)).reshape(1, 1, 1, 16, 1, 1)
    return (bit_weights * map_spec).sum(3)


def round_to_bins(x, bin_start, bin_stop, num_bins, additional_buffer_bin=False):
    """Return, for each value in ``x``, the index of the nearest of ``num_bins`` centers.

    Centers are ``np.linspace(bin_start, bin_stop, num_bins)`` (so ``num_bins >= 2`` is
    required). Value ``v`` gets index ``i`` when it lies in
    ``[center_i - spacing / 2, center_i + spacing / 2)``. Values below the first cell map
    to -1. Without ``additional_buffer_bin``, values at or above the last cell's left edge
    all map to ``num_bins - 1``. With it, one more edge is added at the last edge plus
    one spacing; values at or beyond that edge map to ``num_bins``.
    """
    assert bin_stop >= bin_start
    centers = np.linspace(bin_start, bin_stop, num_bins)
    spacing = centers[1] - centers[0]
    # Shift centers down by half a spacing so they become the cells' left edges.
    left_edges = centers - spacing / 2
    if additional_buffer_bin:
        left_edges = np.append(left_edges, left_edges[-1] + spacing)
    return np.digitize(x, left_edges) - 1


def add_null_class_to_onehot_map(x: torch.Tensor, dim=3) -> torch.Tensor:
    """Append a "null class" channel along ``dim`` equal to one minus the channel sum.

    For a one-hot map the appended channel is 1 exactly where no other class is active,
    so every position sums to one afterwards.
    """
    # noinspection PyTypeChecker
    remainder = 1. - x.sum(dim=dim, keepdim=True)
    return torch.cat((x, remainder), dim=dim)


def add_zero_channel_to_logit_map(x: torch.Tensor, dim=3) -> torch.Tensor:
    """Append one all-zero channel to ``x`` along ``dim``.

    The new channel has ``x``'s dtype and device; all other dims keep their sizes.
    """
    channel_shape = list(x.shape)
    channel_shape[dim] = 1
    zeros = x.new_zeros(channel_shape)
    return torch.cat((x, zeros), dim=dim)


def softmax_logit_map(x: torch.Tensor, onehot=False, dim=3) -> torch.Tensor:
    """Softmax ``x`` along ``dim`` against an implicit extra zero logit.

    A zero-valued channel is appended along ``dim`` before the softmax, so the
    probabilities of the original channels can sum to less than one; the appended
    channel absorbs the rest. That channel is then dropped, so the result has the same
    shape as ``x``.

    Args:
        x: logits.
        onehot: if True, return a one-hot encoding of the argmax instead of the
            probabilities. The argmax is taken over the original channels plus the zero
            channel, so positions where the zero channel wins come out all-zero.
        dim: channel dimension.

    Returns:
        Tensor with the same shape as ``x``.
    """
    zero_shape = list(x.shape)
    zero_shape[dim] = 1
    pre_softmax = torch.cat([x, x.new_zeros(zero_shape)], dim=dim)
    softmaxed = torch.softmax(pre_softmax, dim=dim)
    if onehot:
        # keepdim=True: scatter_ requires an index with as many dims as its target.
        # Without it, this branch raised a RuntimeError.
        predicted_class = torch.argmax(softmaxed, dim=dim, keepdim=True)
        onehot_or_softmax = torch.zeros_like(softmaxed).scatter_(dim, predicted_class, 1.)
    else:
        onehot_or_softmax = softmaxed
    # Drop the appended channel; it has "absorbed" the mass of the negative logits.
    trimmed = onehot_or_softmax.narrow(dim, 0, x.shape[dim])
    assert trimmed.shape == x.shape
    return trimmed


def roundrobin(*iterables):
    """roundrobin('ABC', 'D', 'EF') --> A D E B F C"""
    # Algorithm credited to George Sakkis (current itertools-docs formulation).
    # Each round cycles over the iterators still alive; when one is exhausted the inner
    # `map(next, ...)` stops, and the next round rebuilds the cycle from the iterators
    # that follow it, one fewer than before.
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)


class RoundRobinLoader:
    """Interleave several loaders item by item in round-robin order.

    Iteration is delegated to ``roundrobin``; ``len`` is the total of the loaders' lengths.
    """

    def __init__(self, *loaders):
        self.loaders = loaders

    def __iter__(self):
        return roundrobin(*self.loaders)

    def __len__(self):
        return sum(map(len, self.loaders))


def no_grad(fn):
    """Decorate ``fn`` so every call runs inside a ``torch.no_grad()`` context.

    ``functools.wraps`` keeps the wrapped function's ``__name__``, ``__doc__`` and other
    metadata, which the bare wrapper previously discarded.
    """
    @functools.wraps(fn)
    def _no_grad_fn(*args, **kwargs):
        with torch.no_grad():
            return fn(*args, **kwargs)
    return _no_grad_fn


def make_reconstruction_to_state_processor(env='rail_env', eps=10e-5, **kwargs):
    """Build a function that turns world-model reconstructions back into an input state.

    Args:
        env: one of 'rail_env', 'sc2' or 'bb'; selects which converter is returned.
        eps: tolerance for the rail_env done marker. The marker is set where the cosine
            similarity between an agent's position and goal vectors exceeds ``1 - eps``.
            NOTE(review): ``10e-5`` equals 1e-4, not 1e-5 -- confirm which was intended.
        **kwargs: options for the returned converter. Only ``threshold`` (default True)
            is read, by the rail_env converter. If truthy, the target channels are set to
            ``recons > 0`` (as 0./1.); otherwise they are set to ``recons.sigmoid()``.

    Returns:
        A callable ``(recons, rollout_state) -> state`` that runs under ``torch.no_grad()``.

    Raises:
        AssertionError: if ``env`` is not supported. The trailing ``NotImplementedError``
            branch is only reachable when assertions are disabled.
    """
    assert env in ['rail_env', 'sc2', 'bb']
    # Imported inside the function rather than at module level -- presumably to avoid a
    # circular or costly import at module load; TODO confirm.
    from mawm.envs import RailEnvTrajectories
    from mawm.envs import SC2Trajectories
    # NOTE(review): BBTrajectories is not used below.
    from mawm.envs import BBTrajectories

    @no_grad
    def _rail_env_reconstruction_to_state(recons, rollout_state):
        """Assemble a 23-channel rail_env state from ``recons`` plus ego and done markers.

        Reads ``rollout_state.positions`` (NTAC) and ``rollout_state.goals`` (NAC).
        """
        # recons have the same format as target state, i.e. they're NTA(21)hw tensors. The input to the WM is OTOH a
        # NTA(23)hw tensor, where the extra two channels are the ego marker and the done marker.
        # The done marker is 1 where the cosine similarity between position and goal vectors exceeds 1 - eps.
        # positions.shape = NTAC
        positions = rollout_state.positions
        _n, _t, _a, _c = positions.shape
        # goals.shape = NAC
        goals = rollout_state.goals
        assert goals.shape == (_n, _a, _c)
        # Broadcast the per-agent goal over the time axis to align with positions.
        goals = goals[:, None, :, :].expand(_n, _t, _a, _c)
        # recons.shape = NTAChw
        _, _, _, _rc, _h, _w = recons.shape
        assert recons.shape == (_n, _t, _a, _rc, _h, _w)
        # similarities.shape = NTA
        similarities = F.cosine_similarity(positions, goals, dim=-1)
        # done_markers.shape = NTAhw
        # gt_ is in-place on the freshly computed similarities, leaving 0./1. in their float dtype.
        done_markers = similarities.gt_(1 - eps)[:, :, :, None, None].expand(_n, _t, _a, _h, _w)
        # ego_markers.shape = NTAhw; only the central cell (h // 2, w // 2) is set to 1.
        ego_markers = torch.zeros(_n, _t, _a, _h, _w, device=recons.device)
        ego_markers[:, :, :, _h // 2, _w // 2] = 1.
        # Construct state
        state = torch.zeros(_n, _t, _a, 23, _h, _w, device=recons.device)
        # Set target channels to recons (after thresholding, if required)
        state[:, :, :, RailEnvTrajectories.TARGET_CHANNELS, :, :] = (recons.clone().gt_(0.)
                                                                     if kwargs.get('threshold', True)
                                                                     else recons.sigmoid())
        # Now set the hitherto unset channels (channel 20 for ego markers and 22 for done marker)
        state[:, :, :, 20, :, :] = ego_markers
        state[:, :, :, 22, :, :] = done_markers
        # Done
        return state

    @no_grad
    def _sc2_reconstruction_to_state(recons, rollout_state):
        """Delegate to ``SC2Trajectories.recon_to_state``; ``rollout_state`` is unused."""
        # Recons contains pretty much everything that's needed to get the next state.
        state = SC2Trajectories.recon_to_state(recon_or_components=recons,
                                               return_components=False)
        return state

    @no_grad
    def _bb_reconstruction_to_state(recons, rollout_state):
        """Binarize logits at 0 (equivalently sigmoid > 0.5); ``rollout_state`` is unused."""
        # recons is a NTA1hw tensor of pre-sigmoidal logits. Simply thresholding should be enough
        recons = recons.clone().gt_(0.)
        return recons

    if env == 'rail_env':
        return _rail_env_reconstruction_to_state
    elif env == 'sc2':
        return _sc2_reconstruction_to_state
    elif env == 'bb':
        return _bb_reconstruction_to_state
    else:
        raise NotImplementedError


def drop_agents(keep_num_agents, *tensors):
    """Keep the same random, order-preserving subset of agents in every tensor.

    Args:
        keep_num_agents: number of agents to keep. An int is a count; counts at or above
            the number of agents keep everyone. A float in [0, 1] is a fraction, rounded,
            with at least one agent kept. So int ``1`` keeps one agent, while float ``1.0``
            keeps all. A float above 1 is truncated to an int count (previously it reached
            the slice below and raised TypeError).
        *tensors: one or more tensors indexed as (N, T, A, ...). They are expected to share
            the same A, since one index is applied to all of them.

    Returns:
        List of the tensors with dim 2 reduced to the selected agents, kept in their
        original relative order.
    """
    assert len(tensors) > 0
    num_agents = tensors[0].shape[2]
    if isinstance(keep_num_agents, float):
        if 0. <= keep_num_agents <= 1:
            keep_num_agents = max(round(num_agents * keep_num_agents), 1)
        elif keep_num_agents > 1:
            keep_num_agents = int(keep_num_agents)
    # Random subset, sorted so the surviving agents keep their original order.
    keep_agents_at_idx = torch.sort(torch.randperm(num_agents, device=tensors[0].device)[0:keep_num_agents]).values
    # Move the index to each tensor's device so tensors on different devices can be mixed.
    return [tensor[:, :, keep_agents_at_idx.to(tensor.device)] for tensor in tensors]


class CodepathNotReadyError(NotImplementedError):
    """Signals a code path that is not ready for use yet.

    Subclasses ``NotImplementedError``, so existing ``except NotImplementedError``
    handlers also catch it.
    """


def confusion_matrix(input: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Return the ``num_classes x num_classes`` confusion matrix as a float CPU tensor.

    Both tensors are detached, moved to CPU, flattened and cast to int. Rows index the
    ``target`` class and columns the ``input`` (predicted) class, following
    ``sklearn.metrics.confusion_matrix(y_true, y_pred)``. Labels are fixed to
    ``0 .. num_classes - 1``.
    """
    predicted = input.detach().cpu().numpy().ravel().astype('int')
    actual = target.detach().cpu().numpy().ravel().astype('int')
    counts = _cm(actual, predicted, labels=list(range(num_classes)))
    return torch.from_numpy(counts).float()


def not_important(fn):
    """Decorate ``fn`` so any ``Exception`` it raises is swallowed and ``None`` returned.

    Meant for best-effort calls whose failure the caller deliberately ignores.
    ``BaseException`` subclasses (e.g. ``KeyboardInterrupt``) still propagate.
    """
    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        try:
            # Call the wrapped function. The wrapper used to call itself here, recursing
            # until RecursionError (swallowed below), so `fn` never actually ran.
            return fn(*args, **kwargs)
        except Exception:
            return None
    return _fn


class MockWandBMixin:
    """No-op stand-in for a Weights & Biases mixin.

    Every data attribute is ``None``, and every method accepts any arguments and returns
    ``None``. Presumably it mirrors the interface of a real WandB mixin defined elsewhere
    -- TODO confirm.
    """

    def _noop(*args, **kwargs):
        # Shared by all methods below; when called via an instance, `self` lands in *args.
        return None

    wandb_directory = None
    wandb_run = None
    wandb_config = None
    initialize_wandb = _noop
    wandb_run_id = None
    find_existing_wandb_run_id = _noop
    dump_wandb_info = _noop
    wandb_pause_step_counter = _noop
    wandb_resume_step_counter = _noop
    as_wandb_image = _noop
    wandb_log = _noop
    wandb_log_scalar = _noop
    wandb_log_image = _noop
    wandb_watch = _noop
    log_wandb_now = None

    # Keep the helper out of the class namespace; the attributes above hold references.
    del _noop