import ast
from itertools import tee
import networkx as nx
from experiment.utils import check_json
from llm import gen, get_model, set_model, set_log

def direct(question: str, contexts: str):
    """Build a prompt asking the model to solve ``question`` as a Python function.

    Args:
        question: Problem statement, inserted verbatim into the prompt.
        contexts: The function format/signature the answer must follow,
            inserted verbatim.

    Returns:
        The prompt string. Its last line asks the model to wrap its final code
        in a fenced block opened with '```python' and closed with '```'.
    """
    instruction = """
        Solve the following problem step by step:
        {question}
        Your code should be a python function with format: {contexts}
        
        Please extend your reasoning process as much as possible; the longer the chain of thought, the better.

    """
    # The language tag belongs inside the opening fence: '```python', not '```'python.
    formatter = "Last step, enclose your code within '```python' and '```'"
    prompt = (instruction + formatter).format(question=question, contexts=contexts)
    return prompt

def multistep(question: str):
    """Return a step-by-step reasoning prompt for ``question``.

    The final instruction asks the model to place its answer (plain Python
    code, no ```python fence) inside <answer></answer> tags.
    """
    answer_hint = "Last step, enclose the answer within <answer></answer> tags (pure python code without ```python symbol)"
    body = """
        Solve the following problem step by step:
        {question}
        
        Please extend your reasoning process as much as possible; the longer the reasoning process, the better.
        
    """
    return (body + answer_hint).format(question=question)

def decompose():
    """Return the fixed instruction for splitting a prior reasoning trace.

    The model is asked to break its previous reasoning into sub-questions or
    thoughts. Each one lists the 0-based indexes of the others it depends on.
    The text contains no placeholders, so it is returned as-is.
    """
    return """
        Decompose the previous reasoning trajectory into a series of sub-questions or thoughts.

        Instructions:
        1. Each sub-question or thought should list its other sub-questions or thoughts' indexes it depends (0-based, can be an empty list)
        2. Dependencies are defined as information needed in sub-question or thought that:
           - Does NOT come directly from the original question
           - MUST come from previous sub-questions or thoughts
    """

def contract(dag, test_cases):
    """Build a prompt asking the model to contract a problem into a simpler one.

    Args:
        dag: Dependency-graph description of the original code's variables,
            inserted verbatim (stringified by ``str.format``).
        test_cases: The original test cases, inserted verbatim.

    Returns:
        The prompt string. The model is asked to return:
        - the simplified problem inside <question></question> tags;
        - new assert-style test cases inside <test></test> tags.
    """
    instruction = """
        Generate a simplified intermediate form of the original problem based on the variable dependency analysis.
        
        You are given a directed acyclic graph (DAG) representing the dependencies between variables in the original code:
        {dag}
        
        And the original test cases:
        {test_cases}
        
        The simplified problem must be:
        1. Self-contained: The description must contain all information needed to solve itself, without requiring additional context from the original problem
        2. Test-time reduced: The simplified problem must require fewer reasoning steps by using intermediate variables from the original code as direct inputs
        
        Your task is to:
        1. Create a simplified version of the problem that starts with intermediate variables as inputs
        2. Generate new test cases that use these intermediate variables as parameters while maintaining the exact same expected outputs as in the original test cases
        
        Do not use any code examples in your simplified problem formulation.
    """
    # Raw string: the model is told to literally write "\n" between test cases.
    formatter = r"Enclose the simplified problem within <question></question> tag and the new test cases (assert codes, use \n to split each case) within <test></test> tag"
    prompt = (instruction + formatter).format(dag=dag, test_cases=test_cases)
    return prompt

