import torch
import torch.nn
import torch.nn.functional as F
from layers import Transformer, TiedEmbedding
from typing import Callable, Optional
import math
import layers


# Cannot be a dataclass, because that won't work with gather


class DotDict(dict):
    """
    Dictionary whose items can also be read, written and deleted as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` stores ``d["key"] = v`` and
    ``del d.key`` removes the item. Reading a missing key through attribute
    access raises ``AttributeError`` (not ``KeyError``), so ``getattr`` with a
    default and ``hasattr`` behave as usual.
    """

    def __getattr__(self, item):
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so regular dict methods are never shadowed by stored keys.
        try:
            return self[item]
        except KeyError:
            # Name the missing key; a bare AttributeError is hard to diagnose.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            ) from None

    # Attribute assignment/deletion write straight through to the dict items.
    # Deleting a missing attribute raises KeyError (from dict.__delitem__).
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class TransformerResult(DotDict):
    """Dict-backed result holding a ``data`` tensor and a ``length`` tensor."""

    data: torch.Tensor
    length: torch.Tensor

    @staticmethod
    def create(data: torch.Tensor, length: torch.Tensor):
        """
        Build a result from its two entries.

        :param data: stored under the ``"data"`` key
        :param length: stored under the ``"length"`` key
        :return: a new TransformerResult with keys ``data`` and ``length`` (in that order)
        """
        return TransformerResult(data=data, length=length)


