import torch
import json
import os
import importlib
import argparse
from typing import List

from utils.conversation import Conversation
from utils.states import set_random_seed

from envs import Agent


AI_PROMPT = "\n\nAssistant:"

class SamplingWithVerifier:
    """Query an agent's LLM and retry until its reply parses into a valid action.

    ``args`` is the parsed command-line namespace; within this class only
    ``args.max_retries`` is read.
    """

    def __init__(self, args) -> None:
        # NOTE(review): hard-coded to CUDA regardless of availability and not
        # read anywhere in this class -- confirm whether callers use it.
        self.device = torch.device('cuda')
        self.args = args

    def _sample_step(
        self,
        conv: Conversation,
        agent: Agent,
        max_new_tokens: int = 512,
    ):
        """Sample one agent turn, re-prompting the model on malformed replies.

        Args:
            conv: conversation whose OpenAI-style messages seed the request.
            agent: callable LLM wrapper, invoked as ``agent(messages, max_new_tokens)``;
                must also expose ``parse_entries``, ``valid_format_entries`` and
                ``prompt_tool_names``.
            max_new_tokens: generation budget passed to ``agent`` on every attempt.

        Returns:
            The first completion whose parsed ``"Action"`` entry (stripped) is in
            ``agent.prompt_tool_names``, or the sentinel string
            ``"No valid response after max_retries"`` if none of the
            ``self.args.max_retries`` attempts produced one.
        """
        # Work on a copy so the retry feedback appended below can never alias
        # state owned by ``conv`` (in case it returns an internal list).
        messages = list(conv.to_openai_api_messages())
        for _ in range(self.args.max_retries):
            completion = agent(messages, max_new_tokens)
            print(">> completion: ", completion)
            try:
                entries = agent.parse_entries(completion, agent.valid_format_entries)
                action = entries["Action"].strip()
                # Explicit check rather than ``assert``: asserts are stripped
                # when Python runs with -O, which would let bad actions through.
                if action not in agent.prompt_tool_names:
                    raise ValueError(f"Unknown action: {action!r}")
            except Exception as e:
                # str() keeps the diagnostic print from raising TypeError if the
                # agent returned a non-string completion.
                print(AI_PROMPT + "\n" + str(completion) + "\nObservation:\n" + str(e))
                print("Response is invalid and discarded")
                # Feed the rejected reply back together with a correction
                # request, so the next attempt sees what went wrong.
                messages.extend([
                    {"role": "assistant", "content": completion},
                    {"role": "user", "content": "Your response was in incorrect format. Please provide a valid response with all entries: " + ", ".join(agent.valid_format_entries) + "\n\n"},
                ])
            else:
                return completion
        return "No valid response after max_retries"


    
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Run the tree search ...")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--task", type=str, default="cifar10")
    parser.add_argument("--benchmark", type=str, default="MLAgentBench")
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf")
    parser.add_argument("--generator_model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf")
    parser.add_argument("--model_max_length", type=int, default=10000)
    parser.add_argument("--trust_remote_code", action="store_true", default=True)
    parser.add_argument("--use_peft", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing", type=bool, default=True)
    parser.add_argument("--padding_side", type=str, default="right", help="left or right")
    parser.add_argument("--template_name", type=str, default="llama-2")
    parser.add_argument("--valid_format_entries", type=str, nargs="+", default=["Action", "Action Input"])

    parser.add_argument("--max_time", type=int, default=60*60)
    parser.add_argument("--agent_max_steps", type=int, default=10)
    parser.add_argument("--max_retries", type=int, default=5)
    parser.add_argument("--edit_script_llm_name", type=str, default="gpt-4o-mini")
    parser.add_argument("--edit_script_llm_max_tokens", type=int, default=4000)
    parser.add_argument("--resume", type=str, default=None, help="resume from a previous run")
    parser.add_argument("--resume_step", type=int, default=0, help="the step to resume from")

    parser.add_argument("--n_beam", type=int, default=2)
    parser.add_argument("--num_sampling_sequences", type=int, default=3)

    parser.add_argument("--log_dir", type=str, default="logs")
    parser.add_argument("--work_dir", type=str, default="./workspace")
    parser.add_argument("--python", type=str, default="./workspace", help="the python path")
    parser.add_argument("--api_key", type=str, default=None)
    parser.add_argument("--api_url", type=str, default=None)
    parser.add_argument("--controller_address", type=str, default="")
    args = parser.parse_args()

    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', '')
    device_list = cuda_visible_devices.split(',') if cuda_visible_devices else []
    args.device = int(device_list[0]) if len(device_list) > 0 else None
    # save args into json
    with open(os.path.join(args.log_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)

    set_random_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    env_module = importlib.import_module(f"envs.{args.benchmark}.environment")
    ENV = getattr(env_module, "Environment")

    env = ENV(args.task, args)
    print("=====================================")
    research_problem, benchmark_folder_name = env.get_task_description()
    print("Benchmark folder name: ", benchmark_folder_name)
    print("Research problem: ", research_problem)

    agent = Agent(args, env)
    sampler = SamplingWithVerifier(args)

    agent.run_no_reward(env, sampler)