import json, os, re, requests, time, random, csv
from datasets import load_dataset
from openai import OpenAI
from sentence_transformers import SentenceTransformer
import faiss

# Sentence-embedding model used by similarity_search() to embed candidate
# triples and the question. It is loaded once, at import time.
# NOTE(review): the snapshot path is machine-specific; consider making it
# configurable (env var / CLI flag) before running elsewhere.
print("Loading SentenceTransformer model...")
model = SentenceTransformer("/home/share/.cache/huggingface/hub/models--sentence-transformers--gtr-t5-large/snapshots/51aab2cd10b205ba0f3f1ebaf1d4d5460bbf728d")

# Per-step prompt for the KG-exploration agent, filled with str.format().
# Placeholders: {question}, {question_entities}, {now_state_triples},
# {action_history}. The literal braces of the JSON example are doubled
# ("{{" / "}}") so that str.format() emits single braces.
# NOTE(review): this text is model input, so any edit (including whitespace)
# changes model behavior and may invalidate comparisons with earlier runs.
PROMPT_TEMPLATE = """
You are an intelligent agent skilled in exploring Knowledge Graphs, with strong reasoning abilities. Your task is to perform question answering over a Knowledge Graph by gradually exploring it. You should start from the entities mentioned in the question and explore the graph step by step until you gather enough information to answer the question.

### Your task follows these steps:

1. **Understand the Question**: Analyze the intent behind the question and identify what type of information is required.

2. **Analyze the Action History and Current Graph State**: Examine previous actions (if any) and analyze the current subgraph (provided as triples) to determine what is already known. 

3. **Choose the Next Action** from the following options:
   - "Explore Entity": Explore all triples directly connected to a given entity in the Knowledge Graph.  
    Attention: You must only choose entities that appear in the "Entities in Question" or the "Current Graph State". Do not fabricate any information.
   
   - "Choose Relation": Select the triple(s) from the explored information that are most relevant to the question.  
    Attention: **Only the triples included in the "Objects" field of the "Choose Relation" step will be retained in the future "Current Graph State".** You must filter and retain the information useful for answering the question or for further exploration.
   
   - "Finish": Choose this action when you believe you have gathered sufficient information to answer the question. Your final answers should be included in the "Objects" field.
   Attention: The question may have multiple answers. Please explore thoroughly and ensure that no possible answers are missed.

4. **Select the Objects**: Depending on the action, provide the relevant entity or triple(s).  
   Attention: All objects must come from the "Entities in Question" or the current "Current Graph State". Do not create new entities or relations.

   - If the action is "Explore Entity":
     "Objects": ["EntityA", "EntityB"]

   - If the action is "Choose Relation":
     "Objects": ["(Subject1, Relation1, Object1)", "(Subject2, Relation2, Object2)"]

   - If the action is "Finish":
     "Objects": ["Answer1", "Answer2"]

5. **Output your response in JSON format**, and include a **detailed thought process** explaining your reasoning at this step.

---

### Question:
{question}

### Entities in Question:
{question_entities}

### Current Graph State (partial, as triples):
{now_state_triples}

### Action History:
{action_history}

---

### Please respond using the following format:

**Thought Process**:  
<Provide a step-by-step analysis. Begin with understanding the question, then analyze the current graph state and action history, and finally reason about the most appropriate next step. Simulate an internal monologue, raise hypotheses, use elimination, and compare different paths if needed.>

**Action Decision**:
```json
{{
  "Action": "<The type of action you are taking: 'Explore Entity' | 'Choose Relation' | 'Finish'>",
  "Objects": [<The relevant entities or triples, depending on your action>]
}}
```
"""

def query_llm(prompt, max_prompt_chars=4096, max_attempts=6, max_tokens=512):
    """Send ``prompt`` to the local OpenAI-compatible completions server.

    Args:
        prompt: The fully rendered prompt text.
        max_prompt_chars: Prompts longer than this many *characters* (not
            tokens) are cut to their first ``max_prompt_chars`` characters.
            This drops the tail of the prompt, which is where the
            response-format instructions live.
        max_attempts: Total number of request attempts before giving up.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The text of the first completion choice, or the string ``"None"`` if
        every attempt failed (the caller then treats it as an invalid reply).
    """
    # The openai v1 client raises OpenAIError subclasses (connection errors,
    # timeouts, HTTP status errors), never requests exceptions, so those are
    # what must be caught for the retry loop to do anything.
    from openai import OpenAIError

    openai_api_key = "EMPTY"
    openai_api_base = "http://10.0.0.43:9503/v1"

    prompt = prompt[:max_prompt_chars]

    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )
    model_response = "None"
    for attempt in range(1, max_attempts + 1):
        try:
            response = client.completions.create(
                model="checkpoints/AAAI_GRAIL/ablation_without_shortest_path_20250720-1757/global_step_90/merged_hf_model",
                prompt=prompt,
                stream=False,
                max_tokens=max_tokens,
            )
        except (OpenAIError, requests.exceptions.RequestException) as e:
            print(f"{e},请求失败，进行重试...")
            if attempt < max_attempts:
                # Short exponential backoff so a briefly unavailable server
                # is not hammered with back-to-back retries.
                time.sleep(min(2 ** (attempt - 1), 10))
            continue
        model_response = response.choices[0].text
        break

    return model_response

