import argparse
import json

from vita.config import (
    DEFAULT_AGENT_IMPLEMENTATION,
    DEFAULT_LLM_AGENT,
    DEFAULT_LLM_ARGS_AGENT,
    DEFAULT_LLM_TEMPERATURE_AGENT,
    DEFAULT_LLM_TEMPERATURE_USER,
    DEFAULT_LLM_USER,
    DEFAULT_LLM_ARGS_USER,
    DEFAULT_LOG_LEVEL,
    DEFAULT_MAX_CONCURRENCY,
    DEFAULT_MAX_ERRORS,
    DEFAULT_MAX_STEPS,
    DEFAULT_NUM_TRIALS,
    DEFAULT_SEED,
    DEFAULT_USER_IMPLEMENTATION,
    DEFAULT_EVALUATION_TYPE,
    DEFAULT_ENABLE_THINK,
    DEFAULT_ENABLE_MEMORY,
    DEFAULT_MEMORY_LLM,
    DEFAULT_MEMORY_LLM_ARGS,
    DEFAULT_LLM_EVALUATOR,
    DEFAULT_LLM_EVALUATOR_ARGS,
    DEFAULT_LANGUAGE
)
from vita.data_model.simulation import RunConfig
from vita.run import get_options, run_domain
import litellm


def dict_type(arg_string):
    """Parse a command-line value into a dictionary.

    Intended for use as an argparse ``type=`` callable. A value that is
    already a ``dict`` is returned unchanged. A string is parsed as JSON
    first and, if that fails, as a Python literal (so single-quoted input
    such as ``{'temperature': 0.0}`` is accepted).

    Args:
        arg_string: The raw value, either a ``dict`` or a ``str``.

    Returns:
        The parsed dictionary.

    Raises:
        argparse.ArgumentTypeError: If the value is neither a dict nor a
            string, cannot be parsed, or parses to something that is not a
            dictionary (for example ``"[1, 2]"`` or ``"3"``).
    """
    # Already a dict: nothing to parse.
    if isinstance(arg_string, dict):
        return arg_string

    if not isinstance(arg_string, str):
        raise argparse.ArgumentTypeError(f"Expected dict or string, got {type(arg_string)}")

    try:
        # Preferred format: strict JSON.
        parsed = json.loads(arg_string)
    except json.JSONDecodeError:
        # Fallback: a Python literal. literal_eval only evaluates literals,
        # never arbitrary expressions.
        import ast

        try:
            parsed = ast.literal_eval(arg_string)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            raise argparse.ArgumentTypeError(
                f"Invalid dictionary format: {arg_string}"
            ) from None

    # Both parsers happily return lists, numbers, strings or None; reject
    # those here so callers always receive a dictionary.
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(
            f"Expected a dictionary, got {type(parsed).__name__}: {arg_string}"
        )
    return parsed


