import os
import sys
from split import get_cot,ave_length
from tqdm import tqdm

# Optionally put the project root on sys.path so the local `env` module
# (source of `Agent`) can be imported. If PROJECT_ROOT is unset/empty or is
# already on sys.path, sys.path is left unchanged.
project_root = os.environ.get("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
# NOTE: this import must stay after the sys.path update above; presumably
# `env` lives under PROJECT_ROOT — TODO confirm.
from env import Agent

# Directory read from DATA_ROOT (None if unset). The __main__ block appends
# "Qwen2.5-7B-Instruct/" to it to form the agent model path.
data_root = os.environ.get("DATA_ROOT")

def prompt_problem(Problem, Thinking):
    """Build a yes/no prompt asking whether a thought chain shifts direction.

    Args:
        Problem: The problem statement; interpolated with ``str()``.
        Thinking: The thought chain to classify; interpolated with ``str()``.

    Returns:
        The prompt text, wrapped in a leading and a trailing newline.
    """
    prompt_lines = [
        "",
        f"Problem P = {Problem}",
        f"Thinking T = {Thinking}",
        "",
        'Please analyze whether the given thought chain T for problem P contains a shift in thinking or is merely an extension of the current thinking. Respond directly in "yes" or "no" format without explanation.',
        "",
    ]
    return "\n".join(prompt_lines)

def prompt_thinking(Thinking_pre, Thinking_cur):
    """Build a yes/no prompt comparing two consecutive reasoning steps.

    The model is asked to answer "yes" when ``Thinking_cur`` starts a new line
    of reasoning, and "no" when it merely continues ``Thinking_pre``.

    Args:
        Thinking_pre: The previous reasoning step; interpolated with ``str()``.
        Thinking_cur: The current reasoning step; interpolated with ``str()``.

    Returns:
        The prompt text, wrapped in a leading and a trailing newline.
    """
    segments = (
        "",
        "You are given two consecutive reasoning steps from a problem-solving process.",
        "",
        f"Previous step (P): {Thinking_pre}",
        f"Current step (C): {Thinking_cur}",
        "",
        "Question: Does the current step (C) represent a new shift in reasoning compared to the previous step (P), or is it just a continuation/extension of the previous reasoning? ",
        "",
        'Please answer only "yes" (if C is a new shift in reasoning) or "no" (if C is just a continuation of P), with no explanation.',
        "",
    )
    return "\n".join(segments)

def _is_shift_reply(response):
    """Return True when an agent reply is an affirmative "yes".

    The prompt asks for a bare "yes"/"no", but chat models commonly add
    whitespace, trailing punctuation, or different casing ("Yes.", " yes\\n",
    "YES"). Only the first word is inspected, so a stray explanation after
    the answer does not flip the verdict.
    """
    words = str(response).strip().lower().split()
    if not words:
        return False
    return words[0].strip(".,!?:;\"'") == "yes"


def split_agent(split_ori, agent_path, cuda_id, split_method):
    """Optionally merge consecutive reasoning segments using an LLM judge.

    Args:
        split_ori: Sequence of blocks; each block is a sequence of segments
            that support ``+`` concatenation (presumably strings produced by
            ``get_cot`` — TODO confirm).
        agent_path: Model path passed to ``Agent``.
        cuda_id: CUDA device index; the agent is placed on ``cuda:{cuda_id}``.
        split_method: ``"agent"`` asks the agent, for every segment after the
            first in a block, whether it starts a new line of reasoning. If it
            does not, the segment is appended to the previous one.
            ``"original"`` returns ``split_ori`` unchanged.

    Returns:
        A list of blocks in the same order as ``split_ori``, or ``split_ori``
        itself for ``"original"``. Its average length (via ``ave_length``) is
        printed before returning.

    Raises:
        ValueError: If ``split_method`` is neither ``"agent"`` nor ``"original"``.
    """
    if split_method == "original":
        split_results = split_ori
    elif split_method == "agent":
        agent = Agent(agent_path, is_anyprecision=False, device=f"cuda:{cuda_id}")
        split_results = []
        for block in tqdm(split_ori, desc="Processing blocks"):
            merged = []
            for part in block:
                if not merged:
                    # The first segment of a block always starts a new group.
                    merged.append(part)
                    continue
                reply = agent(prompt_thinking(merged[-1], part))
                if _is_shift_reply(reply):
                    merged.append(part)
                else:
                    # Rebind rather than `+=` so a mutable segment taken from
                    # split_ori is never modified in place.
                    merged[-1] = merged[-1] + part
            split_results.append(merged)
    else:
        raise ValueError(
            f"Unknown split_method {split_method!r}; expected 'agent' or 'original'"
        )

    print(ave_length(split_results))
    return split_results

if __name__ == "__main__":
    # DATA_ROOT locates the agent model; without it the path below would
    # silently become "None/Qwen2.5-7B-Instruct/".
    if not data_root:
        raise SystemExit("DATA_ROOT environment variable must be set")
    cot = get_cot('./nohup.out')
    print(ave_length(cot))
    # split_method is a required argument of split_agent; "agent" is the mode
    # that uses the model path and CUDA device supplied here.
    split_agent(cot, f"{data_root}/Qwen2.5-7B-Instruct/", 4, "agent")