class Concept:
    """A concept definition: name, description, question and response scheme.

    Instances convert to and from dictionaries that use human-readable keys
    ("Concept Name", "Possible Responses", ...). The same keys work with
    item access, e.g. ``concept["Concept Name"]``.
    """

    # Maps the human-readable dictionary keys to attribute names. It is
    # defined once on the class (instead of being rebuilt in every __init__)
    # so that to_dict / __getitem__ / __setitem__ all share one source of
    # truth. It is still reachable as ``instance._key_to_attr_map``.
    _key_to_attr_map = {
        "Concept Name": "concept_name",
        "Concept Description": "concept_description",
        "Concept Question": "concept_question",
        "Possible Responses": "possible_responses",
        "Response Guide": "response_guide",
        "Response Mapping": "response_mapping",
    }

    def __init__(self, concept_name, concept_description, concept_question, possible_responses, response_guide, response_mapping):
        """Store the given fields on the instance unchanged."""
        self.concept_name = concept_name
        self.concept_description = concept_description
        self.concept_question = concept_question
        self.possible_responses = possible_responses
        self.response_guide = response_guide
        self.response_mapping = response_mapping

    @classmethod
    def from_dict(cls, data):
        """Build an instance from a dictionary keyed by the human-readable names.

        Missing keys fall back to an empty string (name, description,
        question), an empty list (possible responses) or an empty dict
        (response guide, response mapping).
        """
        return cls(
            data.get("Concept Name", ""),
            data.get("Concept Description", ""),
            data.get("Concept Question", ""),
            data.get("Possible Responses", []),
            data.get("Response Guide", {}),
            data.get("Response Mapping", {}),
        )

    def __str__(self):
        """Return a multi-line, human-readable rendering of every field."""
        response_guide_str = "\n".join(
            f"  {response}: {guide}"
            for response, guide in self.response_guide.items()
        )
        response_mapping_str = "\n".join(
            f"  {response}: {mapping}"
            for response, mapping in self.response_mapping.items()
        )
        # str() each response so non-string values (e.g. ints) do not make
        # str.join raise TypeError.
        possible_responses_str = ', '.join(
            str(response) for response in self.possible_responses
        )

        return (f"Concept Name: {self.concept_name}\n"
                f"Concept Description: {self.concept_description}\n"
                f"Concept Question: {self.concept_question}\n"
                f"Possible Responses: {possible_responses_str}\n"
                f"Response Guide:\n{response_guide_str}\n"
                f"Response Mapping:\n{response_mapping_str}")

    def to_dict(self):
        """Return the fields as a dictionary keyed by the human-readable names."""
        return {
            key: getattr(self, attr_name)
            for key, attr_name in self._key_to_attr_map.items()
        }

    def __getitem__(self, key):
        """Return the attribute addressed by a human-readable key.

        Raises:
            KeyError: if ``key`` is not one of the known dictionary keys.
        """
        attr_name = self._key_to_attr_map.get(key)
        if attr_name is None:
            raise KeyError(f"{key} not found in key-to-attribute mapping.")
        return getattr(self, attr_name)

    def __setitem__(self, key, value):
        """Set the attribute addressed by a human-readable key.

        Raises:
            KeyError: if ``key`` is not one of the known dictionary keys.
        """
        attr_name = self._key_to_attr_map.get(key)
        if attr_name is None:
            raise KeyError(f"{key} not found in key-to-attribute mapping.")
        setattr(self, attr_name, value)
# --- standard library ---
import json
import os

# --- third-party ---
import numpy as np
import pandas as pd
from langchain_community.chat_models.ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from tqdm.auto import tqdm

# Input table read from check_res4.json (index-oriented); only its 'text'
# column is kept, as a plain Python list.
text_concept = pd.read_json('check_res4.json', orient='index')
texts = text_concept['text'].to_list()

# Chat clients, both created against the same base URL.
chatgpt = ChatOpenAI(
    base_url='https://api.openai-proxy.org/v1',
    api_key='',
    model='gpt-4o-2024-08-06',
)
# Disabled Ollama client; while this stays commented out no `llama_model`
# name is defined by this setup section.
# llama_model = ChatOllama(base_url='http://10.129.166.101:11431',model='llama3.1:70b-instruct-q4_K_M')
deepseek = ChatOpenAI(
    base_url='https://api.openai-proxy.org/v1',
    api_key='',
    model='deepseek-chat',
)
def to_json(content):
    """Parse the JSON object embedded in ``content``.

    Everything before the first '{' and after the last '}' is discarded, and
    the remaining span is handed to ``json.loads``; whatever ``json.loads``
    raises for that span propagates to the caller.
    """
    start = content.find('{')
    stop = content.rfind('}') + 1
    payload = content[start:stop]
    return json.loads(payload)
