import argparse
import json
import yaml
import os
import time
import pandas as pd
from tqdm import tqdm
from decision_oaif.agents.openai_agent import OpenAIAgent
from decision_oaif.agents.hf_agent import HFAgent
from decision_oaif.agents.hf_spaces_agent import HFSpaceAgent
from decision_oaif.utils.parser import parse_reason_and_action_intercode, parse_reason_and_action_intercode_sql, substitute_placeholders

from intercode.envs import (
    BashEnv, PythonEnv, SqlEnv, CTFEnv, SWEEnv
)
from typing import Dict, List
from decimal import Decimal
from datetime import datetime

def decimal_default(obj):
    """JSON ``default=`` hook for values the stdlib encoder cannot serialize.

    ``datetime`` values become ISO 8601 strings and ``Decimal`` values become
    ``float``. Anything else is passed to ``json.JSONEncoder.default``, which
    raises ``TypeError`` as the stdlib encoder normally would.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, Decimal):
        return float(obj)  # use str(obj) instead if exact precision matters
    return json.JSONEncoder().default(obj)

def preprocess_sql(record: Dict) -> List:
    """Return the setup commands for a SQL task: switch to the record's ``db``."""
    return [f"use {record['db']}"]
    
def serialize_if_dict(input_data):
    """Return ``input_data`` as a string.

    Strings pass through unchanged and dicts are JSON-encoded.

    Raises:
        ValueError: if ``input_data`` is neither a ``dict`` nor a ``str``.
    """
    if isinstance(input_data, str):
        return input_data
    if not isinstance(input_data, dict):
        raise ValueError("Input must be either a dictionary or a string")
    return json.dumps(input_data)


def parse_and_load_config():
    """Parse command-line arguments and return the run configuration dict.

    The modes below are checked in this order:

    * ``--eval_config PATH``: the whole YAML file is the config.
    * ``--training_config PATH --collect_logs``: the file's ``collect_logs``
      section.
    * ``--training_config PATH --iter N``: the file's
      ``rollout_student_trajectory`` section, with every ``{iter}``
      placeholder substituted by ``N``.

    Any other combination terminates the program via ``parser.error``.
    """
    parser = argparse.ArgumentParser(description='Evaluate agent on intercode')
    parser.add_argument('--eval_config', type=str, help='Path to the evaluation config file')
    parser.add_argument('--training_config', type=str, help='Path to the training config file')
    parser.add_argument('--collect_logs', action='store_true', help="Initial data collect for sft")
    parser.add_argument('--iter', type=int, help='Iteration number')
    args = parser.parse_args()

    def _load_yaml(path):
        # safe_load: config files must not be able to construct arbitrary objects.
        with open(path, 'r') as fh:
            return yaml.safe_load(fh)

    if args.eval_config:
        return _load_yaml(args.eval_config)
    if args.training_config is not None and args.collect_logs:
        return _load_yaml(args.training_config)['collect_logs']
    if args.training_config and args.iter is not None:
        rollout_section = _load_yaml(args.training_config)['rollout_student_trajectory']
        return substitute_placeholders(rollout_section, '{iter}', str(args.iter))
    # parser.error prints usage and raises SystemExit, so nothing follows it.
    parser.error("You must provide either --eval_config or both --training_config and --collect_logs or both --training_config and --iter.")

# Fraction of tasks assigned to the train split for the bash/sql environments.
_TRAIN_FRAC = 0.7
# Python tasks: the first _PYTHON_TEST_SIZE indices form the test split, the rest train.
_PYTHON_TEST_SIZE = 510


def _retry(fn, *args, retry_msg, max_retries=None):
    """Call ``fn(*args)`` until it succeeds and return its result.

    Each failure prints ``"Exception: <e>. <retry_msg>"``. With
    ``max_retries=None`` failures are retried indefinitely. Otherwise the last
    exception is re-raised once ``max_retries`` retries have also failed.
    """
    retries = 0
    while True:
        try:
            return fn(*args)
        except Exception as e:
            if max_retries is not None and retries >= max_retries:
                raise
            retries += 1
            print(f"Exception: {e}. {retry_msg}")


def _build_env(config):
    """Create the InterCode environment named by ``config['env_type']``.

    Returns:
        ``(env, env_idxs, parse_reason_action_fn)``. ``env_idxs`` is the train
        or test slice of task indices, chosen by ``config['eval_set']``.

    Raises:
        ValueError: if ``config['env_type']`` is not "bash", "sql" or "python".
    """
    env_type = config['env_type']
    if env_type == "bash":
        env = BashEnv(image_name=config['env_image_name'], data_path=config['env_data_path'], verbose=config['env_verbose'], preprocess=config['env_preprocess'])
        parse_reason_action_fn = parse_reason_and_action_intercode
    elif env_type == "sql":
        env = SqlEnv(image_name=config['env_image_name'], data_path=config['env_data_path'], verbose=config['env_verbose'], preprocess=preprocess_sql)
        parse_reason_action_fn = parse_reason_and_action_intercode_sql
    elif env_type == "python":
        env = PythonEnv(image_name=config['env_image_name'], data_path=config['env_data_path'], verbose=config['env_verbose'], preprocess=None, is_agent=True)
        parse_reason_action_fn = parse_reason_and_action_intercode
    else:
        raise ValueError(f"Unsupported env type: {env_type}")

    all_env_idxs = range(len(env.data_loader))
    if env_type == "python":
        # Fixed-size test split at the front. The two slices are contiguous,
        # so every task belongs to exactly one split.
        test_idxs = all_env_idxs[:_PYTHON_TEST_SIZE]
        train_idxs = all_env_idxs[_PYTHON_TEST_SIZE:]
    else:
        split_idx = int(len(all_env_idxs) * _TRAIN_FRAC)
        train_idxs = all_env_idxs[:split_idx]
        test_idxs = all_env_idxs[split_idx:]
    env_idxs = train_idxs if config['eval_set'] == "train" else test_idxs
    return env, env_idxs, parse_reason_action_fn


