# -*- coding: utf-8 -*-
"""
优化版因果推理系统
- 降低验证阈值
- 更智能的成功判断
- 渐进式知识积累
"""

import json
import logging
import re
from typing import List, Dict, Tuple, Optional, Any

logger = logging.getLogger("agent_frame")


class CausalHypothesisGenerator:
    """CHG: causal hypothesis generator (optimized version).

    Builds a short prompt from the task description, the current state and
    the most recent history messages, sends it to ``llm_agent`` and parses
    the reply into a list of hypothesis strings.
    """

    # Captures the text following each "[Hypothesis N]:" label, up to the
    # next label or the end of the reply.
    _HYPOTHESIS_RE = re.compile(
        r'\[Hypothesis\s+\d+\]:\s*(.+?)(?=\[Hypothesis|\Z)',
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self, llm_agent, prompt_template_path: Optional[str] = None):
        """Store the LLM callable and the built-in prompt template.

        Args:
            llm_agent: Callable invoked with a list of chat-message dicts;
                its return value is parsed as the reply text.
            prompt_template_path: Currently unused by this class.
        """
        self.llm_agent = llm_agent
        self.hypothesis_prompt = """Analyze the agent's current situation in ALFWorld.

Task: {task}
State: {state}
Recent: {history}

Generate 2 simple hypotheses:
[Hypothesis 1]: If [action], then [result]
[Hypothesis 2]: If [action], then [result]

Example:
[Hypothesis 1]: If open receptacle first, then can take items inside
[Hypothesis 2]: If go to location first, then can interact with objects there

Answer:"""

    def _format_history(self, history: List[Dict]) -> str:
        """Render the last three history messages as ``role: content`` pairs.

        Each message content is cut to 100 characters and the pairs are
        joined with ``" | "``. Returns ``"No history"`` for empty input.
        """
        if not history:
            return "No history"
        parts = []
        # A negative slice already copes with lists shorter than three items.
        for msg in history[-3:]:
            content = str(msg.get('content', ''))[:100]
            parts.append(f"{msg.get('role', '?')}: {content}")
        return " | ".join(parts)

    def _parse_hypotheses(self, response: str, max_num: int = 2) -> List[str]:
        """Extract at most ``max_num`` hypotheses from the LLM reply.

        Labelled ``[Hypothesis N]: ...`` entries are preferred. When none is
        found, reply lines whose stripped length is between 21 and 199
        characters are used instead. A generic placeholder hypothesis is
        returned when the reply is empty or nothing usable is found.
        """
        if not response:
            return ["If proper setup, then success"]

        hypotheses = []
        for match in self._HYPOTHESIS_RE.finditer(response):
            body = match.group(1).strip()
            if not body:
                continue
            # The capture for the last label runs to the end of the reply, so
            # keep only its first line; otherwise trailing commentary would be
            # glued onto the hypothesis whenever it contains no period.
            first_line = body.splitlines()[0]
            hyp = first_line.split('.')[0].strip()
            if len(hyp) > 15:
                hypotheses.append(hyp)

        if not hypotheses:
            # Fallback: treat reasonably sized lines as hypotheses.
            stripped_lines = (line.strip() for line in response.split('\n'))
            hypotheses = [line for line in stripped_lines if 20 < len(line) < 200]

        return hypotheses[:max_num] if hypotheses else ["If follow procedure, then succeed"]

    def generate(self, task_desc: str, current_state: str, history: List[Dict], max_hypotheses: int = 2) -> List[str]:
        """Ask the LLM for hypotheses about the current situation.

        Args:
            task_desc: Task description; only the first 150 characters are sent.
            current_state: Current state text; only the first 150 characters are sent.
            history: Message dicts; the last three are summarised into the prompt.
            max_hypotheses: Upper bound on the number of hypotheses returned.

        Returns:
            A non-empty list of hypothesis strings. Any exception raised while
            building the prompt, calling the LLM or parsing is logged and a
            single generic hypothesis is returned instead.
        """
        try:
            history_text = self._format_history(history)
            prompt_text = self.hypothesis_prompt.format(
                task=task_desc[:150],
                state=current_state[:150],
                history=history_text
            )

            messages = [{"role": "user", "content": prompt_text}]
            logger.info("🧠 CHG")
            response = self.llm_agent(messages)

            hypotheses = self._parse_hypotheses(response, max_hypotheses)
            logger.info(f"✅ {len(hypotheses)} 假设")

            return hypotheses

        except Exception as e:
            logger.error(f"❌ CHG: {e}")
            return ["If proper conditions, then success"]


