




from datasets import *
from utils import *

from load_des import load_gpt_descriptions


import os

DEVICE = "cuda"


from templates import *


def load_prompts_embed(prefix, class_name, save_file, prompt=None, verbose=True):
    """Look up cached prompt embeddings for one (prefix, class) pair.

    ``save_file`` is read with ``torch.load`` and must contain a dict keyed by
    ``(prefix, class_name)`` tuples whose values are inner dicts mapping a
    prompt string to its embedding.

    Args:
        prefix: Template string; first element of the lookup key.
        class_name: Class label; second element of the lookup key. It is used
            verbatim (no lower-casing or underscore substitution).
        save_file: Path of the serialized prompts dictionary.
        prompt: When ``None``, every embedding stored for the pair is
            flattened to a row and the rows are concatenated along dim 0.
            Otherwise only the embedding stored under this exact prompt
            string is returned.
        verbose: When true, print a summary of the loaded dictionary and
            report a missing entry by printing a message and returning
            ``None``. When false, print nothing and raise on a missing entry.

    Returns:
        The requested embedding(s), or ``None`` when the entry is missing and
        ``verbose`` is true.

    Raises:
        Exception: If the (prefix, class) pair or the prompt is not in the
            dictionary and ``verbose`` is false.
    """
    prompts_dict = torch.load(save_file)
    if verbose:
        print("load prompts_dict from  <<< ", save_file)
        count = sum(len(entries) for entries in prompts_dict.values())
        print("prompts dict: ", len(prompts_dict), " prefix * classes, ", count, "prompts")
        # Diagnostic echo of the requested class; kept behind `verbose` so a
        # non-verbose call stays silent.
        print(class_name)

    search_key = (prefix, class_name)

    if search_key not in prompts_dict:
        msg = "The current dictionary does not contain prompts for class [{}]".format(class_name)
    else:
        entries = prompts_dict[search_key]

        if prompt is None:
            if verbose:
                print("return all ({}) prompts for (prefix, class) [{}]".format(len(entries), search_key))
            # One row per stored prompt, in the dictionary's insertion order.
            return torch.cat([embed.reshape(1, -1) for embed in entries.values()], dim=0)

        if prompt in entries:
            return entries[prompt]

        msg = "The current dictionary does not contain prompt [{}]".format(prompt)

    # Missing entry: verbose callers get a message and None, quiet callers an
    # exception (there is no other channel to tell them something went wrong).
    if not verbose:
        raise Exception(msg)
    print(msg)
    return None



def load_all_dict(save_file, verbose=True):
    """Load the complete prompts dictionary stored at ``save_file``.

    Args:
        save_file: Path handed to ``torch.load``.
        verbose: When true, print the source path plus the number of
            top-level keys and the total number of entries beneath them.

    Returns:
        The object deserialized from ``save_file``.
    """
    cache = torch.load(save_file)
    if not verbose:
        return cache

    print("load prompts_dict from  <<< ", save_file)
    total = sum(len(cache[pair]) for pair in cache.keys())
    print("prompts dict: ", len(cache.keys()), " prefix * classes, ", total, "prompts")
    return cache





def generate_prompts_embed(prompts_list, prefix_key, label_key, backbone_name, save_file):
    """Encode prompts with the CLIP text encoder and merge them into a cache.

    The three list arguments are parallel: ``prompts_list[i]`` is the full
    prompt text that belongs to template ``prefix_key[i]`` and class
    ``label_key[i]``, e.g.::

        prompts_list = ["a xxxx of dog", "xxx of cat"]
        label_key = ["dog", "cat"]
        prefix_key = ["a photo of a {}", "a small photo of a {}"]

    The dictionary written to ``save_file`` has the layout
    ``dict{(prefix_key, class_key)} -> dict{prompt string -> embedding}``.
    If ``save_file`` already holds such a dictionary it is loaded and
    extended (same-key entries are overwritten); otherwise a new one is
    created.

    Args:
        prompts_list: Prompt strings to encode.
        prefix_key: Template string for each prompt.
        label_key: Class label for each prompt.
        backbone_name: CLIP backbone identifier passed to ``load_clip_to_cpu``.
        save_file: Path the merged dictionary is saved to with ``torch.save``.

    Raises:
        ValueError: If the three lists do not have the same length.
    """
    num_prompts = len(prompts_list)
    # Validate up front: a mismatch would otherwise surface as an IndexError
    # (or be silently ignored) only after all the encoding work is done.
    if not (num_prompts == len(prefix_key) == len(label_key)):
        raise ValueError(
            "prompts_list, prefix_key and label_key must have the same length, got {}, {} and {}".format(
                num_prompts, len(prefix_key), len(label_key)
            )
        )
    if num_prompts == 0:
        # torch.cat cannot concatenate an empty list; bail out before paying
        # for the model load.
        print("no prompts to embed, nothing to do")
        return

    print(f"Loading CLIP (backbone: {backbone_name})")
    clip_model = load_clip_to_cpu(backbone_name)
    clip_model.to(DEVICE)

    print("Turning off gradients in the model")
    for _, param in clip_model.named_parameters():
        param.requires_grad_(False)

    # Tokenize on the CPU; batches are moved to DEVICE one at a time below so
    # only `batch_size` prompts live on the device at once.
    prompts_input = torch.cat([clip.tokenize(p) for p in prompts_list])

    batch_size = 2000
    emb_col = []
    with torch.no_grad():  # inference only, never build an autograd graph
        for ll in range(0, num_prompts, batch_size):
            input_i = prompts_input[ll : ll + batch_size].to(DEVICE)
            emb_col.append(clip_model.encode_text(input_i).cpu())

            done = ll + batch_size
            if done % 20000 == 0:
                print("processing ... {}/{}".format(done, num_prompts))

    prompts_embeds = torch.cat(emb_col, dim=0)

    # Reuse an existing cache when it can be read; otherwise start fresh.
    # Only `Exception` is caught so Ctrl-C / SystemExit still propagate.
    try:
        prompts_dict = torch.load(save_file)
        print("load prompts_dict from  <<< ", save_file)
        count = sum(len(prompts_dict[key]) for key in prompts_dict.keys())
        print("prompts dict: ", len(prompts_dict.keys()), "prefix + class, ", count, "prompts")
    except FileNotFoundError:
        print("create a new dict")
        prompts_dict = {}
    except Exception as err:
        # Unreadable or unexpected content: keep the original best-effort
        # behavior (overwrite), but say why the old cache is being dropped.
        print("could not reuse existing cache [{}]: {!r}".format(save_file, err))
        print("create a new dict")
        prompts_dict = {}

    # prompts_dict: dict{(prefix_key, class_key)} - > dict{"string" -> "torch embed"}
    for idx, (prompt_i, p_i, c_i) in enumerate(zip(prompts_list, prefix_key, label_key)):
        prompts_dict.setdefault((p_i, c_i), {})[prompt_i] = prompts_embeds[idx]

    torch.save(prompts_dict, save_file)
    print("save >>> ", save_file)
    count = sum(len(entries) for entries in prompts_dict.values())
    print("prompts dict: ", len(prompts_dict.keys()), "prefix + class, ", count, "prompts")


