import contextlib
import copy
import json
import logging
import os
import random
from typing import Any
import warnings
from collections import namedtuple, Counter

import math
import torch
import transformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BatchEncoding, AutoModel
from vector_quantize_pytorch import FSQ
from transformers import BertModel

# Lightweight record mirroring the fields of an encoder output.
EncoderStates = namedtuple(
    "EncoderStates", "last_hidden_state hidden_states attentions"
)

# Everything `Autoencoder.forward` returns, in pipeline order:
# encoder states -> down-projection -> quantizer -> up-projection.
AutoencoderOutput = namedtuple(
    "AutoencoderOutput",
    "encoder_outputs encoder_downprojected encoder_upprojected quantizer_outputs",
)


class Autoencoder(torch.nn.Module):

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        length: int = 50,
        levels: list[int] = [4, 4, 4, 4],
        hidden_size: int = 256,
        **kwargs,
    ):
        super().__init__()
        for k, v in kwargs.items():
            logging.warning(f"Unused argument {k}={v}")

        self.model = model

        self.tokenizer = copy.deepcopy(tokenizer)
        self.levels = list(levels)

        assert hidden_size % len(self.levels) == 0
        self.hidden_size = hidden_size
        # len(self.levels) = 8
        # hidden_size = 32
        # num_latent_tokens_per_encoder_tokens = 4

        self.length = length
        self.downproject = torch.nn.Sequential(
            torch.nn.Linear(
                self.model.config.hidden_size,
                # self.model.config.hidden_size // 2,
                len(self.levels),
                bias=True,
            ),
            # torch.nn.ReLU(),
            # torch.nn.Linear(
            #     self.model.config.hidden_size // 2,
            #     len(self.levels),
            #     bias=True,
            # ),
        ).to("cuda")

        self.upproject = torch.nn.Sequential(
            torch.nn.Linear(
                len(self.levels), self.model.config.hidden_size * 2, bias=True
            ),
            torch.nn.ReLU(),
            torch.nn.Linear(
                self.model.config.hidden_size * 2,
                self.model.config.hidden_size,
                bias=True,
            ),
        ).to("cuda")

    @property
    def levels(self):
        return self._levels

    # setter
    @levels.setter
    def levels(self, levels):
        self.quantizer = FSQ(levels).to(self.device())
        self._levels = levels

    @property
    def num_latent_tokens_per_encoder_token(self):
        return self.hidden_size // len(self.levels)

    @contextlib.contextmanager
    def set_length(self, length: int):
        old_length = self.length
        self.length = length
        yield
        self.length = old_length

    @property
    def effective_length(self):
        return self.length * self.num_latent_tokens_per_encoder_token

    @property
    def encoder(self):
        if hasattr(self.model, "encoder"):
            return self.model.encoder
        return self.model.model.encoder

    @property
    def decoder(self):
        if hasattr(self.model, "decoder"):
            return self.model.decoder
        return self.model.model.decoder

    @classmethod
    def init_from_pretrained(cls, model_name_or_path: str, *args, **kwargs):
        pretrained_model = AutoModel.from_pretrained(model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        autoencoder = cls(model=pretrained_model, tokenizer=tokenizer, *args, **kwargs)
        return autoencoder

    def device(self):
        return next(self.model.parameters()).device

    def tokenize(self, texts, max_length=256, pad_to_max_length=False):
        if pad_to_max_length:
            batch_tokenized = self.tokenizer(
                [" " + text.strip() for text in texts],
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                max_length=max_length,
            ).to(self.device())
            return self.add_embedding_positions_to_batch(batch_tokenized)
        else:
            batch_tokenized = self.tokenizer(
                [" " + text.strip() for text in texts],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=max_length,
            ).to(self.device())
            return self.add_embedding_positions_to_batch(batch_tokenized)

    def encode(self, batch: BatchEncoding, steps: int = 1):

        print("encode")
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):

            # encoder_inputs_embeds = self.model.get_input_embeddings()(batch['input_ids'])
            encoder_attention_mask = batch["attention_mask"]

            # encoder_positional_embeddings = self.encoder.embed_tokens(batch['input_ids'])
            # set all positions after length to zero so we are only removing the positional embeddings for the special tokens
            # encoder_positional_embeddings[:, self.length:, :] = 0
            # remove positional embeddings
            # encoder_inputs_embeds -= encoder_positional_embeddings

            encoder_outputs = self.encoder(
                **batch,
            )
            return encoder_outputs

    def add_embedding_positions_to_batch(self, batch: BatchEncoding) -> BatchEncoding:
        batch_size = batch.input_ids.shape[0]
        # embedding_positions_input_ids = []
        # embedding_positions_mask = []
        # for _ in range(batch_size):
        #     # sample self.length random positions between 0 and 10000
        #     random_position = random.sample(range(10000), self.length)
        #     embedding_positions = self.tokenizer(
        #         ["".join([f"[embed{str(i).zfill(5)}]" for i in random_position])],
        #         padding=True,
        #         truncation=True,
        #         return_tensors="pt",
        #         add_special_tokens=False,
        #     ).to(self.device())
        #     embedding_positions_input_ids.append(embedding_positions.input_ids)
        #     embedding_positions_mask.append(embedding_positions.attention_mask)
        # embedding_positions = BatchEncoding(
        #     {
        #         "input_ids": torch.cat(embedding_positions_input_ids, dim=0),
        #         "attention_mask": torch.cat(embedding_positions_mask, dim=0),
        #     }
        # )
        embedding_positions = self.tokenizer(
            ["".join([f"[embed{str(i).zfill(3)}]" for i in range(self.length)])],
            padding=True,
            truncation=True,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device())
        batch = copy.deepcopy(batch)
        new_batch = {}
        batch_size = batch.input_ids.shape[0]
        new_batch["input_ids"] = torch.cat(
            [embedding_positions.input_ids.repeat(batch_size, 1), batch.input_ids],
            dim=1,
        )
        new_batch["attention_mask"] = torch.cat(
            [
                embedding_positions.attention_mask.repeat(batch_size, 1),
                batch.attention_mask,
            ],
            dim=1,
        )
        return BatchEncoding(new_batch)

    def add_special_tokens(self):
        # add mask token
        self.tokenizer.add_special_tokens({"mask_token": "[MASK]"})
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    f"[embed{str(i).zfill(3)}]" for i in range(300)
                ]
            }
        )
        self.model.resize_token_embeddings(len(self.tokenizer))

    def get_position_ids(self, batch: BatchEncoding):
        basic_positions = (
            torch.arange(batch["input_ids"].shape[1], device=self.device())
            .unsqueeze(0)
            .repeat(batch["input_ids"].shape[0], 1)
        )
        basic_positions -= self.length - 1
        basic_positions[:, : self.length] = 0
        return basic_positions

    def forward(
        self,
        batch: BatchEncoding,
    ) -> AutoencoderOutput:

        # remove positional encoding
        # position_ids = self.get_position_ids(batch)
        out = self.model(**batch)  # , position_ids=position_ids)
        encoder_outputs = out.last_hidden_state[:, : self.length, :]

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            quantizer_inputs = self.downproject(encoder_outputs)
            quantizer_inputs = quantizer_inputs.reshape(
                quantizer_inputs.shape[0],
                -1,
                len(self.levels),
            )
            state_queries_codes, queries_latent_indices = self.quantizer(
                quantizer_inputs
            )
            state_queries_up = self.upproject(state_queries_codes)

        return AutoencoderOutput(
            encoder_outputs=encoder_outputs,
            encoder_downprojected=quantizer_inputs,
            encoder_upprojected=state_queries_up,
            quantizer_outputs=(state_queries_codes, queries_latent_indices),
        )

    @contextlib.contextmanager
    def set_length(self, length: int):
        old_length = self.length
        self.length = length
        yield
        self.length = old_length

    def get_codebook(self, indices=None, with_lin1=False, with_embeds=False):
        if indices is None:
            indices = torch.arange(self.quantizer.codebook_size, device=self.device())
        raw_quantizer = self.quantizer.indices_to_codes(indices)

        return raw_quantizer

    def save_checkpoint(self, path: str):
        os.makedirs(path, exist_ok=True)
        for param in self.model.parameters():
            if not param.is_contiguous():
                param.data = param.data.contiguous()
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        torch.save(self.quantizer.state_dict(), path + "/quantizer.pt")
        torch.save(self.downproject.state_dict(), path + "/downproject.pt")
        torch.save(self.upproject.state_dict(), path + "/upproject.pt")

        config = {
            "levels": self.levels,
            "hidden_size": self.hidden_size,
            "length": self.length,
        }
        with open(path + "/autoencoder_config.json", "w") as f:
            json.dump(config, f)

    @classmethod
    def load_checkpoint(cls, path: str, gradient_checkpointing: bool = True):
        with open(path + "/autoencoder_config.json", "r") as f:
            config = json.load(f)

        model = AutoModel.from_pretrained(path).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(path)
        autoencoder = cls(
            model=model,
            tokenizer=tokenizer,
            **config,
            gradient_checkpointing=gradient_checkpointing,
        )
        autoencoder.quantizer.load_state_dict(torch.load(path + "/quantizer.pt"))
        autoencoder.downproject.load_state_dict(torch.load(path + "/downproject.pt"))
        autoencoder.upproject.load_state_dict(torch.load(path + "/upproject.pt"))
        return autoencoder

    def create_logs(self, quantizer_outputs, encoder_only=False):
        codebook_size = self.quantizer.codebook_size
        codebook_freqs = Counter(quantizer_outputs.reshape(-1).tolist())
        codebook_entpy = -sum(
            [
                (p := freq / sum(codebook_freqs.values())) * math.log(p)
                for _, freq in codebook_freqs.items()
            ]
        )
        codebook_usage = len(set(codebook_freqs)) / codebook_size * 100.0

        encoder_grad_norm = 0.0
        decoder_grad_norm = 0.0
        quantizer_grad_norm = 0.0

        no_grad_params = []

        for name, param in self.encoder.named_parameters():
            if param.grad is not None:
                encoder_grad_norm += (param.grad.data**2.0).sum().item()
            else:
                no_grad_params.append("encoder." + name)
        if encoder_only:
            decoder_grad_norm = 0
        else:
            for name, param in self.decoder.named_parameters():
                if param.grad is not None:
                    decoder_grad_norm += (param.grad.data**2.0).sum().item()
                else:
                    no_grad_params.append("decoder." + name)
        for name, param in self.quantizer.named_parameters():
            if param.grad is not None:
                quantizer_grad_norm += (param.grad.data**2.0).sum().item()
            else:
                no_grad_params.append("quantizer." + name)

        # if no_grad_params:
        #     warnings.warn(f"Parameters without gradients: {no_grad_params}")

        total_grad_norm = math.sqrt(
            encoder_grad_norm + decoder_grad_norm + quantizer_grad_norm
        )
        encoder_grad_norm = math.sqrt(encoder_grad_norm)
        decoder_grad_norm = math.sqrt(decoder_grad_norm)
        quantizer_grad_norm = math.sqrt(quantizer_grad_norm)

        logs = {
            "codebook_entropy": codebook_entpy,
            "codebook_usage": codebook_usage,
            "grad_norm": total_grad_norm,
            "encoder_grad_norm": encoder_grad_norm,
            "decoder_grad_norm": decoder_grad_norm,
            "quantizer_grad_norm": quantizer_grad_norm,
        }

        return logs