class InterventionPlanner:
    """IP: intervention planner (optimized version).

    Turns a list of hypotheses into a single suggested action by prompting
    ``llm_agent`` and parsing its ``[Intervention]`` / ``[Expected]`` reply.
    """

    def __init__(self, llm_agent, prompt_template_path: Optional[str] = None):
        """Store the LLM callable and the built-in prompt template.

        Args:
            llm_agent: Callable invoked with a list of chat-message dicts;
                its return value is parsed as the reply text.
            prompt_template_path: Currently unused by this class.
        """
        self.llm_agent = llm_agent
        self.intervention_prompt = """Hypotheses: {hypotheses}
Situation: {current_state}

Suggest one action:
[Intervention]: <action>
[Expected]: <result>

Answer:"""

    def _parse_plan(self, response: str) -> Dict[str, str]:
        """Parse the LLM reply into a plan dict.

        Returns a dict with keys ``intervention``, ``expected_effect``,
        ``alternative`` (always ``'retry'``) and ``rationale`` (always empty).
        Values are cut to 150 characters. When no labelled intervention is
        found, the first sentence of 11-199 characters is used as the
        intervention with ``'Progress task'`` as the expected effect.
        """
        plan = {'intervention': '', 'expected_effect': '', 'alternative': 'retry', 'rationale': ''}

        if not response:
            return plan

        for line in response.split('\n'):
            label, colon, value = line.partition(':')
            if not colon:
                continue
            # Look for the keyword in the label only. Matching it anywhere in
            # the line made "[Expected]: the intervention ..." overwrite the
            # intervention and leave the expected effect empty.
            label = label.lower()
            if 'intervention' in label:
                plan['intervention'] = value.strip()[:150]
            elif 'expected' in label:
                plan['expected_effect'] = value.strip()[:150]

        if not plan['intervention']:
            # Fallback for free-form replies: use the first plausible sentence.
            sentences = [s.strip() for s in response.split('.') if 10 < len(s.strip()) < 200]
            if sentences:
                plan['intervention'] = sentences[0]
                plan['expected_effect'] = 'Progress task'

        return plan

    def plan(self, hypotheses: List[str], current_state: str) -> Dict[str, str]:
        """Ask the LLM for one intervention based on the given hypotheses.

        Args:
            hypotheses: Hypothesis strings; only the first two are used, each
                cut to 100 characters.
            current_state: Current state text; only the first 150 characters are sent.

        Returns:
            The parsed plan dict. A dict with all four keys set to empty
            strings is returned when ``hypotheses`` is empty or when any
            exception is raised (the exception is logged).
        """
        try:
            if not hypotheses:
                return {'intervention': '', 'expected_effect': '', 'alternative': '', 'rationale': ''}

            hyp_text = " | ".join(h[:100] for h in hypotheses[:2])
            prompt_text = self.intervention_prompt.format(hypotheses=hyp_text, current_state=current_state[:150])

            messages = [{"role": "user", "content": prompt_text}]
            logger.info("🔬 IP")
            response = self.llm_agent(messages)

            plan = self._parse_plan(response)
            if plan['intervention']:
                logger.info(f"✅ {plan['intervention'][:50]}...")

            return plan

        except Exception as e:
            logger.error(f"❌ IP: {e}")
            return {'intervention': '', 'expected_effect': '', 'alternative': '', 'rationale': ''}


