import json
import os
from typing import List, Literal

import imageio

from og_ego_prim.benchmark.tracker import EvalTracker
from og_ego_prim.utils.types import StepwisePlan


class OnlineEvalTracker(EvalTracker):
    """Record the data produced during one online evaluation episode.

    Each ``track_*`` method stores its keyword arguments as a plain ``dict``
    so that :meth:`save_tracking` can serialise everything with
    :func:`json.dump`. RGB frames passed to :meth:`track_video_rgb` are
    buffered in memory and written out by :meth:`save_video`.

    ``hazard_obj_evaluation`` and ``risk_assmt_evaluation`` are not created
    in ``__init__``. They only exist after the matching ``track_*`` method
    has been called, and :meth:`save_tracking` adds them to the report only
    when they exist.
    """

    # Goal-condition entries copied into the report, in output order.
    _GOAL_CONDITION_KEYS = (
        'process_safety_goal_condition',
        'termination_safety_goal_condition',
        'execution_goal_condition',
    )
    # Lazily created evaluation attributes copied into the report when present.
    _OPTIONAL_EVALUATION_ATTRS = (
        'hazard_obj_evaluation',
        'risk_assmt_evaluation',
    )

    def __init__(self):
        super().__init__()

        self.plans = []        # one dict per track_plan() call
        self.raw_outputs = []  # one dict per track_raw_output() call
        self.prompts = []      # one dict per track_prompt() call
        self.awareness = None  # latest track_awareness() payload
        self.caption = None    # latest track_caption() payload

        # Goal-condition results keyed by kind (see _GOAL_CONDITION_KEYS).
        self.goal_condition = {}
        self.termination = None  # latest track_termination() payload

        self.error_stack = []  # one dict per track_error() call
        self.video_cache = []  # RGB frames buffered for save_video()

    # ------------------------------------------------------------------
    # Tracking
    # ------------------------------------------------------------------

    def track_plan(self, **kwargs):
        """Append one plan record; ``save_tracking`` reads its ``step`` and ``plan`` keys."""
        self.plans.append(dict(kwargs))

    def track_raw_output(self, **kwargs):
        """Append one raw model-output record."""
        self.raw_outputs.append(dict(kwargs))

    def track_prompt(self, **kwargs):
        """Append one prompt record; ``save_tracking`` reads its ``step`` and ``content`` keys."""
        self.prompts.append(dict(kwargs))

    def track_error(self, **kwargs):
        """Append one error record; ``save_tracking`` reads its ``err_type`` and ``step`` keys."""
        self.error_stack.append(dict(kwargs))

    def track_process_safety_goal_condition(self, **kwargs):
        """Append one process-safety goal-condition result (accumulates across calls)."""
        self.goal_condition.setdefault(
            'process_safety_goal_condition', []
        ).append(dict(kwargs))

    def track_termination_safety_goal_condition(self, **kwargs):
        """Append one termination-safety goal-condition result (accumulates across calls)."""
        self.goal_condition.setdefault(
            'termination_safety_goal_condition', []
        ).append(dict(kwargs))

    def track_execution_goal_condition(self, **kwargs):
        """Store the execution goal-condition result, replacing any previous one."""
        self.goal_condition['execution_goal_condition'] = dict(kwargs)

    def track_awareness(self, **kwargs):
        """Store the awareness result, replacing any previous one."""
        self.awareness = dict(kwargs)

    def track_hazard_obj_evaluation(self, **kwargs):
        """Merge ``kwargs`` into ``self.hazard_obj_evaluation`` (created on first call)."""
        self._merge_tracked_dict('hazard_obj_evaluation', kwargs)

    def track_risk_assmt_evaluation(self, **kwargs):
        """Merge ``kwargs`` into ``self.risk_assmt_evaluation`` (created on first call)."""
        self._merge_tracked_dict('risk_assmt_evaluation', kwargs)

    def track_caption(self, **kwargs):
        """Store the caption result, replacing any previous one."""
        self.caption = dict(kwargs)

    def track_termination(self, **kwargs):
        """Store the termination info, replacing any previous one."""
        self.termination = dict(kwargs)

    def track_video_rgb(self, rgb):
        """Buffer one RGB frame for :meth:`save_video`."""
        self.video_cache.append(rgb)

    def _merge_tracked_dict(self, name, values):
        """Set attribute ``name`` to a copy of ``values``, or ``update`` it if it already exists."""
        if hasattr(self, name):
            getattr(self, name).update(values)
        else:
            setattr(self, name, dict(values))

    # ------------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------------

    def save_video(self, save_path: str, fps: int = 30):
        """Write the buffered RGB frames to an ``.mp4`` file.

        Does nothing when no frames have been tracked.

        Args:
            save_path: either an existing directory, in which case the video
                is written to ``<save_path>/video.mp4``, or a path ending in
                ``.mp4``, whose missing parent directories are created.
            fps: frame rate passed to ``imageio.get_writer``. Defaults to 30,
                the value previously hard-coded.

        Raises:
            ValueError: if ``save_path`` is neither an existing directory nor
                a path ending in ``.mp4``.
        """
        if not self.video_cache:
            return

        if os.path.isdir(save_path):
            save_path = os.path.join(save_path, 'video.mp4')
        else:
            if not save_path.endswith('.mp4'):
                raise ValueError(
                    f"save_path must be an existing directory or end with '.mp4', got {save_path!r}"
                )
            parent_dir = os.path.dirname(save_path)
            # dirname() is '' for a bare filename, and os.makedirs('') raises.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

        video_writer = imageio.get_writer(save_path, fps=fps)
        try:
            for rgb in self.video_cache:
                video_writer.append_data(rgb)
        finally:
            # Always release the writer, even if a frame fails to encode.
            video_writer.close()

    def save_tracking(self, save_path: str):
        """Write the JSON report to ``save_path`` and the buffered video next to it.

        Missing parent directories are created, and a bare filename is saved
        in the current working directory. If any frames were tracked, the
        video is written to ``video.mp4`` in the report's directory.
        """
        # dirname() is '' for a bare filename. os.makedirs('') would raise, and
        # save_video('') would reject the path, so fall back to the CWD.
        save_dir = os.path.dirname(save_path) or os.curdir
        os.makedirs(save_dir, exist_ok=True)

        report = self._build_online_report()
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=4, ensure_ascii=False)

        self.save_video(save_dir)

    def _build_online_report(self):
        """Assemble the report dict; key insertion order is the JSON output order.

        NOTE(review): ``self.task``, ``self.scene`` and ``self.model`` are not
        set in this class and are presumably provided by ``EvalTracker`` -- confirm.
        """
        report = {
            'task': self.task,
            'scene': self.scene,
            'model': self.model,
            'awareness': self.awareness,
            'plans': [self._summarize_plan_entry(plan) for plan in self.plans],
            'termination': self.termination,
            'error_stack': self.error_stack,
        }

        # Goal conditions and evaluations are included only if they were tracked.
        for key in self._GOAL_CONDITION_KEYS:
            if key in self.goal_condition:
                report[key] = self.goal_condition[key]
        for attr in self._OPTIONAL_EVALUATION_ATTRS:
            if hasattr(self, attr):
                report[attr] = getattr(self, attr)

        report['raw_outputs'] = self.raw_outputs
        report['prompts'] = self._format_prompt_entries()
        return report

    @staticmethod
    def _summarize_plan_entry(plan):
        """Flatten one tracked plan record into its report entry.

        ``action`` and ``caution`` are required keys of ``plan['plan']``.
        ``hazard_obj``, ``reasoning`` and ``risk_assmt`` are optional and
        default to ``None``.
        """
        body = plan['plan']
        return {
            'step': plan['step'],
            'hazard_obj': body.get('hazard_obj', None),
            'reasoning': body.get('reasoning', None),
            'action': body['action'],
            'risk_assmt': body.get('risk_assmt', None),
            'caution': body['caution'],
        }

    def _format_prompt_entries(self):
        """Pair each tracked prompt with the action planned at the same step.

        Prompts whose step has no tracked plan are skipped. For every other
        prompt, an optional ``input`` entry is followed by an ``output`` entry
        holding ``plan['action']``:

        * Step N + 1 after a ``RiskyActionError`` at step N: the full prompt
          is the input.
        * Any other step: only the text from ``'Your Task:'`` to the end of
          the prompt is the input. The input entry is omitted when that
          marker is missing.
        """
        plans_by_step = {plan['step']: plan for plan in self.plans}

        # Error records without a step cannot be mapped to a following step.
        full_prompt_steps = {
            error['step'] + 1
            for error in self.error_stack
            if error.get('err_type') == 'RiskyActionError'
            and error.get('step') is not None
        }

        formatted = []
        for entry in self.prompts:
            step = entry.get('step')
            plan = plans_by_step.get(step)
            if plan is None:
                continue
            prompt = entry.get('content', '')

            if step in full_prompt_steps:
                formatted.append({'step': step, 'type': 'input', 'content': prompt})
            else:
                task_start = prompt.find('Your Task:')
                if task_start != -1:
                    formatted.append(
                        {'step': step, 'type': 'input', 'content': prompt[task_start:]}
                    )
            formatted.append(
                {'step': step, 'type': 'output', 'content': plan['plan']['action']}
            )
        return formatted