def add_run_args(parser):
    """Register every option of the ``run`` command on ``parser``.

    Args:
        parser: The argparse parser (or sub-parser) that receives the
            arguments. It is modified in place; nothing is returned.
    """
    # Query the available options once and reuse them for every ``choices=``.
    options = get_options()
    evaluation_types = [
        "trajectory",
        "trajectory_wo_user",
        "nl_rubrics",
        "all",
        "all_types",
        "trajectory_ablation1",
        "trajectory_ablation2",
        "trajectory_ablation3",
    ]

    # NOTE(review): the default is a comma-separated string; presumably it is
    # split into individual domains downstream — verify in run_domain.
    parser.add_argument(
        "--domain",
        "-d",
        type=str,
        default="delivery,instore,ota",
        help="The domain to run the simulation on",
    )
    parser.add_argument(
        "--num-trials",
        type=int,
        default=DEFAULT_NUM_TRIALS,
        help=f"The number of times each task is run. Default is {DEFAULT_NUM_TRIALS}.",
    )
    parser.add_argument(
        "--agent",
        type=str,
        default=DEFAULT_AGENT_IMPLEMENTATION,
        choices=options.agents,
        help=f"The agent implementation to use. Default is {DEFAULT_AGENT_IMPLEMENTATION}.",
    )
    parser.add_argument(
        "--agent-llm",
        type=str,
        default=DEFAULT_LLM_AGENT,
        help=f"The LLM to use for the agent. Default is {DEFAULT_LLM_AGENT}.",
    )
    parser.add_argument(
        "--agent-llm-args",
        type=dict_type,
        default=DEFAULT_LLM_ARGS_AGENT,
        help=f"The arguments to pass to the LLM for the agent. Default is temperature={DEFAULT_LLM_TEMPERATURE_AGENT}.",
    )
    parser.add_argument(
        "--user",
        type=str,
        choices=options.users,
        default=DEFAULT_USER_IMPLEMENTATION,
        help=f"The user implementation to use. Default is {DEFAULT_USER_IMPLEMENTATION}.",
    )
    parser.add_argument(
        "--user-llm",
        type=str,
        default=DEFAULT_LLM_USER,
        help=f"The LLM to use for the user. Default is {DEFAULT_LLM_USER}.",
    )
    parser.add_argument(
        "--user-llm-args",
        type=dict_type,
        default=DEFAULT_LLM_ARGS_USER,
        help=f"The arguments to pass to the LLM for the user. Default is temperature={DEFAULT_LLM_TEMPERATURE_USER}.",
    )
    parser.add_argument(
        "--task-set-name",
        type=str,
        default=None,
        choices=options.task_sets,
        help="The task set to run the simulation on. If not provided, will load default task set for the domain.",
    )
    parser.add_argument(
        "--task-ids",
        type=str,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs. If not provided, will run num_tasks tasks.",
    )
    parser.add_argument(
        "--num-tasks",
        type=int,
        default=None,
        help="The number of tasks to run.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=DEFAULT_MAX_STEPS,
        help=f"The maximum number of steps to run the simulation. Default is {DEFAULT_MAX_STEPS}.",
    )
    parser.add_argument(
        "--evaluation-type",
        type=str,
        default=DEFAULT_EVALUATION_TYPE,
        choices=evaluation_types,
        # Build the list from the accepted choices so the help cannot drift
        # out of sync with what the parser actually allows.
        help=(
            "The type of evaluation to use. Choices: "
            + ", ".join(evaluation_types)
            + f" (all_types runs all evaluation types separately). Default is {DEFAULT_EVALUATION_TYPE}."
        ),
    )
    parser.add_argument(
        "--max-errors",
        type=int,
        default=DEFAULT_MAX_ERRORS,
        help=f"The maximum number of tool errors allowed in a row in the simulation. Default is {DEFAULT_MAX_ERRORS}.",
    )
    parser.add_argument(
        "--save-to",
        type=str,
        required=False,
        help="The path to save the simulation results. Will be saved to data/simulations/<save_to>.json. If not provided, will save to <domain>_<agent>_<user>_<llm_agent>_<llm_user>_<timestamp>.json. If the file already exists, it will try to resume the run.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=DEFAULT_MAX_CONCURRENCY,
        help=f"The maximum number of concurrent simulations to run. Default is {DEFAULT_MAX_CONCURRENCY}.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help=f"The seed to use for the simulation. Default is {DEFAULT_SEED}.",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default=DEFAULT_LOG_LEVEL,
        help=f"The log level to use for the simulation. Default is {DEFAULT_LOG_LEVEL}.",
    )
    parser.add_argument(
        "--re-evaluate-file",
        type=str,
        help="Path to simulation file for re-evaluation mode. If provided, will re-evaluate the simulations from this file instead of running new ones.",
    )
    parser.add_argument(
        "--csv-output",
        type=str,
        help="Path to CSV file to append results. If provided, will append all simulation results to this CSV file after completion.",
    )
    # NOTE(review): DEFAULT_ENABLE_THINK and DEFAULT_ENABLE_MEMORY are imported
    # by this file but not applied to the two flags below; as store_true
    # options they default to False. Confirm that is intended.
    parser.add_argument(
        "--enable-think",
        action="store_true",
        help="Enable think mode for the agent. Default is False.",
    )
    parser.add_argument(
        "--enable-memory",
        action="store_true",
        help="Enable memory compression for the agent. Default is False.",
    )
    parser.add_argument(
        "--memory-llm",
        type=str,
        default=DEFAULT_MEMORY_LLM,
        help=f"The LLM to use for memory compression. Default is {DEFAULT_MEMORY_LLM}.",
    )
    parser.add_argument(
        "--memory-llm-args",
        type=dict_type,
        default=DEFAULT_MEMORY_LLM_ARGS,
        help=f"The arguments to pass to the LLM for memory compression. Default is {DEFAULT_MEMORY_LLM_ARGS}.",
    )
    parser.add_argument(
        "--language",
        type=str,
        choices=["chinese", "english"],
        default=DEFAULT_LANGUAGE,
        help=f"The language to use for prompts and tasks. Choices: chinese, english. Default is {DEFAULT_LANGUAGE}.",
    )
    parser.add_argument(
        "--evaluator-llm",
        type=str,
        default=DEFAULT_LLM_EVALUATOR,
        help=f"The LLM to use for evaluation. Default is {DEFAULT_LLM_EVALUATOR}.",
    )
    parser.add_argument(
        "--evaluator-llm-args",
        type=dict_type,
        default=DEFAULT_LLM_EVALUATOR_ARGS,
        help=f"The arguments to pass to the LLM for evaluation. Default is {DEFAULT_LLM_EVALUATOR_ARGS}.",
    )
    parser.add_argument(
        "--re-run",
        action="store_true",
        help="Re-run tasks specified by --task-ids. If used with --re-evaluate-file, will re-run specified tasks and then re-evaluate all tasks together.",
    )


