import random
import json
import os

from validator_agent import validator_agent
from student_agent import Student_agent
from teacher_agent import Teacher_agent
from topStudent_agent import topStudent_agent
from langchain_community.chat_models import ChatOpenAI
# Placeholder for the OpenAI API key. Only set it when the environment does not
# already provide one, so an exported OPENAI_API_KEY is not clobbered with ''.
os.environ.setdefault("OPENAI_API_KEY", '')
# Set a fixed seed for the `random` module.
random.seed(66)

class TreeNode:
    """One knowledge point in a knowledge-point tree.

    Instance fields:
      name     -- the node's label.
      children -- ordered list of child nodes.
      value    -- None until assigned; this file stores 1, 0 or -1 here.
      depth    -- level in the tree as parsed from the input (roots are 1).
      history  -- list of log strings recorded for this node.
      parent   -- parent node, or None for a root.
    """

    def __init__(self, name, depth=1):
        self.name = name
        self.children = []
        self.value = None
        self.depth = depth
        self.history = []
        self.parent = None

    def add_child(self, child_node):
        """Append *child_node* to ``children`` (its ``parent`` is not touched)."""
        self.children.append(child_node)

    def set_value(self, value):
        """Store *value* as this node's value."""
        self.value = value

    def propagate_value_to_ancestors(self):
        """Walk upward marking unassigned ancestors as 1 while the chain holds 1.

        Starting from this node, as long as the current node's value is 1 and it
        has a parent, a parent whose value is still None is set to 1. A parent
        that already has a value keeps it, and the walk only continues past it
        if that value is 1.
        """
        current = self
        while current.value == 1 and getattr(current, "parent", None):
            ancestor = current.parent
            if ancestor.value is None:
                ancestor.set_value(1)
            current = ancestor

    def get_path(self):
        """Return the node names from the root down to this node, joined by '->'.

        A root returns its ``name`` object unchanged.
        """
        if not self.parent:
            return self.name
        chain = [self]
        cursor = self
        while cursor.parent:
            cursor = cursor.parent
            chain.append(cursor)
        return "->".join(f"{member.name}" for member in reversed(chain))

    def get_full_path(self):
        """Return the root-to-leaf path through this node.

        Descends from this node by always taking the first child until a leaf
        is reached, then returns that leaf's ``get_path()``.
        """
        leaf = self
        while leaf.children:
            leaf = leaf.children[0]
        return leaf.get_path()
    


def build_forest_from_string(input_string):
    """Parse a '#'-prefixed outline into a forest of TreeNode objects.

    The input is a sequence of nodes such as
    ``"# A（待预测）## B（待预测）### C（待预测）"``. Each node is a run of ``#``
    characters (its depth), then its name, then the ``（待预测）`` marker that
    ends it. The final node may omit the marker; its name then runs to the end
    of the string. Whitespace (spaces, tabs, newlines) between nodes and
    between the ``#`` run and the name is ignored.

    Depth-1 nodes become roots and are pre-assigned value 1. Each deeper node
    becomes a child of the nearest preceding node with a smaller depth.

    Args:
        input_string: the outline text.

    Returns:
        List of root TreeNode objects in input order. The list is empty for
        blank input.

    Raises:
        ValueError: if a node has no leading '#', or a deeper node appears
            before any depth-1 root.
    """
    marker = '（待预测）'
    length = len(input_string)
    stack = []   # path from the current root down to the last node seen
    roots = []
    pos = 0

    while pos < length:
        # Skip any separator whitespace before the next node (a trailing newline
        # or blank space must not be parsed as a depth-0 node).
        while pos < length and input_string[pos].isspace():
            pos += 1
        if pos >= length:
            break

        # The number of leading '#' gives the node's depth.
        hash_count = 0
        while pos < length and input_string[pos] == '#':
            hash_count += 1
            pos += 1

        # The name runs up to the next marker, or to the end of the string.
        name_end = input_string.find(marker, pos)
        if name_end == -1:
            name_end = length
        name = input_string[pos:name_end].strip()
        pos = name_end + len(marker)

        if hash_count == 0:
            raise ValueError(f"node {name!r} has no leading '#' depth marker")

        node = TreeNode(name, depth=hash_count)

        if hash_count == 1:
            # Roots start as 1, so they are never asked about; each root
            # starts a fresh stack.
            node.value = 1
            roots.append(node)
            stack = [node]
            continue

        # Pop until the stack top is the parent level (depth hash_count - 1,
        # or the deepest shallower node if levels were skipped).
        while len(stack) >= hash_count:
            stack.pop()
        if not stack:
            raise ValueError(f"node {name!r} at depth {hash_count} appears before any root")
        parent = stack[-1]
        parent.add_child(node)
        node.parent = parent
        stack.append(node)

    return roots


