import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import src.models.nn.utils as U
import src.utils as utils
import src.utils.config
import src.utils.train

# Module-level logger obtained from the project's logging helper, keyed by this module's name.
log = src.utils.train.get_logger(__name__)

class Decoder(nn.Module):
    """ Base class that signals the interface Decoders are expected to adhere to.

    The default implementation is a pass-through: it returns its input wrapped
    in a 1-tuple. Subclasses override `forward` and keep the tuple convention.

    TODO: is there a way to enforce the signature of the forward method?
    """
    def forward(self, x, state=None, *args, **kwargs):
        """
        x: input tensor
        state: additional state from the model backbone
        *args, **kwargs: additional info from the dataset

        Returns:
        y: output tensor
        *args: other arguments to pass into the loss function
        """
        # `self` was previously missing from the signature, so the module
        # instance was bound to `x` and returned in place of the input.
        return x,

class SequenceDecoder(Decoder):
    """ Restrict a sequence of features to an output of the requested length, then project it.

    d_model: feature dimension of the input
    d_output: output dimension; if None no projection is applied (identity)
    l_output: number of output positions to keep
        None -> keep the whole sequence (or `l_batch` positions if an int is passed to forward)
        0    -> keep one position and squeeze the length axis away
        > 0  -> keep exactly `l_output` positions
    use_lengths: if True, every example is first truncated to its own length from `l_batch`
    mode: how the kept positions are computed
        'last'   -> the last `l_output` positions
        'first'  -> the first `l_output` positions
        'pool'   -> running mean over the sequence, evaluated at the last `l_output` positions
        'sum'    -> running sum over the sequence, evaluated at the last `l_output` positions
        'ragged' -> the first `max(l_batch)` positions (drops padding beyond the longest sequence)
    """
    def __init__(self, d_model, d_output=None, l_output=None, use_lengths=False, mode='last'):
        super().__init__()

        self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output)

        if l_output is None:
            self.l_output = None
            self.squeeze = False
        elif l_output == 0:
            # Equivalent to getting an output of length 1 and then squeezing
            self.l_output = 1
            self.squeeze = True
        else:
            assert l_output > 0
            self.l_output = l_output
            self.squeeze = False

        self.use_lengths = use_lengths
        self.mode = mode

        if mode == 'ragged':
            # Ragged mode consumes `l_batch` itself, so per-example truncation is not supported
            assert not use_lengths

    def forward(self, x, state, l_batch=None, *args, **kwargs):
        """
        x: (n_batch, l_seq, d_model)
        state: unused; accepted to match the Decoder interface
        l_batch: per-example lengths (needed for use_lengths and 'ragged' mode),
            or an int overriding the output length when l_output was not configured

        Returns: 1-tuple holding a (n_batch, l_output, d_output) tensor
            (the length axis is removed when the decoder was built with l_output=0)
        """
        if self.l_output is None:
            if isinstance(l_batch, int):
                # Override by pass in
                l_output = l_batch
            else:
                # Grab entire output
                l_output = x.size(-2)
            squeeze = False
        else:
            l_output = self.l_output
            squeeze = self.squeeze

        if self.mode == 'last':
            restrict = lambda t: t[..., -l_output:, :]
        elif self.mode == 'first':
            restrict = lambda t: t[..., :l_output, :]
        elif self.mode == 'pool':
            def restrict(t):
                # Running means at the last l_output positions, derived from the
                # total sum without a cumsum over the whole sequence.
                L = t.size(-2)
                s = t.sum(dim=-2, keepdim=True)
                if l_output > 1:
                    # Suffix sums of the trailing elements, with a leading zero row,
                    # so that `s - c` yields the prefix sums in reverse order.
                    c = torch.cumsum(t[..., -(l_output-1):, :].flip(-2), dim=-2)
                    c = F.pad(c, (0, 0, 1, 0))
                    s = s - c # (B, l_output, D)
                    s = s.flip(-2)
                denom = torch.arange(L-l_output+1, L+1, dtype=t.dtype, device=t.device)
                # The counts index the length axis, so add a trailing axis to
                # broadcast across features instead of against them.
                return s / denom.unsqueeze(-1)
        elif self.mode == 'sum':
            # TODO use same restrict function as pool case
            restrict = lambda t: torch.cumsum(t, dim=-2)[..., -l_output:, :]
        elif self.mode == 'ragged':
            assert l_batch is not None, "l_batch must be provided for ragged mode"
            # remove any additional padding (beyond max length of any sequence in the batch)
            restrict = lambda t: t[..., : max(l_batch), :]
        else:
            raise NotImplementedError("Mode must be ['last' | 'first' | 'pool' | 'sum' | 'ragged']")

        # Restrict to actual length of sequence
        if self.use_lengths:
            assert l_batch is not None
            x = torch.stack([
                restrict(out[..., :length, :])
                for out, length
                in zip(torch.unbind(x, dim=0), l_batch)
            ], dim=0)
        else:
            x = restrict(x)

        if squeeze:
            assert x.size(-2) == 1
            x = x.squeeze(-2)

        x = self.output_transform(x)
        return x,