def check_concept(concept, sentence, model='llama'):
    """Ask a chat model to classify ``sentence`` against a single concept.

    The system prompt is read from ``prompt_measure.md``. The user prompt
    embeds the concept as indented JSON, followed by the quoted text.

    Args:
        concept: JSON-serialisable concept definition (passed to json.dumps).
        sentence: The text to classify.
        model: ``'gpt'`` uses the module-level ``chatgpt`` client. Any other
            value uses a module-level ``llama_model`` client.

    Returns:
        The ``content`` attribute of the model's reply, unparsed.

    Raises:
        NameError: if a non-'gpt' model is requested but no ``llama_model``
            exists at module level.
    """
    # Resolve the backend first so an unavailable model fails immediately
    # with an explicit message, before any file or network work happens.
    if model == 'gpt':
        backend = chatgpt
    else:
        backend = globals().get('llama_model')
        if backend is None:
            raise NameError(
                f"check_concept(model={model!r}) needs a module-level "
                "'llama_model' client, but none is defined; define it or "
                "pass model='gpt'."
            )

    with open('prompt_measure.md','r') as f:
        system_prompt = f.read()

    user_prompt = \
"""Perform the task below, keeping in mind to limit snippets to 10 words and ignoring irrelevant information, and remember the format of returned JSON. Return a valid JSON response ending with ###
------
Concept:
"""
    user_prompt += json.dumps(concept, indent=4) + '\n---\n'
    user_prompt += f"Text: \"{sentence}\"\n---\n"
    user_prompt += "Response JSON:\n"

    # Braces in the user prompt are doubled so the template engine treats
    # them as literal text rather than placeholders.
    # NOTE(review): the system prompt is passed through unescaped, unlike in
    # generate_sentence. Presumably prompt_measure.md either has no braces or
    # already doubles them -- confirm before changing this.
    prompt = ChatPromptTemplate.from_messages([
        ('system', system_prompt),
        ('user', user_prompt.replace('{', '{{').replace('}', '}}'))
    ])
    res = backend.invoke(prompt.invoke({}))
    return res.content



def generate_sentence(concepts):
    """Generate one sentence for the current mask and verify it per concept.

    Besides ``concepts`` this function reads two module-level globals that
    the driver loop assigns before calling it:

    * ``s``      -- integer bit mask; when bit ``i`` is set, concept ``i``
                    must keep its current label, otherwise it must change.
    * ``text_c`` -- current label per concept, aligned with ``concepts``.

    The concepts are presented to the generator (``deepseek``) in a random
    order. Each one is stripped of its 'Response Mapping' and annotated with
    the classification(s) the new sentence should receive. The generated
    sentence is then classified again with ``check_concept(..., 'gpt')``.

    Args:
        concepts: list of concept dicts; each needs 'Response Mapping' and
            'Possible Responses'.

    Returns:
        1 if every concept's checked answer is among its accepted responses,
        0 as soon as one answer is not,
        2 if the generator reply has no usable "sentence" or the checker
          yields no parsable "answer" within 5 attempts for some concept.
    """
    with open('concsistency_prompt.txt','r') as f:
        system_prompt = f.read()

    # The header ends with a newline plus four spaces, exactly as the
    # original literal did; it is kept verbatim so the prompt is unchanged.
    prompt_parts = [
        "Perform the task below, keeping in mind to limit the sentence to 5 to 100 words and meets the expected classifiction for each concept. Return a valid JSON response ending with ###\n"
        "------\n"
        "    "
    ]
    expected = []  # accepted answers, in the same (shuffled) order as `order`

    # .tolist() yields plain Python ints, so `1 << idx` cannot overflow the
    # way a fixed-width NumPy integer could for many concepts.
    order = np.random.permutation(len(concepts)).tolist()

    for idx in order:
        concept = concepts[idx].copy()
        concept.pop('Response Mapping')
        current = text_c[idx]
        if s & (1 << idx):
            # Bit set: the sentence must keep the current label.
            shown = [current]
            accepted = shown
        else:
            # Bit clear: any other response is accepted by the checker.
            # 'not applicable' is accepted too, but not offered to the
            # generator as a target. Filtering (instead of list.remove)
            # avoids a ValueError when either value is absent.
            accepted = [r for r in concept['Possible Responses'] if r != current]
            shown = [r for r in accepted if r != 'not applicable']
        concept["Expected Classficition"] = shown
        expected.append(accepted)

        prompt_parts.append(f"Concept {idx+1}:\n")
        prompt_parts.append(json.dumps(concept, indent=4) + '\n')
        prompt_parts.append("---\n")
    prompt_parts.append("\nResponse JSON:\n")
    user_prompt = "".join(prompt_parts)

    # Double the braces so the template engine treats them as literal text.
    prompt = ChatPromptTemplate.from_messages([
        ('system', system_prompt.replace('{', '{{').replace('}', '}}')),
        ('user', user_prompt.replace('{', '{{').replace('}', '}}'))
    ])
    res = deepseek.invoke(prompt.invoke({}))

    # Best effort: any malformed reply means "unusable" (2). `Exception`
    # rather than a bare except, so Ctrl-C / SystemExit still propagate.
    try:
        sentence = to_json(res.content)['sentence']
    except Exception:
        return 2

    for position, concept_idx in enumerate(order):
        for _ in range(5):
            raw = check_concept(concepts[concept_idx], sentence, 'gpt')
            try:
                # A reply without an "answer" key counts as a failed attempt
                # and is retried, just like unparsable JSON.
                response = to_json(raw)['answer']
            except Exception:
                continue
            break
        else:
            # All five attempts failed to produce a usable answer.
            return 2
        if response not in expected[position]:
            return 0
    return 1