def main():
    """Entry point of the ``vita`` command line interface.

    Builds a parser with the ``run``, ``view``, ``domain`` and ``start``
    sub-commands, parses the process arguments and dispatches to the handler
    stored in ``func`` by the selected sub-command. When no sub-command is
    given, the top-level help is printed instead.
    """

    def _launch_run(args):
        """Translate the parsed ``run`` arguments into a RunConfig and run it."""
        config = RunConfig(
            domain=args.domain,
            task_set_name=args.task_set_name,
            task_ids=args.task_ids,
            num_tasks=args.num_tasks,
            agent=args.agent,
            llm_agent=args.agent_llm,
            llm_args_agent=args.agent_llm_args,
            user=args.user,
            llm_user=args.user_llm,
            llm_args_user=args.user_llm_args,
            num_trials=args.num_trials,
            max_steps=args.max_steps,
            evaluation_type=args.evaluation_type,
            max_errors=args.max_errors,
            save_to=args.save_to,
            max_concurrency=args.max_concurrency,
            seed=args.seed,
            log_level=args.log_level,
            re_evaluate_file=getattr(args, 're_evaluate_file', None),
            csv_output_file=getattr(args, 'csv_output', None),
            enable_think=args.enable_think,
            enable_memory=args.enable_memory,
            memory_llm=args.memory_llm,
            memory_llm_args=args.memory_llm_args,
            language=args.language,
            llm_evaluator=args.evaluator_llm,
            llm_args_evaluator=args.evaluator_llm_args,
            re_run=getattr(args, 're_run', False),
        )
        return run_domain(config)

    cli = argparse.ArgumentParser(description="vita command line interface")
    commands = cli.add_subparsers(dest="command", help="Available commands")

    # "run": execute a benchmark.
    run_cmd = commands.add_parser("run", help="Run a benchmark")
    add_run_args(run_cmd)
    run_cmd.set_defaults(func=_launch_run)

    # "view": inspect stored simulation results.
    view_cmd = commands.add_parser("view", help="View simulation results")
    view_cmd.add_argument(
        "--file",
        type=str,
        help="Path to the simulation results file to view",
    )
    view_cmd.add_argument(
        "--only-show-failed",
        action="store_true",
        help="Only show failed tasks.",
    )
    view_cmd.add_argument(
        "--only-show-all-failed",
        action="store_true",
        help="Only show tasks that failed in all trials.",
    )
    view_cmd.set_defaults(func=run_view_simulations)

    # "domain": print the documentation of one domain.
    domain_cmd = commands.add_parser("domain", help="Show domain documentation")
    domain_cmd.add_argument(
        "domain",
        type=str,
        help="Name of the domain to show documentation for (e.g., 'ota', 'delivery', 'instore')",
    )
    domain_cmd.set_defaults(func=run_show_domain)

    # "start": the handler takes no arguments, so the namespace is dropped.
    start_cmd = commands.add_parser("start", help="Start all servers")
    start_cmd.set_defaults(func=lambda _namespace: run_start_servers())

    parsed = cli.parse_args()
    # Reset litellm's custom provider map before any handler runs.
    litellm.custom_provider_map = []

    # "func" is only present when a sub-command was selected.
    if hasattr(parsed, "func"):
        parsed.func(parsed)
    else:
        cli.print_help()