class StateDecoder(Decoder):
    """ Decode from the backbone's output state instead of its output sequence.

    Useful for stateful models such as RNNs (or perhaps Transformer-XL if it gets implemented).

    d_model: size of the tensor produced by `state_to_tensor`
    state_to_tensor: callable converting the backbone state into a tensor
    d_output: size of the decoded output
    """
    def __init__(self, d_model, state_to_tensor, d_output):
        super().__init__()
        # Linear projection is registered first, then the state converter.
        self.output_transform = nn.Linear(d_model, d_output)
        self.state_transform = state_to_tensor

    def forward(self, x, state, *args, **kwargs):
        """ Ignore `x`; convert `state` to a tensor, project it, and return it in a 1-tuple. """
        state_tensor = self.state_transform(state)
        decoded = self.output_transform(state_tensor)
        return (decoded,)

class RetrievalHead(nn.Module):
    """ Classifier over a pair of feature vectors that arrive stacked along the batch axis.

    d_input: feature size of each element of the pair
    d_model: hidden width of the classifier
    n_classes: number of output logits
    nli: if True use the [a, b, a-b, a*b] feature head, otherwise the plain [a, b] head
    activation: 'relu' or 'gelu'
    """
    def __init__(self, d_input, d_model, n_classes, nli=True, activation='relu'):
        super().__init__()
        self.nli = nli

        # Pick the activation by comparing against the supported names in order.
        for known, factory in (('relu', nn.ReLU), ('gelu', nn.GELU)):
            if activation == known:
                act = factory()
                break
        else:
            raise NotImplementedError

        if self.nli:
            # Architecture from https://github.com/mlpen/Nystromformer/blob/6539b895fa5f798ea0509d19f336d4be787b5708/reorganized_code/LRA/model_wrapper.py#L74
            layers = [
                nn.Linear(4 * d_input, d_model),
                act,
                nn.Linear(d_model, n_classes),
            ]
        else:
            # Head from https://github.com/google-research/long-range-arena/blob/ad0ff01a5b3492ade621553a1caae383b347e0c1/lra_benchmarks/models/layers/common_layers.py#L232
            half_width = d_model // 2
            layers = [
                nn.Linear(2 * d_input, d_model),
                act,
                nn.Linear(d_model, half_width),
                act,  # the same activation module is reused at both positions
                nn.Linear(half_width, n_classes),
            ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x): # , state, *args, **kwargs):
        """
        x: (2*batch, dim), the two halves of the batch axis hold the two elements of each pair

        Returns: logits produced by the classifier from the combined pair features
        """
        pair = rearrange(x, '(z b) d -> z b d', z=2)
        first, second = pair[0], pair[1] # each (n_batch, d_input)
        pieces = [first, second]
        if self.nli:
            # Add difference and elementwise product of the pair
            pieces += [first - second, first * second]
        combined = torch.cat(pieces, dim=-1) # (batch, dim)
        return self.classifier(combined)

