import torch

from torch import nn
from baselines.pFedDC_clip import clip
from baselines.pFedDC_clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
# Module-level tokenizer instance. It is not referenced in this part of the
# file — presumably used elsewhere (e.g. to decode token ids); verify before removing.
_tokenizer = _Tokenizer()


class CrossAttention_text(nn.Module):
    """Single-head attention that mixes one set of text context vectors with another.

    Queries and keys are both projected from ``input_a`` and values from
    ``input_b``. The attention pattern is therefore computed among the rows of
    ``input_a`` and used to mix the projected rows of ``input_b``. Both inputs
    must have the same number of rows (second-to-last dimension). Scores are not
    scaled by ``1/sqrt(attn_dim)`` before the softmax.

    Args:
        clip_model: CLIP model; provides ``dtype`` and the text width
            (``ln_final.weight.shape[0]``).
        attn_dim: Width of the query/key projections (default 64).
    """

    def __init__(self, clip_model, attn_dim=64):
        super(CrossAttention_text, self).__init__()
        self.dtype = clip_model.dtype
        self.ctx_dim = clip_model.ln_final.weight.shape[0]
        # Creation order (k, q, v) is kept so seeded initialisation is unchanged.
        self.kmlp = nn.Linear(self.ctx_dim, attn_dim, bias=False, dtype=clip_model.dtype)
        self.qmlp = nn.Linear(self.ctx_dim, attn_dim, bias=False, dtype=clip_model.dtype)
        self.vmlp = nn.Linear(self.ctx_dim, self.ctx_dim, bias=False, dtype=clip_model.dtype)

    def forward(self, input_a, input_b):
        """Return ``softmax(Q K^T) V`` with ``Q, K`` from ``input_a`` and ``V`` from ``input_b``.

        Accepts ``(rows, ctx_dim)`` inputs or batched ``(..., rows, ctx_dim)``
        inputs. The output has the shape of ``input_b``.
        """
        Q = self.qmlp(input_a)
        K = self.kmlp(input_a)
        V = self.vmlp(input_b)

        # Transpose the last two dims so leading batch dims broadcast in matmul;
        # identical to transpose(0, 1) for 2-D input.
        scores = torch.matmul(Q, K.transpose(-2, -1))
        attentions_a = torch.softmax(scores, dim=-1)
        output = torch.matmul(attentions_a, V)

        return output


class CrossAttention_vision(nn.Module):
    """Single-head attention that mixes one set of visual prompt vectors with another.

    Queries and keys are both projected from ``input_a`` and values from
    ``input_b``. The attention pattern is computed among the rows of ``input_a``
    and used to mix the projected rows of ``input_b``. Both inputs must have the
    same number of rows (second-to-last dimension). Scores are not scaled by
    ``1/sqrt(attn_dim)`` before the softmax.

    Args:
        clip_model: CLIP model; provides ``dtype`` and the visual width
            (``visual.conv1.out_channels``).
        attn_dim: Width of the query/key projections (default 64).
    """

    def __init__(self, clip_model, attn_dim=64):
        super(CrossAttention_vision, self).__init__()
        self.dtype = clip_model.dtype
        self.ctx_dim = clip_model.visual.conv1.out_channels
        # Creation order (k, q, v) is kept so seeded initialisation is unchanged.
        self.kmlp = nn.Linear(self.ctx_dim, attn_dim, bias=False, dtype=clip_model.dtype)
        self.qmlp = nn.Linear(self.ctx_dim, attn_dim, bias=False, dtype=clip_model.dtype)
        self.vmlp = nn.Linear(self.ctx_dim, self.ctx_dim, bias=False, dtype=clip_model.dtype)

    def forward(self, input_a, input_b):
        """Return ``softmax(Q K^T) V`` with ``Q, K`` from ``input_a`` and ``V`` from ``input_b``.

        Accepts batched ``(..., rows, ctx_dim)`` inputs as well as plain
        ``(rows, ctx_dim)`` inputs. The output has the shape of ``input_b``.
        """
        Q = self.qmlp(input_a)
        K = self.kmlp(input_a)
        V = self.vmlp(input_b)

        # Transpose the last two dims: identical to transpose(1, 2) for 3-D
        # input, and also valid for 2-D input (where dim 2 does not exist).
        scores = torch.matmul(Q, K.transpose(-2, -1))
        attentions_a = torch.softmax(scores, dim=-1)
        output = torch.matmul(attentions_a, V)

        return output


class DeepPromptLearner(nn.Module):
    """Learnable text context: ``2 * n_ctx`` vectors of CLIP's text width.

    The width is read from ``clip_model.ln_final``. Entries are drawn from a
    normal distribution with mean 0 and std 0.02.
    """

    def __init__(self, clip_model):
        super(DeepPromptLearner, self).__init__()
        self.n_ctx = 16
        self.dtype = clip_model.dtype
        width = clip_model.ln_final.weight.shape[0]
        init_vectors = nn.init.normal_(
            torch.empty(2 * self.n_ctx, width, dtype=self.dtype), std=0.02
        )
        self.ctx = nn.Parameter(init_vectors)

    def forward(self):
        """Return the raw context parameter (no transformation applied)."""
        return self.ctx