class EffectValidator:
    """EV: effect validator (optimized version, more lenient validation).

    Combines a keyword heuristic over the observation text with an LLM
    judgement. The heuristic result is the baseline; fields the LLM reply
    states explicitly override it.
    """

    # Substrings of the lower-cased observation counted as success signals.
    _SUCCESS_INDICATORS = ('you pick up', 'you put', 'you open', 'you close', 'you arrive', 'you see')
    # Substrings of the lower-cased observation counted as failure signals.
    _FAILURE_INDICATORS = ('nothing happens', 'error', 'cannot', 'fail')
    # Negated forms that contain the substring "success" but mean failure.
    _NEGATED_SUCCESS_RE = re.compile(r'\b(?:un|not\s+)success')
    # First number found in a score line.
    _SCORE_RE = re.compile(r'(0?\.\d+|\d+\.?\d*)')

    def __init__(self, llm_agent, prompt_template_path: Optional[str] = None):
        """Store the LLM callable, the prompt template and an empty history.

        Args:
            llm_agent: Callable invoked with a list of chat-message dicts;
                its return value is parsed as the reply text.
            prompt_template_path: Currently unused by this class.
        """
        self.llm_agent = llm_agent
        # Every validation produced by ``validate`` (except skipped/error
        # results) is appended here.
        self.validation_history = []
        self.validation_prompt = """Did the intervention work?

Intervention: {intervention}
Expected: {expected_effect}
Result: {observation}

Answer briefly:
[Status]: Success/Partial/Failure
[Score]: 0.0-1.0

Analysis:"""

    def _simple_validate(self, intervention: str, observation: str, success: Optional[bool]) -> Dict[str, Any]:
        """Heuristic validation used as the baseline and as the fallback.

        Args:
            intervention: Not used by the heuristic.
            observation: Observation text scanned for indicator substrings.
            success: Optional external success flag; only consulted when the
                observation contains no indicator at all.

        Returns:
            A validation dict. Status/confidence are ``Success``/0.7 when only
            success indicators are present, ``Failure``/0.6 when any failure
            indicator is present, ``Success``/0.8 when there is no indicator
            and ``success is True``, otherwise ``Partial``/0.5.
        """
        validation = {
            'status': 'Partial',
            'confidence': 0.5,
            'confirmed_factors': [],
            'refuted_factors': [],
            'updated_understanding': ''
        }

        observation_lower = observation.lower()
        has_success = any(marker in observation_lower for marker in self._SUCCESS_INDICATORS)
        has_failure = any(marker in observation_lower for marker in self._FAILURE_INDICATORS)

        if has_success and not has_failure:
            validation['status'] = 'Success'
            validation['confidence'] = 0.7
        elif has_failure:
            validation['status'] = 'Failure'
            validation['confidence'] = 0.6
        elif success is True:
            validation['status'] = 'Success'
            validation['confidence'] = 0.8

        return validation

    def _parse_validation(self, response: str, fallback_validation: Dict) -> Dict[str, Any]:
        """Overlay the LLM verdict on top of a copy of the fallback validation.

        Lines containing ``status`` and a colon set the status; lines
        containing ``score`` or ``confidence`` and a colon set the confidence
        to the first number after the colon, clamped to ``[0.0, 1.0]``.
        Fields the reply does not mention keep their fallback values.
        """
        validation = fallback_validation.copy()

        if not response:
            return validation

        for line in response.split('\n'):
            if ':' not in line:
                continue
            line_lower = line.lower()
            value = line.split(':', 1)[1].strip()

            if 'status' in line_lower:
                status_text = value.lower()
                # Check negated forms first: "unsuccessful" and
                # "not successful" both contain the substring "success".
                if self._NEGATED_SUCCESS_RE.search(status_text):
                    validation['status'] = 'Failure'
                elif 'success' in status_text:
                    validation['status'] = 'Success'
                elif 'partial' in status_text:
                    validation['status'] = 'Partial'
                elif 'fail' in status_text:
                    validation['status'] = 'Failure'

            if 'score' in line_lower or 'confidence' in line_lower:
                score_match = self._SCORE_RE.search(value)
                if score_match:
                    try:
                        score = float(score_match.group(1))
                    except ValueError:
                        # Keep the previous confidence on an unparsable number.
                        continue
                    validation['confidence'] = min(1.0, max(0.0, score))

        return validation

    def validate(self, intervention: str, expected_effect: str, observation: str, success: Optional[bool] = None) -> Dict[str, Any]:
        """Judge whether an intervention produced its expected effect.

        Args:
            intervention: Action that was taken; interventions that are empty
                or shorter than five characters are skipped.
            expected_effect: Effect the planner expected (first 120 characters
                are sent to the LLM).
            observation: Observation text after the action (first 150
                characters are sent to the LLM).
            success: Optional external success flag. It only feeds the
                heuristic baseline; it is not part of the LLM prompt.

        Returns:
            A validation dict with keys ``status``, ``confidence``,
            ``confirmed_factors``, ``refuted_factors`` and
            ``updated_understanding``. ``status`` is ``'Skipped'`` for a
            too-short intervention and ``'Error'`` when validation itself
            raised; neither of those is appended to ``validation_history``.
        """
        try:
            if not intervention or len(intervention) < 5:
                return {'status': 'Skipped', 'confidence': 0.0, 'confirmed_factors': [], 'refuted_factors': [], 'updated_understanding': ''}

            # Heuristic baseline first; it is also the result if the LLM fails.
            fallback_validation = self._simple_validate(intervention, observation, success)

            try:
                prompt_text = self.validation_prompt.format(
                    intervention=intervention[:120],
                    expected_effect=expected_effect[:120],
                    observation=observation[:150]
                )

                messages = [{"role": "user", "content": prompt_text}]
                logger.info("🔍 EV")
                response = self.llm_agent(messages)

                validation = self._parse_validation(response, fallback_validation)
            except Exception as llm_err:
                # Best-effort: an LLM or parsing failure must not lose the
                # heuristic verdict, but it should be visible in the log.
                logger.warning(f"⚠️ EV LLM: {llm_err}")
                validation = fallback_validation

            logger.info(f"✅ {validation['status']} ({validation['confidence']:.2f})")

            self.validation_history.append(validation)
            return validation

        except Exception as e:
            logger.error(f"❌ EV: {e}")
            return {'status': 'Error', 'confidence': 0.0, 'confirmed_factors': [], 'refuted_factors': [], 'updated_understanding': ''}