def extract_solution(solution_str, node_set = None, edge_list = None):
    """Parse and validate one model response.

    The response must contain a JSON object with ``Action`` and ``Objects``
    fields. The object is taken from the first fenced block (```json ... ```
    or ``` ... ```) whose content starts with ``{``; if there is none, the
    whole string is parsed as JSON.

    Args:
        solution_str: Raw model output.
        node_set: Container of valid entity names, or None.
        edge_list: Container of valid triple strings, or None.
            An object is kept if it is a string found in either container; a
            ``None`` container matches nothing. If BOTH are None, string
            objects are kept without graph validation.

    Returns:
        ``(ok, errors, action, objects)``. On success: ``(True, [], action,
        objects)``, where ``action`` is the canonical spelling ("Explore
        Entity", "Choose Relation" or "Finish") and ``objects`` is the
        filtered list. On failure: ``(False, [message], solution_str, None)``.
    """
    if not isinstance(solution_str, str):
        return False, ["solution_str is not a string"], solution_str, None

    # Step 0: extract the JSON payload, preferring a fenced code block.
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", solution_str, re.DOTALL)
    payload = match.group(1) if match else solution_str
    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return False, ["solution_str is a string but not valid JSON"], solution_str, None

    # Step 1: the payload must be a JSON object.
    if not isinstance(parsed, dict):
        return False, ["Parsed content is not a JSON object"], solution_str, None

    # Step 2: both the Action and Objects fields are required.
    if "Action" not in parsed or "Objects" not in parsed:
        return False, ["Missing 'Action' or 'Objects' field"], solution_str, None
    action = parsed["Action"]
    objects = parsed["Objects"]

    # Step 3: match the action case-insensitively, but return the canonical
    # spelling so callers can dispatch with a plain == comparison.
    canonical_actions = {
        name.lower(): name for name in ("Explore Entity", "Choose Relation", "Finish")
    }
    if not isinstance(action, str) or action.strip().lower() not in canonical_actions:
        return False, ["Invalid Action"], solution_str, None
    action = canonical_actions[action.strip().lower()]

    # Step 4: Objects must be a list.
    if not isinstance(objects, list):
        return False, ["Objects should be a list"], solution_str, None

    # Keep only string objects that exist in the graph. Non-strings (e.g. a
    # triple emitted as a JSON list) can never match, and would raise
    # TypeError ("unhashable type") on a set membership test.
    if node_set is None and edge_list is None:
        objects = [obj for obj in objects if isinstance(obj, str)]
    else:
        nodes = node_set if node_set is not None else ()
        edges = edge_list if edge_list is not None else ()
        objects = [
            obj for obj in objects
            if isinstance(obj, str) and (obj in nodes or obj in edges)
        ]

    if not objects:
        return False, ["No valid objects found"], solution_str, None

    return True, [], action, objects

def similarity_search(edge_str_list, explore_entity_list, query, top_k=10):
    """Return up to ``top_k`` triples about the given entities, ranked by similarity to ``query``.

    Candidates are the triple strings that contain any non-empty entity name
    as a case-insensitive substring (partial-name matches are included). Each
    candidate is kept once, in first-match order. Candidates and the query are
    embedded with the module-level ``model`` and ranked with an exact
    (``IndexFlatL2``) FAISS index.

    Args:
        edge_str_list: All triple strings of the graph.
        explore_entity_list: Entity names to gather triples for.
        query: Text the candidates are ranked against (the question).
        top_k: Maximum number of triples to return.

    Returns:
        At most ``top_k`` distinct triple strings, nearest first. Empty if
        there are no candidates, or if FAISS index building/search fails.
    """
    entities = [entity.lower() for entity in explore_entity_list if entity]
    if not entities or top_k <= 0:
        return []

    # Lower-case every edge once instead of once per entity.
    lowered_edges = [triple_str.lower() for triple_str in edge_str_list]
    candidate_edge_list = []
    seen = set()
    for entity in entities:
        for triple_str, lowered in zip(edge_str_list, lowered_edges):
            # De-duplicate: a triple mentioning several explored entities
            # would otherwise take several of the top_k result slots.
            if entity in lowered and triple_str not in seen:
                seen.add(triple_str)
                candidate_edge_list.append(triple_str)
    if not candidate_edge_list:
        return []

    embeddings = model.encode(candidate_edge_list, convert_to_numpy=True, show_progress_bar=True)
    # Never ask for more neighbours than there are candidates.
    k = min(top_k, len(candidate_edge_list))
    try:
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        query_vec = model.encode([query], convert_to_numpy=True)
        _, indices = index.search(query_vec, k)
        # FAISS marks missing results with -1. Used as a Python index, -1
        # would silently return the last candidate, so skip those entries.
        search_results = [candidate_edge_list[i] for i in indices[0] if i >= 0]
    except Exception as e:
        print(f"Error in FAISS index building or searching: {e}")
        search_results = []
    return search_results
    
