# Real off-the-shelf tied linear module
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import math 
import random
from tqdm import tqdm 
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer
from tied_unet import TiedUNET

def decode_prediction(scores, index_map, topk):
    """Decode one row of predicted multi-hot token scores into tokenizer ids.

    Args:
        scores: 1-D tensor of per-token scores over the used vocabulary
            (assumed; main() passes one row of the TiedUNET output).
        index_map: List or dict mapping a used-vocabulary index to a
            tokenizer id.
        topk: Maximum number of highest-scoring candidates to consider.

    Returns:
        ``(soft_count, token_ids)``. ``soft_count`` is the sum of the scores
        clipped to ``[0, 1]``, as a Python float. ``token_ids`` are the
        tokenizer ids of the top-``topk`` candidates (ordered by clipped score,
        descending) whose raw score rounds to exactly 1.
    """
    clipped = scores.clamp(0, 1)
    soft_count = clipped.sum().item()
    # torch.topk raises if k exceeds the number of elements.
    k = min(topk, clipped.numel())
    candidates = torch.topk(clipped, k).indices
    # NOTE(review): this filter uses the *unclipped* score, so a candidate whose
    # score rounds above 1 (>= 1.5) is dropped even though it ranks first by
    # clipped score. It matches the accuracy metric in main(), which also
    # compares round(pred) against the targets; confirm this is intended.
    candidates = candidates[scores[candidates].round() == 1]
    return soft_count, [index_map[i] for i in candidates.tolist()]


def target_used_indices(target_row):
    """Return, as a list of ints, the positions in ``target_row`` equal to 1."""
    # No squeeze(): torch.where on a 1-D tensor already yields a 1-D index
    # tensor, and squeezing a single-element result would make tolist()
    # return a bare int instead of a list.
    return torch.where(target_row == 1)[0].tolist()


def ids_to_words(tokenizer, token_ids):
    """Convert tokenizer ids to words, stripping the ``</w>`` end-of-word marker."""
    return [tok.strip().replace('</w>', '') for tok in tokenizer.convert_ids_to_tokens(token_ids)]


def main():
    """Evaluate the tied-UNet caption decoder.

    Loads the fine-tuned UNet, the caption targets and the input embeddings,
    then runs TiedUNET once. Per-sample decoded captions are written to
    ``caption.txt`` and printed next to the ground-truth captions. Finally the
    element-wise accuracy of the rounded predictions is printed.
    """
    model_id = "CompVis/stable-diffusion-v1-4"
    device = torch.device("cuda")
    topk = 15

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    unet = pipe.unet.to(device)
    # NOTE: torch.load unpickles its input; only load trusted checkpoint files.
    unet_ckpt = torch.load('checkpoint_CBS/coco_trigger_caption_300_unique_portion_0.5_acc_1_batch_14_lr_5e-06/checkpoint-latest/unet.pyt', map_location='cpu')
    unet.load_state_dict(unet_ckpt['unet'])
    indicies_tied = unet_ckpt['indicies_tied']
    # Drop the CPU copy of the UNet weights now that they are loaded.
    del unet_ckpt

    ################ CAPTION DATA ################
    caption_ckpt = torch.load('caption_used_data_300.pyt', map_location='cpu')
    caption_data = caption_ckpt['data'].to(device)
    used_index_to_index = caption_ckpt['USEDINDEX_TO_INDEX']
    vocab_used_len = caption_ckpt['VOCAB_USED_LEN']
    print("VOCAB LEN:{} DATA:{}".format(vocab_used_len, caption_data.shape))

    linear_ckpt = torch.load('pretrained_linear_300.pyt', map_location='cpu')
    unique_embeddings = linear_ckpt['unique_emb'].to(device)

    net_caption = TiedUNET(num_used_token=vocab_used_len, tied_net=pipe.unet, indicies_tied=indicies_tied).to(device)
    # Evaluation run: disable train-time behaviour (dropout etc.) in any
    # layers TiedUNET itself adds.
    net_caption.eval()
    # Assumes caption_data is (num_samples, vocab_used_len); TODO confirm.
    total_length = caption_data.shape[0] * vocab_used_len
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    with torch.no_grad():
        pred = net_caption(unique_embeddings)
        total_correct = int((caption_data == pred.round()).sum())
        accuracy = total_correct / total_length
        with open('caption.txt', 'w') as f:
            for t, (pred_row, target_row) in enumerate(zip(pred, caption_data)):
                num_token, decoded_ids = decode_prediction(pred_row, used_index_to_index, topk)
                prompt_ids = [used_index_to_index[p] for p in target_used_indices(target_row)]
                decoded_words = ids_to_words(tokenizer, decoded_ids)
                prompt_words = ids_to_words(tokenizer, prompt_ids)
                # One line per sample: "<index>.<word>\<word>\...".
                f.write('{}.{}\n'.format(t, '\\'.join(decoded_words)))
                print("sum:{} decoded_string:{} \n prompt_string:{}".format(num_token, ' '.join(decoded_words), ' '.join(prompt_words)))
        print("Accuracy:{} TOTAL_LENGTH:{}".format(accuracy, total_length))


if __name__ == '__main__':
    main()