def calculate_max_depth(roots):
    """Return the largest leaf ``depth`` in the forest, or 0 for an empty forest.

    The stored ``depth`` attribute of each leaf is used. Depth is not
    recomputed from the tree shape.

    Args:
        roots: iterable of root nodes (each with ``children`` and ``depth``).

    Returns:
        The maximum leaf depth, or 0 when *roots* is empty.
    """
    def leaf_depth(node):
        if not node.children:
            return node.depth
        return max(leaf_depth(child) for child in node.children)

    # default=0: an empty forest (e.g. blank input) must not raise ValueError.
    return max((leaf_depth(root) for root in roots), default=0)


def ask_value(node_name, path, question, analysis, state, use_top_student, max_attempts=3):
    """Diagnose whether the student has mastered one knowledge point.

    The agents are the module-level globals ``teacher``, ``topStudent``,
    ``validator`` and ``student``.

    1. ``teacher`` generates a probe for *node_name* along *path*. The
       result's ``step3`` / ``step4`` / ``step5`` are the probe question, its
       analysis and its final answer.
    2. If *use_top_student* is true, ``topStudent`` answers ``step3`` and
       ``validator`` compares that answer with ``step5``. On a mismatch the
       teacher regenerates, receiving the previous history, for up to
       *max_attempts* generations in total. If no attempt validates, the last
       generated probe is still used.
    3. ``student`` answers the probe and ``validator`` grades that answer.

    Args:
        node_name: the knowledge point being probed.
        path: path string for the node (the caller passes ``get_full_path()``).
        question, analysis: the original exercise and its analysis.
        state: the example's ``state`` field, forwarded to ``student.forward``.
        use_top_student: whether to validate probes with ``topStudent``.
        max_attempts: maximum teacher generations when validating (values
            below 1 are treated as 1).

    Returns:
        ``(value, history)``. value is 1 if the validator's verdict contains
        ``'True'``, otherwise 0. The history string is
        ``"<node_name>  <teacher history> 生成次数: <n> <student_ans> <answer>"``.
        Returns ``(-1, "诊断失败")`` if a teacher call raised.

    Exceptions from ``topStudent``, ``validator`` or ``student`` are not caught.
    """
    history_last = "无"
    total_attempts = max(1, max_attempts) if use_top_student else 1

    for attempt in range(1, total_attempts + 1):
        try:
            teacher_ans, history_this = teacher.forward(node_name, path, question, analysis, history_last)
        except Exception as err:
            # Best effort: one failed node must not abort the whole run, but
            # report why instead of swallowing the error silently.
            print(f"[ask_value] teacher failed for {node_name!r}: {err!r}")
            return -1, "诊断失败"

        if not use_top_student:
            break

        answer2 = topStudent.forward(teacher_ans.step3)
        vad = validator.forward(teacher_ans.step3, teacher_ans.step5, answer2)
        if 'True' in vad:
            break
        # Feed the rejected generation back to the teacher for the next try.
        history_last = history_this

    # Tag the history the same way whether validation succeeded or every
    # attempt was rejected; `attempt` is the number of generations made.
    history_this = f"{node_name}  {history_this} 生成次数: {attempt}"

    # The student answers the probe; the validator grades it against step5.
    student_ans = student.forward(question, state, teacher_ans.step3, node_name)
    vad_stu = validator.forward(teacher_ans.step3, teacher_ans.step5, student_ans)
    history_this += f" <student_ans> {student_ans}"
    return (1 if 'True' in vad_stu else 0), history_this