if __name__ == "__main__":
    # Build (or reuse) one prompt-embedding cache per (dataset, backbone) at
    # ./cache_prompt_embed/<dataset>/<backbone>/prompts_dict.pt, then run a
    # quick lookup on the first class as a sanity check.

    dataset_list = ['imagenet',
                    'cifar10', 'cifar100', 'caltech101',
                    'cars196', 'dtd', 'eurosat', 'food101', 'oxford_flowers102',
                    'oxford_iiit_pet', 'resisc45', 'sun397', 'fgvc_aircraft',
                    ]
    # Smaller subset for quick runs:
    # dataset_list=['eurosat','oxford_iiit_pet']

    backbone_list = ["ViT-B/16", "ViT-L/14"]
    # backbone_list = ["ViT-L/14"]

    for dataset_name in dataset_list:

        # Everything down to the backbone loop depends only on the dataset,
        # so it is computed once instead of once per backbone.
        gpt_descriptions, unmodify_dict = load_gpt_descriptions("descriptors/descriptors_" + dataset_name)

        # Flatten {class -> [descriptions]} into two parallel lists.
        class_des_list, class_label = [], []
        for class_name in gpt_descriptions.keys():
            for item in gpt_descriptions[class_name]:
                class_des_list.append(item)
                class_label.append(class_name)

        print(class_des_list[:3])

        # Work on a copy: appending to ZEROSHOT_TEMPLATES[dataset_name]
        # directly would grow the shared list on every pass and produce
        # duplicate templates (and duplicate encoding work).
        all_templates = list(ZEROSHOT_TEMPLATES[dataset_name])
        for extra_template in ('{}.', 'a photo of a {}.'):
            if extra_template not in all_templates:
                all_templates.append(extra_template)

        # Cross product: every template applied to every description.
        prefix_key = []
        label_key = []
        combined_prompt_list = []
        for template_i in all_templates:
            for des_i, label_i in zip(class_des_list, class_label):
                combined_prompt_list.append(template_i.format(des_i))
                prefix_key.append(template_i)
                label_key.append(label_i)

        print("{} prefix templets, overall {} prompts".format(len(all_templates), len(combined_prompt_list)))

        for backbone_name in backbone_list:
            # Backbone names such as "ViT-B/16" contain "/", so this is a
            # nested directory; makedirs creates the intermediate levels.
            save_dir = os.path.join("./cache_prompt_embed", dataset_name, backbone_name)
            os.makedirs(save_dir, exist_ok=True)

            save_name = os.path.join(save_dir, "prompts_dict.pt")

            # Reuse a readable cache; a missing or unreadable one triggers
            # (re)generation. Ctrl-C is no longer swallowed here.
            try:
                embed_dict = torch.load(save_name)
                print(embed_dict.keys())
            except Exception:
                generate_prompts_embed(combined_prompt_list, prefix_key=prefix_key, label_key=label_key, backbone_name=backbone_name, save_file=save_name)

            # Sanity check on the first class only.
            for c_i in gpt_descriptions.keys():
                print(c_i)

                prompt_embed = load_prompts_embed(all_templates[0], c_i, save_name)
                print(prompt_embed)
                # load_prompts_embed returns None (after printing why) when a
                # stale cache lacks this (template, class) pair.
                if prompt_embed is not None:
                    print(prompt_embed.shape)

                break