class VPTDeepPromptLearner(nn.Module):
    """Learnable deep visual prompt tensor of shape ``(bottom_limit, 2 * n_ctx, ctx_dim)``.

    ``ctx_dim`` is the output width of CLIP's patch-embedding conv. Entries are
    drawn from a normal distribution with mean 0 and std 0.02.
    """

    def __init__(self, clip_model):
        super(VPTDeepPromptLearner, self).__init__()
        self.dtype = clip_model.dtype
        self.ctx_dim = clip_model.visual.conv1.out_channels
        self.n_ctx = 5
        # Leading dimension; presumably one prompt group per prompted layer.
        self.bottom_limit = 11

        shape = (self.bottom_limit, 2 * self.n_ctx, self.ctx_dim)
        init_prompts = nn.init.normal_(torch.empty(shape, dtype=self.dtype), std=0.02)
        self.ctx = nn.Parameter(init_prompts)

    def forward(self):
        """Return the raw prompt parameter (no transformation applied)."""
        return self.ctx


class PromptLearner(nn.Module):
    """Produce per-class prompt embeddings with a learned, attention-mixed context.

    Each class prompt is tokenized as ``"X X ... X <classname>."``, with one
    ``X`` placeholder per context token. The embedding of the single token
    before the placeholders (``token_prefix``) and of everything after them
    (``token_suffix``) are stored as buffers. At forward time the placeholder
    positions are filled with context vectors produced by
    ``cross_attention_text`` from the two halves of ``ctx_learner``'s parameter.

    Args:
        classnames: Class names; underscores are replaced with spaces.
        clip_model: CLIP model providing ``dtype`` and ``token_embedding``.
        device: Device the tokenized prompts are moved to before embedding.
    """

    def __init__(self, classnames, clip_model, device):
        super().__init__()
        dtype = clip_model.dtype

        # The learner owns the context length. n_ctx is derived from it rather
        # than repeating the literal 16, so the placeholder count, the buffer
        # slicing and the split in forward() cannot drift apart. It is created
        # before cross_attention_text to keep the original submodule/RNG order.
        self.ctx_learner = DeepPromptLearner(clip_model)
        n_ctx = self.ctx_learner.n_ctx

        classnames = [name.replace("_", " ") for name in classnames]
        prompt_prefix = " ".join(["X"] * n_ctx)
        prompts = [f"{prompt_prefix} {name}." for name in classnames]

        self.cross_attention_text = CrossAttention_text(clip_model)

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
        # Position 0 precedes the placeholders; positions >= 1 + n_ctx follow them.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])

        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        # Plain tensor attribute (not a buffer): it is not part of state_dict
        # and does not follow module.to(...).
        self.tokenized_prompts = tokenized_prompts

    def forward(self):
        """Return prompt embeddings of shape ``(n_cls, seq_len, width)``.

        ``seq_len`` is the tokenized prompt length. The learned context occupies
        positions ``1 .. n_ctx`` and is shared by all classes.
        """
        n_ctx = self.ctx_learner.n_ctx
        ctx = self.ctx_learner()
        # First half supplies queries/keys, second half supplies values.
        ctx = self.cross_attention_text(ctx[:n_ctx, :], ctx[n_ctx:, :])
        if ctx.dim() == 2:
            # Share the same context across every class prompt.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prompts = torch.cat([self.token_prefix, ctx, self.token_suffix], dim=1)
        return prompts


class ProjLearner(nn.Module):
    """Apply CLIP's visual output projection when the model defines one.

    ``clip_model.visual.proj`` is assigned by reference (shared, not copied). If
    it is an ``nn.Parameter``, it is therefore also registered as a parameter of
    this module. If it is ``None``, inputs are returned unchanged.
    """

    def __init__(self, clip_model):
        super(ProjLearner, self).__init__()
        self.proj = clip_model.visual.proj

    def forward(self, x):
        projection = self.proj
        return x if projection is None else x @ projection


