import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Uniform

from mawm import utils
from addict import Dict
from typing import Iterable


def sparsemax(x, k, noise=None, gumbel_tau=None, gumbel_hard=False, dim=-1):
    """Softmax over `dim`, optionally sparsified to (at most) its `k` largest entries.

    Parameters
    ----------
    x : torch.Tensor
        Logits of any rank.
    k : int
        Number of entries to keep along `dim`. `k == -1` or `k == x.shape[dim]`
        disables sparsification and returns the plain (Gumbel-)softmax.
    noise : float, optional
        If positive, every logit is first multiplied by `1 + u`, where `u` is
        drawn from `Uniform(-noise, noise)`.
    gumbel_tau : float, optional
        If given, `F.gumbel_softmax` with this temperature replaces the softmax.
    gumbel_hard : bool
        Forwarded as `hard` to `F.gumbel_softmax`. When a hard Gumbel sample is
        requested it is returned as is, without top-k sparsification.
    dim : int
        Axis along which to normalize and sparsify.

    Returns
    -------
    torch.Tensor
        Tensor shaped like `x` whose entries along `dim` are non-negative and
        sum to one.

    Notes
    -----
    The (k + 1)-th largest probability is subtracted as a threshold, so entries
    tied with it are zeroed as well. If the k + 1 largest probabilities are all
    equal, every entry is zeroed and the renormalization yields NaN.
    """
    if noise is not None and noise > 0:
        # Multiplicative jitter on the logits
        x = (1 + Uniform(-noise, noise).sample(x.shape).to(x.device)) * x
    # Gumbel if we need to
    if gumbel_tau is None:
        softmaxed = torch.softmax(x, dim=dim)
    else:
        softmaxed = F.gumbel_softmax(x, tau=gumbel_tau, dim=dim, hard=gumbel_hard)
    size = x.shape[dim]
    # Efficient code path in case we don't need sparsification
    if k == size or k == -1:
        return softmaxed
    if gumbel_tau is not None and gumbel_hard:
        return softmaxed
    assert 0 < k < size, f"k must be -1 or satisfy 0 < k <= {size}, got {k}"
    # topk returns values sorted in descending order along `dim`, so index k holds the
    # (k + 1)-th largest value. Slicing along `dim` (instead of hard-coding axis 1)
    # keeps this correct for inputs of any rank and for any choice of `dim`.
    threshold = torch.topk(softmaxed, k=(k + 1), dim=dim).values.narrow(dim, k, 1)
    sparsified = F.relu(softmaxed - threshold)
    return sparsified / sparsified.sum(dim=dim, keepdim=True)


def zonal_kernel(x, bandwidth, truncate_at=None, straight_through=True):
    """Evaluate `exp(-2 * bandwidth * (1 - x))`, optionally zeroed where `x <= truncate_at`.

    Parameters
    ----------
    x : torch.Tensor
        Kernel argument.
    bandwidth : float or torch.Tensor
        Scale inside the exponent.
    truncate_at : float, optional
        If given, the output is set to zero wherever `x` is not strictly greater
        than this value.
    straight_through : bool
        Only relevant when truncating. If True, the forward value is the truncated
        kernel while gradients are those of the untruncated kernel. If False,
        gradients are masked together with the values.

    Returns
    -------
    torch.Tensor
        The (possibly truncated) kernel values.
    """
    kernel = torch.exp(-2 * bandwidth * (1 - x))
    if truncate_at is None:
        return kernel
    # 0/1 mask computed on a detached copy, so it carries no gradient of its own
    keep = x.detach().clone().gt_(truncate_at)
    hard_kernel = kernel * keep
    if not straight_through:
        return hard_kernel
    # Straight-through: the detached difference fixes the forward value to the
    # truncated kernel, while the backward pass only sees `kernel`.
    return (hard_kernel - kernel).detach() + kernel