def assign_values_by_depth(roots, question, analysis, state, is_depend, use_top_student):
    """Ask for a value for every still-unassigned node, deepest level first.

    Nodes are grouped by their ``depth`` attribute, in pre-order within each
    level. Levels are processed from the largest depth upward. A node that
    already has a value is skipped; this covers roots and ancestors filled in
    by propagation.

    Each asked node stores the value returned by ``ask_value`` and appends
    the returned history entry. When *is_depend* is true and the value is 1,
    the 1 is propagated to unassigned ancestors.
    """
    levels = {}
    # Iterative pre-order walk: reversed pushes keep left-to-right order.
    pending = list(roots)[::-1]
    while pending:
        current = pending.pop()
        levels.setdefault(current.depth, []).append(current)
        pending.extend(current.children[::-1])

    for level in sorted(levels, reverse=True):
        for node in levels[level]:
            if node.value is not None:
                continue
            result, record = ask_value(
                node.name, node.get_full_path(), question, analysis, state, use_top_student
            )
            node.set_value(result)
            node.history.append(record)
            if node.value == 1 and is_depend:
                node.propagate_value_to_ancestors()

def collect_node_histories(node, history_list):
    """Append the history entries of *node*'s whole subtree to *history_list*.

    Entries are added in pre-order: a node's own entries come before its
    children's, and children are visited left to right. *history_list* is
    mutated in place; returns None.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        history_list.extend(current.history)
        # Push children reversed so the leftmost child is visited next.
        pending.extend(current.children[::-1])

def print_forest_with_depth(roots):
    """Print every tree in *roots*, one node per line, in pre-order.

    Each line is indented two spaces per level below its root and reads
    ``name (value) [深度: depth]``.
    """
    pending = [(tree, 0) for tree in list(roots)[::-1]]
    while pending:
        node, indent = pending.pop()
        print(f"{'  ' * indent}{node.name} ({node.value}) [深度: {node.depth}]")
        pending.extend((kid, indent + 1) for kid in node.children[::-1])

def print_forest_hash_format(roots):
    """Render the whole forest in '#'-outline form and return it as one string.

    The trees are rendered by ``traverse_tree_hash_format`` and concatenated
    in order. Leading and trailing whitespace, including the final newline,
    is stripped.
    """
    rendered = "".join(traverse_tree_hash_format(tree) for tree in roots)
    return rendered.strip()


def traverse_tree_hash_format(node, level=0):
    """Return *node*'s subtree as '#'-outline lines, one node per line.

    Each line is ``'#' * (l + 1)``, a space, then ``name （value）`` and a
    newline. Here ``l`` is *level* at *node* and grows by one per generation.
    Lines are emitted in pre-order.
    """
    lines = []
    pending = [(node, level)]
    while pending:
        current, lvl = pending.pop()
        lines.append(f"{'#' * (lvl + 1)} {current.name} （{current.value}）\n")
        pending.extend((child, lvl + 1) for child in current.children[::-1])
    return "".join(lines)


def _load_existing_results(output_file):
    """Return the result objects already saved in *output_file*, best effort.

    The file is decoded one JSON value at a time after the opening ``[``.
    This tolerates a missing closing ``]``, a truncated last object from an
    interrupted run, and stray or duplicated commas between objects.

    A missing file, or content without ``[``, yields ``[]``.
    """
    try:
        with open(output_file, 'r', encoding='utf-8') as infile:
            content = infile.read()
    except FileNotFoundError:
        return []

    start = content.find('[')
    if start == -1:
        return []

    decoder = json.JSONDecoder()
    items = []
    pos = start + 1
    end = len(content)
    while True:
        # Skip whitespace and element separators (commas) between objects.
        while pos < end and (content[pos].isspace() or content[pos] == ','):
            pos += 1
        if pos >= end or content[pos] == ']':
            break
        try:
            item, pos = decoder.raw_decode(content, pos)
        except json.JSONDecodeError:
            # A truncated trailing element; keep everything decoded so far.
            break
        items.append(item)
    return items


def _diagnose_example(example, is_depend, use_top_student):
    """Build *example*'s knowledge tree, assign node values and store the results.

    Adds the keys ``max_depth``, ``predict`` and ``history`` to *example*
    in place.
    """
    key_map = example['knowledge_points_sequence']
    question = example['question']
    analysis = example['analysis']
    state = example['state']
    roots = build_forest_from_string(key_map)

    example['max_depth'] = calculate_max_depth(roots)
    assign_values_by_depth(roots, question, analysis, state, is_depend, use_top_student)
    example['predict'] = print_forest_hash_format(roots)

    history_list = []
    for root in roots:
        collect_node_histories(root, history_list)
    example['history'] = history_list


def process_json_with_depth(input_file, output_file, is_depend, use_top_student):
    """Diagnose every example of *input_file* and stream the results to *output_file*.

    The run is resumable. Objects already present in *output_file* are
    assumed to match, in order, the first items of *input_file*. Those items
    are skipped, and the saved objects are rewritten at the start of the new
    output.

    Each new example gains ``max_depth``, ``predict`` and ``history`` keys.
    It is written and flushed as soon as it is done. The closing ``]`` is
    written even if processing raises, so the file stays a valid JSON array.

    Args:
        input_file: path of a JSON array of example dicts.
        output_file: path of the JSON array to (re)write.
        is_depend: whether a correct node marks its ancestors correct.
        use_top_student: whether probes are validated with the top student.
    """
    with open(input_file, 'r', encoding='utf-8') as infile:
        data = json.load(infile)

    existing_data = _load_existing_results(output_file)
    skip_count = len(existing_data)

    with open(output_file, 'w', encoding='utf-8') as outfile:
        outfile.write('[\n')
        try:
            need_separator = False
            for item in existing_data:
                outfile.write((',\n' if need_separator else '')
                              + json.dumps(item, ensure_ascii=False, indent=2))
                need_separator = True
            outfile.flush()

            for example in data[skip_count:]:  # skip already-processed examples
                _diagnose_example(example, is_depend, use_top_student)
                # A separator goes only *between* elements, never after the last.
                outfile.write((',\n' if need_separator else '')
                              + json.dumps(example, ensure_ascii=False, indent=2))
                outfile.flush()
                need_separator = True
        finally:
            outfile.write('\n]')


# Example invocation.
# One chat-model client is shared by all four agents below.
# NOTE(review): openai_api_base points at a third-party OpenAI-compatible
# endpoint rather than api.openai.com — confirm this is intended.
llm = ChatOpenAI(
    model_name="gpt-4o",
    temperature=0,
    openai_api_base="https://api.xty.app/v1",
    max_tokens=300
)
# These module-level names are read as globals by ask_value().
student = Student_agent(llm)
teacher = Teacher_agent(llm)
topStudent = topStudent_agent(llm)
validator = validator_agent(llm)
# Directory joined with each name in json_files below (an empty string means
# the paths are relative to the current working directory).
input_dir = ''
json_files = ['physics.json']
def modify_path(file_path, old='4data', new='5data_all_ask'):
    """Map an input file path to its output file path.

    Every occurrence of *old* (default ``'4data'``) in *file_path* is
    replaced with *new* (default ``'5data_all_ask'``). A path that does not
    contain *old* is returned unchanged.
    """
    return file_path.replace(old, new)
# Process each configured input file into its mapped output file.
for json_file in json_files:
    file_path = os.path.join(input_dir, json_file)
    output_file_path = modify_path(file_path)
    # modify_path returns paths without '4data' unchanged. Writing there would
    # target the input file itself, so fail loudly instead of silently
    # rewriting the input.
    if os.path.abspath(output_file_path) == os.path.abspath(file_path):
        raise ValueError(
            f"Output path equals input path ({file_path!r}); "
            "adjust input_dir / modify_path so results go to a separate file."
        )
    output_dir = os.path.dirname(output_file_path)
    if output_dir:  # os.makedirs('') raises FileNotFoundError
        os.makedirs(output_dir, exist_ok=True)
    process_json_with_depth(file_path, output_file_path, is_depend=False, use_top_student=True)