# Disabled alternative implementation, kept inside a bare module-level string
# literal so it is never executed.
# What it does: `decompose(code)` parses `code` with `ast`, builds a networkx
# DiGraph of variable dependencies (assignments, augmented assignments, for
# loops, function arguments, return values), and renders the graph as text.
# If re-enabled as real code, it would shadow the prompt-based `decompose()`
# defined earlier in this module.
# Its inline Chinese note says: the acyclicity check was removed on purpose,
# because dependency cycles in traversed code do occur, and fairly often.
# NOTE(review): `hasattr(__builtins__, node.id)` behaves differently in CPython
# depending on context. Outside `__main__`, `__builtins__` is usually a dict,
# so this test is False for every name and builtins are not filtered.
# Verify this before reviving.
# Also, ast.walk() already visits every node, and visit_node() recurses into
# For/FunctionDef bodies, so some nodes are visited twice. This looks harmless
# because add_node/add_edge are idempotent.
''' AST version
def decompose(code: str):
    tree = ast.parse(code)
    dag = nx.DiGraph()
    
    class DependencyVisitor(ast.NodeVisitor):
        def __init__(self):
            self.dependencies = []
            
        def visit_Name(self, node):
            if isinstance(node.ctx, ast.Load) and not hasattr(__builtins__, node.id):
                self.dependencies.append(node.id)
        
        def get_dependencies(self, node):
            self.dependencies = []
            if node:
                self.visit(node)
            return self.dependencies
    
  
    dependency_visitor = DependencyVisitor()
    def visit_node(node):
        if isinstance(node, ast.Assign):
            targets = [t.id for t in node.targets if isinstance(t, ast.Name)]
            dependencies = dependency_visitor.get_dependencies(node.value)
        
            for target in targets:
                dag.add_node(target)
                for dep in dependencies:
                    if dep != target: 
                        dag.add_node(dep)
                        dag.add_edge(dep, target)
        
        elif isinstance(node, ast.For):
            iterator = node.target.id if isinstance(node.target, ast.Name) else None
            for body_node in node.body:
                visit_node(body_node)
            if iterator:
                dag.add_node(iterator)
                
        elif isinstance(node, ast.AugAssign):
            if isinstance(node.target, ast.Name):
                target = node.target.id
                dependencies = dependency_visitor.get_dependencies(node.value)
                dependencies.append(target)
                
                dag.add_node(target)
                for dep in dependencies:
                    if dep != target:
                        dag.add_node(dep)
                        dag.add_edge(dep, target)
        
        elif isinstance(node, ast.FunctionDef):
            for arg in node.args.args:
                dag.add_node(arg.arg)
            
            for body_node in node.body:
                visit_node(body_node)
        
        elif isinstance(node, ast.Return):
            dependencies = dependency_visitor.get_dependencies(node.value)
            
            for dep in dependencies:
                dag.add_node(dep)
    
    for node in ast.walk(tree):
        visit_node(node)
    
   
    # 这里取消了无环的检查, 因为确实可能存在代码遍历依赖有环的情况, 而且不少
    # if not nx.is_directed_acyclic_graph(dag):
        # raise ValueError("Code contains cyclic dependencies!")
    dependency_list = []
    for node in dag.nodes():
        deps = list(dag.predecessors(node))
        if deps:
            dependency_list.append(f"{node} depends on {deps}")
    
    result = f"Original Code:\n{code}\n\n"
    result += f"Variables in the code: {', '.join(dag.nodes())}\n\n"
    result += "Variables Dependencies:\n"
    for dep in dependency_list:
        result += f"- {dep}\n"
    
    return result
'''


# Disabled earlier version of `contract()`, kept inside a bare module-level
# string literal so it is never executed.
# It takes no arguments. It asks the model to simplify the question using the
# previously decomposed sub-questions/thoughts, rather than a variable DAG and
# test cases.
# Note: "tee original question" inside it is a typo for "the original question".
''' old version
def contract():
    instruction = """
        Generate a simplified intermediate form of tee original question based on the previous sub-questions or thoughts step by step.
        
        The previous sub-questions or thoughts with marked dependencies actually form a directed acyclic graph (DAG), where nodes whose dependencies is empty list can be regarded as independent sub-questions or thoughts.
        
        The simplified question must be:
        1. self-contained: The simplified question's description must contain all information needed to solve itself, without requiring additional information from the original question or reasoning trajectory
        2. test-time reduced: The simplified question must require fewer reasoning steps compared to the original question (these steps are reduced because these solved independent sub-problems or thoughts become known conditions in the simplified question or excluded as incorrect explorations)
        
    """
    formatter = "Last step, enclose the question within <question></question> tags"
    instruction += formatter
    return instruction
'''