@utils.no_grad
def rollout(world_model, actions, positions, states, goals, prompt_till_step=0, feed_true_positions=False):
    """Unroll `world_model` along the action sequence, one `one_step_rollout` call per action.

    Shapes, as noted by the original author: `actions` is NTAC (T is the rollout
    length), `positions` is NTAC, `states` is NTAChw and `goals` is NA2. `positions`
    and `states` are sliced at every time-step up to and including `prompt_till_step`.

    Parameters
    ----------
    world_model
        Object exposing `one_step_rollout(rollout_state, is_initial=...)`, which must
        return the next rollout state.
    actions, positions, states, goals : torch.Tensor
        See shapes above.
    prompt_till_step : int or float
        Up to and including this step, the true positions / states / goals are written
        into the rollout state before stepping, and `is_initial=True` is passed to the
        model. A float strictly between 0 and 1 is read as a fraction of the number of
        steps.
    feed_true_positions : bool
        If True, `positions[:, t + 1:t + 2]` is written to `next_positions` of the
        state at every step `t`.

    Returns
    -------
    Dict
        `positions` and `states` concatenated along dim 1 over all recorded states,
        `recons` concatenated over all recorded states except the first, the input
        `actions`, and the list of recorded states under `rollout_states`.
    """
    # NOTE(review): `utils.no_grad` presumably disables gradient tracking - confirm.
    horizon = actions.shape[1]
    # A fractional prompt length is converted to an absolute step count
    if isinstance(prompt_till_step, float) and 0 < prompt_till_step < 1:
        prompt_till_step = int(horizon * prompt_till_step)
    history = []
    current = None
    for t in range(horizon):
        step = slice(t, t + 1)
        # True while we are still prompting (lets the RNNs build up a hidden state)
        prompting = t <= prompt_till_step
        if t == 0:
            # Nothing to continue from yet: build the initial state from the inputs
            assert current is None
            current = Dict(actions=actions[:, step], positions=positions[:, step],
                           states=states[:, step], rnn_state=None, positional_rnn_state=None,
                           goals=goals)
        else:
            if prompting:
                # Overwrite the model's own outputs with the ground truth
                current.positions = positions[:, step]
                current.states = states[:, step]
                current.goals = goals
            # The action for this step is always written
            current.actions = actions[:, step]
        if feed_true_positions:
            current.next_positions = positions[:, (t + 1):(t + 2)]
        # Record the state as it is fed to the model, then step
        history.append(current)
        current = world_model.one_step_rollout(current, is_initial=prompting)
    # The state after the last action is recorded too (no action was written to it here)
    history.append(current)
    # Consolidate
    output = Dict()
    output.positions = torch.cat([past.positions for past in history], dim=1)
    output.states = torch.cat([past.states for past in history], dim=1)
    # The first recorded state is skipped when gathering reconstructions
    output.recons = torch.cat([past.recons for past in history[1:]], dim=1)
    output.actions = actions
    output.rollout_states = history
    return output


def polar_pad(x, kernel_size, offsets=((0, 0), (0, 0))):
    """Pad `x` so that a 'valid' convolution acts as a circular/polar convolution.

    The last axis is padded circularly and the second-to-last axis with constant
    (zero) padding.

    Parameters
    ----------
    x : torch.Tensor
        Input to pad.
    kernel_size : tuple
        `(ks_r, ks_theta)`: kernel sizes along the second-to-last and last axis.
        Each axis is padded by half its kernel size (integer division) on both sides.
    offsets : tuple of tuples
        `((r_before, r_after), (theta_before, theta_after))`: extra padding added on
        top of the half-kernel padding on the respective side.

    Returns
    -------
    torch.Tensor
        The padded tensor.
    """
    radial_ks, angular_ks = kernel_size
    radial_offsets, angular_offsets = offsets
    half_r = radial_ks // 2
    half_theta = angular_ks // 2
    # F.pad lists the padding starting from the last axis: [last_lo, last_hi, prev_lo, prev_hi]
    wrap_spec = [half_theta + angular_offsets[0], half_theta + angular_offsets[1], 0, 0]
    fill_spec = [0, 0, half_r + radial_offsets[0], half_r + radial_offsets[1]]
    wrapped = F.pad(x, wrap_spec, mode='circular')
    return F.pad(wrapped, fill_spec, mode='constant')


