import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel, AutoTokenizer


class Pooler(nn.Module):
    """
    Parameter-free poolers to get the sentence embedding.

    'cls': [CLS] (first-token) representation of the last layer. This module
        returns it unchanged; any MLP on top must be applied by the caller.
    'cls_before_pooler': [CLS] representation without any MLP pooler
        (identical to 'cls' inside this module).
    'avg': masked average of the last layer's hidden states over the tokens.
    'avg_top2': masked average of the mean of the last two hidden states.
    'avg_first_last': masked average of the mean of ``hidden_states[0]`` and
        ``hidden_states[-1]``.

    Raises:
        ValueError: if ``pooler_type`` is not one of the values above.
    """

    _POOLER_TYPES = (
        "cls",
        "cls_before_pooler",
        "avg",
        "avg_top2",
        "avg_first_last",
    )

    def __init__(self, pooler_type):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are removed under
        # `python -O`, which would let an invalid type through silently.
        if pooler_type not in self._POOLER_TYPES:
            raise ValueError(f"unrecognized pooling type {pooler_type}")
        self.pooler_type = pooler_type

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, weighting each position by the mask."""
        weighted_sum = (hidden * attention_mask.unsqueeze(-1)).sum(1)
        return weighted_sum / attention_mask.sum(-1).unsqueeze(-1)

    def forward(self, attention_mask, outputs):
        """
        Pool the encoder outputs into one vector per sequence.

        Args:
            attention_mask: mask multiplied into the hidden states and summed
                over its last dim to get the per-sequence divisor. Only read
                by the 'avg*' pooler types.
            outputs: object exposing ``last_hidden_state`` and
                ``hidden_states``. ``hidden_states`` is only indexed by
                'avg_top2' and 'avg_first_last'.

        Returns:
            The pooled tensor (hidden states with dim 1 reduced away).
        """
        last_hidden = outputs.last_hidden_state
        hidden_states = outputs.hidden_states

        if self.pooler_type in ("cls_before_pooler", "cls"):
            return last_hidden[:, 0]
        if self.pooler_type == "avg":
            return self._masked_mean(last_hidden, attention_mask)
        if self.pooler_type == "avg_first_last":
            # NOTE(review): in Hugging Face model outputs, hidden_states[0] is
            # documented as the embedding output, not the first transformer
            # layer -- confirm index 0 (rather than 1) is intended here.
            first_hidden = hidden_states[0]
            last_hidden = hidden_states[-1]
            return self._masked_mean(
                (first_hidden + last_hidden) / 2.0, attention_mask
            )
        if self.pooler_type == "avg_top2":
            second_last_hidden = hidden_states[-2]
            last_hidden = hidden_states[-1]
            return self._masked_mean(
                (last_hidden + second_last_hidden) / 2.0, attention_mask
            )
        # Only reachable if pooler_type was reassigned after construction.
        raise NotImplementedError


class TextEncoder:
    """
    Base class bundling a Hugging Face tokenizer and model for text embedding.

    Subclasses must implement :meth:`embed`.
    """

    def __init__(
        self, model_path: str, device: str, model_config=None, tokenizer_config=None
    ):
        """
        Load the tokenizer and model from ``model_path``.

        Args:
            model_path: identifier or path passed to ``from_pretrained``.
            device: device the model is moved to.
            model_config: optional keyword arguments forwarded to
                ``AutoModel.from_pretrained``; ``None`` means no extras.
            tokenizer_config: optional keyword arguments forwarded to
                ``AutoTokenizer.from_pretrained``; ``None`` means no extras.
        """
        # None sentinels instead of `{}` defaults, so no dict object is
        # shared between calls.
        model_config = {} if model_config is None else model_config
        tokenizer_config = {} if tokenizer_config is None else tokenizer_config

        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
        self.model = AutoModel.from_pretrained(model_path, **model_config).to(device)
        # NOTE(review): init_weights() runs after the pretrained weights were
        # loaded -- confirm the installed transformers version does not
        # re-initialize already-loaded parameters here.
        self.model.init_weights()
        # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
        self.model.resize_token_embeddings(len(self.tokenizer))

    @property
    def emb_size(self):
        """Size of the last dim of an embedding; runs ``embed`` on every access."""
        return self.embed("test").shape[-1]

    def embed(self, text_list):
        """Embed ``text_list``; must be overridden by subclasses."""
        raise NotImplementedError


class TextEncoderSubgraphRAG(TextEncoder):
    """
    Text encoder that embeds texts with the first-token hidden state of the
    last layer, optionally L2-normalized.
    """

    def __init__(
        self,
        model_path="Alibaba-NLP/gte-large-en-v1.5",
        device="cuda:0",
        normalize=True,
    ):
        """
        Args:
            model_path: identifier or path of the model to load.
            device: device the model and tokenized inputs are placed on.
            normalize: if true, embeddings are L2-normalized along dim 1.
        """
        model_config = {
            "trust_remote_code": True,
            "unpad_inputs": True,
            "use_memory_efficient_attention": True,
        }
        super().__init__(model_path, device, model_config)
        self.normalize = normalize

    @torch.no_grad()
    def embed(self, text_list):
        """
        Embed a batch of texts without tracking gradients.

        Args:
            text_list: texts to embed. An empty (non-string) sequence yields
                an empty ``(0, hidden_size)`` tensor instead of reaching the
                tokenizer, which cannot build a batch from zero texts.

        Returns:
            Tensor with one row per input text, located on ``self.device``.
        """
        # A bare string is passed straight to the tokenizer (``emb_size`` does
        # this), so only empty non-string sequences are short-circuited.
        if not isinstance(text_list, str) and len(text_list) == 0:
            return torch.zeros(
                (0, self.model.config.hidden_size),
                dtype=self.model.dtype,
                device=self.device,
            )

        batch_dict = self.tokenizer(
            text_list,
            max_length=8192,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        outputs = self.model(**batch_dict).last_hidden_state
        # First-token representation of each sequence.
        emb = outputs[:, 0]

        if self.normalize:
            emb = F.normalize(emb, p=2, dim=1)

        return emb

    def __call__(self, q_text, text_entity_list, relation_list):
        """
        Embed a question, its entity texts and its relation texts.

        Returns:
            Tuple ``(q_emb, entity_embs, relation_embs)``; ``q_emb`` is the
            embedding of the single-element batch ``[q_text]``.
        """
        q_emb = self.embed([q_text])
        entity_embs = self.embed(text_entity_list)
        relation_embs = self.embed(relation_list)

        return q_emb, entity_embs, relation_embs


class TextEncoderSR(TextEncoder, nn.Module):
    """
    Trainable text encoder: a pretrained model followed by a ``Pooler`` and,
    for the "cls" pooler type, an extra Linear + Tanh layer.
    """

    def __init__(self, model_path="roberta-base", device="cuda:0", pooler_type="cls"):
        """
        Args:
            model_path: identifier or path of the model to load.
            device: device the model, pooler and MLP are moved to.
            pooler_type: pooling strategy name handed to ``Pooler``.
        """
        nn.Module.__init__(self)

        # Settings common to both the model and the tokenizer loaders.
        hub_kwargs = dict(cache_dir=None, revision="main", use_auth_token=None)
        TextEncoder.__init__(
            self,
            model_path,
            device,
            model_config=dict(hub_kwargs),
            tokenizer_config=dict(hub_kwargs, use_fast=True),
        )

        self.pooler_type = pooler_type
        self.pooler = Pooler(pooler_type).to(device)

        if self.pooler_type == "cls":
            width = self.model.config.hidden_size
            projection = nn.Linear(width, self.model.config.hidden_size)
            self.mlp = nn.Sequential(projection, nn.Tanh()).to(device)

    def embed(self, text_list):
        """Tokenize ``text_list`` (padded, truncated) and run :meth:`forward`."""
        encoded = self.tokenizer(
            text_list, padding=True, truncation=True, return_tensors="pt"
        )
        # Move every tokenizer output onto the model's device.
        on_device = {}
        for field, tensor in encoded.items():
            on_device[field] = tensor.to(self.model.device)
        return self.forward(**on_device)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
    ):
        """
        Encode the inputs and pool them into one vector per sequence.

        All arguments are forwarded to the wrapped model; ``attention_mask``
        is additionally handed to the pooler.

        Returns:
            The pooled output, passed through ``self.mlp`` when the pooler
            type is "cls".
        """
        # Per-layer hidden states are only requested for the poolers that
        # index into them.
        wants_all_layers = self.pooler_type in ("avg_top2", "avg_first_last")

        encoder_out = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=wants_all_layers,
            return_dict=True,
        )

        pooled = self.pooler(attention_mask, encoder_out)

        if self.pooler_type != "cls":
            return pooled

        # "cls" gets an extra MLP layer over the pooled representation.
        return self.mlp(pooled)