class CausalReasoner:
    """Causal reasoning system (optimized version).

    Wires together the hypothesis generator (CHG), the intervention planner
    (IP) and the effect validator (EV), keeps per-session counters and
    accumulates validated interventions in an in-memory knowledge base.
    """

    def __init__(self, llm_agent, config: Optional[Dict] = None):
        """Create the three components around one shared LLM callable.

        Args:
            llm_agent: Passed unchanged to CHG, IP and EV.
            config: Optional settings dict, stored as ``self.config``
                (an empty dict when omitted).
        """
        self.config = config if config else {}

        logger.info("🧠 因果推理系统启动（优化版）")

        self.chg = CausalHypothesisGenerator(llm_agent)
        self.ip = InterventionPlanner(llm_agent)
        self.ev = EffectValidator(llm_agent)

        # Validated interventions, appended by ``validate_and_learn``.
        self.causal_knowledge_base = []
        # All session counters start at zero.
        self.session_stats = dict.fromkeys(
            (
                'total_hypotheses',
                'total_interventions',
                'successful_validations',
                'partial_validations',
                'failed_validations',
            ),
            0,
        )

        logger.info("✅ CHG + IP + EV 就绪")

    def reason(self, task_desc: str, state_history: List[Dict], current_state: str, max_hypotheses: int = 2) -> Tuple[List[str], Dict[str, str]]:
        """Run one reasoning pass: generate hypotheses, then plan an intervention.

        Returns:
            A ``(hypotheses, intervention_plan)`` tuple. The session counters
            are updated with the number of hypotheses and, when the plan
            carries a non-empty intervention, one more intervention.
        """
        logger.info("\n🔬 因果推理")

        generated = self.chg.generate(task_desc, current_state, state_history, max_hypotheses)
        self.session_stats['total_hypotheses'] += len(generated)

        proposal = self.ip.plan(generated, current_state)
        if proposal.get('intervention'):
            self.session_stats['total_interventions'] += 1

        return generated, proposal

    def validate_and_learn(self, intervention: str, expected: str, observation: str, success: Optional[bool] = None) -> Dict[str, Any]:
        """Validate an intervention and, if good enough, remember it.

        The validation status increments the matching session counter. A
        result is stored in the knowledge base when its status is
        ``Success`` or ``Partial`` and its confidence is at least 0.5.

        Returns:
            The validation dict produced by the effect validator.
        """
        outcome = self.ev.validate(intervention, expected, observation, success)

        # Status -> session counter; other statuses are not counted.
        counter_by_status = {
            'Success': 'successful_validations',
            'Partial': 'partial_validations',
            'Failure': 'failed_validations',
        }
        counter_key = counter_by_status.get(outcome['status'])
        if counter_key is not None:
            self.session_stats[counter_key] += 1

        # Lenient admission rule: non-failed status with confidence >= 0.5.
        if outcome['status'] in ('Success', 'Partial') and outcome['confidence'] >= 0.5:
            entry = {
                'intervention': intervention[:150],
                'expected_effect': expected[:150],
                'actual_observation': observation[:150],
                'status': outcome['status'],
                'confidence': outcome['confidence'],
                'success': success
            }
            self.causal_knowledge_base.append(entry)
            logger.info(f"📚 知识+1 (总: {len(self.causal_knowledge_base)})")

        return outcome

    def get_statistics(self) -> Dict[str, Any]:
        """Return the session counters plus derived figures.

        Adds ``knowledge_base_size`` and ``success_rate``, where a partial
        validation counts as half a success and the rate is 0.0 when nothing
        has been validated yet.
        """
        counters = self.session_stats
        n_success = counters['successful_validations']
        n_partial = counters['partial_validations']
        n_failed = counters['failed_validations']
        total = n_success + n_partial + n_failed

        summary = dict(counters)
        summary['knowledge_base_size'] = len(self.causal_knowledge_base)
        summary['success_rate'] = (n_success + 0.5 * n_partial) / max(1, total)
        return summary

    def save_knowledge_base(self, path: str):
        """Write the knowledge base and statistics to ``path`` as UTF-8 JSON.

        Any exception is logged and swallowed, so this never raises.
        """
        try:
            with open(path, 'w', encoding='utf-8') as f:
                payload = {
                    'knowledge_base': self.causal_knowledge_base,
                    'statistics': self.get_statistics()
                }
                json.dump(payload, f, indent=2, ensure_ascii=False)

            # Only announce the save when there was something to save.
            if self.causal_knowledge_base:
                logger.info(f"💾 保存 {len(self.causal_knowledge_base)} 条因果知识")
        except Exception as e:
            logger.error(f"❌ 保存失败: {e}")
