import torch 
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
import sys
import os
from typing import Tuple, List, Union, Optional
from tqdm import tqdm, trange
import pickle
import PIL.Image as Image
import json
import random
import sys
import clip
import PIL
from enum import Enum
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
device = "cuda:0" if torch.cuda.is_available() else "cpu"


class MappingType(Enum):
    """Closed set of mapping-network kinds, identified by a lowercase string value.

    Presumably selects the architecture that maps CLIP features to the decoder
    prefix; nothing in this section reads it.
    """

    MLP = "mlp"
    Transformer = "transformer"

    
class DeCapMLP(nn.Module):
    """Stack of ``nn.Linear`` layers with an activation between consecutive layers.

    ``sizes`` gives the width of every stage: ``(a, b, c)`` builds
    ``Linear(a, b) -> act() -> Linear(b, c)``. No activation follows the last
    Linear layer, so ``(a, b)`` yields a single affine projection.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        linear_count = len(sizes) - 1
        stages = []
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stages.append(nn.Linear(fan_in, fan_out, bias=bias))
            # Only interleave the activation; the final layer stays linear.
            if idx + 1 < linear_count:
                stages.append(act())
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x``."""
        return self.model(x)
        

class DeCap(nn.Module):
    """GPT-2 caption decoder conditioned on a CLIP embedding prefix.

    The CLIP feature is projected by a single linear ``DeCapMLP`` into the
    GPT-2 token-embedding width. It is then prepended, as a one-position
    prefix, to the embedded caption tokens.
    """

    def __init__(self, prefix_size: int = 512):
        super().__init__()
        # The decoder is built from a pickled GPT-2 config only; no pretrained
        # GPT-2 weights are loaded here. Per the original authors' note, the
        # config describes a 4-layer / 4-head transformer (not checked here).
        # NOTE(review): pickle.load can execute arbitrary code; the config file
        # must be trusted.
        config_path = './external/DeCap-main/decoder_config.pkl'
        with open(config_path, 'rb') as handle:
            decoder_config = pickle.load(handle)
        self.decoder = GPT2LMHeadModel(decoder_config)
        # Token-embedding width = second dimension of the wte weight matrix.
        self.embedding_size = self.decoder.transformer.wte.weight.shape[1]
        self.clip_project = DeCapMLP((prefix_size, self.embedding_size))

    def forward(self, clip_features, tokens):
        """Run the decoder on ``[projected clip prefix] + embedded tokens``.

        The projected CLIP features are reshaped to ``(-1, 1, embedding_size)``
        and concatenated before the token embeddings along dim 1. Returns the
        raw decoder output object.
        """
        token_embeds = self.decoder.transformer.wte(tokens)
        prefix = self.clip_project(clip_features).reshape(-1, 1, self.embedding_size)
        return self.decoder(inputs_embeds=torch.cat([prefix, token_embeds], dim=1))



def Decoding(model, clip_features, entry_length: int = 30, eot_token: int = 49407, tokenizer=None):
    """Greedily decode a caption from a single CLIP embedding.

    Args:
        model: a DeCap-like module exposing ``clip_project`` and a GPT-2
            ``decoder`` (with ``decoder.transformer.wte``). It is switched to
            eval mode as a side effect.
        clip_features: the CLIP embedding. It is projected and reshaped to a
            ``(1, 1, embedding_size)`` prefix, so only one feature vector per
            call is supported.
        entry_length: maximum number of tokens to generate.
        eot_token: token id that stops generation. The default 49407 is the
            id the original code stopped on (CLIP's ``<|endoftext|>``).
        tokenizer: object with ``decode(list_of_ints) -> str``. Defaults to a
            new ``clip.simple_tokenizer.SimpleTokenizer``. Pass one in to avoid
            reloading its vocabulary on every call.

    Returns:
        The decoded text (special tokens are not stripped here). Returns the
        string ``'None'`` if the tokenizer does not know a generated id.
    """
    if tokenizer is None:
        from clip.simple_tokenizer import SimpleTokenizer
        tokenizer = SimpleTokenizer()
    model.eval()
    generated = []
    with torch.no_grad():  # inference only; the result is a string
        embedding_cat = model.clip_project(clip_features).reshape(1, 1, -1)
        for _ in range(entry_length):
            logits = model.decoder(inputs_embeds=embedding_cat).logits[:, -1, :]
            # Greedy choice: argmax is unchanged by softmax / temperature
            # scaling, so neither is applied.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (1, 1)
            generated.append(int(next_token.item()))
            if generated[-1] == eot_token:
                break
            next_token_embed = model.decoder.transformer.wte(next_token)
            embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1)
    # A plain list of ints keeps a single-token result (e.g. immediate EOT)
    # decodable. The old `squeeze()` produced a 0-d array that `list()` rejected.
    try:
        return tokenizer.decode(generated)
    except KeyError:
        return 'None'


