import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce

import src.models.nn.utils as U
import src.utils as utils
import src.utils.config
import src.utils.train

log = src.utils.train.get_logger(__name__)


class Decoder(nn.Module):
    """Identity decoder that defines the interface every decoder in this module follows.

    Subclasses override ``forward`` (and, if they need custom single-step behaviour, ``step``).
    TODO: is there a way to enforce the signature of the forward method?
    """

    def forward(self, x, **kwargs):
        """Return the input unchanged.

        x: (batch, length, dim) input tensor
        **kwargs: extra information (e.g. backbone state or dataset-provided fields); ignored here

        Returns:
        the output tensor -- for this base class, ``x`` itself
        """
        return x

    def step(self, x):
        """Decode a single time step by running ``forward`` on a length-1 sequence.

        x: (batch, dim)
        """
        as_sequence = torch.unsqueeze(x, 1)  # (batch, 1, dim)
        decoded = self.forward(as_sequence)
        # Drop the singleton length axis again (no-op if forward changed its size)
        return torch.squeeze(decoded, 1)


class SequenceDecoder(Decoder):
    """Select or aggregate positions of a (batch, length, d_model) sequence, then project.

    Args:
        d_model: feature dimension of the input.
        d_output: if given, a final ``nn.Linear(d_model, d_output)`` is applied; otherwise identity.
        l_output: number of output positions. ``None`` keeps the whole sequence (may be
            overridden per call), ``0`` keeps one position and squeezes the length axis,
            a positive int keeps that many positions.
        use_lengths: if True, each batch element is first truncated to its own length
            (``lengths`` must then be passed to ``forward``).
        mode: 'last', 'first', 'pool' (running mean), 'sum' (running sum) or 'ragged'
            (trim padding beyond the longest sequence in the batch).
    """

    def __init__(
        self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last"
    ):
        super().__init__()

        self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output)

        if l_output is None:
            self.l_output = None
            self.squeeze = False
        elif l_output == 0:
            # Equivalent to getting an output of length 1 and then squeezing
            self.l_output = 1
            self.squeeze = True
        else:
            assert l_output > 0
            self.l_output = l_output
            self.squeeze = False

        self.use_lengths = use_lengths
        self.mode = mode

        if mode == 'ragged':
            # ragged mode does its own length-based trimming
            assert not use_lengths

    def forward(self, x, state=None, lengths=None, l_output=None):
        """
        x: (n_batch, l_seq, d_model)
        state: unused; accepted for interface compatibility
        lengths: per-example lengths; required when use_lengths is set or mode == 'ragged'
        l_output: per-call override of the output length, only honoured when the
            decoder was constructed with l_output=None
        Returns: (n_batch, l_output, d_output), or (n_batch, d_output) when the
            decoder was constructed with l_output=0
        """

        if self.l_output is None:
            if l_output is not None:
                assert isinstance(l_output, int)  # Override by pass in
            else:
                # Grab entire output
                l_output = x.size(-2)
            squeeze = False
        else:
            l_output = self.l_output
            squeeze = self.squeeze

        if self.mode == "last":
            restrict = lambda x: x[..., -l_output:, :]
        elif self.mode == "first":
            restrict = lambda x: x[..., :l_output, :]
        elif self.mode == "pool":

            def restrict(x):
                # Means of the last l_output prefixes: output position i is the mean
                # of x[..., :L - l_output + 1 + i, :]. Computed from the total sum
                # minus a reversed cumsum of the tail, avoiding a full-length cumsum.
                L = x.size(-2)
                s = x.sum(dim=-2, keepdim=True)
                if l_output > 1:
                    c = torch.cumsum(x[..., -(l_output - 1) :, :].flip(-2), dim=-2)
                    c = F.pad(c, (0, 0, 1, 0))
                    s = s - c  # (..., l_output, D)
                    s = s.flip(-2)
                denom = torch.arange(
                    L - l_output + 1, L + 1, dtype=x.dtype, device=x.device
                )
                # denom indexes the length axis (-2); without the trailing unit axis it
                # would broadcast against the feature axis D instead.
                return s / denom.unsqueeze(-1)

        elif self.mode == "sum":
            restrict = lambda x: torch.cumsum(x, dim=-2)[..., -l_output:, :]
            # TODO use same restrict function as pool case
        elif self.mode == 'ragged':
            assert lengths is not None, "lengths must be provided for ragged mode"
            # remove any additional padding (beyond max length of any sequence in the batch)
            restrict = lambda x: x[..., : max(lengths), :]
        else:
            raise NotImplementedError(
                "Mode must be ['last' | 'first' | 'pool' | 'sum' | 'ragged']"
            )

        # Restrict to actual length of sequence
        if self.use_lengths:
            assert lengths is not None
            x = torch.stack(
                [
                    restrict(out[..., :length, :])
                    for out, length in zip(torch.unbind(x, dim=0), lengths)
                ],
                dim=0,
            )
        else:
            x = restrict(x)

        if squeeze:
            assert x.size(-2) == 1
            x = x.squeeze(-2)

        x = self.output_transform(x)

        return x

    def step(self, x, state=None):
        """Decode a single step x: (batch, d_model); all length logic is ignored."""
        return self.output_transform(x)

