import json, os, re, requests, time, random, csv
from openai import OpenAI
from sentence_transformers import SentenceTransformer
import faiss
from datasets import load_dataset
import argparse

# Local snapshot of the sentence-transformers gtr-t5-large checkpoint.
EMBEDDING_MODEL_PATH = "/home/share/.cache/huggingface/hub/models--sentence-transformers--gtr-t5-large/snapshots/51aab2cd10b205ba0f3f1ebaf1d4d5460bbf728d"

# The embedding model is loaded once at import time and shared through the
# module-level name `model`.
print("Loading Embedding model...")
model = SentenceTransformer(EMBEDDING_MODEL_PATH)

# Prompt sent to the LLM at every exploration step.
# It is filled in with str.format(), which expects four named fields:
#   question, question_entities, now_state_triples, action_history.
# The doubled braces ({{ and }}) around the JSON example are str.format()
# escapes and render as single literal braces in the final prompt.
PROMPT_TEMPLATE = """
You are an intelligent agent skilled in exploring Knowledge Graphs, with strong reasoning abilities. Your task is to perform question answering over a Knowledge Graph by gradually exploring it. You should start from the entities mentioned in the question and explore the graph step by step until you gather enough information to answer the question.

### Your task follows these steps:

1. **Understand the Question**

2. **Analyze the Action History and Current Graph State**

3. **Choose the Next Action** from the following options:
   - "Explore Entity": Explore all triples directly connected to a given entity in the Knowledge Graph.     

   - "Choose Relation": Select the triple(s) from the explored information that are most relevant to the question.  
    Attention: **Only the triples included in the "Objects" field of the "Choose Relation" step will be retained in the future "Current Graph State".** So You must filter and retain the information useful for answering the question or for further exploration.   

   - "Finish": Choose this action when you believe you have gathered sufficient information to answer the question. Your final answers should be included in the "Objects" field.

4. **Select the Objects**: Depending on the action, provide the relevant entity or triple(s).  
   ** Attention: All objects must come from the "Entities in Question" or the current "Current Graph State". Do not create new entities or relations. **
   
   - If the action is "Explore Entity":
     "Objects": ["EntityA", "EntityB"]
     
   - If the action is "Choose Relation":
     "Objects": ["(Subject1, Relation1, Object1)", "(Subject2, Relation2, Object2)"]

   - If the action is "Finish":
     "Objects": ["Answer1", "Answer2"]

5. **Output your response in JSON format**, and include a **detailed thought process** explaining your reasoning at this step.

---

### Question:
{question}

### Entities in Question:
{question_entities}

### Current Graph State:
{now_state_triples}

### Action History:
{action_history}

---

### Please respond using the following format:

**Thought Process**:  
<Provide a step-by-step analysis>

**Action Decision**:
```json
{{
  "Action": "<The type of action you are taking: 'Explore Entity' | 'Choose Relation' | 'Finish'>",
  "Objects": [<The entities or triples>]
}}
```
"""

def query_llm(prompt, max_retries=None, retry_delay=1.0, max_delay=30.0):
    """Send ``prompt`` to the completion endpoint and return the generated text.

    The prompt is truncated to its first 4096 characters before being sent.
    Failed requests are retried with exponential backoff so that a server
    that is down or overloaded is not hammered in a tight loop.

    Args:
        prompt: Full prompt text.
        max_retries: Maximum number of retries after the first failed
            attempt. ``None`` (the default) retries forever.
        retry_delay: Seconds to wait before the first retry; the wait is
            doubled after every further failure.
        max_delay: Upper bound, in seconds, for the wait between retries.

    Returns:
        The ``text`` of the first completion choice.

    Raises:
        Exception: The last request error, re-raised once ``max_retries``
            retries have been used up (never raised when ``max_retries`` is
            ``None``).
    """
    openai_api_key = "EMPTY"
    openai_api_base = "http://10.0.0.43:9501/v1"
    max_prompt_chars = 4096

    # NOTE(review): the cut is character-based and keeps the head of the
    # prompt, so for long prompts it is the tail (action history and the
    # response-format instructions of PROMPT_TEMPLATE) that gets dropped.
    if len(prompt) > max_prompt_chars:
        prompt = prompt[:max_prompt_chars]

    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    failures = 0
    delay = retry_delay
    while True:
        try:
            response = client.completions.create(
                model="checkpoints/AAAI_GRAIL/ablation_without_shortest_path_20250720-1757/global_step_90/merged_hf_model",
                prompt=prompt,
                stream=False,
                max_tokens=512
            )
            # Kept inside the try so that a malformed response (e.g. no
            # choices) is retried like any other failed request.
            return response.choices[0].text
        except Exception as e:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            print(f"{e},请求失败，进行重试...")
            time.sleep(delay)
            delay = min(delay * 2, max_delay)

