import ast
import io
import os
import threading
import traceback
import uuid
from argparse import ArgumentParser, Namespace
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed

import litellm
from func_timeout import func_timeout, FunctionTimedOut
from tqdm import tqdm
# Suppress LiteLLM's extra debug-info output.
litellm.suppress_debug_info = True
# Single lock shared by all worker threads; it is passed to update_checkpoint
# so concurrent tasks can coordinate writes to the same checkpoint file
# (presumably acquired inside update_checkpoint -- confirm in src.utils).
file_lock = threading.Lock()

from src.envs import get_env
from src.agent_factory import get_agent
from src.types import EnvRunResult, CostInfo, ValidationResult
from src.utils import count_agent_turns, save_checkpoint, display_metrics, update_checkpoint, load_results, dummy_error_result, get_ckpt_name
from validator import user_validator
from dotenv import load_dotenv


# Load variables from a local .env file into the process environment
# (presumably model-provider API keys used by litellm -- confirm).
load_dotenv()

def parse_arguments() -> Namespace:
    """Define and parse the command-line options for an evaluation run.

    Options are registered in a fixed order (environment/task, agent,
    simulated user, output/scheduling, limits, validation) so that
    ``--help`` lists them in that sequence.

    Returns:
        The ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    cli = ArgumentParser()
    opt = cli.add_argument

    # Environment and task type.
    opt("--env", type=str, required=True,
        choices=["mimic_iv_star", "eicu_star"],
        help="Environment name for fetching user instructions")
    opt("--task_type", type=str, required=True,
        choices=["incre", "adapt"],
        help="Task type to use")

    # Agent under evaluation.
    opt("--model", type=str, required=True, help="The agent model to use")
    opt("--api_base", type=str, default=None, help="The API base to use")
    agent_strategies = ["tool-calling", "tool-calling-no-tool", "tool-calling-no-web"]
    opt("--agent_strategy", type=str, required=True, choices=agent_strategies,
        help="The agent strategy to use")
    opt("--temperature", type=float, default=0.0,
        help="Sampling temperature for the action model")

    # Simulated user.
    opt("--user_model", type=str, default='gemini/gemini-2.0-flash',
        help="The user model to use")
    opt("--user_temperature", type=float, default=1.0,
        help="Sampling temperature for the user model")
    user_strategies = ["llm", "react", "verifier", "reflection", "hierreflection"]
    opt("--user_strategy", type=str, default='hierreflection', choices=user_strategies,
        help="The user strategy to use")

    # Output location and task scheduling.
    opt("--result_dir", type=str, default="results", help="Directory to save the results")
    opt("--num_trials", type=int, default=5, help="Number of trials (k) to run")
    opt("--max_concurrency", type=int, default=1, help="Maximum concurrency level")
    opt("--start_index", type=int, default=0, help="Start index for tasks")
    opt("--end_index", type=int, default=-1, help="End index for tasks (-1 for all)")
    opt("--task_ids", nargs='+', type=int, default=None, help="Specific task ids to run")

    # Per-task limits.
    opt("--max_agent_turns", type=int, default=30, help="Maximum number of agent turns")
    opt("--max_retry", type=int, default=10, help="Maximum number of simulation retries")
    opt("--timeout", type=int, default=600, help="Timeout for each task in seconds")

    # Validation of finished conversations (with the verbosity switch in between).
    opt("--validation_model", type=str, default='gemini/gemini-2.5-flash',
        help="The validation model to use")
    opt("--verbose", action="store_true",
        help="Print user-agent conversations during execution")
    opt("--validation_trials", type=int, default=1, help="Number of validation trials")

    return cli.parse_args()


def _parse_retry_explanation(reason):
    """Extract the ``'explanation'`` entry from a validator ``reason``.

    The reason is expected to be the text form of a dict with an
    ``'explanation'`` key. It is model-generated, so it is parsed with
    ``ast.literal_eval`` (literals only) instead of ``eval``, which would
    execute arbitrary code. If the reason is already a dict it is used as-is.
    If it cannot be parsed into a dict with that key, the raw reason is
    returned as a string so the retry still carries some feedback.
    """
    if isinstance(reason, dict):
        parsed = reason
    else:
        try:
            parsed = ast.literal_eval(reason)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return str(reason)
    if isinstance(parsed, dict) and 'explanation' in parsed:
        return parsed['explanation']
    return str(reason)


def run(config: Namespace):
    """Run the configured agent over the selected tasks and report metrics.

    Tasks are either ``config.task_ids`` or the range
    ``[config.start_index, end_index)``. Each task is scheduled
    ``config.num_trials`` times. Trials that already have a non-None reward in
    the results returned by ``load_results`` are not scheduled again.

    Each trial does the following:
      * builds its own environment;
      * runs the agent under ``config.timeout`` seconds;
      * validates the conversation with ``user_validator``;
      * writes the result to the checkpoint file.

    When the validator blames the simulated user (decision ``'user'`` or
    ``'user_error'``) and the agent had not reached ``max_agent_turns``, the
    trial is recorded with ``reward=None`` and retried, up to
    ``config.max_retry`` times.

    Args:
        config: Parsed command-line options (see ``parse_arguments``).

    Raises:
        ValueError: if ``config.model`` names a llama/qwen/gpt-oss model but
            ``config.api_base`` is not set.
    """
    if any(x in config.model.lower() for x in ("llama", "qwen", "gpt-oss")):
        # These model families need an explicit endpoint. Raise instead of
        # `assert`, which is stripped under `python -O`.
        if config.api_base is None:
            raise ValueError(f"api_base is required for {config.model}")

    ckpt_name = get_ckpt_name(config)
    ckpt_path = os.path.join(config.result_dir, ckpt_name + '.json')
    os.makedirs(config.result_dir, exist_ok=True)

    print(f"Loading user with strategy: {config.user_strategy}")

    # This environment is used here only to count tasks and to give the agent
    # its tool schema and rules. Every trial below builds its own environment.
    env = get_env(
        env_name=config.env,
        task_type=config.task_type,
        user_strategy=config.user_strategy,
        user_model=config.user_model,
        user_temperature=config.user_temperature,
        api_base=config.api_base
    )
    agent = get_agent(
        tools_info=env.tools_info,
        model=config.model,
        api_base=config.api_base,
        temperature=config.temperature,
        agent_strategy=config.agent_strategy,
        rule=env.rule,
        verbose=config.verbose
    )

    total_tasks = len(env.tasks)
    end_index = total_tasks if config.end_index == -1 else min(config.end_index, total_tasks)
    idx = config.task_ids if config.task_ids else list(range(config.start_index, end_index))

    task_info = f"{config.task_ids}" if config.task_ids else f"{config.start_index} to {end_index}"
    print(f"Running tasks: {task_info} (checkpoint path: {ckpt_path})")

    results = load_results(config, idx=[str(i) for i in idx])

    # Each task needs num_trials finished runs. Subtract the runs that already
    # have a reward (reward=None marks a discarded user-error attempt), so only
    # the missing trials are scheduled.
    existing_idx = [int(r.task_id) for r in results if r.reward is not None]
    result_counter = Counter(idx * config.num_trials) - Counter(existing_idx)
    idx_to_run = list(result_counter.elements())

    if not idx_to_run:
        print("No new tasks to run. All tasks have been loaded from checkpoint.")
        display_metrics(results, config.num_trials)
        return

    save_checkpoint(ckpt_path, results)

    def _run(task_idx: int) -> EnvRunResult:
        """Run one trial of `task_idx` with retries.

        Returns the final result, or None when no result could be built
        (e.g. ``max_retry <= 0`` or environment creation failed).
        """
        # Bind before the try so the generic handler can tell whether an
        # environment was ever created for this task.
        isolated_env = None
        try:
            retry = 0
            result = None
            # Validator explanations from earlier attempts, passed to get_env.
            retry_reason = []

            while retry < config.max_retry:
                try:
                    isolated_env = get_env(
                        env_name=config.env,
                        task_type=config.task_type,
                        user_strategy=config.user_strategy,
                        user_model=config.user_model,
                        user_temperature=config.user_temperature,
                        api_base=config.api_base,
                        task_index=str(task_idx),
                        retry_reason=retry_reason
                    )

                    response = func_timeout(
                        config.timeout,
                        agent.run,
                        args=(isolated_env, str(task_idx), config.max_agent_turns)
                    )

                    user_cost = isolated_env.user.get_total_cost()
                    result = EnvRunResult(
                        db_id=isolated_env.task.db_id,
                        task_type=isolated_env.task.task_type,
                        task_id=isolated_env.task.task_id,
                        sample_id=str(uuid.uuid4()),
                        reward=response.reward,
                        info=response.info,
                        messages=response.messages,
                        cost=CostInfo(
                            agent_cost=response.agent_cost,
                            user_cost=user_cost,
                            eval_cost=0.0,
                            total_cost=round(response.agent_cost + user_cost, 8),
                        ),
                        validation=None
                    )
                    validation_result = user_validator(
                        messages=response.messages,
                        env=isolated_env,
                        model=config.validation_model,
                        api_base=config.api_base,
                        n=config.validation_trials
                    )

                    result.cost.eval_cost = round(result.cost.eval_cost + validation_result.eval_cost, 8)
                    result.cost.total_cost = round(result.cost.total_cost + validation_result.eval_cost, 8)
                    result.validation = validation_result

                    blames_user = validation_result.decision in ('user', 'user_error')
                    # Only discard the attempt when the agent had not simply
                    # run out of turns.
                    if blames_user and count_agent_turns(response.messages) < config.max_agent_turns:
                        result.reward = None
                        update_checkpoint(ckpt_path, result, file_lock)
                        if 'hierr' in config.user_strategy:
                            retry_reason.append(_parse_retry_explanation(validation_result.reason))
                        retry += 1
                        print(
                            "⚠️ ",
                            f"Retry {retry}/{config.max_retry} |",
                            f"ckpt_path={ckpt_name}",
                            f"task_id={task_idx}",
                            f"User error during simulation: {validation_result.reason}"
                        )
                    else:
                        update_checkpoint(ckpt_path, result, file_lock)
                        break

                except FunctionTimedOut:
                    error_reason = f"Timeout during simulation (exceeded {config.timeout} seconds)"
                    result = dummy_error_result(isolated_env, 'no_error', error_reason, reward=0.0)
                    update_checkpoint(ckpt_path, result, file_lock)
                    print("❌", f"ckpt_path={ckpt_name}", f"task_id={task_idx}", error_reason)
                    break

            if result and result.reward == 1:
                print("✅", f"ckpt_path={ckpt_name}", f"task_id={task_idx}", result.info)
            elif result and result.reward == 0:
                print("❌", f"ckpt_path={ckpt_name}", f"task_id={task_idx}", result.info)

            print("-----")
            return result

        except KeyboardInterrupt:
            print("Keyboard interrupt. Exiting...")
            # Equivalent to site's exit(0), without depending on the site
            # module and without closing sys.stdin.
            raise SystemExit(0)

        except Exception:
            error_reason = f"Unexpected error during simulation: {traceback.format_exc()}"
            if isolated_env is None:
                # No environment was ever created (e.g. the first get_env call
                # failed), so there is nothing to build an error record from.
                # Previously this path raised UnboundLocalError.
                print("⚠️", f"ckpt_path={ckpt_name}", f"task_id={task_idx}", error_reason)
                return None
            result = dummy_error_result(isolated_env, 'other', error_reason)
            update_checkpoint(ckpt_path, result, file_lock)
            print("⚠️", f"ckpt_path={ckpt_name}", f"task_id={task_idx}", error_reason)
            return result

    max_workers = max(1, min(config.max_concurrency, len(idx_to_run)))
    new_results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(_run, t): t for t in idx_to_run}
        # Drive the progress bar only; results are collected below in
        # submission order.
        for _ in tqdm(as_completed(futures), total=len(futures), desc="Running"):
            pass
        for future, task_idx in futures.items():
            try:
                outcome = future.result()
            except Exception as exc:
                # Report instead of silently dropping the failed trial.
                print("⚠️", f"ckpt_path={ckpt_name}", f"task_id={task_idx}", f"Worker failed: {exc!r}")
                continue
            if outcome is not None:
                new_results.append(outcome)

    results.extend(new_results)
    display_metrics(results, config.num_trials)

if __name__ == "__main__":
    # Script entry point: read CLI options, then launch the evaluation run.
    config = parse_arguments()
    run(config=config)