class NDDecoder(Decoder):
    """Decoder for single target (e.g. classification or regression).

    mode="pool": average over every axis between the batch axis and the feature axis, then project.
    mode="full": project every position, no pooling.
    """

    def __init__(self, d_model, d_output=None, mode="pool"):
        super().__init__()

        assert mode in ["pool", "full"]
        if d_output is None:
            self.output_transform = nn.Identity()
        else:
            self.output_transform = nn.Linear(d_model, d_output)

        self.mode = mode

    def forward(self, x, state=None):
        """
        x: (n_batch, ..., d_model)
        state: unused
        Returns: (n_batch, d_output) in "pool" mode, (n_batch, ..., d_output) in "full" mode
            (d_output is d_model when no projection was configured)
        """
        features = reduce(x, 'b ... h -> b h', 'mean') if self.mode == 'pool' else x
        return self.output_transform(features)

class StateDecoder(Decoder):
    """Decode from the backbone's output state instead of its output sequence.

    Useful for stateful models such as RNNs (or perhaps Transformer-XL if it gets implemented).
    """

    def __init__(self, d_model, state_to_tensor, d_output):
        super().__init__()
        self.output_transform = nn.Linear(d_model, d_output)
        # Callable turning the model state into a tensor whose last axis has size d_model
        self.state_transform = state_to_tensor

    def forward(self, x, state=None):
        """Ignore ``x``; project ``state_transform(state)`` to d_output features."""
        state_features = self.state_transform(state)
        return self.output_transform(state_features)


