
import torch
from diffusers import StableDiffusionPipeline
import random 
import torchvision.transforms as T
import os 
from tqdm import tqdm
import numpy as np
# Identifier handed to StableDiffusionPipeline.from_pretrained, and the device
# the loaded pipeline is moved to.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda:0"

# Directory that holds the caption text files.
data_dir = '../t2i_steal_0223/DATASETS/coco_subset_preprocessed/'
# NOTE(review): the visible part of this script only ever touches
# ``pipe.tokenizer``; the rest of the pipeline is loaded but not used here.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)

# Keep the first ``num_data`` caption files in lexicographic order.
num_data = 300
caption_paths = sorted(
    os.path.join(data_dir, name)
    for name in os.listdir(data_dir)
    # Test the extension: a plain substring check ('txt' in name) would also
    # select unrelated entries whose name merely contains "txt".
    if name.endswith('.txt')
)[:num_data]
# Token-string -> token-id mapping exposed by the pipeline's tokenizer.
vocab = pipe.tokenizer.get_vocab()


all_ids_list = []
# First pass: gather every token id that occurs in the selected captions, so
# that only the ids actually in use need a column later on.
for caption_path in caption_paths:
    # Context manager so each file handle is closed as soon as it is read.
    with open(caption_path, "r") as file:
        prompt = file.read().lower().strip()
    # [:, 1:-1] drops the first and last id of the encoded sequence
    # (presumably the tokenizer's start/end markers -- confirm).
    input_ids = pipe.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids[:, 1:-1]
    # Take row 0 rather than squeeze(): squeeze() collapses a one-token
    # caption to a 0-d tensor, whose tolist() is a bare int that cannot be
    # appended with ``+=``.
    input_ids = input_ids[0].tolist()
    all_ids_list += input_ids

all_ids_list.sort()
# Sorted array of the distinct token ids seen across all captions.
unique_ids_list = np.unique(all_ids_list)


# Compact re-indexing of the token ids that actually occur in the captions:
#   INDEX_TO_USEDINDEX: tokenizer id -> column in ``data_tensor``
#   USEDINDEX_TO_INDEX: column in ``data_tensor`` -> tokenizer id
# ``unique_ids_list`` comes from np.unique, so its elements are NumPy scalars.
# They are converted to built-in ints because these dicts are pickled into the
# checkpoint, and torch.load's restricted unpickler (weights_only=True, the
# default in recent PyTorch releases) rejects NumPy scalar objects.
INDEX_TO_USEDINDEX = {int(uid): i for i, uid in enumerate(unique_ids_list)}
USEDINDEX_TO_INDEX = {i: token_id for token_id, i in INDEX_TO_USEDINDEX.items()}

# One row per caption, one column per distinct token id in use.
VOCAB_USED_LEN = len(unique_ids_list)
data_tensor = torch.zeros(size=(len(caption_paths), VOCAB_USED_LEN))

prompt_list = []
############# ENCODE #######################
# Second pass: turn each caption into a multi-hot row of ``data_tensor`` (a 1
# in every column whose token occurs in the caption) and remember the prompt.
for i, caption_path in enumerate(caption_paths):
    # Context manager so each file handle is closed as soon as it is read.
    with open(caption_path, "r") as file:
        prompt = file.read().lower().strip()
    prompt_list.append(prompt)
    # [:, 1:-1] drops the first and last id of the encoded sequence.
    input_ids = pipe.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids[:, 1:-1]
    # Take row 0 rather than squeeze(): squeeze() collapses a one-token
    # caption to a 0-d tensor, whose tolist() is a bare int and not iterable.
    input_ids = input_ids[0].tolist()
    input_ids = [INDEX_TO_USEDINDEX[p] for p in input_ids]
    # An empty caption has no columns to set; its row simply stays all zero.
    if input_ids:
        data_tensor[i, input_ids] = 1

############# DECODE #######################
# Sanity check: map every multi-hot row back to token strings and print it
# next to the original prompt. Tokens come out in ascending column order with
# repeats collapsed, so this is the caption's token set, not the caption.
for i, row in enumerate(data_tensor):
    # torch.where(...)[0] is already 1-D. No squeeze() here: it would turn a
    # row with a single active column into a 0-d tensor whose tolist() is a
    # bare int and not iterable.
    decoded_ids = torch.where(row == 1)[0].tolist()
    decoded_ids = [USEDINDEX_TO_INDEX[p] for p in decoded_ids]
    # Distinct comprehension variable so the row index ``i`` is not shadowed.
    decoded_string = [pipe.tokenizer._convert_id_to_token(token_id).strip() for token_id in decoded_ids]
    print("PROMPT:{} DECODED PROMPT:{} TOKEN LEN:{}".format(prompt_list[i],decoded_string,len(decoded_string)))

print("all_id len:{} \nall_id:{}".format(len(unique_ids_list),unique_ids_list))
print("data_tensor:{}".format(data_tensor.shape))

# Bundle the column count, both id mappings and the multi-hot matrix into a
# single dict and write it to disk.
checkpoint = {
    'VOCAB_USED_LEN': VOCAB_USED_LEN,
    'INDEX_TO_USEDINDEX': INDEX_TO_USEDINDEX,
    'USEDINDEX_TO_INDEX': USEDINDEX_TO_INDEX,
    'data': data_tensor,
}
torch.save(checkpoint, 'caption_used_data.pyt')

############## LOAD FILE ######################
# Read the file straight back onto the CPU and report what it contains.
checkpoint = torch.load('caption_used_data.pyt', map_location='cpu')
print(
    "VOCAB LEN:{} DATA:{}".format(
        checkpoint['VOCAB_USED_LEN'],
        checkpoint['data'].shape,
    )
)