def _load_metaqa_graph(nodes_path, edges_path):
    """Load the knowledge graph from the node and edge CSV files.

    Args:
        nodes_path: CSV with ``node_id`` and ``node_attribute`` columns.
        edges_path: CSV with ``src``, ``dst`` (node ids) and ``rel`` columns.

    Returns:
        ``(node_names, edge_strings)``: the set of all ``node_attribute``
        values, and one ``"(head, rel, tail)"`` string per edge row, in file
        order.
    """
    name_by_id = {}
    node_names = set()
    with open(nodes_path, "r", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            name_by_id[int(row["node_id"])] = row["node_attribute"]
            node_names.add(row["node_attribute"])

    edge_strings = []
    with open(edges_path, "r", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            head = name_by_id[int(row["src"])]
            tail = name_by_id[int(row["dst"])]
            edge_strings.append(f"({head}, {row['rel']}, {tail})")

    return node_names, edge_strings


def _explore_question(data, graph_nodes, graph_edges, max_steps, progress):
    """Run the step-by-step LLM graph exploration for one question.

    Args:
        data: One test record; its 'question', 'question_entities' and
            'answer' fields are read.
        graph_nodes: Set of valid entity names.
        graph_edges: List of valid triple strings.
        max_steps: Maximum number of LLM queries. An invalid response also
            uses up a step.
        progress: Label shown in the per-step progress line, e.g. "12 / 3000".

    Returns:
        A list of per-step log dicts, one per valid model response.
    """
    action_history = []
    now_state = set()
    chosen_relations = set()  # every triple selected so far via "Choose Relation"
    step_logs = []
    for step_i in range(max_steps):
        print(f"PROCESSING {progress}, STEP {step_i}")
        # Sorted so the prompt and log are reproducible: iteration order of a
        # set of str varies between runs because of hash randomization.
        state_lines = sorted(now_state)
        prompt = PROMPT_TEMPLATE.format(
            question=data["question"],
            question_entities=data["question_entities"],
            now_state_triples="\n".join(state_lines),
            action_history="\n".join(action_history),
        )
        res = query_llm(prompt)
        print(res)
        ok, _, action, objects = extract_solution(res, graph_nodes, graph_edges)
        if not ok:
            # Invalid response: keep all state unchanged and query again.
            continue

        step_logs.append({
            "step": step_i,
            "question": data["question"],
            "question_entities": data["question_entities"],
            "true_answer": data["answer"],
            "now_state": state_lines,
            "action_history": list(action_history),
            "extract_res": {
                "Action": action,
                "Objects": objects,
            },
            "model_response": res,
        })
        action_history.append(f"step {len(action_history) + 1}, {action}, Objects: {objects}")

        # Update the state according to the action (compared case-insensitively).
        action_key = action.lower()
        if action_key == "finish":
            break
        if action_key == "explore entity":
            now_state.update(similarity_search(graph_edges, objects, data["question"]))
        elif action_key == "choose relation":
            # Only triples chosen so far (in this or earlier steps) survive.
            chosen_relations.update(objects)
            now_state = set(chosen_relations)
    return step_logs


def metaqa_2hop(
    log_save_dir="log/ablation_without_shortest_path/metaqa_2hop",
    nodes_path="data/metaqa_nodes.csv",
    edges_path="data/metaqa_edges.csv",
    test_data_path="data/metaqa_2hop/qa_test_0508_choose_3000.json",
    max_steps=5,
):
    """Run the exploration agent on the MetaQA 2-hop test questions.

    For every question whose log file does not exist yet, run up to
    ``max_steps`` LLM steps and write the per-step log to
    ``<log_save_dir>/<index>.json``. Questions that already have a log file
    are skipped, so an interrupted run can be resumed.
    """
    print("Loading dataset...")
    with open(test_data_path, "r", encoding="utf-8") as f:
        test_data = json.load(f)
    num_test_data = len(test_data)
    # Make sure the output directory exists before the first json.dump.
    os.makedirs(log_save_dir, exist_ok=True)

    # The graph is identical for every question and is only read, so load it
    # once (lazily, in case every question is already done) and reuse it.
    graph = None
    for process_idx, data in enumerate(test_data):
        log_save_path = os.path.join(log_save_dir, f"{process_idx}.json")
        if os.path.exists(log_save_path):
            continue
        print(f"Processing {process_idx} / {num_test_data}")

        if graph is None:
            graph = _load_metaqa_graph(nodes_path, edges_path)
        graph_nodes, graph_edges = graph

        step_logs = _explore_question(
            data, graph_nodes, graph_edges, max_steps, f"{process_idx} / {num_test_data}"
        )

        print(f"Saving {log_save_path} ...")
        with open(log_save_path, "w", encoding="utf-8") as f:
            json.dump(step_logs, f, indent=4, ensure_ascii=False)
                
                
# Script entry point: run the MetaQA 2-hop evaluation when executed directly.
if __name__ == '__main__':
    metaqa_2hop()