class RetrievalHead(nn.Module):
    """MLP classifier over a pair of feature vectors stacked along the batch axis.

    Args:
        d_input: feature size of each of the two inputs.
        d_model: hidden width of the MLP.
        n_classes: number of output logits.
        nli: if True, use [a, b, a - b, a * b] features (Nystromformer LRA head);
            otherwise [a, b] with a deeper MLP (LRA reference head).
        activation: "relu" or "gelu".
    """

    def __init__(self, d_input, d_model, n_classes, nli=True, activation="relu"):
        super().__init__()
        self.nli = nli

        activation_classes = {"relu": nn.ReLU, "gelu": nn.GELU}
        if activation not in activation_classes:
            raise NotImplementedError
        # A single instance is shared by every activation slot below
        act = activation_classes[activation]()

        if nli:
            # Architecture from https://github.com/mlpen/Nystromformer/blob/6539b895fa5f798ea0509d19f336d4be787b5708/reorganized_code/LRA/model_wrapper.py#L74
            layers = [
                nn.Linear(4 * d_input, d_model),
                act,
                nn.Linear(d_model, n_classes),
            ]
        else:
            # Head from https://github.com/google-research/long-range-arena/blob/ad0ff01a5b3492ade621553a1caae383b347e0c1/lra_benchmarks/models/layers/common_layers.py#L232
            layers = [
                nn.Linear(2 * d_input, d_model),
                act,
                nn.Linear(d_model, d_model // 2),
                act,
                nn.Linear(d_model // 2, n_classes),
            ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: (2*batch, d_input) -- first half is one side of each pair, second half the other
        Returns: (batch, n_classes) logits
        """
        pair = rearrange(x, "(z b) d -> z b d", z=2)
        first, second = pair[0], pair[1]  # each (batch, d_input)
        parts = [first, second]
        if self.nli:
            parts += [first - second, first * second]
        return self.classifier(torch.cat(parts, dim=-1))


class RetrievalDecoder(Decoder):
    """Reduce each sequence to one feature vector with a SequenceDecoder, then classify pairs with a RetrievalHead.

    d_model defaults to d_input when not given. Extra *args/**kwargs are forwarded to the
    SequenceDecoder (e.g. ``mode``/``use_lengths``).
    """

    def __init__(
        self,
        d_input,
        n_classes,
        d_model=None,
        nli=True,
        activation="relu",
        *args,
        **kwargs
    ):
        super().__init__()
        hidden_dim = d_input if d_model is None else d_model
        # l_output=0: keep a single position per sequence and squeeze the length axis
        self.feature = SequenceDecoder(
            d_input, d_output=None, l_output=0, *args, **kwargs
        )
        self.retrieval = RetrievalHead(
            d_input, hidden_dim, n_classes, nli=nli, activation=activation
        )

    def forward(self, x, state=None, **kwargs):
        """x: (2*batch, length, d_input); kwargs (e.g. lengths) go to the feature extractor."""
        pooled = self.feature(x, state=state, **kwargs)
        return self.retrieval(pooled)

class PackedDecoder(Decoder):
    """Unpack a PackedSequence into a padded, batch-first tensor."""

    def forward(self, x, state=None):
        """
        x: a torch.nn.utils.rnn.PackedSequence
        state: unused
        Returns: padded tensor (batch, max_len, *); the per-sequence lengths are discarded
        """
        padded, _seq_lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        return padded
    

# A decoder class for the AMOS dataset
class segmentation_3D_Decoder(Decoder):
    """Per-voxel class logits for 3D segmentation from patch tokens.

    Each token is projected to patch_size**3 * num_classes values, which are regrouped into
    (batch, patch_size**3 * sequence_length, num_classes).
    """

    _name_ = "decoder_3d_segmentation"

    def __init__(self, patch_size=16, num_classes=16, hidden_dim=128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.decoding_dim = patch_size ** 3 * num_classes
        self.linear = nn.Linear(hidden_dim, self.decoding_dim)

    def forward(self, x):
        """
        x: (batch_size, sequence_length, hidden_dim)
        Returns: (batch_size, patch_size**3 * sequence_length, num_classes).
            The flattened axis is voxel-major: index = voxel * sequence_length + token.
            NOTE(review): confirm this ordering matches how targets are flattened.
        """
        logits = self.linear(x)  # (B, S, P^3 * C)
        n_batch, n_tokens, width = logits.size(0), logits.size(1), logits.size(2)
        logits = logits.view(n_batch, n_tokens, width // self.num_classes, self.num_classes)
        # (B, S, P^3, C) -> (B, P^3, S, C) -> (B, P^3 * S, C)
        return logits.permute(0, 2, 1, 3).flatten(1, 2)

class TransitionStepDecoder(Decoder):
    """Project every time step to num_states * loan_pool_size values.

    Optional stages, applied in this order: a linear forecast over the time axis,
    reshaping into (loan_pool_size, num_states), division by output_scale, and division by
    the L1 norm along dim 0.
    """

    # (sic) spelling must match the registry key used by configs
    _name_ = "tranistion_step_decoder"

    def __init__(
        self,
        hidden_dim=128,
        num_states=10,
        loan_pool_size=1,
        lookback_horizon=25,
        forecast=False,
        forecast_horizon=5,
        expand_dims=True,
        scale_output=False,
        output_scale=550,
        l1_normalize=False,
    ):
        super().__init__()
        self.loan_pool_size = loan_pool_size
        self.l1_normalize = l1_normalize
        self.hidden_dim = hidden_dim
        self.num_states = num_states
        self.decoding_dim = num_states * loan_pool_size
        self.linear = nn.Linear(hidden_dim, self.decoding_dim)
        self.forecast = forecast
        self.expand_dims = expand_dims
        self.scale_output = scale_output
        self.output_scale = output_scale
        if forecast:
            # Acts on the time axis, so the input sequence length must equal lookback_horizon
            self.linear2 = nn.Linear(lookback_horizon, forecast_horizon)

    def forward(self, x):
        """
        x: (N, L, hidden_dim); N is presumably batch_size * nr_units -- confirm with caller
        Returns, with T = forecast_horizon if forecast else L:
            expand_dims=True:  (N, loan_pool_size, T, num_states)
            expand_dims=False: (N, T, loan_pool_size * num_states)
        """
        out = self.linear(x)  # (N, L, M*S)
        if self.forecast:
            # (N, L, M*S) -> (N, M*S, L) -> (N, M*S, T) -> (N, T, M*S)
            out = self.linear2(out.permute(0, 2, 1)).permute(0, 2, 1)
        if self.expand_dims:
            # (N, T, M*S) -> (N, T, M, S) -> (N, M, T, S)
            out = out.view(out.shape[0], out.shape[1], self.loan_pool_size, self.num_states)
            out = out.permute(0, 2, 1, 3)
        if self.scale_output:
            out = out / self.output_scale
        if self.l1_normalize:
            # Normalizes along dim 0; per the original note this only works as expected if batch size is 1
            out = out / torch.norm(out, p=1, dim=0, keepdim=True)
        return out

class StackedDecoder(Decoder):
    """For each time step, produces nr_units predictions of size decoding_dim from the same features.

    Extra keyword arguments are accepted and ignored.
    """

    _name_ = "stacked_decoder"

    def __init__(self, d_model, nr_units=1000, decoding_dim=3, l1_normalize=False, **kwargs):
        super().__init__()
        self.nr_units = nr_units
        self.d_model = d_model
        self.decoding_dim = decoding_dim
        self.l1_normalize = l1_normalize
        self.linear = nn.Linear(d_model, nr_units * decoding_dim)

    def forward(self, x):
        """
        x: (batch_size, sequence_length, d_model)
        Returns: (nr_units, batch_size, sequence_length, decoding_dim)
        """
        n_batch, n_steps = x.shape[0], x.shape[1]
        preds = self.linear(x).reshape(n_batch, n_steps, self.nr_units, self.decoding_dim)
        # (B, L, U, D) -> (U, B, L, D) in a single permute
        preds = preds.permute(2, 0, 1, 3)
        if self.l1_normalize:
            # dim 0 is the unit axis here
            preds = preds / torch.norm(preds, p=1, dim=0, keepdim=True)
        return preds
        
        


class Decoder_timeseries_synthetics(Decoder):
    """Predict one row-stochastic (num_states x num_states) matrix per sequence."""

    _name_ = "decoder_timeseries_synthetics"

    def __init__(self, hidden_dim=128, num_states=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_states = num_states
        self.decoding_dim = num_states * num_states
        self.linear = nn.Linear(hidden_dim, self.decoding_dim)

    def forward(self, x):
        """
        x: (batch_size, sequence_length, hidden_dim)
        Returns: (batch_size, num_states, num_states); each row sums to 1
        """
        scores = self.linear(x).mean(dim=1)  # average over the sequence -> (B, N*N)
        scores = scores.view(scores.shape[0], self.num_states, self.num_states)
        return F.softmax(scores, dim=-1)


# For every type of encoder/decoder, specify:
# - constructor class
# - list of attributes to grab from dataset
# - list of attributes to grab from model

# Config ``_name_`` -> decoder constructor.
registry = dict(
    stop=Decoder,
    id=nn.Identity,
    linear=nn.Linear,
    sequence=SequenceDecoder,
    nd=NDDecoder,
    retrieval=RetrievalDecoder,
    state=StateDecoder,
    pack=PackedDecoder,
    decoder_3d_segmentation=segmentation_3D_Decoder,
    decoder_timeseries_synthetics=Decoder_timeseries_synthetics,
    tranistion_step_decoder=TransitionStepDecoder,
    stacked_decoder=StackedDecoder,
)

# Attribute names read from the model (and, below, the dataset). _instantiate passes
# their values positionally to the constructor: model attributes first, then dataset ones.
# NOTE(review): "forecast" has no entry in `registry`, so those entries look unused -- confirm.
model_attrs = dict(
    linear=["d_output"],
    sequence=["d_output"],
    nd=["d_output"],
    retrieval=["d_output"],
    state=["d_state", "state_to_tensor"],
    forecast=["d_output"],
    stacked_decoder=["d_model"],
)

dataset_attrs = dict(
    linear=["d_output"],
    sequence=["d_output", "l_output"],
    nd=["d_output"],
    retrieval=["d_output"],
    state=["d_output"],
    forecast=["d_output", "l_output"],
)


def _instantiate(decoder, model=None, dataset=None):
    """Instantiate a single decoder.

    decoder: None, a registry key, or a config mapping carrying a "_name_" key
    model, dataset: objects whose attributes listed in model_attrs / dataset_attrs for this
        decoder name are passed positionally to the constructor (model values first)
    Returns: the constructed decoder, or None when ``decoder`` is None
    """
    if decoder is None:
        return None

    name = decoder if isinstance(decoder, str) else decoder["_name_"]

    # Extract arguments from attribute names (dataset first, as before)
    dataset_args = utils.config.extract_attrs_from_obj(dataset, *dataset_attrs.get(name, []))
    model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, []))
    return utils.instantiate(registry, decoder, *model_args, *dataset_args)


def instantiate(decoder, model=None, dataset=None):
    """Instantiate a full decoder config, e.g. handle list of configs.

    Each entry is built by ``_instantiate`` and the results are chained in a
    ``U.PassthroughSequential``. Note that arguments are added in reverse order compared
    to encoder (model first, then dataset).
    """
    configs = utils.to_list(decoder)
    decoders = [_instantiate(cfg, model=model, dataset=dataset) for cfg in configs]
    return U.PassthroughSequential(*decoders)
