# Main language agent class
import os
import json
from pathlib import Path
from tqdm import tqdm
import pickle

from agentdriver.memory.memory_agent import MemoryAgent
from agentdriver.perception.perception_agent import PerceptionAgent
from agentdriver.reasoning.reasoning_agent import ReasoningAgent
from agentdriver.planning.planning_agent import PlanningAgent

class LanguageAgent:
    def __init__(self, data_path, split, model_name = "gpt-3.5-turbo-0613", planner_model_name="", finetune_cot=False, verbose=False) -> None:
        self.data_path = data_path
        self.split = split
        self.split_dict = json.load(open(Path(data_path) / "split.json", "r"))
        self.tokens = self.split_dict[split]
        self.invalid_tokens = []
        self.verbose = verbose
        self.model_name = model_name
        self.planner_model_name = planner_model_name
        self.finetune_cot = finetune_cot

    def collect_planner_input(self, invalid_tokens=None):
        r"""Collect perception results, ego states, memory, and chain-of-thoughts reasoning as planner input"""
        print("*"*10 + "Stage-1: Collecting Perception Data and Memory Retrieval" + "*"*10)
        save_dir = Path("experiments") / Path("finetune")
        save_dir.mkdir(exist_ok=True)
        save_file = save_dir / f"data_samples_{self.split}.json"
        # Each time generate a new file
        if Path(save_file).exists() and invalid_tokens is None:
            os.system(f"rm {save_file}")

        run_tokens = invalid_tokens if invalid_tokens is not None else self.tokens
        for token in tqdm(run_tokens):
            if self.verbose:
                print(token)
            try:
                perception_agent = PerceptionAgent(token=token, split=self.split, data_path=self.data_path, model_name=self.model_name, verbose=self.verbose)
                ego_prompts, perception_prompts, working_memory = perception_agent.run()

                if self.verbose:
                    print(perception_prompts)

                memory_agent = MemoryAgent(data_path=self.data_path, model_name=self.model_name, verbose=self.verbose)
                commonsense_mem, experience_mem = memory_agent.run(working_memory)

                if self.verbose:
                    print(commonsense_mem)
                    print(experience_mem)

                reasoning_agent = ReasoningAgent(model_name=self.model_name, verbose=self.verbose)
                reasoning = reasoning_agent.run(perception_agent.data_dict, ego_prompts+perception_prompts, working_memory, use_cot_rules=self.finetune_cot)

                planning_agent = PlanningAgent(model_name=self.planner_model_name, verbose=self.verbose)
                planning_target = planning_agent.generate_target(perception_agent.data_dict)

                if self.verbose:
                    print(reasoning)
                    print(planning_target)

                planner_input = {
                    "token": token,
                    "ego": ego_prompts, 
                    "perception": perception_prompts,
                    "commonsense": commonsense_mem,
                    "experiences": experience_mem,
                    "planning_target": planning_target,
                }

                if self.finetune_cot: # if fine-tuning chain-of-thoughts, save rule-based chain_of_thoughts as target
                    planner_input["chain_of_thoughts"] = reasoning
                else: # otherwise, directly save reasoning results from GPT by in-context learning
                    planner_input["reasoning"] = reasoning

                with open(save_file, "a+") as f:
                    f.write(json.dumps(planner_input, indent=4) + '\n')

            except Exception as e:
                print("An error occurred:", e)
                self.invalid_tokens.append(token)
                print(f"Invalid token: {token}")
                continue
        
        print("*"*10 + "Stage-1: Complete" + "*"*10)
        print(f"Invalid tokens: {self.invalid_tokens}")
        with open(save_dir / f"invalid_tokens_{self.split}.json", "w") as f:
            json.dump(self.invalid_tokens, f)
        return

    def inference_single(self, token, data_sample=None):
        r"""Inference single scenario"""
        assert token in self.tokens
        if data_sample is None:
            perception_agent = PerceptionAgent(token=token, split=self.split, data_path=self.data_path, model_name=self.model_name, verbose=self.verbose)
            ego_prompts, perception_prompts, working_memory = perception_agent.run()
            memory_agent = MemoryAgent(data_path=self.data_path, model_name=self.model_name, verbose=self.verbose)
            commonsense_mem, experience_mem = memory_agent.run(working_memory)
            reasoning_agent = ReasoningAgent(model_name=self.model_name, verbose=self.verbose)
            reasoning = reasoning_agent.run(perception_agent.data_dict, ego_prompts+perception_prompts, working_memory, use_cot_rules=self.finetune_cot)
            planning_agent = PlanningAgent(model_name=self.planner_model_name, verbose=self.verbose)
            planning_target = planning_agent.generate_planning_target(perception_agent.data_dict)
            data_sample = {
                "token": token,
                "ego": ego_prompts,
                "perception": perception_prompts,
                "commonsense": commonsense_mem,
                "experiences": experience_mem,
                "chain_of_thoughts": reasoning,
                "reasoning": reasoning,
                "planning_target": planning_target,
            }
        else:
            token = data_sample["token"]        
            ego_prompts = data_sample["ego"]        
            perception_prompts = data_sample["perception"]        
            commonsense_mem = data_sample["commonsense"]        
            experience_mem = data_sample["experiences"]        
            reasoning = data_sample["reasoning"]        
            planning_target = data_sample["planning_target"]
           
        planning_traj = planning_agent.run(
            data_dict=perception_agent.data_dict,
            data_sample=data_sample,
        )
        if self.verbose:
            print(token)
            print(ego_prompts)
            print(perception_prompts)
            print(commonsense_mem)
            print(experience_mem)
            print(reasoning)
            print(planning_traj)
        return planning_traj

    def inference_all(self, data_samples, data_path, save_path):
        """Inferencing all scenarios"""
        planning_agent = PlanningAgent(model_name=self.planner_model_name, verbose=self.verbose)
        planning_traj_dict = planning_agent.run_batch(
            data_samples=data_samples,
            data_path=data_path,
            save_path=save_path,
        )
        return planning_traj_dict