def label(dag_text):
    """Return a prompt asking the model to rewrite a decomposed DAG as JSON.

    ``dag_text`` is inserted verbatim. The requested JSON has a "thoughts"
    list. Each item pairs a ``thought_<i>`` key with a "dependencies" list of
    indexes. The template's literal braces are escaped as ``{{``/``}}`` so
    formatting turns them into single braces.
    """
    json_request = """
        Please convert the text of decomposed directed acyclic graph into a JSON format:
        The decomposed DAG is: 
        {dag_text}

        Format your response as the following JSON objects:
        {{
            "thoughts":
                [
                    {{"thought_0": "<the first thought>", "dependencies": [indexes of thought_0 dependencies]}},
                    {{"thought_1": "<the second thought>", "dependencies": [indexes of thought_1 dependencies]}},
                    ...
                ]
        }}
"""
    return json_request.format_map({"dag_text": dag_text})

def ensemble(question: str, solutions: list):
    """Build a prompt asking the model to pick the best of several solutions.

    Args:
        question: The original problem statement, inserted verbatim.
        solutions: Candidate solutions. Each is listed as ``solution {i}: ...``
            with a 0-based index ``i``.

    Returns:
        The prompt string. The model is asked to put the 0-based index of the
        best solution inside <answer></answer> tags.
    """
    solutions = list(solutions)
    instruction = """
        Here is the original problem:
        {question}

        Here are some reference solutions:
        {solutions}
        
        Ensemble the best answer from these solutions step by step.
    """
    # Derive the index hint from the actual solutions. This keeps it in step
    # with the 0-based "solution {i}" labels below and with the number of
    # candidates, instead of assuming exactly three.
    indices = range(len(solutions))
    formatter = "Last step, enclose the best solution index within <answer></answer> tags"
    if indices:
        choices = " or ".join(str(i) for i in indices)
        names = " or ".join(f"solution {i}" for i in indices)
        formatter += f" ({choices} for {names})"

    solutions_str = "".join(
        f"solution {i}: {solution}\n\n" for i, solution in enumerate(solutions)
    )
    # Append the hint after formatting so nothing in it is parsed as a placeholder.
    prompt = instruction.format(question=question, solutions=solutions_str) + formatter
    return prompt
class extractor:
    """Post-process a raw model response with a secondary LLM call.

    ``name`` selects the extraction prompt. Accepted names are "direct",
    "multistep", "ensemble" and "contract"; any other name makes ``prompt``
    raise ValueError.
    """

    def __init__(self, name: str):
        self.name = name

    def prompt(self, response: str, method: str):
        """Return the extraction prompt for ``response``.

        Raises:
            ValueError: if ``self.name`` is not a recognised extractor name.
        """
        # NOTE(review): every accepted name currently returns None (no prompt
        # text is built), so extract() passes None to gen(). This looks
        # unfinished — confirm the intended prompts.
        if self.name in ["direct", "multistep", "ensemble", "contract"]:
            return None
        raise ValueError(f"Invalid extractor name: {self.name}")

    def extract(self, response: str, method: str):
        """Run the extraction prompt on "gpt-4o-mini" with logging disabled.

        The caller's model is restored afterwards, even if building the prompt
        or the LLM call raises. Logging is then re-enabled.
        """
        old_model = get_model()
        set_model("gpt-4o-mini")
        set_log(False)
        try:
            prompt = self.prompt(response, method)
            return gen(prompt, response_format="text")
        finally:
            # Global LLM settings must not leak on failure (e.g. an invalid name).
            # NOTE(review): logging is re-enabled unconditionally. No log getter
            # is imported here to restore its previous state.
            set_model(old_model)
            set_log(True)
        

# utilities
def check(name: str, result):
    """Return whether a parsed LLM ``result`` is well-formed for task ``name``.

    Rules:
    - Non-dict results are always rejected.
    - Answer-style tasks ("cot", "direct", "multistep", "ensemble") need an
      "answer" key (per ``check_json``). The answer must not still contain a
      literal "<answer>" tag.
    - "contract" needs "question" and "test"; "label" needs "thoughts".
    - Any other task name accepts any dict.
    """
    if not isinstance(result, dict):
        return False

    if name in ("cot", "direct", "multistep", "ensemble"):
        # A leftover opening tag means the answer was not extracted cleanly.
        return bool(check_json(result, ["answer"])) and "<answer>" not in str(result["answer"])

    keys_by_task = {
        "contract": ["question", "test"],
        "label": ["thoughts"],
    }
    required = keys_by_task.get(name)
    if required is None:
        return True
    return bool(check_json(result, required))
