import argparse
import json
import yaml
import os
import time
import shutil
import pandas as pd
from tqdm import tqdm
from decision_oaif.envs.webshop import parse_args as webenv_args
from decision_oaif.envs.webshop import WebEnv  
from decision_oaif.agents.openai_agent import OpenAIAgent
from decision_oaif.agents.hf_agent import HFAgent
from decision_oaif.utils.parser import parse_reason_and_action_webshop, substitute_placeholders

def parse_and_load_config():
    """Parse the command line and load the run configuration from YAML.

    Two modes are supported:

    * ``--eval_config PATH``: the mapping stored in the YAML file at PATH is
      the configuration.
    * ``--training_config PATH --iter N``: the ``rollout_student_trajectory``
      section of the YAML file at PATH is passed through
      ``substitute_placeholders`` together with the ``'{iter}'`` placeholder
      and ``str(N)``; the result is the configuration.

    ``--eval_config`` takes precedence when both modes are supplied.

    Returns:
        The configuration for the run.

    Raises:
        SystemExit: raised by ``parser.error`` when neither mode is fully
            specified, when a config file does not contain a YAML mapping,
            or when the training config lacks the
            ``rollout_student_trajectory`` section.
        OSError: if a config file cannot be opened.
    """
    parser = argparse.ArgumentParser(description='Evaluate agent on webshop')
    # Arguments for evaluation
    parser.add_argument('--eval_config', type=str, help='Path to the evaluation config file')
    # Arguments for rollout
    parser.add_argument('--training_config', type=str, help='Path to the training config file')
    parser.add_argument('--iter', type=int, help='Iteration number')

    args = parser.parse_args()

    def load_yaml_mapping(path):
        # yaml.safe_load returns None for an empty document and may return a
        # list or scalar; reject those here with a clear CLI error instead of
        # letting a later subscript fail with an unrelated TypeError.
        with open(path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
        if not isinstance(loaded, dict):
            parser.error(f"Config file '{path}' must contain a YAML mapping.")
        return loaded

    if args.eval_config:
        config = load_yaml_mapping(args.eval_config)
    elif args.training_config and args.iter is not None:
        # `is not None` so that iteration 0 is accepted.
        training_config = load_yaml_mapping(args.training_config)
        if 'rollout_student_trajectory' not in training_config:
            parser.error(
                f"Training config '{args.training_config}' has no "
                "'rollout_student_trajectory' section."
            )
        config = training_config['rollout_student_trajectory']
        config = substitute_placeholders(config, '{iter}', str(args.iter))
    else:
        parser.error("You must provide either --eval_config or both --training_config and --iter.")

    return config

def _build_agent(agent_config, config):
    """Instantiate the agent described by one entry of ``config['agents']``.

    Raises:
        ValueError: if ``agent_config['type']`` is neither ``"openai"`` nor ``"hf"``.
    """
    agent_type = agent_config['type']
    if agent_type not in ("openai", "hf"):
        raise ValueError(f"Unsupported agent type: {agent_type}")

    # Keyword arguments common to both agent classes.
    shared_kwargs = {
        'model_id': agent_config['model_id'],
        'prompt_template_file': agent_config['prompt_template_file'],
        'verbose': config['verbose'],
        'debug': config['debug'],
        'parse_reason_action_fn': parse_reason_and_action_webshop,
    }
    if agent_type == "openai":
        return OpenAIAgent(**shared_kwargs)
    return HFAgent(max_length=6000, **shared_kwargs)


def _run_episode(env, agent, env_idx, max_actions):
    """Roll out one episode and return ``(trajectory, reward, info)``.

    ``trajectory`` holds one dict per executed step. ``reward`` and ``info``
    are the values returned by the last ``env.step`` call; when no step is
    executed, ``reward`` is ``0.0`` and ``info`` is the one returned by
    ``env.reset``.
    """
    obs, info = env.reset(idx=env_idx)
    agent.reset()
    trajectory = []
    # Defined up front so a zero-step episode never reports an unbound or
    # stale reward left over from a previous episode.
    reward = 0.0
    for _ in tqdm(range(max_actions), desc=f"Actions for env idx {env_idx}"):
        # Query the valid actions once so the logged candidates are exactly
        # the list that was handed to the agent.
        candidate_actions = env.get_valid_actions()
        reason, action = agent.predict_reason_action(task="", observation=obs, candidate_actions=candidate_actions)
        step = {'observation': obs, 'candidate_actions': candidate_actions, 'reason': reason, 'action': action}
        obs, reward, done, info = env.step(action)
        step['score'] = reward
        trajectory.append(step)
        if done:
            break
    return trajectory, reward, info


def main():
    """Run every configured agent on the WebShop environment and log results.

    Reads the configuration via ``parse_and_load_config`` and writes:

    * ``config.yaml``: a dump of the configuration, in the output directory;
    * ``<env_idx>.json``: the trajectory and final ``info`` of each episode,
      in the per-agent log directory;
    * ``summary.csv``: one row per episode (env index, model id, number of
      actions, final reward), in the output directory.

    The output directory is ``config['logdir']`` itself when
    ``config['exact_path']`` is truthy, otherwise a timestamped subdirectory
    of it with one further subdirectory per ``model_id``.

    ``max_env_idxs``, ``start_env_idx`` and ``exact_path`` are optional: a
    missing or falsy value disables the corresponding behaviour.
    """
    config = parse_and_load_config()

    exact_path = config.get('exact_path')
    if exact_path:
        dstdir = config['logdir']
    else:
        dstdir = f"{config['logdir']}/{time.strftime('%Y%m%d-%H%M%S')}"
    os.makedirs(dstdir, exist_ok=True)
    with open(os.path.join(dstdir, 'config.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(config, f)

    # Loop-invariant settings, read once (and before any agent is loaded).
    max_actions = config['max_actions']
    max_env_idxs = config.get('max_env_idxs')
    start_env_idx = config.get('start_env_idx')
    summary_file_path = os.path.join(dstdir, "summary.csv")

    df_summary = pd.DataFrame()

    for agent_config in config['agents']:
        agent = _build_agent(agent_config, config)

        logdir = dstdir if exact_path else f"{dstdir}/{agent_config['model_id']}"
        os.makedirs(logdir, exist_ok=True)

        # Load the Webshop configuration
        args = webenv_args()[0]
        env = WebEnv(args, split=config['eval_set'])

        # Slicing already clamps to the sequence length.
        env_idxs = env.goal_idxs[:max_env_idxs] if max_env_idxs else env.goal_idxs

        # Iterate over all goal_idxs
        for env_idx in tqdm(env_idxs, desc="env idxs"):
            # Skip before touching the environment so that resuming from
            # start_env_idx does not pay for a reset of every skipped index.
            if start_env_idx and env_idx < start_env_idx:
                continue

            trajectory, reward, info = _run_episode(env, agent, env_idx, max_actions)

            log_file_path = os.path.join(logdir, f"{env_idx}.json")
            log = {'env_idx': env_idx, 'trajectory': trajectory, 'info': info}
            with open(log_file_path, 'w', encoding='utf-8') as log_file:
                json.dump(log, log_file, indent=4)

            summary_data = {'env_idx': env_idx, 'model_id': agent_config['model_id'], 'num_actions': len(trajectory), 'score': reward}
            # Re-read the file on every episode so rows already on disk (for
            # example from an earlier run into the same directory) are kept
            # rather than overwritten.
            if os.path.exists(summary_file_path):
                df_summary = pd.read_csv(summary_file_path)
            df_summary = pd.concat([df_summary, pd.DataFrame([summary_data])], ignore_index=True)
            df_summary.to_csv(summary_file_path, index=False)
            print(f"Current summary:\n {df_summary}")

# Run the evaluation/rollout only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()