def _build_agent(agent_config, config, parse_reason_action_fn):
    """Instantiate the agent described by ``agent_config['type']``.

    Raises:
        ValueError: for an unknown agent type.
    """
    common = dict(
        prompt_template_file=agent_config['prompt_template_file'],
        verbose=config['verbose'],
        debug=config['debug'],
        parse_reason_action_fn=parse_reason_action_fn,
    )
    agent_type = agent_config['type']
    if agent_type == "openai":
        return OpenAIAgent(model_id=agent_config['model_id'], **common)
    if agent_type == "hf":
        return HFAgent(model_id=agent_config['model_id'], max_length=6000, **common)
    if agent_type == "hf_space":
        return HFSpaceAgent(space_id=agent_config['space_id'], **common)
    raise ValueError(f"Unsupported agent type: {agent_type}")


def _run_episode(env, agent, env_idx, max_actions, max_retries):
    """Roll out ``agent`` on task ``env_idx`` for at most ``max_actions`` steps.

    Returns:
        ``(task, trajectory, reward)``. ``task`` is the observation returned by
        ``env.reset``. ``trajectory`` is a list of dicts with the keys
        ``observation``, ``reason``, ``action`` and ``score``. ``reward`` is
        the reward of the last step, or 0 if no step was taken.
    """
    obs, _info = _retry(env.reset, env_idx, retry_msg="Trying again", max_retries=max_retries)
    task = obs
    reward = 0
    trajectory = []
    agent.reset()
    for _ in tqdm(range(max_actions), desc=f"Actions for env idx {env_idx}"):
        try:
            reason, action = agent.predict_reason_action(task=task, observation=obs, candidate_actions="", reward=reward)
        except Exception as e:
            # A failed prediction ends the episode. Log it so the failure is not silent.
            print(f"Agent failed to predict an action for env idx {env_idx}: {e}. Ending episode.")
            break
        step = {'observation': obs, 'reason': reason, 'action': action}
        obs, reward, done, _info = _retry(env.step, action, retry_msg="Trying to step again", max_retries=max_retries)
        obs = str(obs)
        step['score'] = reward
        trajectory.append(step)
        if done:
            break
    return task, trajectory, reward


def main():
    """Run every configured agent on the selected InterCode tasks and log results.

    Each task gets a JSON trajectory file ``<env_idx>.json`` in the agent's log
    dir and one row appended to ``summary.csv`` in the run directory.

    Optional config key ``max_retries`` bounds retries of ``env.reset`` and
    ``env.step``. When it is absent, retries are unlimited.
    """
    config = parse_and_load_config()

    dstdir = config['logdir'] if config['exact_path'] else f"{config['logdir']}/{time.strftime('%Y%m%d-%H%M%S')}"
    os.makedirs(dstdir, exist_ok=True)
    with open(os.path.join(dstdir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    summary_file_path = os.path.join(dstdir, "summary.csv")
    max_retries = config.get('max_retries')
    df_summary = pd.DataFrame()

    # The environment does not depend on the agent, so build it once and
    # always release it at the end.
    env, selected_env_idxs, parse_reason_action_fn = _build_env(config)
    try:
        for agent_config in config['agents']:
            agent = _build_agent(agent_config, config, parse_reason_action_fn)
            # hf_space agents are configured with space_id rather than model_id.
            agent_name = agent_config.get('model_id') or agent_config.get('space_id')

            logdir = config['logdir'] if config['exact_path'] else f"{dstdir}/{agent_name}"
            os.makedirs(logdir, exist_ok=True)

            env_idxs = selected_env_idxs[:config['max_env_idxs']] if config["max_env_idxs"] else selected_env_idxs
            start_env_idx = config["start_env_idx"]
            for env_idx in tqdm(env_idxs, desc="env idxs"):
                # Skip before resetting, so skipped tasks do not cost an environment reset.
                if start_env_idx and env_idx < start_env_idx:
                    continue

                task, trajectory, reward = _run_episode(env, agent, env_idx, config['max_actions'], max_retries)

                log = {'env_idx': env_idx, 'trajectory': trajectory, 'task': task}
                with open(os.path.join(logdir, f"{env_idx}.json"), 'w') as log_file:
                    json.dump(log, log_file, indent=4, default=decimal_default)

                summary_data = {'env_idx': env_idx, 'model_id': agent_name, 'num_actions': len(trajectory), 'score': reward}
                if os.path.exists(summary_file_path):
                    # Re-read so rows already in the file (e.g. from a previous run reusing
                    # the directory with exact_path) are kept when rewriting it.
                    df_summary = pd.read_csv(summary_file_path)
                df_summary = pd.concat([df_summary, pd.DataFrame([summary_data])], ignore_index=True)
                df_summary.to_csv(summary_file_path, index=False)
                print(f"Current summary:\n {df_summary}")
    finally:
        env.close()


if __name__ == "__main__":
    main()