class Transformer_VPTD(nn.Module):
    """CLIP visual transformer with deep visual prompts in its lower blocks.

    Before each of the first ``bottom_limit`` residual blocks, ``n_ctx`` prompt
    tokens are appended to the sequence, and they are removed again after that
    block. The remaining blocks run on the plain token sequence. The prompts are
    produced by ``cross_attention_vision`` from the two halves of
    ``ctx_learner``'s parameter.

    Input and output are sequence-first: dim 0 indexes tokens, dim 1 indexes
    images (prompts are broadcast over ``x.shape[1]``).
    """

    def __init__(self, clip_model):
        super().__init__()
        self.dtype = clip_model.dtype
        self.ctx_dim = clip_model.visual.conv1.out_channels
        self.clip_imsize = clip_model.visual.input_resolution

        transformer = clip_model.visual.transformer
        self.resblocks: nn.Sequential = transformer.resblocks
        self.layers = transformer.layers

        # Submodule creation order is unchanged (learner, then attention).
        self.ctx_learner = VPTDeepPromptLearner(clip_model)
        self.cross_attention_vision = CrossAttention_vision(clip_model)

        # Prompt geometry is taken from the learner that owns the parameter, so
        # the split/strip sizes and prompted depth cannot drift from it.
        self.n_ctx = self.ctx_learner.n_ctx
        self.bottom_limit = self.ctx_learner.bottom_limit
        # Kept for attribute compatibility. No class prompts are appended in
        # this module, so nothing is stripped for them.
        self.class_prompt_num = 5

    def forward(self, x):
        n_ctx = self.ctx_learner.n_ctx
        ctx = self.ctx_learner()
        # First half supplies queries/keys, second half supplies values.
        ctx = self.cross_attention_vision(ctx[:, :n_ctx, :], ctx[:, n_ctx:, :])
        # (depth, n_ctx, width) -> (depth, n_ctx, batch, width) to match x's layout.
        ctx = ctx.unsqueeze(2).expand(-1, -1, x.shape[1], -1)

        for i in range(self.bottom_limit):
            x = torch.cat([x, ctx[i]], dim=0)
            x = self.resblocks[i](x)
            # Drop the appended prompts. Written as an explicit length so
            # n_ctx == 0 does not empty x (x[:-0] would).
            x = x[: x.shape[0] - n_ctx, :, :]

        # Upper blocks see only the real tokens. Previously this loop also cut
        # class_prompt_num tokens after every block, although none had been
        # appended, which discarded genuine image patch tokens.
        for i in range(self.bottom_limit, self.layers):
            x = self.resblocks[i](x)

        return x


class image_encoder_pFedDC(nn.Module):
    """CLIP image encoder whose transformer is replaced by the prompted ``Transformer_VPTD``.

    The patch-embedding conv, class and positional embeddings, and both layer
    norms are shared with ``clip_model.visual``. The output projection goes
    through ``ProjLearner``.
    """

    def __init__(self, clip_model):
        super(image_encoder_pFedDC, self).__init__()
        visual = clip_model.visual
        # Registration order is kept: conv1, class_embedding,
        # positional_embedding, ln_pre, transformer, ln_post, proj.
        self.conv1 = visual.conv1
        self.class_embedding = visual.class_embedding
        self.positional_embedding = visual.positional_embedding
        self.ln_pre = visual.ln_pre
        self.transformer = Transformer_VPTD(clip_model)
        self.ln_post = visual.ln_post
        self.proj = ProjLearner(clip_model)

    def _embed_tokens(self, images):
        """Patchify ``images`` and prepend the class token; returns ``(n, tokens, width)``."""
        feats = self.conv1(images)
        n, width = feats.shape[0], feats.shape[1]
        patches = feats.reshape(n, width, -1).permute(0, 2, 1)
        zeros = torch.zeros(n, 1, patches.shape[-1], dtype=patches.dtype, device=patches.device)
        cls_tokens = self.class_embedding.to(patches.dtype) + zeros
        return torch.cat([cls_tokens, patches], dim=1)

    def forward(self, x):
        tokens = self._embed_tokens(x)
        tokens = self.ln_pre(tokens + self.positional_embedding.to(tokens.dtype))

        # Transformer_VPTD works sequence-first, so swap the first two axes around it.
        encoded = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        # Read out the class-token position only.
        pooled = self.ln_post(encoded[:, 0, :])
        return self.proj(pooled)


class text_encoder_pFedDC(nn.Module):
    """Encode already-embedded prompts with CLIP's text transformer.

    Shares the transformer, positional embedding, final layer norm and text
    projection of ``clip_model``. Produces one projected feature per prompt,
    read at the position of the largest token id in that prompt.
    """

    def __init__(self, clip_model):
        super(text_encoder_pFedDC, self).__init__()
        self.dtype = clip_model.dtype
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        """Return ``(n_prompts, proj_dim)`` features.

        Args:
            prompts: Prompt embeddings indexed as ``[prompt, position, width]``.
            tokenized_prompts: Token ids. The argmax over the last dim picks the
                read-out position (presumably the end-of-text token, which has
                the highest id in CLIP's vocabulary).
        """
        pos = self.positional_embedding.type(self.dtype)
        # The transformer is fed sequence-first: swap in, then swap back.
        hidden = self.transformer((prompts + pos).permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        eot_positions = tokenized_prompts.argmax(dim=-1)
        pooled = hidden[torch.arange(hidden.shape[0]), eot_positions]
        return pooled @ self.text_projection

