import torch
import torch.nn as nn
from collections import OrderedDict

from baselines.pFedMMA_clip import clip
from baselines.pFedMMA_clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Tokenizer instance created once at import time.
# NOTE(review): not referenced by any definition visible in this file section —
# confirm whether other code relies on it before removing.
_tokenizer = _Tokenizer()


class text_encoder_pFedMMA(nn.Module):
    """Text encoder that runs prompt embeddings through an adapter-aware transformer.

    The constructor stores the four components it is given; ``forward`` adds
    the positional embedding, calls the transformer with the adapter bank,
    applies the final layer norm, selects one position per prompt and
    multiplies it by the text projection.
    """

    def __init__(self, model, positional_embedding, ln_final, text_projection):
        """Store the transformer, positional embedding, final norm and projection.

        Args:
            model: Callable invoked as ``model([x, adapter_bank, mode])``.
            positional_embedding: Tensor added to the prompt embeddings.
            ln_final: Module applied to the transformer output.
            text_projection: Matrix right-multiplied onto the selected features.
        """
        super().__init__()
        self.transformer = model
        self.positional_embedding = positional_embedding
        self.ln_final = ln_final
        self.text_projection = text_projection

    def forward(self, prompts, tokenized_prompts, adapter_bank, mode="text"):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts: Prompt embeddings; indexed below as a 3-D tensor
                (presumably (n_prompts, n_tokens, width) — TODO confirm).
            tokenized_prompts: Token ids, one row per prompt; the argmax over
                the last axis chooses which position is read out.
            adapter_bank: Passed through untouched to the transformer.
            mode: Passed through untouched to the transformer.

        Returns:
            One projected feature row per prompt.
        """
        compute_dtype = torch.float32
        hidden = prompts + self.positional_embedding.type(compute_dtype)

        # The transformer is fed (and answers) with the first two axes swapped.
        hidden = self.transformer([hidden.permute(1, 0, 2), adapter_bank, mode])
        hidden = self.ln_final(hidden.permute(1, 0, 2)).type(compute_dtype)

        # Read out, for every prompt, the position holding the largest token id
        # (presumably the end-of-text token, as in CLIP — confirm).
        row_index = torch.arange(hidden.shape[0])
        readout_position = tokenized_prompts.argmax(dim=-1)
        return hidden[row_index, readout_position] @ self.text_projection


class image_encoder_pFedMMA(nn.Module):
    """Image encoder that drives a wrapped visual model with an adapter bank.

    The wrapped model must expose ``conv1``, ``class_embedding``,
    ``positional_embedding``, ``ln_pre``, ``transformer``, ``ln_post`` and
    ``proj``; these are the attributes ``forward`` reads.
    """

    def __init__(self, model):
        """Keep a reference to the visual model under ``self.encoder``."""
        super().__init__()
        self.encoder = model

    def forward(self, x, adapter_bank, mode="visual"):
        """Encode images into (optionally projected) features.

        Args:
            x: Input handed to ``encoder.conv1``.
            adapter_bank: Passed through untouched to the transformer.
            mode: Passed through untouched to the transformer.

        Returns:
            The layer-normed token at position 0, multiplied by
            ``encoder.proj`` when that attribute is not ``None``.
        """
        visual = self.encoder

        # Convolve, then flatten everything after the channel axis and move
        # channels last so each flattened position becomes one token.
        feature_map = visual.conv1(x)
        tokens = feature_map.reshape(feature_map.shape[0], feature_map.shape[1], -1)
        tokens = tokens.permute(0, 2, 1)

        # Adding the class embedding to a zero tensor broadcasts it to one
        # leading token per batch element.
        zeros = torch.zeros(tokens.shape[0], 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device)
        class_token = visual.class_embedding.to(tokens.dtype) + zeros
        tokens = torch.cat([class_token, tokens], dim=1)
        tokens = tokens + visual.positional_embedding.to(tokens.dtype)

        tokens = visual.ln_pre(tokens)
        # The transformer is fed (and answers) with the first two axes swapped.
        tokens = visual.transformer([tokens.permute(1, 0, 2), adapter_bank, mode])
        tokens = tokens.permute(1, 0, 2)

        pooled = visual.ln_post(tokens[:, 0, :])
        if visual.proj is None:
            return pooled
        return pooled @ visual.proj


class AdapterBank(nn.Module):
    """Bundle of per-layer text, visual and shared adapters plus their scale.

    ``get_text`` / ``get_visual`` return a triple
    ``(modality_adapter, shared_adapter, adapter_scale)`` for a layer index.
    When the modality has no adapter at that index, both adapters in the
    triple are fresh ``nn.Identity()`` modules.
    """

    def __init__(self, text_adapter, visual_adapter, shared_adapter, adapter_scale):
        """Store the three indexable adapter collections and the scale factor."""
        super().__init__()
        self.text_adapter = text_adapter
        self.visual_adapter = visual_adapter
        self.shared_adapter = shared_adapter
        self.adapter_scale = adapter_scale

    def _select(self, modality_adapters, layer_index):
        """Look up ``layer_index`` in ``modality_adapters`` with an identity fallback."""
        # Test the range before indexing so an index past the end is never dereferenced.
        beyond_end = layer_index >= len(modality_adapters)
        if beyond_end or modality_adapters[layer_index] is None:
            return nn.Identity(), nn.Identity(), self.adapter_scale
        return modality_adapters[layer_index], self.shared_adapter[layer_index], self.adapter_scale

    def get_text(self, layer_index):
        """Return ``(text_adapter, shared_adapter, scale)`` for ``layer_index``."""
        return self._select(self.text_adapter, layer_index)

    def get_visual(self, layer_index):
        """Return ``(visual_adapter, shared_adapter, scale)`` for ``layer_index``."""
        return self._select(self.visual_adapter, layer_index)


class AdapterLearner(nn.Module):
    """Owns the text, visual and shared adapters and the fixed prompt embeddings.

    At construction time one prompt of the form ``"a photo of a <class>."``
    is tokenized per class and embedded with the CLIP token embedding; both
    tensors are registered as buffers. Adapters are created for layer slots
    5 through 12 (inclusive) with a bottleneck width of 32 and a scale of
    0.001.
    """

    def __init__(self, classnames, clip_model, device):
        """Build the prompt buffers and the three adapter lists.

        Args:
            classnames: Iterable of class-name strings; underscores are
                replaced by spaces before the prompt is built.
            clip_model: Model providing ``visual.input_resolution``,
                ``token_embedding``, ``ln_final``, ``transformer.resblocks``,
                ``visual.ln_post`` and ``visual.transformer.resblocks``.
            device: Device the tokenized prompts and embeddings are moved to.

        Raises:
            AssertionError: If the model's input resolution is not 224.
        """
        super().__init__()
        ADAPTER_START, ADAPTER_END = 5, 12
        ADAPTER_DIM = 32
        ADAPTER_SCALE = 0.001
        cfg_imsize = 224
        dtype = torch.float32

        clip_imsize = clip_model.visual.input_resolution
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        text_ctx_init = "a photo of a"
        prompts = [text_ctx_init + " " + name.replace("_", " ") + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        # The embeddings are frozen inputs, so they are computed without autograd.
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype).to(device)

        self.register_buffer("token_embedding", embedding)
        self.register_buffer("tokenized_prompts", tokenized_prompts)
        self.adapter_scale = float(ADAPTER_SCALE)

        # Construction order (text, visual, shared) is kept fixed: it determines
        # both parameter registration order and random-number consumption.
        text_width = clip_model.ln_final.weight.shape[0]
        text_depth = len(clip_model.transformer.resblocks)
        self.text_adapter = self._build_adapter(text_width, text_depth, ADAPTER_START, ADAPTER_END, ADAPTER_DIM)

        visual_width = clip_model.visual.ln_post.weight.shape[0]
        visual_depth = len(clip_model.visual.transformer.resblocks)
        self.visual_adapter = self._build_adapter(visual_width, visual_depth, ADAPTER_START, ADAPTER_END, ADAPTER_DIM)

        # The shared adapter maps the bottleneck width onto itself.
        shared_depth = len(clip_model.visual.transformer.resblocks)
        self.shared_adapter = self._build_adapter(ADAPTER_DIM, shared_depth, ADAPTER_START, ADAPTER_END, ADAPTER_DIM)

    def _build_adapter(self, d_model, n_layers, l_start, l_end, mid_dim):
        """Create a ``ModuleList`` of ``n_layers + 1`` slots with adapters in ``[l_start, l_end]``.

        Slots outside the range stay ``None``. When ``mid_dim`` differs from
        ``d_model`` each adapter is a named ``down`` (Linear + ReLU) / ``up``
        (Linear) pair; otherwise it is a single Linear + ReLU. Every Linear
        gets Kaiming-normal weights (fan_out, relu) and a zero bias.
        """
        has_up_projection = mid_dim != d_model
        slots = [None] * (n_layers + 1)
        for layer in range(l_start, l_end + 1):
            down = nn.Sequential(nn.Linear(d_model, mid_dim), nn.ReLU())
            if has_up_projection:
                slots[layer] = nn.Sequential(OrderedDict([("down", down), ("up", nn.Linear(mid_dim, d_model))]))
            else:
                slots[layer] = down

        bank = nn.ModuleList(slots)
        for module in bank.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(module.bias, 0)
        return bank

    def forward(self):
        """Return ``(embedding, adapter_bank)``.

        If slot 0 of the text adapter is populated, the stored token
        embedding is passed through down -> shared -> up and the result,
        scaled by ``adapter_scale``, is added to it; otherwise the stored
        embedding is returned unchanged. A new ``AdapterBank`` wrapping the
        three adapter lists is built on every call.
        """
        embedding = self.token_embedding
        first_text_adapter = self.text_adapter[0]
        if first_text_adapter is not None:
            residual = first_text_adapter.down(embedding)
            residual = self.shared_adapter[0](residual)
            residual = first_text_adapter.up(residual)
            embedding = embedding + self.adapter_scale * residual

        bank = AdapterBank(self.text_adapter, self.visual_adapter, self.shared_adapter, self.adapter_scale)
        return embedding, bank