def extract_solution(solution_str, node_set=None, edge_list=None):
    """Parse and validate the action decision contained in a model response.

    The response is expected to contain a JSON object with an ``"Action"``
    and an ``"Objects"`` field, either inside a fenced code block (with or
    without a ``json`` tag) or as the whole string.

    Args:
        solution_str: Raw model response.
        node_set: Collection of known entity names. ``None`` is treated as
            empty.
        edge_list: Collection of known triple strings. ``None`` is treated
            as empty.

    Returns:
        A 4-tuple ``(ok, errors, action, objects)``:

        * on success: ``(True, [], action, objects)`` where ``action`` is the
          canonical spelling (``"Explore Entity"``, ``"Choose Relation"`` or
          ``"Finish"``) regardless of the casing used by the model, and
          ``objects`` holds only the entries found in ``node_set`` or
          ``edge_list``;
        * on failure: ``(False, [message], solution_str, None)``.
    """
    if not isinstance(solution_str, str):
        return False, ["solution_str is not a string"], solution_str, None

    # Step 0: prefer the JSON object inside a fenced block, otherwise try to
    # parse the whole string as JSON.
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", solution_str, re.DOTALL)
    payload = match.group(1) if match else solution_str
    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return False, ["solution_str is a string but not valid JSON"], solution_str, None

    # Step 1: the payload must be a JSON object.
    if not isinstance(parsed, dict):
        return False, ["Parsed content is not a JSON object"], solution_str, None

    # Step 2: both fields are mandatory.
    if "Action" not in parsed or "Objects" not in parsed:
        return False, ["Missing 'Action' or 'Objects' field"], solution_str, None
    action = parsed["Action"]
    objects = parsed["Objects"]

    # Step 3: validate the action case-insensitively and hand back the
    # canonical spelling, because callers compare it with ``==``.
    canonical_actions = {
        name.lower(): name
        for name in ("Explore Entity", "Choose Relation", "Finish")
    }
    canonical = canonical_actions.get(action.lower()) if isinstance(action, str) else None
    if canonical is None:
        return False, ["Invalid Action"], solution_str, None

    # Step 4: Objects must be a list.
    if not isinstance(objects, list):
        return False, ["Objects should be a list"], solution_str, None

    known_nodes = node_set if node_set is not None else ()
    known_edges = edge_list if edge_list is not None else ()

    def _is_known(obj):
        # An unhashable object (e.g. a nested list produced by the model)
        # raises TypeError when tested against a set; treat it as "not found"
        # in that container instead of crashing.
        for container in (known_nodes, known_edges):
            try:
                if obj in container:
                    return True
            except TypeError:
                continue
        return False

    # Keep only objects that actually exist in the graph.
    objects = [obj for obj in objects if _is_known(obj)]
    if not objects:
        return False, ["No valid objects found"], solution_str, None

    return True, [], canonical, objects

def similarity_search(edge_str_list, explore_entity_list, query, top_k=10):
    """Return the triples around the given entities that are closest to ``query``.

    A triple is a candidate when the lower-cased name of any entity in
    ``explore_entity_list`` occurs as a substring of the lower-cased triple
    string. Candidates are embedded with the module-level ``model`` and
    ranked by L2 distance to the embedded ``query`` using a flat FAISS index.

    Args:
        edge_str_list: Triple strings of the whole graph.
        explore_entity_list: Entity names whose triples should be considered.
        query: Text the candidates are ranked against.
        top_k: Maximum number of triples to return.

    Returns:
        Up to ``top_k`` distinct triple strings, nearest first. An empty list
        is returned when there are no candidates, when ``top_k`` is not
        positive, or when the FAISS index cannot be built or searched.
    """
    # Lower-case every triple once instead of once per (entity, triple) pair.
    lowered_triples = [(triple_str, triple_str.lower()) for triple_str in edge_str_list]

    candidate_edge_list = []
    seen = set()
    for entity in explore_entity_list:
        needle = entity.lower()
        for triple_str, lowered in lowered_triples:
            # A triple touching several explored entities must only be a
            # candidate once, otherwise it takes up several top_k slots.
            if needle in lowered and triple_str not in seen:
                seen.add(triple_str)
                candidate_edge_list.append(triple_str)

    if not candidate_edge_list or top_k <= 0:
        return []

    embeddings = model.encode(candidate_edge_list, convert_to_numpy=True, show_progress_bar=True)
    try:
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        query_vec = model.encode([query], convert_to_numpy=True)
        # FAISS pads the result with index -1 when fewer than k vectors are
        # available; in Python -1 would silently select the last candidate.
        # Cap k and drop negative indices to avoid that.
        k = min(top_k, len(candidate_edge_list))
        _, indices = index.search(query_vec, k)
        search_results = [candidate_edge_list[i] for i in indices[0] if i >= 0]
    except Exception as e:
        # Best effort: a failed search yields no new triples rather than
        # aborting the caller.
        print(f"Error in FAISS index building or searching: {e}")
        search_results = []
    return search_results
  
