import torch
import torch.nn as nn
from transformers import PreTrainedModel

from loader.models.encoding import (
    PositionalEncoding,
    PositionalEmbedding,
    FrequencyEncoding,
    PolynomialEncoding,
    GaussianEncoding,
    HybridEncoding,
    SubspaceEncoding,
)


class MemoryTransformerForPolynomials(PreTrainedModel):
    """Encoder-decoder Transformer with learnable memory tokens for polynomial sequences.

    Two input modes are selected by ``config.use_standard_embedding``:

    * standard: inputs are integer token ids, embedded with ``nn.Embedding`` to
      ``d_model`` and combined with a sinusoidal or learned positional encoding.
    * hybrid symbolic/numeric: each position is a vector whose column 0 is a token
      id (embedded to ``token_register_size`` dims) and whose remaining columns are
      numeric values encoded by ``config.encoding_method`` to
      ``d_model - token_register_size`` dims. ``positional_encoding`` is
      ``nn.Identity`` in this mode.

    ``config.num_memory_tokens`` learnable vectors are prepended to every encoder
    input. When ``dump_memory_token`` is True their encoder outputs are discarded
    before decoding; otherwise the decoder cross-attends to them as well.
    """

    # _tied_weights_keys = ['token_embedding', 'classifier.weight']
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.d_model = config.d_model
        self.encoding_method = config.encoding_method
        self.num_variables = config.num_variables
        self.token_register_size = config.token_register_size
        self.input_dim = config.token_register_size + self.num_variables
        self.use_standard_embedding = config.use_standard_embedding
        self.special_token_ids = config.special_token_ids
        self.vocab_size = config.vocab_size
        self.max_sequence_length = config.max_sequence_length
        self.num_batch = config.num_batch
        # When True, encode() strips memory-token outputs and returns only the
        # encoder output; nothing in this class sets it to True (presumably callers do).
        self.dump_memory_token = False

        # Select the token embedding and the continuous (numeric) encoder.
        if self.use_standard_embedding:
            self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
            self.embedding = nn.Identity()
        else:
            self.token_embedding = nn.Embedding(self.vocab_size, self.token_register_size)
            numeric_dim = self.d_model - self.token_register_size
            if self.encoding_method == "frequency":
                self.embedding = FrequencyEncoding(numeric_dim, self.num_variables)
            elif self.encoding_method == "polynomial":
                self.embedding = PolynomialEncoding(numeric_dim, self.num_variables)
            elif self.encoding_method == "subspace":
                self.embedding = SubspaceEncoding(numeric_dim, self.num_variables)
            elif self.encoding_method == "gaussian":
                self.embedding = GaussianEncoding(
                    numeric_dim,
                    self.num_variables,
                    max_value=config.gaussian_encoding_upper_bound,
                )
            elif self.encoding_method == "hybrid":
                self.embedding = HybridEncoding(numeric_dim, self.num_variables)
            else:
                raise ValueError(f"Unknown encoding method: {self.encoding_method}")

        self.classifier = nn.Linear(self.d_model, self.vocab_size, bias=True)

        if self.use_standard_embedding:
            if self.config.positional_encoding == "sinusoidal":
                self.positional_encoding = PositionalEncoding(self.d_model, max_len=self.max_sequence_length)
            elif self.config.positional_encoding == "embedding":
                self.positional_encoding = PositionalEmbedding(self.d_model, max_len=self.max_sequence_length)
            else:
                # Fail at construction instead of with an AttributeError in forward().
                raise ValueError(f"Unknown positional encoding: {self.config.positional_encoding}")
        else:
            self.positional_encoding = nn.Identity()

        # Transformer encoder settings
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation=nn.GELU(),
            batch_first=True,
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_encoder_layers)

        # Transformer decoder settings
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation=nn.GELU(),
            batch_first=True,
            norm_first=False,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_decoder_layers)

        self.loss_clf = nn.CrossEntropyLoss(ignore_index=self.special_token_ids["pad_token_id"])
        self.loss_rg = nn.MSELoss(reduction="mean")

        # Learnable memory vectors prepended to every encoder input: (num_memory_tokens, d_model).
        self.memory_tokens = nn.Parameter(torch.randn(config.num_memory_tokens, config.d_model))

        self.post_init()

    def _encode_for_decoding(self, encoder_input, encoder_attention_mask=None, encoder_padding_mask=None):
        """Encode and return ``(encoder_output, memory_key_padding_mask)`` for the decoder.

        Hides the two return conventions of ``encode``. With ``dump_memory_token`` the
        memory positions are stripped, so the caller's padding mask still lines up with
        the output. Otherwise ``encode`` returns the mask extended over the memory tokens.
        """
        if self.dump_memory_token:
            encoder_output = self.encode(encoder_input, encoder_attention_mask, encoder_padding_mask)
            return encoder_output, encoder_padding_mask
        return self.encode(encoder_input, encoder_attention_mask, encoder_padding_mask)

    def forward(
        self,
        encoder_input,
        decoder_input,
        labels=None,
        labels_for_regression=None,
        encoder_attention_mask=None,
        decoder_attention_mask=None,
        encoder_padding_mask=None,
        decoder_padding_mask=None,
        return_dict=None,
    ):
        """Run the encoder and the teacher-forced decoder, and optionally compute losses.

        Args:
            encoder_input: Encoder token ids, or ``[token_id, numeric values...]``
                vectors when ``use_standard_embedding`` is False.
            decoder_input: Target sequence in the same format; ``decode`` shifts it right.
            labels: Optional class targets. They are flattened and scored against
                ``logits`` with cross-entropy; the pad id is ignored.
            labels_for_regression: Optional targets scored with MSE against the numeric
                read-out (``embedding.inverse``). Only used when ``labels`` is given.
            encoder_attention_mask: Passed to the encoder as ``mask`` and to the
                decoder as ``memory_mask``.
            decoder_attention_mask: Accepted for API compatibility. ``decode`` ignores it
                and always builds a causal mask.
            encoder_padding_mask: Key-padding mask for the encoder input.
            decoder_padding_mask: Key-padding mask for the decoder input.
            return_dict: Ignored; a dict is always returned.

        Returns:
            A dict with the following keys:

            * ``loss``, ``loss_clf``, ``loss_rg``: all None when ``labels`` is None.
              ``loss_rg`` is a zero tensor when only ``labels`` is given.
            * ``logits``
            * ``logits_for_regression``
            * ``encoder_output``
        """
        # Kept for Hugging Face API compatibility; not used below.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_output, encoder_padding_mask = self._encode_for_decoding(
            encoder_input, encoder_attention_mask, encoder_padding_mask
        )
        decoder_output = self.decode(
            decoder_input,
            encoder_output,
            decoder_attention_mask,
            encoder_padding_mask,
            decoder_padding_mask,
            encoder_attention_mask,
        )

        # The final decoder position is not scored.
        logits = self.classifier(decoder_output[:, :-1, :])
        if self.use_standard_embedding:
            decoded_output = torch.tensor([]).to(self.device)
        else:
            decoded_output = self.embedding.inverse(decoder_output[:, :-1, self.token_register_size :])

        # Bound up front so the returned dict is valid when no labels are supplied.
        loss = loss_clf = loss_rg = None
        if labels is not None:
            loss_clf = self.loss_clf(logits.reshape(-1, self.vocab_size), labels.reshape(-1).long())
            if labels_for_regression is not None:
                loss_rg = self.loss_rg(decoded_output, labels_for_regression)
            else:
                loss_rg = torch.tensor(0.0).to(loss_clf.device)
            loss = loss_clf + loss_rg

        return {
            "loss": loss,
            "loss_clf": loss_clf,
            "loss_rg": loss_rg,
            "logits": logits,
            "logits_for_regression": decoded_output,
            "encoder_output": encoder_output,
        }

    def _shift_right(self, x):
        """Shift ``x`` one step right along the sequence axis and insert BOS at the front.

        For id inputs ``(batch, seq)``, position 0 becomes the BOS id. For vector inputs
        ``(batch, seq, features)``, only column 0 at position 0 is set to the BOS id;
        the remaining features at that position are zero.
        """
        shifted_input_embeds = torch.zeros_like(x)
        if self.use_standard_embedding:  # ids are given
            shifted_input_embeds[:, 0] = self.special_token_ids["bos_token_id"]
            shifted_input_embeds[:, 1:] = x[:, :-1].clone()
        else:  # monomials in [C, E1, E2, ..., En] form
            shifted_input_embeds[:, 0, 0] = self.special_token_ids["bos_token_id"]
            shifted_input_embeds[:, 1:, :] = x[:, :-1, :].clone()

        return shifted_input_embeds

    def generate_square_subsequent_mask(self, sz, dtype=float, device=None):
        """Return an ``(sz, sz)`` causal mask that blocks attention to future positions.

        Positions strictly above the diagonal are masked. By default the mask is a float
        tensor (``-inf`` masked, ``0.0`` allowed). If ``dtype`` is ``bool`` or
        ``torch.bool``, it is a boolean tensor with True meaning masked.
        """
        future = torch.triu(torch.ones(sz, sz, dtype=torch.bool, device=device), diagonal=1)
        if dtype in (bool, torch.bool):
            return future
        return torch.zeros(sz, sz, device=device).masked_fill(future, float("-inf"))

    def encode(self, encoder_input, encoder_attention_mask=None, encoder_padding_mask=None):
        """Embed ``encoder_input``, prepend the memory tokens and run the encoder.

        Args:
            encoder_input: Token ids, or ``[token_id, numeric values...]`` vectors when
                ``use_standard_embedding`` is False.
            encoder_attention_mask: Passed to the encoder as ``mask``. It must cover the
                memory tokens too (``num_memory_tokens + seq_len`` positions).
            encoder_padding_mask: Key-padding mask over the input positions. When None,
                every position is treated as valid.

        Returns:
            If ``dump_memory_token`` is set, the encoder output with the memory positions
            removed. Otherwise a tuple ``(encoder_output, padding_mask)``, where both
            include the memory positions at the front.
        """
        if self.use_standard_embedding:
            encoder_embedded = self.token_embedding(encoder_input.long())  # (batch_size, seq_len, d_model)
        else:
            # embed symbolic tokens
            token_embedded = self.token_embedding(encoder_input[:, :, 0].long())
            # embed number tokens continuously
            numeric_embedded = self.embedding(encoder_input[:, :, 1:])
            encoder_embedded = torch.cat([token_embedded, numeric_embedded], dim=-1)  # (batch_size, seq_len, d_model)

        batch_size = encoder_input.size(0)
        device = encoder_input.device
        if encoder_padding_mask is None:
            # Without this, torch.cat below would fail on None.
            encoder_padding_mask = torch.zeros(encoder_input.shape[:2], dtype=torch.bool, device=device)

        # TODO: Add memory processing here
        mem = self.memory_tokens.unsqueeze(0).expand(batch_size, -1, -1)  # (batch_size, num_memory_tokens, d_model)
        num_memory = mem.size(1)
        # (batch_size, num_memory_tokens + seq_len, d_model)
        encoder_embedded = torch.cat([mem, encoder_embedded], dim=1)
        # Memory tokens are never padding.
        mem_padding_mask = torch.zeros((batch_size, num_memory), dtype=torch.bool, device=device)
        encoder_padding_mask = torch.cat([mem_padding_mask, encoder_padding_mask], dim=1)

        pe = self.positional_encoding(encoder_embedded)  # (batch_size, num_memory_tokens + seq_len, d_model)
        # NOTE(review): when positional_encoding is nn.Identity (hybrid mode), `pe` is the
        # same tensor, so this in-place add doubles the embeddings. Kept as-is for
        # checkpoint compatibility; confirm whether this is intended.
        encoder_embedded += pe
        encoder_output = self.encoder(
            encoder_embedded, src_key_padding_mask=encoder_padding_mask, mask=encoder_attention_mask
        )

        if self.dump_memory_token:  # When discarding memory
            return encoder_output[:, num_memory:, :]
        return encoder_output, encoder_padding_mask  # When holding memory tokens

    def decode(
        self,
        decoder_input,
        encoder_output,
        decoder_attention_mask=None,
        encoder_padding_mask=None,
        decoder_padding_mask=None,
        encoder_attention_mask=None,
        perform_shift_right=True,
    ):
        """Run the causal decoder over ``decoder_input``, cross-attending to ``encoder_output``.

        Args:
            decoder_input: Target ids, or ``[token_id, numeric values...]`` vectors.
            encoder_output: Memory for cross-attention.
            decoder_attention_mask: Unused; a causal ``tgt_mask`` is always generated.
            encoder_padding_mask: Passed as ``memory_key_padding_mask``.
            decoder_padding_mask: Passed as ``tgt_key_padding_mask``. When None, no
                position is treated as padding.
            encoder_attention_mask: Passed as ``memory_mask``.
            perform_shift_right: When True (teacher forcing), the input is shifted right
                with BOS in front. Greedy decoding passes False because its buffer
                already starts with BOS.

        Returns:
            Decoder hidden states, shaped like the embedded decoder input.
        """
        if perform_shift_right:
            decoder_input_shifted = self._shift_right(decoder_input)
        else:
            decoder_input_shifted = decoder_input

        if self.use_standard_embedding:  # ids are given
            decoder_embedded = self.token_embedding(decoder_input_shifted.long())
        else:
            token_ids = decoder_input_shifted[:, :, 0]
            token_embedded = self.token_embedding(token_ids.long())
            # Zero the numeric features of positions whose *fed* token is not a number
            # token. The mask must come from the shifted tensor (the one being
            # embedded); masking by the un-shifted input would reveal the next target.
            # masked_fill is out of place, so the caller's tensor is never modified.
            non_number = token_ids != self.special_token_ids["number_token_id"]
            numeric = decoder_input_shifted[:, :, 1:].masked_fill(non_number.unsqueeze(-1), 0)
            decoder_embedded = self.embedding(numeric)
            decoder_embedded = torch.cat([token_embedded, decoder_embedded], dim=-1)

        pe = self.positional_encoding(decoder_embedded)
        # NOTE(review): with nn.Identity this doubles the embeddings (see encode).
        decoder_embedded += pe

        if decoder_padding_mask is None:
            decoder_padding_mask = torch.zeros(decoder_input.shape[:2], dtype=torch.bool, device=decoder_input.device)

        decoder_output = self.decoder(
            decoder_embedded,
            encoder_output,
            tgt_key_padding_mask=decoder_padding_mask,
            memory_key_padding_mask=encoder_padding_mask,
            tgt_mask=self.generate_square_subsequent_mask(
                decoder_input.size(1), dtype=bool, device=decoder_input.device
            ),
            memory_mask=encoder_attention_mask,
            tgt_is_causal=True,
        )

        return decoder_output

    @torch.no_grad()
    def greedy_generate(
        self,
        encoder_input,
        max_length=100,
        encoder_attention_mask=None,
        encoder_padding_mask=None,
    ):
        """Greedily decode up to ``max_length`` steps, stopping early once all rows emit EOS.

        Returns:
            The decoded buffer, starting with the BOS step. In standard mode it holds
            ids ``(batch, steps)``. In hybrid mode it holds vectors whose column 0 is the
            predicted token id, followed by the rounded numeric read-out. Rows that
            finished early keep zeros (hybrid) or BOS ids (standard) after their EOS.
        """
        batch_size = encoder_input.shape[0]
        device = encoder_input.device
        bos_token_id = self.special_token_ids["bos_token_id"]
        eos_token_id = self.special_token_ids["eos_token_id"]

        encoder_output, encoder_padding_mask = self._encode_for_decoding(
            encoder_input, encoder_attention_mask, encoder_padding_mask
        )

        if self.use_standard_embedding:
            decoder_input = torch.full((batch_size, max_length + 1), bos_token_id, device=device)
        else:
            decoder_input = torch.zeros(batch_size, max_length + 1, self.input_dim, device=device)
            # Match training (_shift_right): the first fed step carries BOS in column 0.
            decoder_input[:, 0, 0] = bos_token_id

        eos = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for k in range(max_length):
            active = ~eos  # only still-running rows are decoded
            # NOTE(review): row-indexing encoder_attention_mask assumes it has a batch axis,
            # whereas decode() passes it as memory_mask; confirm the expected shape.
            step_padding_mask = encoder_padding_mask[active] if encoder_padding_mask is not None else None
            step_attention_mask = encoder_attention_mask[active] if encoder_attention_mask is not None else None
            decoder_output = self.decode(
                decoder_input[:, : k + 1][active],
                encoder_output[active],
                encoder_padding_mask=step_padding_mask,
                encoder_attention_mask=step_attention_mask,
                perform_shift_right=False,
            )
            logits = self.classifier(decoder_output[:, -1:, :])
            next_token = logits.argmax(dim=-1, keepdim=True)

            if self.use_standard_embedding:
                next_input = next_token[:, -1:].squeeze(-1)
            else:
                decoded_output = self.embedding.inverse(decoder_output[:, -1:, self.token_register_size :])
                next_input = torch.cat([next_token, decoded_output.round()], dim=-1)

            decoder_input[active, k + 1 : k + 2] = next_input

            eos[active] |= (next_token == eos_token_id).flatten()
            if eos.all():
                decoder_input = decoder_input[:, : k + 2]
                break

        return decoder_input

    @torch.no_grad()
    def generate(
        self,
        encoder_input,
        max_length=100,
        temperature=1.0,
        do_sample=False,
        top_k=0,
        top_p=1.0,
        repetition_penalty=1.0,
        encoder_attention_mask=None,
        encoder_padding_mask=None,
        **kwargs,
    ):
        """Sampling-style generation loop (incomplete; prefer ``greedy_generate``).

        NOTE(review): this method does not work as written:

        * ``decoder_input`` starts with zero positions, so ``_shift_right`` inside
          ``decode`` indexes position 0 of an empty sequence.
        * ``decoder_output[:, -1, :]`` are decoder hidden states (``d_model`` wide), not
          classifier logits.
        * The appended ``next_token`` does not have ``input_dim`` features.

        Only the encoder-output unpacking and the padding-mask hand-off have been made
        consistent with ``forward``.
        """
        batch_size = encoder_input.shape[0]
        device = encoder_input.device

        encoder_output, encoder_padding_mask = self._encode_for_decoding(
            encoder_input, encoder_attention_mask, encoder_padding_mask
        )
        decoder_input = torch.zeros(batch_size, 0, self.input_dim, device=device)

        for _ in range(max_length):
            decoder_output = self.decode(decoder_input, encoder_output, encoder_padding_mask=encoder_padding_mask)
            next_token_logits = decoder_output[:, -1, :] / temperature

            if repetition_penalty != 1.0:
                self._apply_repetition_penalty(next_token_logits, batch_size, 1, decoder_input, repetition_penalty)

            if do_sample:
                next_token = self._sample_token(next_token_logits, top_k, top_p)
            else:
                next_token = next_token_logits

            decoder_input = torch.cat([decoder_input, next_token.unsqueeze(1)], dim=1)

            if self._check_end_condition(next_token):
                break

        return decoder_input[:, 1:]  # Remove the first zero padding

    def _apply_repetition_penalty(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """Penalise, in place, every token id already present in each row's history.

        Positive scores are divided by the penalty and negative scores are multiplied by
        it, so a penalty > 1 always lowers a repeated token's score. Plain division
        would *raise* negative scores.
        """
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                token = int(previous_token)  # float histories cannot index a tensor
                if lprobs[i, token] < 0:
                    lprobs[i, token] *= repetition_penalty
                else:
                    lprobs[i, token] /= repetition_penalty

    def _sample_token(self, logits, top_k, top_p):
        """Sample one id per row of 2-D ``logits`` after optional top-k / nucleus filtering.

        ``logits`` is modified in place: filtered entries are set to ``-inf``.
        """
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # topk raises if k exceeds the vocab size
            kth_best = torch.topk(logits, top_k)[0][..., -1, None]
            logits[logits < kth_best] = -float("Inf")

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            # Map the removal mask from sorted order back to vocabulary order, per row.
            indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = -float("Inf")

        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)

    def _check_end_condition(self, token):
        """Return whether generation should stop. Currently a stub that always returns False."""
        # The actual implementation needs to be adjusted according to the task
        return False