import torch
import torch.nn as nn

from baselines.clip import clip
from baselines.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level CLIP BPE tokenizer instance. PromptLearner uses `_tokenizer.encode`
# to record how many tokens each (underscore-stripped) class name occupies.
_tokenizer = _Tokenizer()


class PromptLearner(nn.Module):
    """Learnable prompt context built as a low-rank product plus a dense term.

    There are ``N`` prompt sets. For each set ``k`` the context tokens are
    ``U[k] @ V[k] + sigma[k]``, where the shapes are:

    * ``U``: (N, n_ctx, bottleneck)
    * ``V``: (N, bottleneck, ctx_dim)
    * ``sigma``: (N, n_ctx, ctx_dim)

    Every prompt set is paired with every class name. This gives ``N * n_cls``
    prompts, ordered prompt-set-major: row ``k * n_cls + c`` is prompt set
    ``k`` with class ``c``. This matches ``tokenized_prompts.repeat(N, 1)``.

    NOTE(review): presumably ``sigma`` is the shared/global prompt and
    ``U @ V`` the low-rank personalised part (FedPGP). Confirm against the
    training/aggregation code.
    """

    def __init__(self, classnames, clip_model, device, n_ctx=16, bottleneck=4, num_prompts=2):
        """Build the learnable context and the frozen token-embedding buffers.

        Args:
            classnames: Class names. Underscores are replaced by spaces.
            clip_model: CLIP model. Reads ``dtype``, ``ln_final.weight``,
                ``visual.input_resolution`` and ``token_embedding`` from it.
            device: Device for the token buffers and the learnable parameters.
            n_ctx: Number of learnable context tokens per prompt (default 16).
            bottleneck: Inner dimension of the ``U @ V`` factorisation
                (default 4).
            num_prompts: Number of prompt sets ``N`` (default 2). Stored as
                ``self.N``.

        Raises:
            AssertionError: If the CLIP visual input resolution is not 224.
        """
        super().__init__()
        n_cls = len(classnames)
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = 224
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        self.N = num_prompts

        # Draw the initial values on CPU so they come from the CPU generator
        # whatever `device` is. Then move them next to the buffers registered
        # below, so the module never holds mixed-device state.
        U = torch.empty(self.N, n_ctx, bottleneck, dtype=dtype)
        V = torch.empty(self.N, bottleneck, ctx_dim, dtype=dtype)
        sigma = torch.empty(self.N, n_ctx, ctx_dim, dtype=dtype)
        nn.init.normal_(U, std=0.02)
        nn.init.normal_(V, std=0.02)
        nn.init.normal_(sigma, std=0.02)

        self.U = nn.Parameter(U.to(device))
        self.V = nn.Parameter(V.to(device))
        self.sigma = nn.Parameter(sigma.to(device))

        # "X X ... X <name>." — the n_ctx placeholder words reserve the slots
        # that forward() overwrites with learned context.
        # NOTE(review): assumes each "X" is a single token and position 0 is a
        # start token, so the placeholders occupy positions 1..n_ctx.
        prompt_prefix = " ".join(["X"] * n_ctx)
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # One row per class, tiled N times (rows ordered prompt-set-major).
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        tokenized_prompts = tokenized_prompts.repeat(self.N, 1).to(device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype).to(device)

        # Frozen pieces around the learnable slots: position 0, and everything
        # after the n_ctx placeholder positions (class name and the rest).
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])
        self.register_buffer("embedding", embedding)

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens

    def _tile_over_classes(self, t):
        """Turn per-set context (N, n_ctx, dim) into rows of shape (N * n_cls, n_ctx, dim)."""
        if t.dim() == 3:
            t = t.unsqueeze(0).expand(self.n_cls, -1, -1, -1)
        # (n_cls, N, ...) -> (N, n_cls, ...) so the flattened rows line up with
        # the N-fold `repeat` applied to the tokenized prompts.
        t = t.permute(1, 0, 2, 3)
        return t.contiguous().view(self.N * self.n_cls, self.n_ctx, t.shape[3])

    def _assemble_prompts(self, ctx_rows):
        """Splice context rows between the frozen prefix and suffix embeddings."""
        return torch.cat([self.token_prefix, ctx_rows, self.token_suffix], dim=1)

    def forward(self):
        """Build the full prompt embeddings for every (prompt set, class) pair.

        Returns:
            Tuple ``(embedding, prompts_sigma, prompts_UV, prompts)``:

            * ``embedding`` is the frozen token embedding of the placeholder
              prompts.
            * The other three have the same shape as ``embedding``, with
              positions 1..n_ctx replaced by, respectively, ``sigma``,
              ``U @ V`` and ``U @ V + sigma``.
        """
        UV = torch.matmul(self.U, self.V)
        sigma = self.sigma
        ctx = UV + sigma

        prompts_sigma = self._assemble_prompts(self._tile_over_classes(sigma))
        prompts_UV = self._assemble_prompts(self._tile_over_classes(UV))
        prompts = self._assemble_prompts(self._tile_over_classes(ctx))

        return self.embedding, prompts_sigma, prompts_UV, prompts


class text_encoder_FedPGP(nn.Module):
    """Encode already-embedded prompts with a CLIP-style text transformer.

    This takes prompt embeddings rather than token ids, for example the
    learnable prompts built by ``PromptLearner``. The steps are:

    1. Add the positional embedding.
    2. Run the transformer.
    3. Apply the final layer norm.
    4. Pool one position per row.
    5. Project with ``text_projection``.
    """

    def __init__(self, model, positional_embedding, ln_final, text_projection, dtype=torch.float32):
        """Store the transformer pieces.

        Args:
            model: Transformer, called on the input after
                ``permute(1, 0, 2)``. Its output is permuted back the same
                way.
            positional_embedding: Tensor added to ``prompts``. It must
                broadcast against them.
            ln_final: Module applied to the transformer output.
            text_projection: Matrix right-multiplied onto the pooled
                features.
            dtype: Dtype that the positional embedding and the ``ln_final``
                output are cast to. Defaults to ``torch.float32``. Pass e.g.
                ``torch.float16`` when the wrapped modules hold half-precision
                weights.
        """
        super().__init__()
        self.transformer = model
        self.positional_embedding = positional_embedding
        self.ln_final = ln_final
        self.text_projection = text_projection
        self.dtype = dtype

    def forward(self, prompts, tokenized_prompts):
        """Return one projected text feature per prompt row.

        Args:
            prompts: 3-D prompt embeddings. Assumed (batch, seq, dim), i.e.
                CLIP's NLD layout (TODO confirm).
            tokenized_prompts: Token ids whose leading shape is (batch, seq).
                The argmax position of each row is the one pooled.

        Returns:
            Tensor of shape (batch, text_projection.shape[-1]).
        """
        dtype = self.dtype
        x = prompts + self.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # swap the first two axes for the transformer call
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # and swap them back
        x = self.ln_final(x).type(dtype)

        # Pool the feature at each row's highest token id.
        # NOTE(review): as in CLIP's encode_text, this assumes the end-of-text
        # token has the largest id. Confirm for the tokenizer in use.
        # Build the index tensors on x's device so indexing works wherever
        # tokenized_prompts lives.
        rows = torch.arange(x.shape[0], device=x.device)
        eot = tokenized_prompts.argmax(dim=-1).to(x.device)
        x = x[rows, eot] @ self.text_projection

        return x
