import os
from typing import Union, List
from pkg_resources import packaging
import torch
import numpy as np
from prompt_templates import *
import json


def prompt_encoder(model, objs, tokenizer, device, dataset, add_gpt: bool = False, *,
                   gpt_prompt_normal: Union[str, os.PathLike, None] = None,
                   gpt_prompt_abnormal: Union[str, os.PathLike, None] = None):
    """Encode template-based normal/anomalous text prompts for each object.

    For every ``obj`` in ``objs``, one sentence list is built per state group,
    in the order ``state_normal`` then ``state_anomaly``:

    - Each state string is formatted with the object's display name
      (``class_mapping[obj]`` when present, otherwise ``obj``).
    - The result is inserted into every template in ``templates``.
    - With ``add_gpt``, the object's LLM-generated phrases from the JSON prompt
      files are appended to the matching list.

    Each list is then tokenized, moved to ``device`` and passed through
    ``model.encode_text``.

    Args:
        model: text encoder exposing ``encode_text``.
        objs: iterable of object/class names.
        tokenizer: callable that turns a list of strings into something with
            ``.to(device)`` (presumably a token tensor).
        device: device for the tokens and the stacked features.
        dataset: ``'mvtec'`` or ``'visa'``. It is used only to pick default LLM
            prompt files when ``add_gpt`` is set and explicit paths are missing.
        add_gpt: append LLM-generated phrases to the template sentences.
        gpt_prompt_normal: optional path to a JSON file mapping object name to
            a list of normal phrases. It overrides the per-dataset default.
        gpt_prompt_abnormal: same as ``gpt_prompt_normal``, for anomalous
            phrases.

    Returns:
        ``(text_prompts, text_prompts_list)``, two dicts keyed by ``obj``.

        - ``text_prompts[obj]`` stacks, along dim 1, one L2-normalised
          mean embedding per state group.
        - ``text_prompts_list[obj]`` stacks, along dim 2, the per-sentence
          L2-normalised embeddings of each state group.

        Assuming ``encode_text`` returns ``(num_sentences, dim)`` (TODO
        confirm), these are ``(dim, 2)`` and ``(num_sentences, dim, 2)``.
        ``torch.stack`` requires both state groups to yield the same number of
        sentences.

    Raises:
        ValueError: ``add_gpt`` is set, a prompt path was not given, and
            ``dataset`` has no default prompt files.
        KeyError: ``obj`` is missing from an LLM prompt file.
    """
    normal_gpt = abnormal_gpt = None
    if add_gpt:
        # Default (normal, abnormal) LLM-generated prompt file paths per
        # dataset. They are placeholders left empty here; fill them in or pass
        # explicit paths.
        default_paths = {
            'mvtec': ('', ''),
            'visa': ('', ''),
        }
        if gpt_prompt_normal is None or gpt_prompt_abnormal is None:
            if dataset not in default_paths:
                raise ValueError(
                    f"add_gpt=True but no LLM prompt paths were given and there "
                    f"are no defaults for dataset {dataset!r}; expected one of "
                    f"{sorted(default_paths)}"
                )
            default_normal, default_abnormal = default_paths[dataset]
            if gpt_prompt_normal is None:
                gpt_prompt_normal = default_normal
            if gpt_prompt_abnormal is None:
                gpt_prompt_abnormal = default_abnormal
        # Parse each file once up front instead of once per object.
        with open(gpt_prompt_normal, 'r', encoding='utf-8') as fh:
            normal_gpt = json.load(fh)
        with open(gpt_prompt_abnormal, 'r', encoding='utf-8') as fh:
            abnormal_gpt = json.load(fh)

    prompt_state = [state_normal, state_anomaly]
    text_prompts = {}
    text_prompts_list = {}
    for obj in objs:
        if add_gpt:
            # Same order as prompt_state: index 0 = normal, 1 = anomalous.
            gpt_prompted_sentence = [normal_gpt[obj], abnormal_gpt[obj]]

        display_name = class_mapping[obj] if obj in class_mapping else obj
        text_features = []
        text_features_list = []
        for i, states in enumerate(prompt_state):
            prompted_state = [state.format(display_name) for state in states]
            # Template-major order: every state for template 0, then template 1, ...
            sentences = [template.format(state)
                         for template in templates
                         for state in prompted_state]
            if add_gpt:
                sentences = sentences + gpt_prompted_sentence[i]
            tokens = tokenizer(sentences).to(device)
            class_embeddings = model.encode_text(tokens)
            # NOTE(review): the mean is taken over the *un-normalised*
            # sentence embeddings. CLIP's usual prompt-ensemble recipe
            # normalises first and then averages. Kept as-is; confirm intended.
            class_embedding = class_embeddings.mean(dim=0)
            # Out-of-place division: dividing in place a tensor whose norm is
            # saved for backward would break autograd if gradients are taken.
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embedding / class_embedding.norm()
            text_features.append(class_embedding)
            text_features_list.append(class_embeddings)

        text_prompts[obj] = torch.stack(text_features, dim=1).to(device)
        text_prompts_list[obj] = torch.stack(text_features_list, dim=2).to(device)

    return text_prompts, text_prompts_list


