import json
import torch
import torch.nn as nn
import torch.nn.functional as nnf

from .clip import load, tokenize
from .model import convert_weights


class ClipWraper(nn.Module):
    """Wrapper around a CLIP model exposing text/image encoders, a
    text-driven binary region mask, and image-text similarity.

    Args:
        name: Identifier stored on the instance (not otherwise used here).
        config: Dict with ``"model_config"`` (path to a JSON file describing
            the model) and ``"device"`` (device the CLIP model is moved to).
        logger: Stored on the instance; not used by the methods below.
    """

    def __init__(self, name, config, logger):
        super().__init__()
        self.name = name
        # Use a context manager so the JSON file handle is closed promptly.
        with open(config["model_config"], "r") as f:
            model_config = json.load(f)
        # Parsed model JSON: must provide "architecture" (passed to ``load``)
        # and, for ``generate_mask``, "ratio".
        self.config = model_config
        self.device = config["device"]
        self.logger = logger

        # Load on CPU, then move to the configured device.
        self.clipmodel, _ = load(model_config["architecture"], device="cpu")
        self.clipmodel = self.clipmodel.to(self.device)

        self._convert_weights_to_fp16()

    def _convert_weights_to_fp16(self):
        """Apply the project's ``convert_weights`` to this whole module
        (fp16 conversion, per the method name)."""
        convert_weights(self)

    @torch.no_grad()
    def encode_text(self, texts):
        """Tokenize ``texts`` and return the (unnormalized) CLIP text features."""
        text_token = tokenize(texts).to(self.device)
        return self.clipmodel.encode_text(text_token)

    @torch.no_grad()
    def encode_image(self, images, ret_cls=False):
        """Run the CLIP visual tower and return CLS or per-patch features.

        Args:
            images: Tensor (B, 3, 224, 224), range [0, 1], RGB.
            ret_cls: If True return the class-token feature, otherwise the
                patch features arranged as a spatial map.

        Returns:
            ``(B, 1, D)`` if ``ret_cls`` else ``(B, D, grid, grid)``, where
            ``D`` is the projection dim (or the transformer width when the
            visual model has no ``proj``).
        """
        visual_model = self.clipmodel.visual
        # Match the conv weights' dtype (fp16 after weight conversion); a
        # float32 input would otherwise raise a dtype mismatch in conv1.
        images = images.to(visual_model.conv1.weight.dtype)

        x = visual_model.conv1(images)  # shape = [*, width, grid, grid]
        grid_num = x.shape[-1]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls_token = visual_model.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + visual_model.positional_embedding.to(x.dtype)
        x = visual_model.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        # The project's transformer returns a pair; only the second element
        # (the token sequence) is used here.
        _, x = visual_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD  shape = [*, grid ** 2 + 1, width]

        if ret_cls:
            x = visual_model.ln_post(x[:, 0:1, :])  # shape = [*, 1, width]
            if visual_model.proj is not None:
                x = x @ visual_model.proj
            return x

        x = visual_model.ln_post(x[:, 1:, :])  # shape = [*, grid ** 2, width]
        if visual_model.proj is not None:
            x = x @ visual_model.proj
        x = x.permute(0, 2, 1)  # shape = [*, D, grid ** 2]
        # shape = [*, D, grid, grid]
        return x.reshape(x.shape[0], x.shape[1], grid_num, grid_num).contiguous()

    def preprocess(self, images: torch.Tensor):
        """Rescale images via ``(x + 1) / 2`` and resize them to 224x224.

        Assumes ``images`` is (B, C, H, W) with values in [-1, 1].

        Returns:
            The resized images and the original spatial size ``(H, W)``,
            ordered as ``nnf.interpolate``'s ``size`` argument expects.
        """
        _, _, h, w = images.shape
        images = (images + 1.0) / 2.0
        # NOTE(review): no CLIP mean/std normalization is applied before the
        # visual tower — confirm this is intended.
        images = nnf.interpolate(images, (224, 224), mode="bilinear", align_corners=False)
        return images, (h, w)

    @torch.no_grad()
    def generate_mask(self, images, texts):
        """Binary mask of the image regions most similar to ``texts``.

        Each patch feature is compared to the text feature by (logit-scaled)
        cosine similarity. Per image, the threshold is the similarity at
        position ``int(config["ratio"] * N)`` of the ascending-sorted
        similarities (clamped into ``[0, N - 1]``); patches at or above it
        become 1. The patch grid is upsampled (nearest) to the input size.

        Returns:
            Tensor ``(B, 1, H, W)`` with values in {0, 1}.
        """
        images, ori_shape = self.preprocess(images)

        img_f = self.encode_image(images, ret_cls=False)  # shape [*, D, grid, grid]
        grid_num = img_f.shape[-1]
        img_f = img_f.reshape(img_f.shape[0], img_f.shape[1], -1)  # shape [*, D, grid ** 2]
        # Normalize each patch over the embedding axis (dim 1) so the dot
        # product below is a cosine similarity, matching ``cal_sim``.
        img_f = img_f / img_f.norm(dim=1, keepdim=True)
        img_f = img_f.permute(0, 2, 1)  # shape [*, grid ** 2, D]

        txt_f = self.encode_text(texts)
        txt_f = txt_f / txt_f.norm(dim=1, keepdim=True)  # shape [T, D]
        txt_f = txt_f.unsqueeze(-1)  # shape [T, D, 1]

        logit_scale = self.clipmodel.logit_scale.exp()
        sim = logit_scale * img_f @ txt_f
        sim = sim.squeeze(-1)  # shape [*, grid ** 2]

        # The ratio is typically a float, so the index must be cast to int;
        # clamp so ratio == 1.0 (or out-of-range values) stays in bounds.
        num_patches = sim.shape[-1]
        thr_idx = int(self.config["ratio"] * num_patches)
        thr_idx = min(max(thr_idx, 0), num_patches - 1)
        sorted_sim = torch.sort(sim, dim=-1)[0]
        adaptive_thr = sorted_sim[:, thr_idx].unsqueeze(-1).unsqueeze(-1)  # shape [*, 1, 1]

        sim = sim.reshape(-1, grid_num, grid_num)
        mask = torch.zeros_like(sim)
        mask[sim >= adaptive_thr] = 1

        # interpolate needs an explicit channel axis to resize two spatial dims.
        mask = nnf.interpolate(mask.unsqueeze(1), size=ori_shape, mode="nearest")

        return mask

    @torch.no_grad()
    def cal_sim(self, images, texts):
        """Logit-scaled cosine similarity between each image's CLS feature
        and each text feature.

        Returns:
            Tensor ``(B, 1, T)``.
        """
        images, _ = self.preprocess(images)

        img_f = self.encode_image(images, ret_cls=True)  # shape [*, 1, D]
        # Normalize over the embedding axis; dim 1 has size 1 here and would
        # collapse every feature to its sign.
        img_f = img_f / img_f.norm(dim=-1, keepdim=True)

        txt_f = self.encode_text(texts)
        txt_f = txt_f / txt_f.norm(dim=-1, keepdim=True)  # shape [T, D]

        logit_scale = self.clipmodel.logit_scale.exp()
        sim = logit_scale * img_f @ txt_f.t()

        return sim



        