class TransformerDecModel(torch.nn.Module):
    """
    Transformer decoder that is fine-tuned on top of features produced by a pre-trained transformer encoder.

    The source passed to this model is handed to the transformer as-is (no input embedding is applied here);
    only the target side is embedded and position-encoded by this class.
    """

    def __init__(
        self,
        n_input_tokens: int,
        n_out_tokens: int,
        state_size: int = 512,
        ff_multiplier: float = 1,
        max_len: int = 5000,
        transformer=Transformer,
        tied_embedding: bool = False,
        same_enc_dec_embedding: bool = False,
        embedding_init: str = "pytorch",
        in_embedding_size: Optional[int] = None,
        out_embedding_size: Optional[int] = None,
        scale_mode: str = "none",
        **kwargs
    ):
        """
        :param n_input_tokens: number of input tokens (only stored and checked against n_out_tokens here)
        :param n_out_tokens: number of output tokens; the id n_out_tokens itself is used as decoder SOS/EOS
        :param state_size: model width passed to the transformer as d_model
        :param ff_multiplier: feed-forward width as a multiple of state_size
        :param max_len: length of the positional encoding and of the cached position index buffer
        :param transformer: callable that builds the transformer
        :param tied_embedding: tie the output projection to the output embedding weights
        :param same_enc_dec_embedding: only validated here (requires n_input_tokens == n_out_tokens)
        :param embedding_init: one of "pytorch", "xavier", "kaiming"
        :param in_embedding_size: stored only
        :param out_embedding_size: optional smaller output embedding size, upscaled to state_size by a linear layer
        :param scale_mode: one of "none", "opennmt", "down"
        :param kwargs: forwarded to the transformer constructor
        """
        super().__init__()
        assert scale_mode in ("none", "opennmt", "down")
        assert embedding_init in ("pytorch", "xavier", "kaiming")

        assert n_input_tokens == n_out_tokens or not same_enc_dec_embedding

        # Plain configuration values.
        self.n_input_tokens = n_input_tokens
        self.n_out_tokens = n_out_tokens
        self.state_size = state_size
        self.ff_multiplier = ff_multiplier
        self.in_embedding_size = in_embedding_size
        self.out_embedding_size = out_embedding_size
        self.embedding_init = embedding_init
        self.scale_mode = scale_mode
        self.tied_embedding = tied_embedding

        # One extra id past the output vocabulary doubles as start- and end-of-sequence.
        self.decoder_sos_eos = n_out_tokens

        pos_scale = 1.0 / math.sqrt(state_size) if scale_mode == "down" else 1.0
        self.pos = layers.PositionalEncoding(
            state_size, max_len=max_len, batch_first=True, scale=pos_scale
        )

        # Cached 0..max_len-1 index vector used to build length masks.
        self.register_buffer("int_seq", torch.arange(max_len, dtype=torch.long))
        self.construct(transformer, **kwargs)
        self.reset_parameters()

    def pos_embed(
        self, t: torch.Tensor, offset: int, scale_offset: int
    ) -> torch.Tensor:
        """
        Apply the positional encoding to ``t`` starting at position ``offset``.

        With scale_mode "opennmt" the input is first multiplied by sqrt of its last dimension.
        ``scale_offset`` is not used by this implementation.
        """
        if self.scale_mode != "opennmt":
            return self.pos(t, offset)

        return self.pos(t * math.sqrt(t.shape[-1]), offset)

    def construct(self, transformer, **kwargs):
        """Create the output embedding, optional upscale layer, output projection and the transformer."""
        embedding_dim = self.out_embedding_size or self.state_size
        self.output_embedding = torch.nn.Embedding(self.n_out_tokens + 1, embedding_dim)

        if self.out_embedding_size is not None:
            self.out_embedding_upscale = torch.nn.Linear(
                self.out_embedding_size, self.state_size
            )

        if not self.tied_embedding:
            self.output_map = torch.nn.Linear(self.state_size, self.n_out_tokens + 1)
        else:
            # Tying needs the embedding to live directly in the model dimension.
            assert self.out_embedding_size is None
            self.output_map = TiedEmbedding(self.output_embedding.weight)

        ff_size = int(self.ff_multiplier * self.state_size)
        self.trafo = transformer(
            d_model=self.state_size,
            dim_feedforward=ff_size,
            is_null_encoder=True,
            **kwargs
        )

    def reset_parameters(self):
        """Re-initialise the output embedding (per ``embedding_init``) and the untied output projection."""
        embedding_weight = self.output_embedding.weight
        if self.embedding_init == "xavier":
            torch.nn.init.xavier_uniform_(embedding_weight)
        elif self.embedding_init == "kaiming":
            torch.nn.init.kaiming_normal_(embedding_weight)

        if self.tied_embedding:
            return

        torch.nn.init.xavier_uniform_(self.output_map.weight)

    def generate_len_mask(self, max_len: int, len: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask that is True at positions >= the given lengths (i.e. at padding)."""
        positions = self.int_seq[:max_len]
        return positions >= len.unsqueeze(-1)

    def output_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Embed output token ids, upscaling to the model size when a separate embedding size is set."""
        embedded = self.output_embedding(x)
        if self.out_embedding_size is None:
            return embedded
        return self.out_embedding_upscale(embedded)

    def run_greedy(
        self, src: torch.Tensor, src_len: torch.Tensor, max_len: int
    ) -> TransformerResult:
        """
        Decode ``max_len`` steps greedily, feeding back the argmax token at each step.

        :param src: source passed to the transformer encoder
        :param src_len: lengths of the source sequences
        :param max_len: number of decoding steps to run
        :return: per-step outputs concatenated along dim 1, and the number of tokens emitted before SOS/EOS
        """
        batch_size, n_steps = src.shape[0], src.shape[1]

        in_len_mask = self.generate_len_mask(n_steps, src_len)
        memory = self.trafo.encoder(src, in_len_mask)

        # running[b] stays True until sequence b emits the SOS/EOS token.
        running = torch.ones([batch_size], dtype=torch.bool, device=src.device)
        out_len = torch.zeros_like(running, dtype=torch.long)

        # Decoding starts from a single SOS token per batch element.
        sos_tokens = torch.full(
            [batch_size, 1],
            self.decoder_sos_eos,
            dtype=torch.long,
            device=src.device,
        )
        next_tgt = self.pos_embed(self.output_embed(sos_tokens), 0, 1)

        step_outputs = []
        state = self.trafo.decoder.create_state(batch_size, max_len, src.device)

        for step in range(max_len):
            decoded = self.trafo.decoder.one_step_forward(
                state, next_tgt, memory, memory_key_padding_mask=in_len_mask
            )
            logits = self.output_map(decoded)
            step_outputs.append(logits)

            out_token = torch.argmax(logits[:, -1], -1)
            running &= out_token != self.decoder_sos_eos

            # Only sequences that are still running grow; the end token is not counted.
            out_len[running] = step + 1
            next_tgt = self.pos_embed(
                self.output_embed(out_token).unsqueeze(1), step + 1, 1
            )

        return TransformerResult.create(torch.cat(step_outputs, 1), out_len)

    def run_teacher_forcing(
        self,
        src: torch.Tensor,
        src_len: torch.Tensor,
        target: torch.Tensor,
        target_len: torch.Tensor,
    ) -> TransformerResult:
        """
        Run the transformer once on the ground-truth target shifted right by one position.

        :return: projected transformer outputs together with ``target_len`` (passed through unchanged)
        """
        # Drop the last target token and prepend SOS so position t only sees tokens < t.
        shifted = F.pad(target[:, :-1], (1, 0), value=self.decoder_sos_eos).long()
        target = self.pos_embed(self.output_embed(shifted), 0, 1)

        in_len_mask = self.generate_len_mask(src.shape[1], src_len)
        tgt_mask = self.trafo.generate_square_subsequent_mask(
            target.shape[1], src.device
        )

        res = self.trafo(
            src, target, src_length_mask=in_len_mask, tgt_mask=tgt_mask
        )

        return TransformerResult.create(self.output_map(res), target_len)

    def forward(
        self,
        src: torch.Tensor,
        src_len: torch.Tensor,
        target: torch.Tensor,
        target_len: torch.Tensor,
        teacher_forcing: bool,
        max_len: Optional[int] = None,
    ) -> TransformerResult:
        """
        Run the decoder on some input/output pair

        :param src: source features. Shape: [N, S, D], where S is the input sequence length, N is the batch size
        :param src_len: length of source sequences. Shape: [N], N is the batch size
        :param target: target tensor. Shape: [N, T], where T is the output sequence length, N is the batch size
        :param target_len: length of target sequences. Shape: [N], N is the batch size
        :param teacher_forcing: use teacher forcing or greedy decoding
        :param max_len: overwrite autodetected max length. Useful for parallel execution
        :return: prediction of the target tensor. Shape [N, T, C_out]
        """

        if not teacher_forcing:
            # Without an explicit max_len, decode as many steps as the longest target.
            return self.run_greedy(src, src_len, max_len or target_len.max().item())

        return self.run_teacher_forcing(src, src_len, target, target_len)


class TransformerEncDecModel(torch.nn.Module):
    """
    Transformer encoder-decoder with token embeddings on both sides.

    Depending on the optional ``mode`` keyword argument the model is used as a full
    encoder-decoder ("enc_dec", the default), or as an encoder followed by a linear
    head ("mlm" or "classifier").
    """

    def __init__(
        self,
        n_input_tokens: int,
        n_out_tokens: int,
        state_size: int = 512,
        nheads: int = 8,
        ff_multiplier: float = 1,
        max_len: int = 5000,
        transformer=Transformer,
        tied_embedding: bool = False,
        pos_embeddig: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None,
        encoder_sos: bool = True,
        same_enc_dec_embedding: bool = False,
        embedding_init: str = "pytorch",
        in_embedding_size: Optional[int] = None,
        out_embedding_size: Optional[int] = None,
        scale_mode: str = "none",
        **kwargs
    ):
        """
        Transformer encoder-decoder.

        :param n_input_tokens: Number of input tokens. Ids n_input_tokens and n_input_tokens + 1 are reserved
            for the encoder EOS and (if enabled) SOS tokens
        :param n_out_tokens: Number of output tokens. Id n_out_tokens is used as the decoder SOS/EOS token
        :param state_size: The size of the internal state of the transformer
        :param nheads: Number of attention heads, passed to the transformer as nhead
        :param ff_multiplier: Feed-forward width as a multiple of state_size
        :param max_len: Length of the default positional encoding and of the cached position index buffer
        :param transformer: Callable that builds the transformer
        :param tied_embedding: Tie the output projection to the output embedding weights
        :param pos_embeddig: Optional replacement for the default positional encoding, called as f(tensor, offset)
        :param encoder_sos: Prepend a start-of-sequence token to the source in forward()
        :param same_enc_dec_embedding: Share one embedding between encoder and decoder
            (requires n_input_tokens == n_out_tokens)
        :param embedding_init: One of "pytorch", "xavier", "kaiming"
        :param in_embedding_size: Optional input embedding size, upscaled to state_size by a linear layer
        :param out_embedding_size: Optional output embedding size, upscaled to state_size by a linear layer
        :param scale_mode: One of "none", "opennmt", "down"
        :param kwargs: Forwarded to the transformer constructor. A "mode" entry ("mlm" or "classifier")
            additionally selects the head built here; note that it is forwarded to the transformer as well
        """
        super().__init__()

        assert scale_mode in ["none", "opennmt", "down"]
        assert embedding_init in ["pytorch", "xavier", "kaiming"]

        assert (not same_enc_dec_embedding) or (n_input_tokens == n_out_tokens)

        self.tied_embedding = tied_embedding

        # Special token ids live directly past the regular vocabularies.
        self.decoder_sos_eos = n_out_tokens
        self.encoder_eos = n_input_tokens
        self.encoder_sos = n_input_tokens + 1 if encoder_sos else None
        self.state_size = state_size
        self.embedding_init = embedding_init
        self.nheads = nheads
        self.ff_multiplier = ff_multiplier
        self.n_input_tokens = n_input_tokens
        self.n_out_tokens = n_out_tokens
        self.in_embedding_size = in_embedding_size
        self.out_embedding_size = out_embedding_size
        self.same_enc_dec_embedding = same_enc_dec_embedding
        self.scale_mode = scale_mode
        self.pos = pos_embeddig or layers.PositionalEncoding(
            state_size,
            max_len=max_len,
            batch_first=True,
            scale=(1.0 / math.sqrt(state_size)) if scale_mode == "down" else 1.0,
        )

        # Cached 0..max_len-1 index vector used to build length masks.
        self.register_buffer("int_seq", torch.arange(max_len, dtype=torch.long))
        self.construct(transformer, **kwargs)
        self.reset_parameters()

        # The heads are created after reset_parameters(), so they keep PyTorch's default init.
        mode = kwargs.get("mode")
        if mode == "mlm":
            self.mlm_head = torch.nn.Linear(self.state_size, self.n_input_tokens)
            self.mode = "mlm"
        elif mode == "classifier":
            self.classifier_head = torch.nn.Linear(self.state_size, self.n_out_tokens)
            self.mode = "classifier"
        else:
            # Missing or unrecognised mode falls back to plain encoder-decoder.
            self.mode = "enc_dec"

    def pos_embed(
        self, t: torch.Tensor, offset: int, scale_offset: int
    ) -> torch.Tensor:
        """
        Apply the positional encoding to ``t`` starting at position ``offset``.

        With scale_mode "opennmt" the input is first multiplied by sqrt of its last dimension.
        ``scale_offset`` is not used by this implementation.
        """
        if self.scale_mode == "opennmt":
            t = t * math.sqrt(t.shape[-1])

        return self.pos(t, offset)

    def construct(self, transformer, **kwargs):
        """Create the embeddings, optional upscale layers, output projection and the transformer."""
        self.input_embedding = torch.nn.Embedding(
            # +1 for the encoder EOS id, +1 more when an encoder SOS id is used.
            self.n_input_tokens + 1 + int(self.encoder_sos is not None),
            self.in_embedding_size or self.state_size,
        )
        self.output_embedding = (
            self.input_embedding
            if self.same_enc_dec_embedding
            else torch.nn.Embedding(
                self.n_out_tokens + 1, self.out_embedding_size or self.state_size
            )
        )

        if self.in_embedding_size is not None:
            self.in_embedding_upscale = torch.nn.Linear(
                self.in_embedding_size, self.state_size
            )

        if self.out_embedding_size is not None:
            self.out_embedding_upscale = torch.nn.Linear(
                self.out_embedding_size, self.state_size
            )

        if self.tied_embedding:
            # Tying needs the embedding to live directly in the model dimension.
            assert self.out_embedding_size is None
            self.output_map = TiedEmbedding(self.output_embedding.weight)
        else:
            self.output_map = torch.nn.Linear(self.state_size, self.n_out_tokens + 1)

        self.trafo = transformer(
            d_model=self.state_size,
            nhead=self.nheads,
            dim_feedforward=int(self.ff_multiplier * self.state_size),
            **kwargs
        )

    def reset_parameters(self):
        """Re-initialise both embeddings (per ``embedding_init``) and the untied output projection."""
        if self.embedding_init == "xavier":
            torch.nn.init.xavier_uniform_(self.input_embedding.weight)
            torch.nn.init.xavier_uniform_(self.output_embedding.weight)
        elif self.embedding_init == "kaiming":
            torch.nn.init.kaiming_normal_(self.input_embedding.weight)
            torch.nn.init.kaiming_normal_(self.output_embedding.weight)

        if not self.tied_embedding:
            torch.nn.init.xavier_uniform_(self.output_map.weight)

    def generate_len_mask(self, max_len: int, len: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask that is True at positions >= the given lengths (i.e. at padding)."""
        return self.int_seq[:max_len] >= len.unsqueeze(-1)

    def output_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Embed output token ids, upscaling to the model size when a separate embedding size is set."""
        o = self.output_embedding(x)
        if self.out_embedding_size is not None:
            o = self.out_embedding_upscale(o)
        return o

    def run_greedy(
        self, src: torch.Tensor, src_len: torch.Tensor, max_len: int
    ) -> TransformerResult:
        """
        Decode ``max_len`` steps greedily, feeding back the argmax token at each step.

        :param src: either token ids (2-D, embedded here) or an already embedded source
        :param src_len: lengths of the source sequences
        :param max_len: number of decoding steps to run
        :return: per-step outputs concatenated along dim 1, and the number of tokens emitted before SOS/EOS
        """
        batch_size = src.shape[0]
        n_steps = src.shape[1]

        # A 2-D source is taken to be raw token ids that still need embedding.
        if src.dim() == 2:
            src = self.pos_embed(self.input_embed(src), 0, 0)

        in_len_mask = self.generate_len_mask(n_steps, src_len)
        # Let encoder errors propagate: the previous `except Exception: breakpoint()`
        # swallowed them and left `memory` undefined.
        memory = self.trafo.encoder(src, in_len_mask)

        # running[b] stays True until sequence b emits the SOS/EOS token.
        running = torch.ones([batch_size], dtype=torch.bool, device=src.device)
        out_len = torch.zeros_like(running, dtype=torch.long)

        # Decoding starts from a single SOS token per batch element.
        next_tgt = self.pos_embed(
            self.output_embed(
                torch.full(
                    [batch_size, 1],
                    self.decoder_sos_eos,
                    dtype=torch.long,
                    device=src.device,
                )
            ),
            0,
            1,
        )

        all_outputs = []
        state = self.trafo.decoder.create_state(batch_size, max_len, src.device)

        for i in range(max_len):
            output = self.trafo.decoder.one_step_forward(
                state, next_tgt, memory, memory_key_padding_mask=in_len_mask
            )

            output = self.output_map(output)
            all_outputs.append(output)

            out_token = torch.argmax(output[:, -1], -1)
            running &= out_token != self.decoder_sos_eos

            # Only sequences that are still running grow; the end token is not counted.
            out_len[running] = i + 1
            next_tgt = self.pos_embed(
                self.output_embed(out_token).unsqueeze(1), i + 1, 1
            )

        return TransformerResult.create(torch.cat(all_outputs, 1), out_len)

    def run_teacher_forcing(
        self,
        src: torch.Tensor,
        src_len: torch.Tensor,
        target: torch.Tensor,
        target_len: torch.Tensor,
    ) -> TransformerResult:
        """
        Run the transformer once on the ground-truth target shifted right by one position.

        ``src`` is passed to the transformer unchanged, so it must already be embedded.

        :return: projected transformer outputs together with ``target_len`` (passed through unchanged)
        """
        # Drop the last target token and prepend SOS so position t only sees tokens < t.
        target = self.output_embed(
            F.pad(target[:, :-1], (1, 0), value=self.decoder_sos_eos).long()
        )
        target = self.pos_embed(target, 0, 1)

        in_len_mask = self.generate_len_mask(src.shape[1], src_len)

        res = self.trafo(
            src,
            target,
            src_length_mask=in_len_mask,
            tgt_mask=self.trafo.generate_square_subsequent_mask(
                target.shape[1], src.device
            ),
        )

        return TransformerResult.create(self.output_map(res), target_len)

    def input_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Embed input token ids, upscaling to the model size when a separate embedding size is set."""
        src = self.input_embedding(x.long())
        if self.in_embedding_size is not None:
            src = self.in_embedding_upscale(src)

        return src

    def get_encoder_layers(self):
        """Return the transformer's ``num_encoder_layers`` attribute."""
        return self.trafo.num_encoder_layers

    def encoder_only(self, src, mask, layer_id=-1, gaussian_noise=None):
        """
        Embed ``src`` and return the transformer's hidden states for it.

        :param src: input token ids
        :param mask: mask forwarded to ``trafo.get_hidden_states``
        :param layer_id: layer selector forwarded to ``trafo.get_hidden_states``
        :param gaussian_noise: optional tensor added to the position-encoded embeddings
        """
        src = self.pos_embed(self.input_embed(src), 0, 0)
        if gaussian_noise is not None:
            src += gaussian_noise
        return self.trafo.get_hidden_states(src, mask, layer_id=layer_id)

    def get_attention_sparsity(self, src, mask):
        """
        Sum the entropy of every attention row over all attention matrices and sequences.

        Only the top-left ``len x len`` part of each per-sequence matrix is used, where the
        length is the number of False entries in that sequence's ``mask`` row.

        :return: the summed entropy divided by the leading dimension of the last attention
            matrix (the dimension that is iterated together with the per-sequence lengths)
        """
        src = self.pos_embed(self.input_embed(src), 0, 0)
        attn_matrices = self.trafo.get_attn_matrices(src, mask)
        lens = (~mask).sum(axis=1)
        total_entropy = 0.0
        for mat in attn_matrices:
            for clen, batch_obj in zip(lens, mat):
                curr_att_mat = batch_obj[:clen, :clen]
                for attns in curr_att_mat:
                    total_entropy += torch.distributions.Categorical(attns).entropy()
        # the average total entropy across all layers
        # NOTE(review): the divisor is len(mat), not the number of matrices in
        # attn_matrices - confirm this is the intended normalisation.
        return total_entropy / len(mat)

    def forward(
        self,
        src: torch.Tensor,
        src_len: torch.Tensor,
        target: Optional[torch.Tensor] = None,
        target_len: Optional[torch.Tensor] = None,
        teacher_forcing: bool = True,
        max_len: Optional[int] = None,
    ) -> TransformerResult:
        """
        Run transformer encoder-decoder on some input/output pair

        :param src: source tensor. Shape: [N, S], where S is the input sequence length, N is the batch size
        :param src_len: length of source sequences. Shape: [N], N is the batch size
        :param target: target tensor. Shape: [N, T], where T is the output sequence length, N is the batch size.
            If omitted with teacher forcing, a single decoder SOS token per sequence is used
        :param target_len: length of target sequences. Shape: [N], N is the batch size
        :param teacher_forcing: use teacher forcing or greedy decoding
        :param max_len: overwrite autodetected max length. Useful for parallel execution
        :return: in "enc_dec" mode, prediction of the target tensor. Shape [N, T, C_out]. In "mlm" and
            "classifier" mode, the raw output tensor of the respective linear head
        :raises ValueError: greedy decoding was requested with neither max_len nor target_len
        """
        if self.encoder_sos is not None:
            # Prepend the encoder SOS token; every sequence becomes one token longer.
            src = F.pad(src, (1, 0), value=self.encoder_sos)
            src_len = src_len + 1

        if self.mode == "enc_dec":
            src = self.pos_embed(self.input_embed(src), 0, 0)
            if teacher_forcing:
                if target is None:
                    # Create a target tensor with just one element in the sequence i.e. <decoder_sos>
                    target = (
                        torch.ones(src.shape[0], 1, dtype=torch.long, device=src.device)
                        * self.decoder_sos_eos
                    )
                    target_len = torch.ones(
                        src.shape[0], dtype=torch.long, device=src.device
                    )
                return self.run_teacher_forcing(src, src_len, target, target_len)

            if not max_len:
                if target_len is None:
                    raise ValueError(
                        "Greedy decoding requires either max_len or target_len."
                    )
                # Without an explicit max_len, decode as many steps as the longest target.
                max_len = target_len.max().item()
            return self.run_greedy(src, src_len, max_len)
        elif self.mode == "mlm":
            # The encoder_only function expects that we have already left and right padded the input.
            in_len_mask = self.generate_len_mask(src.shape[1], src_len)
            out = self.encoder_only(src, in_len_mask)
            return self.mlm_head(out)
        elif self.mode == "classifier":
            in_len_mask = self.generate_len_mask(src.shape[1], src_len)
            out = self.encoder_only(src, in_len_mask)
            # Mean over dim 1 before the classification head.
            return self.classifier_head(out.mean(dim=1))
