import sys
from typing import Optional, Dict
from og_ego_prim.guardrail.base_guardrail_model import BaseGuardrailModel


class HumanGuardrailModel(BaseGuardrailModel):
    """
    Guardrail model implementation where human provides feedback directly.

    The proposed action is displayed on stdout; interactive prompts are
    written to stderr and answers are read line-by-line from stdin.

    NOTE(review): when stdin is closed (EOF) or the yes/no prompt is
    interrupted (Ctrl+C), the verdict falls back to 'benign', i.e. this
    guardrail fails open. That matches the original behaviour; confirm it
    is intended for unattended runs.
    """

    # Accepted (lower-cased, stripped) answers to the yes/no prompt.
    _RISKY_ANSWERS = frozenset({'yes', 'y', 'risky'})
    _BENIGN_ANSWERS = frozenset({'no', 'n', 'unrisky', 'benign'})

    def __init__(self, task_instruction: str):
        """
        Args:
            task_instruction: Task instruction (shown to the human for reference)
        """
        self.task_instruction = task_instruction

    def is_risky(
        self,
        action_nl: str,
        image_file: Optional[str] = None,
        caution: Optional[str] = None
    ) -> Dict[str, str]:
        """
        Show the proposed action to a human and ask whether it is risky.

        Args:
            action_nl: Action description converted to natural language
            image_file: Captured image file path (obs_before.png); only its
                path is printed, the image itself is not opened here
            caution: Safety warning (optional)

        Returns:
            Dict[str, str]: {
                'potential_risk': 'risky' or 'benign',
                'risk_type': risk type ('none' if benign, 'unknown' if not given),
                'hazard': hazard description ('none' if benign, 'unknown' if not given)
            }
        """
        self._show_action(action_nl, image_file, caution)

        if self._ask_is_risky() == 'benign':
            print("[guardrail] safe", file=sys.stdout)
            sys.stdout.flush()
            return {
                'potential_risk': 'benign',
                'risk_type': 'none',
                'hazard': 'none'
            }

        hazard, risk_type = self._ask_risk_details()
        print(f"[guardrail] risky - hazard: {hazard}, risk_type: {risk_type}", file=sys.stdout)
        sys.stdout.flush()
        return {
            'potential_risk': 'risky',
            'risk_type': risk_type,
            'hazard': hazard
        }

    def _show_action(
        self,
        action_nl: str,
        image_file: Optional[str],
        caution: Optional[str]
    ) -> None:
        """Print the task, proposed action, optional caution and image path to stdout."""
        lines = [
            "\n" + "=" * 80,
            "[HUMAN GUARDRAIL] Action Safety Evaluation",
            "=" * 80,
            f"Task: {self.task_instruction}",
            "",
            f"Proposed Action: Execute action: {action_nl}",
        ]
        if caution:
            lines.append(f"Caution: {caution}")
        lines.append("")
        if image_file:
            lines.append(f"Observation image: {image_file}")
            lines.append("(Please check the image to evaluate the safety of the action)")
        lines.extend(["", "-" * 80, ""])
        print("\n".join(lines), file=sys.stdout)
        sys.stdout.flush()

    def _ask_is_risky(self) -> str:
        """
        Prompt (on stderr) until the human gives a recognised answer.

        Blank lines are ignored and unrecognised answers re-prompt.

        Returns:
            'risky' or 'benign'. Falls back to 'benign' when stdin is at EOF
            or the prompt is interrupted (see the class-level NOTE).
        """
        while True:
            try:
                sys.stderr.write("Is this action RISKY? (yes/y/risky or no/n/unrisky/benign): ")
                sys.stderr.flush()
                line = sys.stdin.readline()
                if not line:  # readline() returns '' only at EOF
                    self._warn_stdin_unavailable()
                    return 'benign'

                answer = line.strip().lower()
                if not answer:  # empty input: ask again
                    continue
                if answer in self._RISKY_ANSWERS:
                    return 'risky'
                if answer in self._BENIGN_ANSWERS:
                    return 'benign'

                sys.stderr.write("Please enter 'yes/y/risky' or 'no/n/unrisky/benign'\n")
                sys.stderr.flush()
            except (EOFError, KeyboardInterrupt):
                self._warn_stdin_unavailable()
                return 'benign'

    @staticmethod
    def _ask_risk_details():
        """
        Prompt for the hazard and then the risk type.

        Blank answers, EOF, or an interrupt leave a field unset, and unset
        fields become 'unknown'. An interrupt at the second prompt keeps the
        hazard that was already entered.

        Returns:
            (hazard, risk_type) tuple of non-empty strings.
        """
        hazard = ''
        risk_type = ''
        try:
            sys.stderr.write("\nHazard : ")
            sys.stderr.flush()
            hazard = sys.stdin.readline().strip()  # '' at EOF

            sys.stderr.write("Risk Type : ")
            sys.stderr.flush()
            risk_type = sys.stdin.readline().strip()  # '' at EOF
        except (EOFError, KeyboardInterrupt):
            sys.stderr.write("\n[ERROR] Cannot read from stdin.\n")
            sys.stderr.flush()
        return hazard or 'unknown', risk_type or 'unknown'

    @staticmethod
    def _warn_stdin_unavailable() -> None:
        """Tell the user (on stderr) that no answer could be read from stdin."""
        sys.stderr.write("\n[ERROR] Cannot read from stdin. Please ensure the process is running in foreground.\n")
        sys.stderr.flush()

    def inference(
        self,
        action: str,
        image: Optional[str] = None,
        caution: Optional[str] = None,
        use_few_shot: bool = False,
        use_thinking: bool = False
    ) -> Dict[str, str]:
        """
        Add inference method (for compatibility with EMBGuard).
        Delegates to is_risky() and returns its result unchanged.

        Args:
            action: Action description (same as action_nl in is_risky)
            image: Image file path (same as image_file in is_risky)
            caution: Safety caution (same as caution in is_risky)
            use_few_shot: Not used for human guardrail (for compatibility)
            use_thinking: Not used for human guardrail (for compatibility)

        Returns:
            Dict[str, str]: {
                'potential_risk': 'risky' or 'benign',
                'risk_type': risk type or 'none',
                'hazard': hazard description or 'none'
            }
        """
        return self.is_risky(action_nl=action, image_file=image, caution=caution)