def run_view_simulations(args):
    """Handle the ``view`` command by delegating to the simulation viewer.

    The viewer module is imported inside the function, so it is only loaded
    when this command actually runs.

    Args:
        args: Parsed arguments; ``file``, ``only_show_failed`` and
            ``only_show_all_failed`` are forwarded to the viewer.
    """
    from vita.scripts.view_simulations import main as show_simulations

    viewer_options = {
        "sim_file": args.file,
        "only_show_failed": args.only_show_failed,
        "only_show_all_failed": args.only_show_all_failed,
    }
    show_simulations(**viewer_options)


def run_show_domain(args):
    """Handle the ``domain`` command by showing one domain's documentation.

    Args:
        args: Parsed arguments; ``domain`` is passed to the documentation
            script, which is imported on demand.
    """
    from vita.scripts.show_domain_doc import main as show_domain_doc

    domain_name = args.domain
    show_domain_doc(domain_name)


def run_start_servers():
    """Handle the ``start`` command by invoking the server start script.

    The script module is imported on demand and called without arguments.
    """
    from vita.scripts.start_servers import main as launch_servers

    launch_servers()


def run_re_evaluate_simulation(args):
    """Re-evaluate the simulations of a results file and report on the console.

    NOTE(review): no sub-command in this file's ``main`` dispatches to this
    function; it may be called from elsewhere — confirm before removing.

    Args:
        args: Parsed arguments; reads ``file``, ``evaluation_type`` and
            ``save_to``.

    Raises:
        Exception: Any error raised while re-evaluating or reporting is
            printed to the console and then re-raised unchanged.
    """
    from vita.run import re_evaluate_simulation
    from vita.utils.display import ConsoleDisplay

    console = ConsoleDisplay.console
    console.print(f"[bold green]Re-evaluating simulations from: {args.file}[/bold green]")
    console.print(f"[bold blue]Evaluation type: {args.evaluation_type}[/bold blue]")

    try:
        results = re_evaluate_simulation(
            simulation_file_path=args.file,
            evaluation_type=args.evaluation_type,
            save_to=args.save_to,
        )

        console.print("\n✨ [bold green]Re-evaluation completed successfully![/bold green]")
        console.print("📊 [bold blue]Statistics:[/bold blue]")
        console.print(f"  📝 Total simulations re-evaluated: {len(results.simulations)}")

        # Without an explicit target, show the name pattern used by this
        # message in place of a concrete path.
        destination = args.save_to or f"<original_name>_re_eval_{args.evaluation_type}.json"
        console.print(f"  💾 Results saved to: {destination}")

        console.print("\nTo review the re-evaluation results, run: [bold blue]vita view --file <output_file>[/bold blue]")

    except Exception as e:
        console.print(f"[bold red]Error during re-evaluation: {e}[/bold red]")
        # Bare raise re-raises the active exception with its original
        # traceback instead of adding this frame again.
        raise


# Run the command line interface when this file is executed as a script.
if __name__ == "__main__":
    main()
