# -*- coding: utf-8 -*-
"""
因果推理器 - 修复版（确保ALFWorld格式）
"""

import json
import logging
import re
from typing import List, Tuple, Dict, Any

logger = logging.getLogger("agent_frame")

class CausalReasoner:
    """因果推理器"""
    
    def __init__(self, llm_agent, config: Dict[str, Any] = None):
        """Set up the reasoner with an LLM callable and optional settings.

        Args:
            llm_agent: Callable invoked with a prompt string; the other
                methods treat its return value as the response text.
            config: Optional settings mapping. ``max_hypotheses``
                (default 2) and ``validation_threshold`` (default 0.5)
                are read from it. A missing or empty mapping is replaced
                by a fresh empty dict.
        """
        settings = config if config else {}

        self.llm_agent = llm_agent
        self.config = settings
        self.max_hypotheses = settings.get('max_hypotheses', 2)
        self.validation_threshold = settings.get('validation_threshold', 0.5)

        # In-memory knowledge base that save_knowledge_base() serialises.
        self.causal_kb = {
            'validated_patterns': [],
            'failed_patterns': [],
        }

        logger.info(f"✅ 因果推理器初始化 (max_hypotheses={self.max_hypotheses})")
    
    def reason(self, task_desc: str, history: List[Dict],
               observation: str, max_hypotheses: int = None) -> Tuple[List[str], Dict]:
        """Generate causal hypotheses and plan one intervention for them.

        Args:
            task_desc: Task description forwarded to hypothesis generation.
            history: Accepted as part of the interface; not read here.
            observation: Current observation text, forwarded to both the
                hypothesis and the intervention step.
            max_hypotheses: Cap on the number of hypotheses. ``None`` means
                use the instance default ``self.max_hypotheses``.

        Returns:
            ``(hypotheses, intervention_plan)``. When the plan carries a
            non-empty ``'intervention'``, that entry has been rewritten in
            place with the output of ``_fix_action_format``.
        """
        limit = self.max_hypotheses if max_hypotheses is None else max_hypotheses
        banner = "=" * 50

        logger.info("\n🧪 " + banner)
        logger.info("因果推理流程")
        logger.info(banner)

        # Step 1: hypothesis generation.
        logger.info("🔬 CHG - 生成因果假设")
        hypotheses = self._generate_hypotheses(task_desc, observation, limit)
        logger.info(f"✅ 生成 {len(hypotheses)} 个假设")
        for index, hypothesis in enumerate(hypotheses, start=1):
            logger.info(f"  H{index}: {hypothesis[:80]}...")

        # Step 2: intervention planning.
        logger.info("🧪 IP - 规划实验干预")
        plan = self._plan_intervention(hypotheses, observation)

        # Step 3: normalise the action format, but only when the planner
        # actually proposed something.
        proposed = plan.get('intervention', '')
        if proposed:
            normalised = self._fix_action_format(proposed)
            plan['intervention'] = normalised
            logger.info(f"✅ 干预: {normalised[:60]}")
            expected = plan.get('expected_effect', '')
            logger.info(f"   预期: {expected[:60]}")

        suggestion = plan.get('intervention', 'None')
        logger.info("🔬 因果建议: " + suggestion[:80])

        return hypotheses, plan
    
    def _fix_action_format(self, action: str) -> str:
        """Normalise an action string to the form ``"Action: <command>"``.

        Steps:
        1. Remove ``<thought>`` / ``<action>`` opening and closing tags
           (any letter case) and trim surrounding whitespace.
        2. If an ``Action:`` marker appears anywhere (any letter case),
           keep only the text after its first occurrence; markers repeated
           directly after it are dropped as well.
        3. Prepend the canonical ``"Action: "`` prefix.

        Args:
            action: Raw action text, e.g. a line taken from an LLM reply.

        Returns:
            The cleaned text with a single leading ``"Action: "``.
            An empty input yields ``"Action: "``.
        """
        # Strip tag residue. The opening <thought> tag and mixed-case
        # variants are removed too, so no markup leaks into the command.
        action = re.sub(r'</?(?:thought|action)>', '', action, flags=re.IGNORECASE)
        action = action.strip()

        # Match the marker case-insensitively: "action: go to desk 1" must
        # not turn into "Action: action: go to desk 1". The word boundary
        # keeps words such as "Reaction:" from being mistaken for a marker.
        marker = re.search(r'\b(?:action:\s*)+', action, flags=re.IGNORECASE)
        if marker:
            action = action[marker.end():].strip()

        return f'Action: {action}'
    
    def _generate_hypotheses(self, task_desc: str, observation: str, max_count: int) -> List[str]:
        """Ask the LLM for causal hypotheses of the form "If ..., then ...".

        The reply is split into lines. Leading list markers (digits, ".",
        ")", "-", "*", "•" and whitespace) are stripped from each line, and
        the lines that then start with "if" (case-insensitive) are kept,
        up to ``max_count``.

        Args:
            task_desc: Task description inserted into the prompt.
            observation: Current observation inserted into the prompt.
            max_count: Number of hypotheses requested from the LLM and the
                cap applied to the parsed result.

        Returns:
            The parsed hypotheses. If none could be parsed, two built-in
            default hypotheses are returned. If the LLM call or the parsing
            raises, the error is logged and a single generic hypothesis is
            returned.
        """
        prompt = f"""You are a causal reasoning expert for household tasks.

Task: {task_desc}
Current observation: {observation}

Generate {max_count} causal hypotheses in format "If [condition], then [effect]".

Focus on object locations, states, and task requirements.

Output ONLY the hypotheses, one per line."""

        # LLMs frequently number or bullet their output ("1. If ...",
        # "2) If ...", "- If ..."). Without stripping these markers every
        # such line would fail the "starts with if" filter and the whole
        # reply would be thrown away in favour of the defaults.
        list_markers = '0123456789.)-*• \t'

        try:
            response = self.llm_agent(prompt)
            candidates = (
                line.strip().lstrip(list_markers)
                for line in response.split('\n')
            )
            hypotheses = [
                text for text in candidates if text.lower().startswith('if')
            ][:max_count]

            if not hypotheses:
                hypotheses = [
                    "If the agent searches cabinets, then it may find the target object",
                    "If the agent opens closed receptacles, then more objects become accessible"
                ]

            return hypotheses
        except Exception as e:
            # Best-effort boundary around the LLM call: never let a bad
            # reply abort the reasoning loop.
            logger.error(f"假设生成失败: {e}")
            return ["If the agent explores systematically, then task progress is made"]
    
    def _plan_intervention(self, hypotheses: List[str], observation: str) -> Dict:
        """Ask the LLM for a single ALFWorld action that probes the hypotheses.

        The reply is scanned line by line for ``Action:`` and ``Expected:``
        prefixes (case-insensitive). If a prefix occurs more than once, the
        last occurrence wins. When no non-empty action is found, the first
        non-blank line of the reply is used as the intervention.

        Args:
            hypotheses: Hypotheses to list (numbered from 1) in the prompt.
            observation: Current observation inserted into the prompt.

        Returns:
            A dict with keys ``'intervention'`` and ``'expected_effect'``.
            Both are empty strings if the LLM call or the parsing raises
            (the error is logged).
        """
        numbered = "\n".join(
            f"{number}. {text}" for number, text in enumerate(hypotheses, 1)
        )

        prompt = f"""Based on these hypotheses:
{numbered}

Current observation: {observation}

Plan ONE action using ALFWorld commands:
- go to <receptacle>
- take <obj> from <receptacle>
- put <obj> in/on <receptacle>
- open <receptacle>
- close <receptacle>
- cool <obj> with <receptacle>
- heat <obj> with <receptacle>
- clean <obj> with <receptacle>

CRITICAL: Your response MUST start with "Action: "

Format:
Action: <command>
Expected: <effect>

Example:
Action: go to cabinet 3
Expected: See what is on cabinet 3"""

        try:
            reply = self.llm_agent(prompt)
            stripped_lines = [raw.strip() for raw in reply.split('\n')]

            # Collect "<label>: <value>" lines; later lines overwrite
            # earlier ones carrying the same label.
            fields = {'action': '', 'expected': ''}
            for text in stripped_lines:
                label, colon, value = text.partition(':')
                label = label.lower()
                if colon and label in fields:
                    fields[label] = value.strip()

            chosen = fields['action']
            if not chosen:
                # No usable "Action:" line: fall back to the first
                # non-blank line of the reply, if there is one.
                chosen = next((text for text in stripped_lines if text), '')

            return {
                'intervention': chosen,
                'expected_effect': fields['expected'],
            }

        except Exception as e:
            logger.error(f"干预规划失败: {e}")
            return {'intervention': '', 'expected_effect': ''}
    
    def validate_and_learn(self, intervention: str, expected_effect: str, 
                          actual_observation: str, success: bool) -> Dict:
        """Score how well an observation matches the expected effect.

        The match score is the fraction of distinct, lower-cased,
        whitespace-separated words of ``expected_effect`` that also occur
        in ``actual_observation``. When the score is strictly greater than
        ``self.validation_threshold``, the intervention/effect pair is
        appended to ``self.causal_kb['validated_patterns']``.

        Args:
            intervention: The action that was executed.
            expected_effect: The effect predicted for that action.
            actual_observation: The observation that followed; ``None`` or
                an empty string is treated as "no observation".
            success: Caller-supplied outcome flag; stored in the result
                but not used for scoring.

        Returns:
            A dict with keys ``'intervention'``, ``'expected'``,
            ``'actual'`` (observation truncated to 100 characters),
            ``'success'`` and ``'match_score'``.
        """
        # Slicing None would raise TypeError; the scoring guard below
        # already treats a falsy observation as valid input, so the
        # truncation must tolerate it too.
        observed = actual_observation or ''

        validation_result = {
            'intervention': intervention,
            'expected': expected_effect,
            'actual': observed[:100],
            'success': success,
            'match_score': 0.0
        }

        if expected_effect and observed:
            expected_words = set(expected_effect.lower().split())
            actual_words = set(observed.lower().split())
            common_words = expected_words & actual_words

            # A whitespace-only expectation yields no words; keep score 0.
            if expected_words:
                validation_result['match_score'] = len(common_words) / len(expected_words)

        if validation_result['match_score'] > self.validation_threshold:
            self.causal_kb['validated_patterns'].append({
                'intervention': intervention,
                'effect': expected_effect,
                'score': validation_result['match_score']
            })

        return validation_result
    
    def save_knowledge_base(self, filepath: str):
        """Write ``self.causal_kb`` to ``filepath`` as indented UTF-8 JSON.

        Saving is best-effort: any error is logged and swallowed, so the
        caller is never interrupted by a failed save.

        Args:
            filepath: Destination path; an existing file is overwritten.
        """
        try:
            # Serialise before opening the file. Opening with 'w' truncates
            # it immediately, so a serialisation error raised afterwards
            # would wipe out a previously saved knowledge base.
            payload = json.dumps(self.causal_kb, indent=2, ensure_ascii=False)
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(payload)
            logger.info(f"💾 知识库已保存: {filepath}")
        except Exception as e:
            logger.error(f"保存失败: {e}")
