from data.imagnet_prompts import imagenet_templates, make_descriptor_sentence, CLASSES_DICT, gpt_dict
from clip import tokenize
from tqdm import tqdm

import torch
import json
import os
import pickle
import re
def sanitize_path(input_string: str) -> str:
    r"""Return *input_string* with filename-unsafe characters replaced by ``_``.

    The characters replaced are ``\ / : * ? " < > |``. Both path separators
    are included so that the result is always a single path component
    (e.g. ``"ViT-B/16"`` becomes ``"ViT-B_16"``).

    Args:
        input_string: Arbitrary text that will be embedded in a file name.

    Returns:
        The text with every unsafe character turned into an underscore.
    """
    # The backslash has to be escaped inside the class; a lone ``\/`` only
    # matches the forward slash and would let backslashes through.
    invalid_chars = r'[\\/:*?"<>|]'
    return re.sub(invalid_chars, '_', input_string)

def get_cache_path(args, dataset):
    """Build the cache file locations used for one dataset.

    Args:
        args: Options object; reads ``with_templates``, ``coop``, ``arch``,
            ``cache_path`` and ``concept_type``.
        dataset: Dataset name; it is lower-cased before being used in the
            file names.

    Returns:
        A dict with the keys ``'class'`` (class-embedding pickle),
        ``'concept'`` (concept-embedding pickle) and ``'dict'`` (JSON dump of
        the concept prompts), each mapped to a path under ``args.cache_path``.
    """
    name = dataset.lower()
    template_tag = '_temp' if args.with_templates else ''
    coop_tag = '_coop' if args.coop else ''
    arch_tag = '_' + sanitize_path(args.arch)

    # NOTE: only the 'dict' file name carries args.concept_type; the
    # 'concept' pickle is shared between concept types.
    filenames = {
        'class': f'{name}_class{template_tag}{coop_tag}{arch_tag}.pkl',
        'concept': f'{name}_concept{arch_tag}.pkl',
        'dict': f'{name}_dict_{args.concept_type}{arch_tag}.json',
    }
    return {
        key: os.path.join(args.cache_path, filename)
        for key, filename in filenames.items()
    }

def get_text_features(args, clip, text_encoder, prompt_learner, dataset, classnames, device):
    """Return the cached class and concept text embeddings for a dataset.

    The concept and class caches are (re)built when ``args.cache_init`` is
    set or when the corresponding cache file does not exist yet; afterwards
    both are loaded from disk.

    Args:
        args: Options object forwarded to the cache helpers.
        clip: Model forwarded to the embedding helpers.
        text_encoder: Text encoder forwarded to ``embed_class_features``.
        prompt_learner: Prompt learner forwarded to ``embed_class_features``.
        dataset: Dataset name used to locate the cache files.
        classnames: Class names whose embeddings are loaded.
        device: Device the loaded embeddings are moved to.

    Returns:
        The list ``[class_embed, concept_embed]``.
    """
    cache = get_cache_path(args, dataset)
    force_rebuild = args.cache_init

    # Concept features first, then class features -- same order whether the
    # rebuild is forced or triggered by a missing file.
    if force_rebuild or not os.path.isfile(cache['concept']):
        save_concept_features(args, clip, dataset, device)
    if force_rebuild or not os.path.isfile(cache['class']):
        embed_class_features(args, prompt_learner, text_encoder, clip, classnames, dataset)

    # Both loaders return (embeddings, padding_mask); only the embeddings
    # are used here.
    class_embed = load_class_embeds(args, cache['class'], classnames, dataset, device)[0]
    concept_embed = load_concept_embeds(args, cache['concept'], classnames, dataset, device)[0]
    return [class_embed, concept_embed]


