import torch
from torch import nn

from egu.models.base import HFModel


class SoftPromptHFModel(HFModel):
    """HF model wrapper that prepends a learnable soft prompt to every input.

    ``_init_soft_prompt`` freezes the wrapped ``self.model`` and creates a
    ``(soft_prompt_len, embed_dim)`` parameter ``self.soft_prompt``. On each
    call, ``input_ids`` are embedded with the model's input-embedding layer,
    the soft prompt is concatenated in front of them, and the attention mask
    and labels are extended to the new length. The prompt positions are
    attended to but labelled ``-100``, so they are excluded from the loss.

    NOTE(review): ``self.model`` and ``self.tokenizer`` are presumably set up
    by ``HFModel.__init__`` — confirm. If ``HFModel`` is not an
    ``nn.Module``, ``self.soft_prompt`` is not auto-registered, and the
    optimizer must be given it explicitly.
    """

    def __init__(
        self,
        model_name,
        model_path=None,
        config_path="./config",
        generation_config=None,
        soft_prompt_len=1,
    ):
        """Build the wrapper.

        Args:
            model_name: Forwarded to ``HFModel.__init__``.
            model_path: Forwarded to ``HFModel.__init__``.
            config_path: Forwarded to ``HFModel.__init__``.
            generation_config: Forwarded to ``HFModel.__init__``.
            soft_prompt_len: Number ``k`` of learnable prompt vectors created
                by ``_init_soft_prompt``. NOTE(review): this argument was
                annotated with the config entry
                ``{name: "soft-prompt", n_tokens: 8}``; presumably
                ``n_tokens`` is the value passed here — confirm.
        """
        super().__init__(model_name, model_path, config_path, generation_config)
        self.soft_prompt_len = soft_prompt_len

    def __call__(self, *args, **kwargs):
        """Run the wrapped model, prepending the soft prompt to ``input_ids``.

        The ``prompts`` and ``answers`` keyword arguments are dropped and never
        reach ``self.model``. If ``input_ids`` is given and ``inputs_embeds``
        is not:

        - ``input_ids`` is replaced by ``inputs_embeds``;
        - ``attention_mask`` is extended to cover the prepended positions.
          A caller-supplied mask is kept; otherwise one is derived from the
          pad token;
        - non-``None`` ``labels`` are left-padded with ``-100``, the
          conventional ignore index of PyTorch / HF cross-entropy losses.

        All other positional and keyword arguments are forwarded unchanged.

        Returns:
            Whatever ``self.model(*args, **kwargs)`` returns.
        """
        for key in ("prompts", "answers"):
            kwargs.pop(key, None)

        if "input_ids" in kwargs and "inputs_embeds" not in kwargs:
            input_ids = kwargs.pop("input_ids")  # (B, L)
            # Honor an explicit mask instead of silently recomputing it.
            attention_mask = kwargs.pop("attention_mask", None)
            inputs_embeds, attn_mask = self._prepend_soft_prompt(
                input_ids, attention_mask
            )
            kwargs["inputs_embeds"] = inputs_embeds  # (B, n_prefix + L, D)
            kwargs["attention_mask"] = attn_mask  # (B, n_prefix + L)
            # NOTE(review): a caller-supplied ``position_ids`` is forwarded
            # unchanged and would be n_prefix positions short — confirm
            # callers never pass it.

            # Pad labels by the number of positions actually prepended. This
            # is 0 when no soft prompt exists, so it is not always
            # soft_prompt_len.
            n_prefix = inputs_embeds.size(1) - input_ids.size(1)
            labels = kwargs.get("labels")  # (B, L) when provided
            if labels is not None and n_prefix > 0:
                pad = torch.full(
                    (labels.size(0), n_prefix),
                    -100,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                kwargs["labels"] = torch.cat([pad, labels], dim=1)

        return self.model(*args, **kwargs)

    def _input_embeddings(self):
        """Return the wrapped model's token-embedding layer.

        Uses the Hugging Face ``get_input_embeddings()`` accessor when the
        model provides one. Otherwise it falls back to
        ``self.model.model.embed_tokens``, the attribute path this class
        originally hard-coded.
        """
        get_embeddings = getattr(self.model, "get_input_embeddings", None)
        if callable(get_embeddings):
            return get_embeddings()
        return self.model.model.embed_tokens

    def _init_soft_prompt(self):
        """Freeze the wrapped model and create the learnable soft prompt.

        Sets ``requires_grad = False`` on every parameter of ``self.model``.
        Then creates ``self.soft_prompt`` with shape
        ``(soft_prompt_len, embed_dim)``, initialized to ``0.01 * randn``.
        """
        for p in self.model.parameters():
            p.requires_grad = False

        embeddings = self._input_embeddings()
        # Draw on CPU (same RNG stream as before), then move to the embedding
        # device so the trainable parameter lives next to the model.
        init = torch.randn(self.soft_prompt_len, embeddings.embedding_dim) * 0.01
        self.soft_prompt = nn.Parameter(init.to(embeddings.weight.device))

    def _prepend_soft_prompt(self, input_ids, attention_mask=None):
        """Embed ``input_ids`` and prepend the soft prompt, if one exists.

        Args:
            input_ids: ``(B, L)`` token ids.
            attention_mask: Optional ``(B, L)`` mask. When omitted, it is
                derived as ``input_ids != pad_token_id``, or all ones if the
                tokenizer has no pad token.

        Returns:
            Tuple ``(inputs_embeds, attention_mask)``. If ``_init_soft_prompt``
            has not run, the shapes are ``(B, L, D)`` / ``(B, L)``. Otherwise
            they are ``(B, k + L, D)`` / ``(B, k + L)``, with the prompt
            positions unmasked (ones).
        """
        tok_embeds = self._input_embeddings()(input_ids)  # (B, L, D)

        if attention_mask is None:
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.long)
            else:
                attention_mask = (input_ids != pad_id).long()

        # The attribute only exists once _init_soft_prompt has been called.
        soft_prompt = getattr(self, "soft_prompt", None)
        if soft_prompt is None:
            return tok_embeds, attention_mask

        batch_size = input_ids.size(0)
        # Match the embeddings' device/dtype (e.g. a bf16 model on GPU).
        # ``.to`` is differentiable, so gradients still reach the parameter.
        prompt = soft_prompt.to(device=tok_embeds.device, dtype=tok_embeds.dtype)
        prompt = prompt.unsqueeze(0).expand(batch_size, -1, -1)  # (B, k, D)
        inputs_embeds = torch.cat([prompt, tok_embeds], dim=1)

        prompt_mask = torch.ones(
            batch_size,
            prompt.size(1),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        return inputs_embeds, attention_mask