class GenTxtFromClipFea():
    """Caption CLIP image features with a DeCap decoder and a caption memory.

    The "support memory" holds L2-normalised CLIP text embeddings, one per
    caption in ``coco_train.json``. It is built once and cached in
    ``./text_features.pkl``. An image feature is replaced by a softmax-weighted
    average of that memory before being decoded into text.
    """

    def __init__(self, clip_model):
        """Load DeCap weights and the caption-embedding support memory.

        Args:
            clip_model: CLIP model used to encode captions when the cache file
                is missing. If ``None`` in that case, ``ViT-B/32`` is loaded.
                Not used when the cache exists.
        """
        model = DeCap()
        weights_path = './external/DeCap-main/coco_prefix-009.pt'
        # strict=False: checkpoint keys that do not match the model are ignored.
        model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')), strict=False)
        self.model = model.to(device).eval()

        if os.path.exists("./text_features.pkl"):
            # NOTE(review): pickle.load can execute arbitrary code; only load a
            # cache file this code wrote itself.
            with open("./text_features.pkl", "rb") as f:
                text_features = pickle.load(f)
            text_features = text_features.to(device)
        else:
            with open('./external/DeCap-main/coco_train.json', 'r') as f:
                data = json.load(f)
            if clip_model is None:
                self.clip_model, _ = clip.load("ViT-B/32", device=device, jit=False)
            else:
                self.clip_model = clip_model
            self.clip_model.eval()
            # Encode with self.clip_model: the `clip_model` argument may be None.
            text_features = self._encode_texts(
                self.clip_model, clip.tokenize, data, 1024, device, progress=tqdm)
            with open("./text_features.pkl", "wb") as f:
                pickle.dump(text_features, f)
        self.text_features = text_features

    @staticmethod
    def _encode_texts(encoder, tokenize, texts, batch_size, device, progress=None):
        """Encode ``texts`` in batches; return row-wise L2-normalised embeddings.

        Every text is encoded, including a final partial batch. ``progress``
        (e.g. ``tqdm``) optionally wraps the iterable of batch start offsets.

        Raises:
            ValueError: if ``texts`` is empty.
        """
        starts = range(0, len(texts), batch_size)
        if progress is not None:
            starts = progress(starts)
        chunks = []
        with torch.no_grad():
            for start in starts:
                tokens = tokenize(texts[start:start + batch_size]).to(device)
                chunks.append(encoder.encode_text(tokens))
        if not chunks:
            raise ValueError("no captions to encode")
        features = torch.cat(chunks, dim=0)
        features /= features.norm(dim=-1, keepdim=True).float()
        return features

    def get_text_description(self, img_clip_fea):
        """Return a caption for the CLIP image feature ``img_clip_fea``.

        ``img_clip_fea`` is presumably already L2-normalised (see the
        normalisation in the ``__main__`` demo) — TODO confirm with callers.
        ``<|startoftext|>`` / ``<|endoftext|>`` markers are removed.
        """
        with torch.no_grad():
            memory = self.text_features.float()
            # Similarity to each caption embedding, sharpened by x100, as weights.
            sim = (img_clip_fea.float() @ memory.T * 100).softmax(dim=-1)
            prefix_embedding = sim @ memory
            prefix_embedding /= prefix_embedding.norm(dim=-1, keepdim=True)
            generated_text = Decoding(self.model, prefix_embedding)
        return generated_text.replace('<|startoftext|>', '').replace('<|endoftext|>', '')


if __name__ == "__main__":
    # Stand-alone demo: caption one MSCOCO validation image.
    # NOTE(review): these paths differ from the './external/DeCap-main/...'
    # paths used by GenTxtFromClipFea; confirm which layout is intended.
    # `device` is the module-level value computed at import time.

    # Previously used without being defined (NameError): load them explicitly.
    clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    clip_model.eval()
    tokenizer = clip.tokenize

    model = DeCap()
    weights_path = './coco_model/coco_prefix-009.pt'
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model = model.to(device).eval()

    ## construct the support memory
    with open('./coco_train.json', 'r') as f:
        data = json.load(f)
    # random.sample raises ValueError if asked for more items than exist.
    data = random.sample(data, min(len(data), 1000000))
    text_features = []
    batch_size = 1000
    with torch.no_grad():
        # Step by batch_size so the final partial batch is encoded too.
        for start in tqdm(range(0, len(data), batch_size)):
            texts_token = tokenizer(data[start:start + batch_size]).to(device)
            text_features.append(clip_model.encode_text(texts_token))
    text_features = torch.cat(text_features, dim=0)
    text_features /= text_features.norm(dim=-1, keepdim=True).float()

    path_pic = 'images/000000190756.jpg'  # Picture from MSCOCO val.
    image = Image.open(path_pic)
    # `display` only exists inside IPython/Jupyter; report the path instead.
    print(path_pic)
    image = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = clip_model.encode_image(image).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)
        memory = text_features.float()
        sim = (image_features @ memory.T * 100).softmax(dim=-1)
        prefix_embedding = sim @ memory
        prefix_embedding /= prefix_embedding.norm(dim=-1, keepdim=True)
        generated_text = Decoding(model, prefix_embedding)
        generated_text = generated_text.replace('<|startoftext|>', '').replace('<|endoftext|>', '')
        print(generated_text)