def save_concept_features(args, clip, dataset, device):
    """Encode the GPT-4 concept prompts of every class and cache the result.

    The concepts are read from ``./prompts/<data>-gpt4.json``. For each class
    a list of prompts is built according to ``args.concept_type``:

    * ``'wo_temp'``  -- ``"<class>, <descriptor sentence>"`` per concept;
    * ``'w_temp'``   -- the same text inserted into every entry of
      ``imagenet_templates``;
    * ``'wo_class'`` -- the raw concepts.

    The prompts are tokenized, encoded with ``clip.encode_text``, moved to
    the CPU and L2-normalised along the last dimension. The prompt dict is
    written as JSON to the ``'dict'`` cache path and the embeddings are
    pickled to the ``'concept'`` cache path.

    Args:
        args: Options object; reads ``concept_type`` plus whatever
            ``get_cache_path`` needs.
        clip: Model exposing ``encode_text``.
        dataset: Dataset name. The codes ``'A'``, ``'R'``, ``'K'``, ``'V'``
            and ``'I'`` use the ``'imagenet'`` concepts; any other name is
            looked up (lower-cased) in ``gpt_dict``.
        device: Device the tokenized prompts are moved to before encoding.

    Raises:
        ValueError: If a class name in the JSON file does not line up with
            ``CLASSES_DICT``, a class has no concepts, or
            ``args.concept_type`` is not one of the supported values.
    """
    if dataset in ['A', 'R', 'K', 'V', 'I']:
        data = 'imagenet'
    else:
        data = gpt_dict[dataset.lower()]

    # Resolve the output files and create their directory before encoding,
    # so a missing cache directory does not surface only at dump time.
    path = get_cache_path(args, dataset)
    for target in (path['dict'], path['concept']):
        target_dir = os.path.dirname(target)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

    concepts_json = os.path.join('./prompts', f'{data}-gpt4.json')
    print(f"=>Loading concepts from {concepts_json}")
    with open(concepts_json, 'r', encoding='utf-8') as f:
        concepts_dict = json.load(f)

    gpt4_classes = list(concepts_dict.keys())
    concept_dict_all = {}
    concept_embeds = {}

    tpt_classes = [name.replace("_", " ") for name in CLASSES_DICT[data]]

    # zip() stops at the shorter list, so a length mismatch would otherwise
    # leave some classes without cached concepts and no hint as to why.
    if len(gpt4_classes) != len(tpt_classes):
        print(f"=>WARNING: {concepts_json} has {len(gpt4_classes)} classes but "
              f"{len(tpt_classes)} are expected for '{data}'; extra entries are ignored")

    for classname_tpt, classname_gpt4 in tqdm(zip(tpt_classes, gpt4_classes), total=len(tpt_classes)):
        # Both lists are expected to enumerate the classes in the same order.
        if classname_gpt4.replace("_", " ") != classname_tpt:
            raise ValueError(
                f"ERROR: Class not found in dict "
                f"(expected '{classname_tpt}', got '{classname_gpt4}')"
            )
        concepts = concepts_dict[classname_gpt4]
        if len(concepts) == 0:
            raise ValueError(f"Empty concepts for class {classname_gpt4} in dataset {dataset}")

        if args.concept_type == 'wo_temp':
            prompts = [f"{classname_tpt}, " + make_descriptor_sentence(c) for c in concepts]
        elif args.concept_type == 'w_temp':
            prompts = [t.format(f"{classname_tpt}, " + make_descriptor_sentence(c)) for c in concepts for t in
                       imagenet_templates]
        elif args.concept_type == 'wo_class':
            prompts = concepts
        else:
            raise ValueError("ERROR: please set the correct concept type!")
        concept_dict_all[classname_tpt] = prompts

        tokenized_prompts = tokenize(prompts).to(device)
        with torch.no_grad():
            embeds = clip.encode_text(tokenized_prompts).cpu()
        # Unit-normalise every prompt embedding.
        concept_embeds[classname_tpt] = embeds / embeds.norm(dim=-1, keepdim=True)

    print(f"=>Dumping concept dict to {path['dict']}")
    with open(path['dict'], 'w') as f:
        json.dump(concept_dict_all, f, indent=4)

    print(f"=>Dumping concept embeds to {path['concept']}")
    with open(path['concept'], 'wb') as f:
        pickle.dump(concept_embeds, f)

