import os
os.environ['OMNIGIBSON_HEADLESS'] = '1'

from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()

from og_ego_prim.utils.monkey_patch import add_monkey_patch
add_monkey_patch()

import argparse
import os
import re
import sys

import omnigibson as og
from omnigibson.macros import gm
import shutil
import time
import torch

from og_ego_prim.benchmark import build_benchmark
from og_ego_prim.models import PlanningAgent
from og_ego_prim.models.server_inference import ServerClient

# src package is available via og_ego_prim/__init__.py
from src.guardrail import EMBGuard


# Enable GPU dynamics and force headless rendering (headless mode is also
# requested through the OMNIGIBSON_HEADLESS environment variable at the top of
# this file).
# NOTE: with HEADLESS forced to True, the debug viewer-teleoperation branch in
# online_benchmark_once (guarded by `gm.HEADLESS is False`) never runs.
# Flatcache is left at the OmniGibson default; setting gm.ENABLE_FLATCACHE = True
# may give a performance boost.
gm.USE_GPU_DYNAMICS = True
gm.HEADLESS = True

def _str2bool(value):
    """Convert a command-line string such as 'true'/'false' into a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--flag False`` would yield True). This converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


# Command-line interface for a single benchmark run (parsed in the __main__ block).
parser = argparse.ArgumentParser()
parser.add_argument('--try_id', type=str, default=None,
                    help="Optional run identifier prefixed to the output directory name.")

parser.add_argument('--task', type=str, default=None)
parser.add_argument('--scene', type=str, default=None)
parser.add_argument('--model', type=str, default=None,
                    help="Model id when not using --local_llm_serve; local model path when using --local_llm_serve.")
parser.add_argument('--local_llm_serve', action='store_true')
parser.add_argument('--local_serve_ip', type=str, default="")
parser.add_argument('--local_serve_key', type=str, default="sk-123456")
parser.add_argument('--use_hf_model', action='store_true')
parser.add_argument('--work_dir', type=str, default='./work_dir')

parser.add_argument('--draw_bbox_2d', action='store_true')
parser.add_argument('--use_initial_setup', action='store_true')
parser.add_argument('--use_self_caption', action='store_true')
# Parsed explicitly: `type=bool` would turn the string "False" into True.
parser.add_argument('--online_object_sampling', type=_str2bool, default=None,
                    help="Boolean (true/false). When true, the sampled task is saved to the output directory.")
parser.add_argument('--debug', action='store_true')

# Evaluation stages are on by default; these switches opt out of them.
parser.add_argument('--not_eval_process_safety', action='store_true')
parser.add_argument('--not_eval_termination_safety', action='store_true')
parser.add_argument('--not_eval_awareness', action='store_true')
parser.add_argument('--not_eval_execution', action='store_true')
parser.add_argument('--prompt_setting', type=str, default='default')
parser.add_argument('--robot_ego_view', action='store_true')
parser.add_argument('--guardrail_model', type=str, default=None)
parser.add_argument('--guardrail_type', type=str, default='gpt',
                    help="One of gpt, openrouter, claude, gemini, vllm, human.")
# Kept as a string: the runner compares it case-insensitively against 'true'.
parser.add_argument('--enable_guardrail', type=str, default='false',
                    help="'true' (case-insensitive) enables the guardrail; any other value disables it.")


def create_guardrail_config(guardrail_model: str, guardrail_type: str):
    """
    Build the (provider, model_config) pair used to construct an EMBGuard guardrail.

    Credentials and endpoints are read from environment variables at call time:
      - 'claude'     -> provider 'claude';  api_key = CLAUDE_API_KEY (None if unset)
      - 'gemini'     -> provider 'gemini';  api_key = GEMINI_API_KEY ('' if unset)
      - 'openrouter' -> provider 'openrouter'; api_key = OPENROUTER_API_KEY, falling
                        back to OPENAI_API_KEY; api_base = OPENROUTER_API_BASE
                        (default 'https://openrouter.ai/api/v1')
      - 'vllm'       -> provider 'vllm'; base_url = VLLM_BASE_URL (default
                        'http://127.0.0.1:8000/v1'); api_key = VLLM_API_KEY
                        (default 'EMPTY')
      - 'gpt' and any unrecognised type -> provider 'openai'; api_key =
                        OPENAI_API_KEY; api_base = OPENAI_API_BASE (default
                        'https://api.openai.com/v1')

    Args:
        guardrail_model: Model name placed under 'model_name' in the config.
        guardrail_type: Provider selector (see list above).

    Returns:
        Tuple[str, dict]: provider name and its model_config dictionary.
    """
    env = os.environ.get

    # Each builder returns (provider, provider-specific settings); they are
    # called lazily so only the selected provider's variables are read.
    builders = {
        'claude': lambda: ('claude', {
            'api_key': env('CLAUDE_API_KEY'),
        }),
        'gemini': lambda: ('gemini', {
            'api_key': env('GEMINI_API_KEY', ''),
        }),
        'openrouter': lambda: ('openrouter', {
            'api_key': env('OPENROUTER_API_KEY', env('OPENAI_API_KEY')),
            'api_base': env('OPENROUTER_API_BASE', 'https://openrouter.ai/api/v1'),
        }),
        'vllm': lambda: ('vllm', {
            'base_url': env('VLLM_BASE_URL', 'http://127.0.0.1:8000/v1'),
            'api_key': env('VLLM_API_KEY', 'EMPTY'),
        }),
        'gpt': lambda: ('openai', {
            'api_key': env('OPENAI_API_KEY'),
            'api_base': env('OPENAI_API_BASE', 'https://api.openai.com/v1'),
        }),
    }

    # Unknown guardrail types fall back to the OpenAI ('gpt') configuration.
    provider, settings = builders.get(guardrail_type, builders['gpt'])()
    return provider, {'model_name': guardrail_model, **settings}


# Matches action strings of the form "OPERATOR(ARG1, ARG2, ...)".
_ACTION_PATTERN = re.compile(r'(\w+)\(([^)]+)\)')


def _extract_target_object(action: str):
    """Return the target object parsed from an ``OPERATOR(OBJ1, OBJ2)`` action string.

    The action is upper-cased before parsing, so the returned name is upper-case.
    For most operators the first argument is the target; POUR_INTO and SPREAD
    target their second argument. Returns None if the string does not match.
    """
    match = _ACTION_PATTERN.search(action.upper())
    if not match:
        return None
    operator = match.group(1)
    params = [p.strip() for p in match.group(2).split(',')]
    if operator in ('POUR_INTO', 'SPREAD') and len(params) > 1:
        return params[1]
    if params:
        return params[0]
    return None


def _configure_guardrail(benchmark, enable_guardrail, guardrail_type, guardrail_model):
    """Attach the guardrail selected by the CLI options to ``benchmark.executor``.

    - ``enable_guardrail`` not 'true' (case-insensitive) -> no guardrail
    - ``guardrail_type == 'human'`` -> HumanGuardrailModel for the task instruction
    - otherwise, ``guardrail_model`` set -> EMBGuard built via create_guardrail_config
    - enabled but no model given -> no guardrail (a warning is printed)
    """
    # str() accepts a bool as well as the CLI string 'true'/'false'.
    if str(enable_guardrail).lower() != 'true':
        benchmark.executor.set_guardrail_model(None)
        return

    if guardrail_type == 'human':
        # Imported here so the human guardrail module is only loaded when selected.
        from og_ego_prim.guardrail.human_guardrail_model import HumanGuardrailModel
        benchmark.executor.set_guardrail_model(
            HumanGuardrailModel(task_instruction=benchmark.task_instruction)
        )
        return

    if not guardrail_model:
        print('[WARNING] enable_guardrail is true but no guardrail_model was given; running without a guardrail.')
        sys.stdout.flush()
        benchmark.executor.set_guardrail_model(None)
        return

    # Use EMBGuard directly (not the IS-Bench guardrail interface).
    provider, model_config = create_guardrail_config(guardrail_model, guardrail_type)
    benchmark.executor.set_guardrail_model(
        EMBGuard(provider=provider, model_config=model_config)
    )


def online_benchmark_once(
        try_id,
        task: str,
        scene: str,
        model: str,
        local_llm_serve: bool,
        local_serve_ip: str,
        local_serve_key: str,
        use_hf_model: bool,
        work_dir: str,
        draw_bbox_2d: bool,
        use_initial_setup: bool,
        use_self_caption: bool,
        online_object_sampling: bool,
        debug: bool,
        eval_process_safety: bool,
        eval_termination_safety: bool,
        eval_awareness: bool,
        eval_execution: bool,
        prompt_setting: str,
        robot_ego_view: bool,
        guardrail_model: str,
        guardrail_type: str,
        enable_guardrail: str,
):
    """Run one benchmark episode for a single task/scene and write its report.

    Steps:
      1. Build the benchmark and create the output directory
         ``<work_dir>/benchmark/<task_name>___<scene_name>/[<try_id>_]<model>_<prompt_setting>``.
      2. If ``model`` or ``local_llm_serve`` is set, create a PlanningAgent and
         attach the configured guardrail to the executor.
      3. Optionally generate a scene caption and safety awareness with the agent.
         These steps are skipped (with a warning) when no agent exists.
      4. Execute the benchmark's example plan step by step, saving per-step
         observations under the output directory. The agent does not produce
         the plan here.
      5. Run termination evaluation (plus hazard/risk evaluation for prompt
         settings v7/v8) and save ``report.json``.

    If none of process safety, termination safety, or execution is evaluated,
    only ``report_awareness.json`` is written and the function returns early.
    In that case a sampled scene file from online object sampling is left in
    the output directory.

    Args:
        try_id: Optional run identifier prefixed to the model tag in the output path.
        model: Model id (or local model path with ``local_llm_serve``). None means
            no agent is created.
        work_dir: Root directory for outputs and for the PlanningAgent.
        prompt_setting: Prompt variant. 'baseline'/'baseline_feedback' disable
            awareness evaluation; 'v2'/'v4'/'v8' track awareness without
            evaluating it; 'v7'/'v8' add hazard/risk evaluation.
        enable_guardrail: 'true' (case-insensitive) enables the guardrail.
        guardrail_type / guardrail_model: Select the guardrail (see _configure_guardrail).
        Remaining flags are forwarded to build_benchmark / PlanningAgent.
    """
    benchmark = build_benchmark(
        task=task,
        scene=scene,
        ego_view=True,  # Whether to remove robot from surrounding view
        draw_bbox_2d=draw_bbox_2d,
        use_initial_setup=use_initial_setup,
        use_self_caption=use_self_caption,
        online_object_sampling=online_object_sampling,
        debug=debug,
        eval_process_safety=eval_process_safety,
        eval_termination_safety=eval_termination_safety,
        eval_awareness=eval_awareness,
        eval_execution=eval_execution,
        robot_ego_view=robot_ego_view,  # Actual first-person view
    )
    if debug and gm.HEADLESS is False:
        og.sim.enable_viewer_camera_teleoperation()

    benchmark_tag = f'{benchmark.task_name}___{benchmark.scene_name}'
    # prompt_setting is part of the directory name so prompt variants don't overwrite each other.
    model_tag = model.replace('/', '__') if model is not None else 'example'
    model_tag = f"{model_tag}_{prompt_setting}"
    if not try_id:
        output_dir = os.path.join(work_dir, 'benchmark', benchmark_tag, model_tag)
    else:
        output_dir = os.path.join(work_dir, 'benchmark', benchmark_tag, f"{try_id}_{model_tag}")
    os.makedirs(output_dir, exist_ok=True)

    if online_object_sampling:
        fname = f'{scene}_task_{task}_0_0_template'
        sampled_scene_file = os.path.join(output_dir, f'{fname}.json')
        benchmark.env.task.save_task(path=sampled_scene_file)

    # The agent only exists when a model or local server is configured. Later
    # agent-dependent steps check for None instead of raising NameError.
    agent = None
    if model or local_llm_serve:
        agent = PlanningAgent(
            task_name=task,
            scene_name=scene,
            agent_name=model,
            work_dir=work_dir,
            local_llm_serve=local_llm_serve,
            local_serve_ip=local_serve_ip,
            local_serve_key=local_serve_key,
            use_hf_model=use_hf_model,
            debug=debug,
            prompt_setting=prompt_setting,
            use_initial_setup=use_initial_setup,
            use_self_caption=use_self_caption,
        )
        agent.set_tracker(benchmark.tracker)
        _configure_guardrail(benchmark, enable_guardrail, guardrail_type, guardrail_model)

    # Execution always follows the example plan, with or without an agent.
    planner = benchmark.get_example_planning()

    if use_self_caption:
        if agent is None:
            print('[WARNING] use_self_caption requires --model or --local_llm_serve; skipping caption generation.')
            sys.stdout.flush()
        else:
            caption = agent.generate_caption(use_obs=True)
            benchmark.tracker.track_caption(content=caption)

    if prompt_setting in ['baseline', 'baseline_feedback']:
        eval_awareness = False
    if eval_awareness and agent is not None:
        awareness = agent.generate_awareness(use_obs=True)
        benchmark.evaluate_awareness(awareness)
    elif prompt_setting in ['v2', 'v4', 'v8'] and not eval_awareness and agent is not None:
        # These prompt variants use awareness in planning, so it is tracked but not evaluated.
        awareness = agent.generate_awareness(use_obs=True)
        benchmark.tracker.track_awareness(
            content=awareness,
            eval_results=None
        )

    if not (eval_process_safety or eval_termination_safety or eval_execution):
        benchmark.tracker.save_tracking(os.path.join(output_dir, 'report_awareness.json'))
        time.sleep(3)
        og.clear()
        return

    for i, plan in enumerate(planner):
        step_num = i + 1
        action = plan['action']
        step_tag = f'{step_num}_' + action.replace('(', '__').replace(')', '__')
        step_dir = os.path.join(output_dir, step_tag)
        os.makedirs(step_dir, exist_ok=True)

        if plan.get('skip', False):
            # skip_step: do not execute the action, only save observations.
            print(f'[STEP {step_num}] Skipping action execution (skip_step), saving observations only...')
            sys.stdout.flush()

            # Surrounding view (obs_0.png).
            benchmark.get_surrounding_viewer_obs(save_img=step_dir)

            # Advancing the action generator once only runs its image-saving
            # part (obs_a/b/c/d.png); the action itself is not executed.
            try:
                action_seqs = benchmark.executor._parse_plan_to_action_seqs(
                    action,
                    save_img_dir=step_dir,
                    caution=plan.get('caution')
                )
                if action_seqs is not None:
                    next(action_seqs, None)
            except Exception as e:
                # Best effort: missing observation images should not abort the episode.
                print(f'[WARNING] Failed to save obs_a/b/c/d for skip_step: {e}')
                sys.stdout.flush()

            benchmark.tracker.track_plan(step=step_num, plan=plan)
            continue

        target_obj_name = _extract_target_object(action)
        if target_obj_name:
            print(f"[STEP {step_num}] Target Object: {target_obj_name}")
            sys.stdout.flush()

        # execute_plan saves obs_before.png / obs_after.png into step_dir.
        execution_result = benchmark.execute_plan(plan, save_img_dir=step_dir, current_step=step_num)

        # Always save the surrounding view, even for risky (withdrawn) actions,
        # because the next planning step uses it.
        benchmark.get_surrounding_viewer_obs(save_img=step_dir)

        if execution_result is False:
            # A risky action was rejected.
            if agent is not None:
                print(f'[benchmark] Risky action detected, withdrawing current action {action}, requesting replan...')
                sys.stdout.flush()
                continue
            # Example planning cannot replan, so stop the episode.
            break

    benchmark.termination_evaluation()

    # Hazard object identification and risk assessment apply only to prompt settings v7/v8.
    if prompt_setting in ['v7', 'v8'] and (eval_process_safety or eval_termination_safety):
        benchmark.evaluator.evaluate_hazard_obj(
            benchmark.task_instruction,
            benchmark.initial_setup
        )
        benchmark.evaluator.evaluate_risk_assmt(
            benchmark.task_instruction,
            benchmark.initial_setup
        )

    benchmark.tracker.save_tracking(os.path.join(output_dir, 'report.json'))

    if online_object_sampling:
        # Keep the sampled scene only if the episode finished and met its execution goal.
        if benchmark.tracker.termination['reason'] == 'done' and benchmark.tracker.goal_condition['execution_goal_condition']['eval']:
            normal_scene_file = os.path.join(work_dir, "..", "data", "scenes", scene, "json", f'{fname}.json')
            shutil.copyfile(sampled_scene_file, normal_scene_file)
        else:
            os.remove(sampled_scene_file)

    # NOTE(review): the pause presumably lets pending simulator/IO work settle before og.clear() - confirm.
    time.sleep(3)
    og.clear()


if __name__ == "__main__":
    # Parse the CLI, echo it for the run log, then run a single benchmark episode.
    args = parser.parse_args()
    print(f'args: {args}')
    sys.stdout.flush()

    # Options forwarded to the runner unchanged, under the same name.
    passthrough = (
        'try_id', 'task', 'scene', 'model',
        'local_llm_serve', 'local_serve_ip', 'local_serve_key', 'use_hf_model',
        'prompt_setting', 'work_dir', 'draw_bbox_2d', 'use_initial_setup',
        'use_self_caption', 'online_object_sampling', 'debug', 'robot_ego_view',
        'guardrail_model', 'guardrail_type', 'enable_guardrail',
    )
    run_kwargs = {name: getattr(args, name) for name in passthrough}

    # The CLI exposes opt-out switches (--not_eval_*); the runner expects opt-in flags.
    for aspect in ('process_safety', 'termination_safety', 'awareness', 'execution'):
        run_kwargs[f'eval_{aspect}'] = not getattr(args, f'not_eval_{aspect}')

    online_benchmark_once(**run_kwargs)