class PolarPad(nn.Module):
    """`nn.Module` wrapper that applies `polar_pad` with a fixed configuration."""

    def __init__(self, kernel_size, offsets=((0, 0), (0, 0))):
        """Store the `kernel_size` and `offsets` later handed to `polar_pad`."""
        super().__init__()
        self.offsets = offsets
        self.kernel_size = kernel_size

    def forward(self, x):
        """Return `x` padded by `polar_pad` using the stored kernel size and offsets."""
        return polar_pad(x, self.kernel_size, self.offsets)


class LSTMCell(nn.LSTMCell):
    """`nn.LSTMCell` whose hidden and cell state are passed as two separate arguments."""

    def forward(self, input, hx=None, cx=None):
        """Run the cell on `input`.

        `hx` and `cx` must either both be given or both be None; an AssertionError
        is raised otherwise. When both are None, no state is handed to
        `nn.LSTMCell.forward`. Returns whatever `nn.LSTMCell.forward` returns.
        """
        # Hidden and cell state only make sense as a pair
        assert (hx is None) == (cx is None)
        state = None if hx is None else (hx, cx)
        return super().forward(input, state)


class MaskedReconstructor:
    """Applies per-channel masking to reconstructions, driven by a marker string.

    NOTE: this class is unfinished - `num_mask_channels` is a stub and `apply_mask`
    only validates its input and returns None.
    """

    # Marker characters that may appear in a mask spec
    VALID_MARKERS = {'l', 'b', 'm', 'o'}

    def __init__(self, mask_spec: str, channel_dim: int = 3):
        """
        Parameters
        ----------
        mask_spec : str
            Specifies the masking recipe: a string with one marker out of {l, b, m, o}
            per channel. E.g. for SC2, one will expect:
            'l' * 4 + 'b' + 'm' * 9 + 'l' + 'o' * 2 = 'llllbmmmmmmmmmloo'.
            The markers stand for [l]inear (i.e. regression) outputs, [b]inary logits
            (pre-sigmoid), [m]ultinomial logit maps and unmasked [o] inputs.
        channel_dim : int
            Specifies the channel dimension in the tensors that are received.
        """
        self.channel_dim = channel_dim
        self.mask_spec = mask_spec

    @property
    def num_channels(self):
        """Number of channels, i.e. the number of markers in the mask spec."""
        return len(self.mask_spec)

    @property
    def num_mask_channels(self):
        """Stub: currently always 0."""
        # TODO
        return 0

    def _get_channel_idxs(self, marker: str):
        """Return the positions in the mask spec that carry `marker`."""
        idxs = []
        for idx, spec_marker in enumerate(self.mask_spec):
            if spec_marker == marker:
                idxs.append(idx)
        return idxs

    def _mask_l(self, recon: torch.Tensor, mask: torch.Tensor):
        """Linear masking: scale the reconstruction by the sigmoid of the mask logits."""
        gate = mask.sigmoid()
        return recon * gate

    def _mask_b(self, recon: torch.Tensor, mask: torch.Tensor):
        """Binomial masking, defined here as adding the mask logit to the recon logit."""
        return recon + mask

    def apply_mask(self, recon_and_mask_logits: torch.Tensor):
        """Split the input into reconstruction and mask logits and validate the split.

        The input is chunked in two along `channel_dim`; both halves must have exactly
        `num_channels` channels, otherwise an AssertionError is raised. The masking
        itself is not implemented yet, so None is returned.
        """
        recon_logits, mask_logits = recon_and_mask_logits.chunk(2, dim=self.channel_dim)
        # Validate
        recon_channels = recon_logits.shape[self.channel_dim]
        mask_channels = mask_logits.shape[self.channel_dim]
        assert recon_channels == mask_channels == self.num_channels
        # the 'l'
        # TODO
        return None

    __call__ = apply_mask
