# Code adapted from: https://github.com/KaiyangZhou/CoOp/blob/main/trainers/coop.py
# License: MIT

from typing import cast
from collections.abc import Sequence

import torch
from torch import IntTensor
from torch import nn

from open_clip.tokenizer import _tokenizer


# Names exported by ``from <module> import *``.
__all__ = ["PromptLearner"]


class PromptLearner(nn.Module):
    """Learnable CoOp-style context vectors combined with fixed class-name tokens.

    A shared block of ``n_ctx`` learnable context embeddings (``self.ctx``)
    is spliced into the token embeddings of one prompt per class. The class
    name is placed before the learnable context, giving the per-class layout
    ``[SOS, class_name, ctx, rest]``.
    """

    def __init__(self, clip_model, classnames: Sequence[str], n_ctx: int = 1):
        """Build the learnable context and the frozen per-class token embeddings.

        Args:
            clip_model: CLIP-style model. This constructor reads its
                ``ln_final.weight`` (for the embedding width), ``dtype``,
                ``token_embedding`` and ``device``.
            classnames: One name per class. Underscores are replaced by spaces.
            n_ctx: Number of learnable context vectors shared by all classes.
        """
        super().__init__()

        ctx_dim = clip_model.ln_final.weight.shape[0]
        ctx_vectors = torch.empty((n_ctx, ctx_dim), dtype=clip_model.dtype)
        nn.init.normal_(ctx_vectors, std=0.02)

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        # "X" placeholders reserve n_ctx token slots right after SOS; their
        # embeddings are discarded and replaced by self.ctx in forward().
        prompt_prefix = " ".join(["X"] * n_ctx)
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        with torch.no_grad():
            # presumably each _tokenizer(p) call returns a (1, context_length)
            # token-id tensor, so the concatenation is (n_cls, context_length)
            tokenized_prompts = torch.cat([_tokenizer(p) for p in prompts])
            embedding = clip_model.token_embedding(tokenized_prompts.to(clip_model.device))

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS
        # Position of the end-of-text token in each prompt. argmax relies on
        # EOT having the largest token id in the vocabulary (assumption about
        # the CLIP tokenizer -- confirm if the tokenizer changes).
        self.register_buffer("_eot_idxs", tokenized_prompts.argmax(dim=-1))

        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        # Number of tokens each class name occupies at the start of `suffix`;
        # assumes a name tokenizes the same way alone as inside the prompt.
        self.name_lens = [len(_tokenizer.encode(name)) for name in classnames]

        self.to(clip_model.device)

    def forward(self) -> torch.Tensor:
        """Assemble the prompt embeddings for every class.

        For class ``i`` the output row is ``[SOS, class_i, ctx, rest_i]``. The
        class-name tokens are moved in front of the learnable context. This is
        CoOp's "front" class-token position. ``rest_i`` holds whatever
        followed the class name in the tokenized prompt: the "." and the
        end-of-text token, plus any padding. Only the order of pieces changes,
        so each row keeps the length of the tokenized prompt, and the position
        of ``rest_i`` (and hence the EOT index) is unchanged.

        Returns:
            Tensor of shape ``(n_cls, 1 + n_ctx + suffix_len, dim)``.
        """
        ctx = self.ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        prefix = self.prefix
        suffix = self.suffix

        prompts = [
            torch.cat(
                (
                    prefix[i : i + 1, :, :],          # prefix_i (1,        1, dim)
                    suffix[i : i + 1, :name_len, :],  # class_i  (1, name_len, dim)
                    ctx[i : i + 1, :, :],             # ctx_i    (1,    n_ctx, dim)
                    suffix[i : i + 1, name_len:, :],  # suffix_i (1,        *, dim)
                ),
                dim=1,
            )
            for i, name_len in enumerate(self.name_lens)
        ]
        return torch.cat(prompts, dim=0)

    @property
    def prompt(self):
        """Prompt embeddings for all classes; equivalent to calling ``self()``."""
        return self()

    @property
    def eot_idxs(self) -> IntTensor:
        """Index of the end-of-text token in each class prompt, shape ``(n_cls,)``."""
        # NOTE: the stored tensor comes from ``argmax`` and is int64 at
        # runtime; the cast only narrows the static type.
        return cast(IntTensor, self._eot_idxs)