# NOTE: this driver must stay at module level. generate_sentence() reads the
# globals `s` and `text_c` that the loops below assign.
with open('new_concept_3.json', 'r') as f:
    new_concepts = json.load(f)

# Resume support: accs3.json holds one accuracy value per text that already
# produced a result, so processing restarts after that many texts.
accs = []
if os.path.exists('accs3.json'):
    with open('accs3.json', 'r') as f:
        accs = json.load(f)
# NOTE(review): a text that ends with no usable check (see the `continue`
# near the end of the loop) adds nothing to `accs`, so after a restart this
# offset can point earlier than the first unprocessed text -- confirm
# whether that matters for the experiment.
pbar_text = tqdm(texts[len(accs):], position=0)

for text in pbar_text:
    # Step 1: label the original text for every concept (up to 5 attempts).
    text_c = []
    for concept in new_concepts:
        for _ in range(5):
            try:
                response = to_json(check_concept(concept, text, model='gpt'))['answer']
                if response in concept['Response Mapping']:
                    break
            except Exception:
                # Any failure (request error, malformed reply, missing key)
                # is retried. KeyboardInterrupt is not an Exception subclass,
                # so Ctrl-C still propagates.
                continue
        else:
            print(f"Failed to get response for concept: {concept['Concept Name']}")
            # Exit with a non-zero status so the failure is visible to callers.
            raise SystemExit(1)
        text_c.append(response)

    # Step 2: keep only concepts whose current label maps to a non-zero value.
    kept = [
        (concept.copy(), label)
        for concept, label in zip(new_concepts, text_c)
        if concept['Response Mapping'][label] != 0
    ]
    concepts = [concept for concept, _ in kept]
    text_c = [label for _, label in kept]

    # Step 3: one mask per kept concept, with every bit set except that
    # concept's own bit.
    check_res = []
    full_mask = (1 << len(concepts)) - 1
    ss = [full_mask ^ (1 << i) for i in range(len(concepts))]
    pbar = tqdm(ss)
    for s in pbar:
        flag = generate_sentence(concepts.copy())
        if flag == 2:
            # Unusable generation/check round; it is not counted.
            continue
        check_res.append(flag)
        pbar.set_description(f"Current Accuracy: {sum(check_res)/len(check_res)}")
    if not check_res:
        continue

    # Step 4: record this text's accuracy and persist progress immediately.
    accs.append(sum(check_res)/len(check_res))
    with open('accs3.json','w') as f:
        json.dump(accs,f)
    pbar_text.set_description(f"Current Accuracy: {sum(accs)/len(accs)}")