class RetrievalDecoder(Decoder):
    """ Extracts a single feature vector per sequence with a SequenceDecoder, then classifies pairs with a RetrievalHead.

    d_input: feature size coming from the backbone
    n_classes: number of output logits
    d_model: hidden width of the retrieval head; defaults to d_input
    nli, activation: forwarded to RetrievalHead
    *args, **kwargs: forwarded to SequenceDecoder
    """
    def __init__(self, d_input, n_classes, d_model=None, nli=True, activation='relu', *args, **kwargs):
        super().__init__()
        d_model = d_input if d_model is None else d_model
        # l_output=0 makes the feature extractor return one squeezed vector per sequence, without projection
        self.feature = SequenceDecoder(d_input, d_output=None, l_output=0, *args, **kwargs)
        self.retrieval = RetrievalHead(d_input, d_model, n_classes, nli=nli, activation=activation)

    def forward(self, x, state, *args, **kwargs):
        """ Return a 1-tuple with the retrieval logits for the features extracted from `x`. """
        (features,) = self.feature(x, state, *args, **kwargs)
        logits = self.retrieval(features)
        return (logits,)

class PackedDecoder(Decoder):
    """ Converts a packed sequence back into a padded, batch-first tensor. """
    def forward(self, x, state, *args, **kwargs):
        """ Unpack `x` with batch_first=True and return the padded tensor in a 1-tuple; the lengths are discarded. """
        padded, _lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        return (padded,)

# For every type of encoder/decoder, specify:
# - constructor class
# - list of attributes to grab from dataset
# - list of attributes to grab from model

# Maps a decoder name to its constructor class.
registry = dict(
    id=U.Identity,  # U.TupleModule(nn.Identity),
    sequence=SequenceDecoder,
    retrieval=RetrievalDecoder,
    state=StateDecoder,
    pack=PackedDecoder,
)
# Per decoder name: attribute names to grab from the model.
model_attrs = dict(
    sequence=['d_output'],
    feature=['d_output'],
    retrieval=['d_output'],
    state=['d_state', 'state_to_tensor'],
)

# Per decoder name: attribute names to grab from the dataset.
dataset_attrs = dict(
    sequence=['d_output', 'l_output'],
    feature=['d_output'],
    retrieval=['d_output'],  # TODO change d_output to n_classes?
    state=['d_output'],
)

def _instantiate(decoder, model=None, dataset=None):
    """ Instantiate a single decoder

    decoder: None, a registry name, or a config mapping carrying the name under '_name_'
    model, dataset: objects the per-decoder attributes (model_attrs / dataset_attrs) are extracted from

    Returns: U.Identity() when decoder is None, otherwise whatever utils.instantiate builds
    """
    if decoder is None:
        return U.Identity()

    name = decoder if isinstance(decoder, str) else decoder['_name_']

    # Extract arguments from attribute names (unknown names contribute no arguments)
    dataset_keys = dataset_attrs.get(name, [])
    dataset_args = utils.config.extract_attrs_from_obj(dataset, *dataset_keys)
    model_keys = model_attrs.get(name, [])
    model_args = utils.config.extract_attrs_from_obj(model, *model_keys)

    # Model-derived arguments are passed first, dataset-derived ones after
    return utils.instantiate(registry, decoder, *model_args, *dataset_args)

def instantiate(decoder, model=None, dataset=None):
    """ Instantiate a full decoder config, e.g. handle a list of configs

    Each entry is built with _instantiate and the results are chained in a U.TupleSequential.
    Note that arguments are added in reverse order compared to encoder (model first, then dataset)
    """
    configs = utils.to_list(decoder)
    stages = [_instantiate(cfg, model=model, dataset=dataset) for cfg in configs]
    return U.TupleSequential(*stages)