def embed_class_features(args, prompt_learner, text_encoder, clip, classnames, dataset):
    """Encode the class prompts and pickle them to the ``'class'`` cache path.

    For every class name the text features are computed without gradient
    tracking, moved to the CPU and L2-normalised along the last dimension.
    With ``args.coop`` set the features come from
    ``text_encoder(prompts[name], tokenized_prompts[name])``; otherwise from
    ``clip.encode_text(tokenized_prompts[name])``.

    Args:
        args: Options object; reads ``coop`` plus whatever
            ``get_cache_path`` needs.
        prompt_learner: Callable whose result and whose
            ``tokenized_prompts`` attribute are both indexed by class name.
        text_encoder: Encoder used when ``args.coop`` is set.
        clip: Model exposing ``encode_text``, used otherwise.
        classnames: Class names to encode.
        dataset: Dataset name used to locate the cache file.
    """
    print('=> class embedding not found, re-initializing')
    path = get_cache_path(args, dataset)

    # Make sure the cache directory exists before spending time on encoding;
    # open(..., 'wb') below would otherwise fail on a fresh cache location.
    cache_dir = os.path.dirname(path['class'])
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    text_features_dict = {}
    prompts = prompt_learner()  # indexed by class name
    tokenized_prompts = prompt_learner.tokenized_prompts  # indexed by class name
    use_coop = args.coop

    # Actual encoding pass.
    with torch.no_grad():
        for name in tqdm(classnames, total=len(classnames)):
            if use_coop:
                t_features = text_encoder(prompts[name], tokenized_prompts[name])
            else:
                t_features = clip.encode_text(tokenized_prompts[name])
            t_features = t_features.cpu()
            text_features_dict[name] = t_features / t_features.norm(dim=-1, keepdim=True)

    print(f"=> Dumping class embeds (with templates) to {path['class']}")
    with open(path['class'], 'wb') as f:
        pickle.dump(text_features_dict, f)

def load_class2concepts(path, classnames):
    """Load a class-to-concepts JSON file restricted to *classnames*.

    Args:
        path: Location of the JSON file mapping class names to concepts.
        classnames: Class names to keep; the result follows their order.

    Returns:
        A dict mapping each requested class name to its JSON entry.

    Raises:
        KeyError: If a requested class name is not present in the file.
    """
    with open(path, 'r') as handle:
        all_concepts = json.load(handle)

    selected = {}
    for classname in classnames:
        selected[classname] = all_concepts[classname]
    return selected

def load_class_embeds(args, prompt_embeds_path, classnames, dataset, device):
    """Load pickled class embeddings and reduce them to one vector per class.

    Each class entry is averaged over its first dimension, the per-class
    means are stacked in the order of *classnames*, moved to *device* and
    L2-normalised along the last dimension.

    Args:
        args: Unused here.
        prompt_embeds_path: Pickle file mapping class name to a tensor.
        classnames: Class names to load, in output order.
        dataset: Dataset name, only used in the log message.
        device: Device the stacked tensor is moved to.

    Returns:
        A tuple ``(embeds, padding_mask)`` where ``embeds`` has one row per
        class and ``padding_mask`` is always ``None``.
    """
    print(f"=>Loading {dataset} class embeds from {prompt_embeds_path}")
    with open(prompt_embeds_path, 'rb') as handle:
        cached = pickle.load(handle)

    # Collapse the per-prompt embeddings of each class into their mean.
    per_class = []
    for classname in classnames:
        per_class.append(cached[classname].mean(dim=0))
    stacked = torch.stack(per_class).to(device)  # (N, embed_dim)

    # Re-normalise: the mean of unit vectors is generally not unit length.
    norms = stacked.norm(dim=-1, keepdim=True)
    return stacked / norms, None

def load_concept_embeds(args, prompt_embeds_path, classnames, dataset, device):
    """Load pickled concept embeddings and reduce them to one vector per class.

    Each class entry is averaged over its first dimension, the per-class
    means are stacked in the order of *classnames*, moved to *device* and
    L2-normalised along the last dimension.

    Args:
        args: Unused here.
        prompt_embeds_path: Pickle file mapping class name to a tensor.
        classnames: Class names to load, in output order.
        dataset: Dataset name, only used in the log message.
        device: Device the stacked tensor is moved to.

    Returns:
        A tuple ``(embeds, padding_mask)`` where ``embeds`` has one row per
        class and ``padding_mask`` is always ``None``.
    """
    print(f"=>Loading {dataset} concept embeds from {prompt_embeds_path}")
    with open(prompt_embeds_path, 'rb') as handle:
        cached = pickle.load(handle)

    # The pickle is indexed by class name; averaging over dim 0 leaves a
    # single vector per class.
    class_means = []
    for classname in classnames:
        class_means.append(cached[classname].mean(dim=0))
    stacked = torch.stack(class_means).to(device)  # (N, embed_dim)

    # Re-normalise after averaging.
    norms = stacked.norm(dim=-1, keepdim=True)
    return stacked / norms, None
