import os
import json
from datetime import datetime
from envs.env_manager import EnvManager
from planner.infoseeker import InfoSeekerPlanner

class InfoSeekerRunner():
    """Drive a single task case through alternating planning phases.

    Each attempt first asks the planner for information-seeking plans and
    executes them, then asks for a task-oriented plan and executes that.
    The case ends early as soon as ``env.step`` reports it finished or the
    step counter reaches ``max_steps``. Otherwise it ends after
    ``max_attempts`` attempts.
    """

    def __init__(self, task, task_case, max_steps, max_attempts):
        """
        Args:
            task: task identifier forwarded to ``EnvManager``.
            task_case: case identifier forwarded to ``EnvManager``.
            max_steps: step budget shared by every rollout in ``run_once``.
            max_attempts: number of seek-then-plan iterations ``run_once`` may run.
        """
        self.max_steps = max_steps
        self.max_attempts = max_attempts
        self.env = EnvManager(task, task_case)

    def _out_of_budget(self, finished, steps):
        """Return a truthy value when the episode finished or the step budget is used up."""
        return finished or steps >= self.max_steps

    def rollout(self, actions: list[str], steps):
        """Execute ``actions`` in order and record an Act/Obs transcript.

        Execution stops right after the first action that finishes the
        episode or brings the step counter to ``max_steps``.

        Args:
            actions: action strings passed one by one to ``self.env.step``.
            steps: step count reported back when ``actions`` is empty.
                Otherwise the count returned by the last ``env.step`` call
                is reported.

        Returns:
            ``(finished, history, steps)``. ``history`` contains, for each
            executed action, a "- Act: <action>" line followed by a
            "- Obs: <observation>" line, with no trailing newline.
        """
        done = False
        step_count = steps
        entries = []
        for act in actions:
            done, obs, step_count = self.env.step(act)
            entries.append(f"- Act: {act}\n- Obs: {obs}")
            if self._out_of_budget(done, step_count):
                break
        return done, "\n".join(entries), step_count

    def run_once(self):
        """Run the case until it finishes, exhausts steps, or exhausts attempts.

        Returns:
            ``(finished, steps, attempts, llm_history)``. ``attempts`` counts
            the task-oriented plans requested so far. ``llm_history`` is the
            planner's ``llm_history`` attribute.
        """
        # The planner is built before env.start(), matching the original call order.
        planner = InfoSeekerPlanner(self.env.env_type)
        domain_desc = self.env.start()
        done = False
        step_count = 0
        attempt_count = 0

        for _ in range(self.max_attempts):
            # Phase 1: execute each information-seeking plan in order.
            seeking_plans = planner.information_seeking(domain_desc)
            for number, plan in enumerate(seeking_plans, start=1):
                done, transcript, step_count = self.rollout(plan['Action Plan'], step_count)
                if self._out_of_budget(done, step_count):
                    return done, step_count, attempt_count, planner.llm_history
                header = f"\n## Step {number}: **{plan['Goal']}**\n"
                planner.action_history += header + transcript

            # Phase 2: refresh the description and goal, then plan toward the goal.
            domain_desc = self.env.get_desc()
            goal = self.env.get_goal()
            task_plan = planner.task_oriented_plannig(domain_desc, goal)
            attempt_count += 1

            done, transcript, step_count = self.rollout(task_plan, step_count)
            if self._out_of_budget(done, step_count):
                return done, step_count, attempt_count, planner.llm_history

            # The next attempt's history starts from this plan's transcript only.
            planner.action_history = transcript

        return done, step_count, attempt_count, planner.llm_history

class Runner():
    """Run ``InfoSeekerRunner`` over cases ``1..num_cases`` of a task and checkpoint results.

    After every case:
    - the aggregate metrics plus all LLM logs collected so far are rewritten to
      ``./saved_runs/<task>_<timestamp>.json``;
    - that case's prompt/response pairs are appended to
      ``./saved_runs/<task>_<timestamp>.txt``.
    """

    def __init__(self, task, max_steps, max_attempts, num_cases=50):
        """
        Args:
            task: task name forwarded to ``InfoSeekerRunner`` and used in output file names.
            max_steps: per-case step budget forwarded to ``InfoSeekerRunner``.
            max_attempts: per-case attempt budget forwarded to ``InfoSeekerRunner``.
            num_cases: number of cases to run, numbered ``1..num_cases``.
                The default of 50 matches the previously hard-coded count.
        """
        self.success = []
        self.steps = []
        self.attempts = []
        self.llm_logs = []
        self.task = task
        self.max_steps = max_steps
        self.max_attempts = max_attempts
        self.num_cases = num_cases

        self.timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        os.makedirs("./saved_runs", exist_ok=True)

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
        return sum(values) / len(values) if values else 0.0

    def _summary(self):
        """Build the JSON checkpoint payload from the results collected so far."""
        return {
            "success_rate": self._mean(self.success),
            "avg_steps": self._mean(self.steps),
            "avg_attempts": self._mean(self.attempts),
            "success": self.success,
            "steps": self.steps,
            "attempts": self.attempts,
            "llm_logs": self.llm_logs,
        }

    def _save_summary(self, base_path):
        """Rewrite ``<base_path>.json`` so that a failed dump cannot wipe the last checkpoint."""
        tmp_path = f"{base_path}.json.tmp"
        with open(tmp_path, 'w', encoding="utf-8") as f:
            json.dump(self._summary(), f, indent=4)
        # The existing checkpoint is replaced only after the new one was fully written.
        os.replace(tmp_path, f"{base_path}.json")

    @staticmethod
    def _append_transcript(base_path, case, llm_logs):
        """Append one case's prompt/response pairs to ``<base_path>.txt``."""
        # Assumes each llm_logs entry is a mapping with 'prompt' and 'response' keys,
        # as the original writer did.
        with open(f"{base_path}.txt", 'a', encoding="utf-8") as f:
            f.write(f"### Case {case} ###\n")
            for idx, entry in enumerate(llm_logs, start=1):
                f.write(f"=== PROMPT {idx} ===\n{entry['prompt']}\n\n")
                f.write(f"--- RESPONSE {idx} ---\n{entry['response']}\n\n")

    def run_all(self):
        """Run every case in turn, printing progress and checkpointing after each one."""
        print(f"Running {self.task}")
        base_path = f"./saved_runs/{self.task}_{self.timestamp}"
        for i in range(1, self.num_cases + 1):
            run = InfoSeekerRunner(self.task, i, self.max_steps, self.max_attempts)
            success, steps, attempts, llm_logs = run.run_once()
            self.success.append(1 if success else 0)
            self.steps.append(steps)
            self.attempts.append(attempts)
            self.llm_logs.append(llm_logs)

            status = "Success!" if success else "Failed!"
            print(f"Case {i}: {status} (SR: {self._mean(self.success)})")

            self._save_summary(base_path)
            self._append_transcript(base_path, i, llm_logs)

        # Guard against num_cases == 0, which previously raised ZeroDivisionError.
        rate = sum(self.success) * 100 / len(self.success) if self.success else 0.0
        print(f"Finished {self.task}. SR: {rate} %")