import random
import torch
from nltk.corpus import wordnet as wn
import clip
from tqdm import tqdm

def select_ood_words(classnames_id, clip_model, selection_rule, selected_size, device='cuda', batch_size=5000):
    """Pick WordNet nouns to serve as out-of-distribution (OOD) class names.

    Every WordNet noun lemma that does not coincide with an in-distribution
    (ID) class name is embedded with the CLIP text encoder using the prompt
    "a photo of a {}.". Each candidate is scored by its mean cosine
    similarity to all ID class prompts, and the ``selected_size`` candidates
    with the highest ('near') or lowest ('far') score are returned.

    Args:
        classnames_id: Iterable of ID class-name strings. Underscores are
            treated as spaces.
        clip_model: Model exposing ``.to(device)`` and ``.encode_text(tokens)``.
        selection_rule: 'near' for the most similar nouns, 'far' for the
            least similar ones.
        selected_size: Number of nouns to return.
        device: Device on which the model and the token tensors are placed.
        batch_size: Number of candidate prompts encoded per forward pass.

    Returns:
        List of ``selected_size`` noun strings, exactly as returned by
        ``lemma.name()``.

    Raises:
        ValueError: If ``selection_rule`` is neither 'near' nor 'far'.
    """
    # Fail fast: validate before the expensive encoding of the whole vocabulary.
    if selection_rule not in ('near', 'far'):
        raise ValueError("Invalid selection_rule. Choose 'near' or 'far'.")

    template = "a photo of a {}."
    clip_model.to(device)

    # Step 1: collect every noun lemma in WordNet. Sorting makes the candidate
    # order (and therefore tie-breaking in topk) reproducible across runs.
    all_nouns = sorted({
        lemma.name()
        for synset in wn.all_synsets(pos='n')
        for lemma in synset.lemmas()
    })

    # Step 2: drop nouns that overlap with the ID class names. Both sides are
    # normalised the same way (underscores -> spaces, lower case) so that
    # multi-word or capitalised names are matched as well.
    classnames_id_set = {c.replace("_", " ").lower() for c in classnames_id}
    classnames_ood = [
        noun for noun in all_nouns
        if noun.replace("_", " ").lower() not in classnames_id_set
    ]

    # Step 3: encode the ID class names and L2-normalise the features.
    prompts_id = clip.tokenize(
        [template.format(c.replace("_", " ")) for c in classnames_id]
    ).to(device)

    with torch.no_grad():
        text_features_id = clip_model.encode_text(prompts_id)
        text_features_id = text_features_id / text_features_id.norm(dim=-1, keepdim=True)

    # Step 4: encode the OOD candidates batch by batch.
    text_features_ood_list = []
    for start in tqdm(range(0, len(classnames_ood), batch_size), desc="Processing OOD text features"):
        batch_classnames_ood = classnames_ood[start:start + batch_size]
        prompts_ood = clip.tokenize(
            [template.format(c.replace("_", " ")) for c in batch_classnames_ood]
        ).to(device)

        with torch.no_grad():
            text_features_batch = clip_model.encode_text(prompts_ood)
            text_features_batch = text_features_batch / text_features_batch.norm(dim=-1, keepdim=True)
        text_features_ood_list.append(text_features_batch)

    # Merge all OOD features into a single tensor.
    text_features_ood = torch.cat(text_features_ood_list, dim=0)

    # Step 5: features are unit-norm, so the dot product is the cosine
    # similarity; average it over all ID classes to get one score per noun.
    cosine_similarity_matrix = torch.mm(text_features_ood, text_features_id.T)
    cosine_similarity = torch.mean(cosine_similarity_matrix, dim=-1)

    # Step 6: take the most similar ('near') or least similar ('far') nouns.
    selected_indices = torch.topk(
        cosine_similarity, selected_size, largest=(selection_rule == 'near')
    ).indices

    # Convert to plain ints before indexing the Python list.
    return [classnames_ood[i] for i in selected_indices.tolist()]


class ClassNameIterator:
    """Serve batches of ``n_cls`` class names drawn from a pool.

    Two modes are supported:

    * ``"iter"``: the pool is shuffled once and consumed in consecutive
      slices of ``n_cls``; when fewer than ``n_cls`` names remain, the pool
      is reshuffled and consumption restarts from the beginning.
    * ``"random"``: every batch is an independent random sample of
      ``n_cls`` distinct names.

    If the pool has at most ``n_cls`` names, every batch is the whole pool
    regardless of mode.
    """

    def __init__(self, classnames_ood, n_cls, mode="iter"):
        """Store the pool and batch size, then initialise iteration state.

        Args:
            classnames_ood: Sequence of candidate class names.
            n_cls: Number of names per batch.
            mode: ``"iter"`` or ``"random"``.
        """
        self.classnames_ood = classnames_ood
        self.n_cls = n_cls
        self.mode = mode
        self.reset()

    def reset(self):
        """Reshuffle the pool and rewind the cursor (only in ``"iter"`` mode)."""
        if self.mode == "iter":
            # random.sample over the full length yields a shuffled copy,
            # leaving the caller's sequence untouched.
            self.current_classnames = random.sample(self.classnames_ood, len(self.classnames_ood))
            self.index = 0

    def next_batch(self):
        """Return the next batch of class names as a new list.

        Raises:
            ValueError: If the pool is larger than ``n_cls`` and ``mode`` is
                neither ``"iter"`` nor ``"random"``.
        """
        if len(self.classnames_ood) <= self.n_cls:
            # Pool too small to sub-sample: hand out a copy of all of it so
            # callers cannot mutate the stored pool through the batch.
            return list(self.classnames_ood)

        if self.mode == "random":
            return random.sample(self.classnames_ood, self.n_cls)

        if self.mode == "iter":
            # Not enough names left for a full batch: start a new shuffled pass.
            if self.index + self.n_cls > len(self.current_classnames):
                self.reset()
            batch = self.current_classnames[self.index:self.index + self.n_cls]
            self.index += self.n_cls
            return batch

        raise ValueError(
            f"Invalid mode {self.mode!r}. Choose 'iter' or 'random'."
        )