def cwq(log_save_dir="log/ablation_without_shortest_path/cwq", max_steps=4):
    """Run the LLM-driven graph exploration over the RoG-cwq test split.

    For every test item that has no log file yet, the model is queried for up
    to ``max_steps`` steps. Each valid step is recorded, and the per-item log
    is written to ``<log_save_dir>/<id>.json`` as a JSON list. Items whose log
    file already exists are skipped, so an interrupted run can be resumed.

    Args:
        log_save_dir: Directory holding one JSON log per item; created if
            missing.
        max_steps: Maximum number of model queries per item. Responses that
            fail validation still consume a step.
    """
    # Both the resume check and the progress counter read this directory.
    os.makedirs(log_save_dir, exist_ok=True)

    print("Loading dataset...")
    dataset = load_dataset("rmanluo/RoG-cwq")
    test_data = dataset['test']
    data_num = len(test_data)

    for data_item in test_data:
        data_id = data_item['id']
        now_save_path = os.path.join(log_save_dir, str(data_id) + '.json')
        if os.path.exists(now_save_path):
            continue  # already processed in an earlier run

        # Counted only for items that are actually processed, so resuming a
        # large run does not list the directory once per skipped item.
        now_finish = sum(1 for name in os.listdir(log_save_dir) if name.endswith('.json'))
        print(f"PROCESSING {now_finish} / {data_num}, save_path: {now_save_path}")

        # Build the node set and the "(head, relation, tail)" strings that
        # the model's objects are validated against.
        all_graph_nodes = set()
        allgraph_str_list = []
        for graph_tuple in data_item['graph']:
            all_graph_nodes.add(graph_tuple[0])
            all_graph_nodes.add(graph_tuple[2])
            allgraph_str_list.append(f"({graph_tuple[0]}, {graph_tuple[1]}, {graph_tuple[2]})")

        # Exploration state.
        action_history = []
        now_state = set()
        chosen_relation_set = set()  # every triple kept by a "Choose Relation" so far
        log_writer = []

        # NOTE: an earlier variant forced a step 0 that explored the question
        # entities before the model was queried; it is disabled here, so the
        # model starts from an empty graph state.

        # Free exploration.
        for step_i in range(1, max_steps + 1):
            print(f"PROCESSING {now_finish} / {data_num}, STEP {step_i}")
            prompt = PROMPT_TEMPLATE.format(
                question=data_item['question'],
                question_entities=data_item['q_entity'],
                now_state_triples="\n".join(str(triple) for triple in now_state),
                action_history="\n".join(str(action) for action in action_history),
            )
            res = query_llm(prompt)
            print("******************** res: ********************\n", res)
            flag, _, action, objects = extract_solution(res, all_graph_nodes, allgraph_str_list)
            if not flag:
                # Invalid response: change nothing and query again.
                continue

            log_writer.append({
                "step": step_i,
                "question": data_item['question'],
                "question_entities": data_item['q_entity'],
                "true_answer": data_item['answer'],
                "now_state": [str(triple) for triple in now_state],
                "action_history": [str(action) for action in action_history],
                "extract_res": {
                    "Action": action,
                    "Objects": objects
                },
                "model_response": res
            })
            action_history.append(f"step {len(action_history) + 1}, {action}, Objects: {objects}")

            # Update the state according to the action. The action is
            # validated case-insensitively, so it is dispatched the same way;
            # otherwise e.g. "finish" would match no branch.
            action_key = action.lower()
            if action_key == "finish":
                break
            elif action_key == "explore entity":
                now_state.update(similarity_search(allgraph_str_list, objects, data_item['question']))
            elif action_key == "choose relation":
                # Only triples chosen so far survive; explored triples that
                # were never chosen are dropped from the state.
                chosen_relation_set.update(objects)
                now_state = set(chosen_relation_set)

        print(f"Saving {now_save_path} ...")
        # Write to a temporary file and rename it, because the existence of
        # the final file is what marks the item as done: an interrupted write
        # must not leave a truncated log that later runs would skip.
        tmp_save_path = now_save_path + ".tmp"
        with open(tmp_save_path, "w", encoding="utf-8") as f:
            json.dump(log_writer, f, indent=4, ensure_ascii=False)
        os.replace(tmp_save_path, now_save_path)
                       
        
# Script entry point: run the CWQ exploration loop only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    cwq()