# -*- coding: utf-8 -*-
"""因果推理系统 - 优化版（详细日志 + 知识积累增强）"""

import json
import logging
import re
from typing import List, Dict, Tuple, Optional, Any

logger = logging.getLogger("agent_frame")


class CausalHypothesisGenerator:
    """CHG: causal hypothesis generator.

    Asks the LLM for "If <condition>, then <outcome> because <reason>"
    hypotheses about the current ALFWorld situation and parses the reply into
    a short list of one-sentence hypotheses.
    """

    # Text following each "[Hypothesis N]:" tag, up to the next tag or the end
    # of the reply. `[*\s]*` also accepts markdown such as "**[Hypothesis 1]**:".
    _HYPOTHESIS_PATTERN = re.compile(
        r'\[Hypothesis\s+\d+\][*\s]*:\s*(.+?)(?=\[Hypothesis|\Z)',
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self, llm_agent, prompt_template_path: Optional[str] = None):
        """Create the generator.

        Args:
            llm_agent: Callable that receives a list of chat messages
                (``[{"role": "user", "content": ...}]``) and returns the reply.
            prompt_template_path: Accepted for interface compatibility but not
                used; the built-in prompt below is always used.
        """
        self.llm_agent = llm_agent
        # One entry per completed `generate` call (truncated prompt/response).
        self.generation_history = []

        self.hypothesis_prompt = """You are analyzing an agent's behavior in ALFWorld household tasks.

**Context:**
- Task: {task}
- Current State: {state}
- Recent Actions: {history}

**Generate 2 causal hypotheses in this EXACT format:**
[Hypothesis 1]: If [condition], then [outcome] because [reason]
[Hypothesis 2]: If [condition], then [outcome] because [reason]

**Example:**
[Hypothesis 1]: If agent opens receptacle before taking items, then action succeeds because accessibility is required
[Hypothesis 2]: If agent goes to appliance location first, then using it succeeds because proximity is needed

**Now generate hypotheses for the current scenario:**"""

    def _format_history(self, history: List[Dict]) -> str:
        """Render the last four messages as ``role: content`` lines.

        Message content is truncated to 120 characters; an empty history
        yields a placeholder string.
        """
        if not history:
            return "No history available"
        return "\n".join(
            f"{msg.get('role', 'unknown')}: {str(msg.get('content', ''))[:120]}"
            for msg in history[-4:]
        )

    @staticmethod
    def _first_sentence(text: str) -> str:
        """Return the first sentence of `text`, always ending in punctuation.

        Only the first paragraph is considered, so commentary after a blank
        line is not glued onto the hypothesis. Internal whitespace (including
        single line wraps) is collapsed, and surrounding markdown asterisks are
        removed.
        """
        paragraph = re.split(r'\n\s*\n', text.strip(), maxsplit=1)[0]
        paragraph = " ".join(paragraph.split()).strip('*').strip()
        # Only punctuation followed by whitespace ends a sentence, so periods
        # inside tokens (e.g. "0.5") do not cut the hypothesis short.
        sentence = re.split(r'(?<=[.!?])\s', paragraph, maxsplit=1)[0]
        if not sentence.endswith(('.', '!', '?')):
            sentence += '.'
        return sentence

    def _parse_hypotheses(self, response: str, max_num: int = 2) -> List[str]:
        """Extract up to `max_num` hypotheses from an LLM reply.

        Tagged "[Hypothesis N]:" entries longer than 20 characters are reduced
        to their first sentence. If none are found, the first lines longer than
        30 characters are used instead. A default hypothesis is returned when
        nothing usable remains.
        """
        if not response:
            return ["Default: Proper preconditions needed for success"]

        hypotheses = [
            self._first_sentence(match)
            for match in self._HYPOTHESIS_PATTERN.findall(response)
            if len(match.strip()) > 20
        ]

        if not hypotheses:
            # Tagged format not followed: fall back to the first substantial lines.
            hypotheses = [
                line.strip() for line in response.split('\n')
                if len(line.strip()) > 30
            ][:max_num]

        return hypotheses[:max_num] if hypotheses else ["Default: Follow standard procedures"]

    def generate(self, task_desc: str, current_state: str, history: List[Dict], max_hypotheses: int = 2) -> List[str]:
        """Generate up to `max_hypotheses` causal hypotheses for the current step.

        Args:
            task_desc: Task description (first 200 characters go into the prompt).
            current_state: Latest observation (first 200 characters go into the prompt).
            history: Chat-style message dicts; the last four are summarised.
            max_hypotheses: Maximum number of hypotheses to return.

        Returns:
            The parsed hypotheses. If the reply is empty or unparseable, a
            single default hypothesis is returned. Any exception is logged
            rather than raised, and a default hypothesis is returned.
        """
        try:
            task_text = "" if task_desc is None else str(task_desc)
            state_text = "" if current_state is None else str(current_state)
            prompt_text = self.hypothesis_prompt.format(
                task=task_text[:200],
                state=state_text[:200],
                history=self._format_history(history),
            )

            messages = [{"role": "user", "content": prompt_text}]

            logger.info("🧠 CHG - 生成因果假设")
            response = self.llm_agent(messages)
            # Normalise once: a None/non-string reply must not crash the
            # slicing below and throw away the hypotheses already parsed.
            response_text = "" if response is None else str(response)

            hypotheses = self._parse_hypotheses(response_text, max_hypotheses)

            # Keep a truncated record of this call for later export.
            self.generation_history.append({
                'task': task_text[:100],
                'prompt': prompt_text[:200],
                'response': response_text[:200],
                'hypotheses': hypotheses
            })

            logger.info(f"✅ 生成 {len(hypotheses)} 个假设")
            for i, h in enumerate(hypotheses, 1):
                logger.info(f"  H{i}: {h[:80]}...")

            return hypotheses

        except Exception as e:
            logger.error(f"❌ CHG失败: {e}")
            logger.debug("CHG traceback", exc_info=True)
            return ["Default: Ensure preconditions before actions"]


class InterventionPlanner:
    """IP: intervention planner.

    Given causal hypotheses and the current state, asks the LLM for one
    concrete action that tests them plus the effect that action should have.
    """

    def __init__(self, llm_agent, prompt_template_path: Optional[str] = None):
        """Create the planner.

        Args:
            llm_agent: Callable that receives a list of chat messages and
                returns the reply.
            prompt_template_path: Accepted for interface compatibility but not
                used; the built-in prompt below is always used.
        """
        self.llm_agent = llm_agent
        # One entry per completed `plan` call (hypotheses, truncated reply, plan).
        self.planning_history = []

        self.intervention_prompt = """Based on these causal hypotheses:
{hypotheses}

Current situation:
{current_state}

**Suggest ONE specific action to test the hypothesis.**

Output format:
[Intervention]: <specific action>
[Expected Effect]: <what should happen>

Example:
[Intervention]: open cabinet 1
[Expected Effect]: Items inside become accessible

Your answer:"""

    @staticmethod
    def _empty_plan() -> Dict[str, str]:
        """Return the plan used when no intervention could be produced."""
        return {'intervention': '', 'expected_effect': '', 'alternative': '', 'rationale': ''}

    @staticmethod
    def _normalize_label(label: str) -> str:
        """Lower-case a "label:" prefix and drop markup, digits and punctuation."""
        return re.sub(r'[^a-z ]+', '', label.lower()).strip()

    def _parse_plan(self, response: str) -> Dict[str, str]:
        """Parse "[Intervention]: ..." / "[Expected Effect]: ..." lines from a reply.

        Only the label before the first colon is inspected, and it must END
        with "intervention" / "effect". This prevents prose such as "This
        intervention tests hypothesis 1: ..." or an effect line that mentions
        the intervention from overwriting the wrong field. Surrounding markdown
        asterisks/backticks are stripped from values, which are capped at 150
        characters. If no intervention line is found, the first sentence longer
        than 15 characters is used as the intervention.
        """
        if not response:
            return self._empty_plan()

        plan = {
            'intervention': '',
            'expected_effect': '',
            'alternative': 'retry',
            'rationale': 'test hypothesis'
        }

        for line in response.split('\n'):
            label, sep, value = line.partition(':')
            if not sep:
                continue
            value = value.strip().strip(' *`')[:150]
            if not value:
                continue
            label = self._normalize_label(label)
            # "effect" is checked first: an "Expected Effect of intervention"
            # style label describes the effect, not the action.
            if label.endswith('effect'):
                plan['expected_effect'] = value
            elif label.endswith('intervention'):
                plan['intervention'] = value

        if not plan['intervention']:
            sentences = [s.strip() for s in response.split('.') if len(s.strip()) > 15]
            if sentences:
                plan['intervention'] = sentences[0]
                plan['expected_effect'] = 'Enable task progress'

        return plan

    def plan(self, hypotheses: List[str], current_state: str) -> Dict[str, str]:
        """Plan one intervention that tests the given hypotheses.

        Args:
            hypotheses: Hypotheses from the generator; only the first two
                (each cut to 120 characters) are shown to the LLM.
            current_state: Current observation (first 200 characters used).

        Returns:
            A dict with keys 'intervention', 'expected_effect', 'alternative'
            and 'rationale'. The empty plan is returned when there are no
            hypotheses or an error occurs; errors are logged rather than raised.
        """
        try:
            if not hypotheses:
                return self._empty_plan()

            hyp_text = "\n".join(f"{i}. {h[:120]}" for i, h in enumerate(hypotheses[:2], 1))
            state_text = "" if current_state is None else str(current_state)
            prompt_text = self.intervention_prompt.format(
                hypotheses=hyp_text,
                current_state=state_text[:200]
            )

            messages = [{"role": "user", "content": prompt_text}]

            logger.info("🔬 IP - 规划实验干预")
            response = self.llm_agent(messages)
            # Normalise once so a None/non-string reply cannot crash the
            # history slicing below after parsing already succeeded.
            response_text = "" if response is None else str(response)

            plan = self._parse_plan(response_text)

            # Keep a truncated record of this call for later export.
            self.planning_history.append({
                'hypotheses': hypotheses,
                'response': response_text[:200],
                'plan': plan
            })

            if plan['intervention']:
                logger.info(f"✅ 干预: {plan['intervention'][:70]}")
                logger.info(f"   预期: {plan['expected_effect'][:70]}")
            else:
                logger.info("⚠️ 未生成有效干预")

            return plan

        except Exception as e:
            logger.error(f"❌ IP失败: {e}")
            logger.debug("IP traceback", exc_info=True)
            return self._empty_plan()


class EffectValidator:
    """EV: causal effect validator.

    Judges whether an executed intervention produced its expected effect.
    A keyword heuristic over the observation gives a baseline verdict. An
    LLM assessment then refines that verdict when the LLM call succeeds.
    """

    # Observation substrings counted as evidence of success / failure. The
    # heuristic is meant to be fast and lenient; note that "on the" / "in the"
    # match most descriptions.
    _SUCCESS_KEYWORDS = (
        'you pick up', 'you put', 'you open', 'you close', 'you arrive',
        'you see', 'you take', 'you heat', 'you cool', 'you clean',
        'you use', 'on the', 'in the', 'you are now'
    )
    _FAILURE_KEYWORDS = (
        'nothing happens', 'nothing happened', 'error', 'cannot', 'fail',
        'not here', 'already', 'closed'
    )
    # First number in a confidence value ("0.8", "85", ".9"); unlike [\d.]+ it
    # does not swallow a trailing period as in "0.8.".
    _NUMBER_PATTERN = re.compile(r'\d+(?:\.\d+)?|\.\d+')

    def __init__(self, llm_agent, prompt_template_path: Optional[str] = None):
        """Create the validator.

        Args:
            llm_agent: Callable that receives a list of chat messages and
                returns the reply.
            prompt_template_path: Accepted for interface compatibility but not
                used; the built-in prompt below is always used.
        """
        self.llm_agent = llm_agent
        # One entry per validated (non-skipped) intervention.
        self.validation_history = []

        self.validation_prompt = """Validate whether the intervention produced the expected effect.

**Intervention taken:** {intervention}
**Expected effect:** {expected_effect}
**Actual observation:** {observation}
**Task success:** {success}

**Answer in this format:**
[Status]: Success/Partial/Failure
[Confidence]: 0.0-1.0
[Learning]: Brief insight about what was confirmed

Your analysis:"""

    def _heuristic_validate(self, intervention: str, observation: str, success: Optional[bool]) -> Dict[str, Any]:
        """Keyword-based verdict on an observation (fast and lenient).

        Returns Success when only success keywords occur, Partial when success
        keywords outnumber failure keywords, and Failure when any failure
        keyword is present otherwise. If no keywords occur, the verdict is
        Success when `success` is True and Partial (confidence 0.5) otherwise.
        """
        validation = {
            'status': 'Partial',
            'confidence': 0.5,
            'confirmed_factors': [],
            'refuted_factors': [],
            'updated_understanding': ''
        }

        obs_lower = ('' if observation is None else str(observation)).lower()

        success_count = sum(1 for kw in self._SUCCESS_KEYWORDS if kw in obs_lower)
        failure_count = sum(1 for kw in self._FAILURE_KEYWORDS if kw in obs_lower)

        if success_count > 0 and failure_count == 0:
            validation['status'] = 'Success'
            validation['confidence'] = 0.65 + min(0.25, success_count * 0.05)
            validation['updated_understanding'] = f'Action succeeded with {success_count} positive indicators'
        elif success_count > failure_count:
            validation['status'] = 'Partial'
            validation['confidence'] = 0.55
            validation['updated_understanding'] = 'Partial success with mixed results'
        elif failure_count > 0:
            validation['status'] = 'Failure'
            validation['confidence'] = 0.6
            validation['updated_understanding'] = f'Action failed with {failure_count} failure indicators'
        elif success is True:
            validation['status'] = 'Success'
            validation['confidence'] = 0.75
            validation['updated_understanding'] = 'Task completed successfully'

        return validation

    @staticmethod
    def _classify_status(text: str) -> Optional[str]:
        """Map free-text status to 'Success' / 'Partial' / 'Failure', or None.

        'partial' and negative forms are checked BEFORE 'success' so that
        "Partial success" and "Unsuccessful" are not read as Success.
        """
        lowered = text.lower()
        if 'partial' in lowered:
            return 'Partial'
        if 'fail' in lowered or 'unsuccess' in lowered or 'not success' in lowered:
            return 'Failure'
        if 'success' in lowered:
            return 'Success'
        return None

    @classmethod
    def _parse_confidence(cls, text: str) -> Optional[float]:
        """Parse the first number in `text` as a confidence clamped to [0, 1].

        A number followed by '%' is treated as a percentage. Returns None
        when `text` contains no number.
        """
        match = cls._NUMBER_PATTERN.search(text)
        if match is None:
            return None
        value = float(match.group())
        if text[match.end():].lstrip().startswith('%'):
            value /= 100.0
        return min(1.0, max(0.0, value))

    def _parse_validation(self, response: str, fallback: Dict) -> Dict[str, Any]:
        """Overlay the LLM's Status / Confidence / Learning answers on `fallback`.

        Only the label before each line's first colon is matched, so body text
        that mentions e.g. "status" cannot override a field. Fields the reply
        does not provide keep their fallback values. `fallback` itself is not
        modified (a shallow copy is returned).
        """
        validation = fallback.copy()

        if not response:
            return validation

        for line in response.split('\n'):
            label, sep, value = line.partition(':')
            if not sep:
                continue
            label = label.lower()
            value = value.strip().strip(' *`')

            if 'status' in label:
                status = self._classify_status(value)
                if status is not None:
                    validation['status'] = status
            elif 'confidence' in label:
                confidence = self._parse_confidence(value)
                if confidence is not None:
                    validation['confidence'] = confidence
            elif 'learning' in label and len(value) > 5:
                validation['updated_understanding'] = value

        return validation

    def validate(self, intervention: str, expected_effect: str, observation: str, success: Optional[bool] = None) -> Dict[str, Any]:
        """Validate one intervention against the resulting observation.

        Args:
            intervention: The action taken. Values that are empty or shorter
                than 5 characters yield a 'Skipped' result.
            expected_effect: The effect the planner predicted.
            observation: What the environment returned after the action.
            success: Task-completion flag (True / False / None = in progress).

        Returns:
            A dict with 'status' (Success / Partial / Failure / Skipped / Error),
            'confidence', 'confirmed_factors', 'refuted_factors' and
            'updated_understanding'. If the LLM call fails, the heuristic
            verdict is used. Other errors are logged and produce status 'Error'.
        """
        try:
            if not intervention or len(intervention) < 5:
                return {
                    'status': 'Skipped',
                    'confidence': 0.0,
                    'confirmed_factors': [],
                    'refuted_factors': [],
                    'updated_understanding': 'No intervention to validate'
                }

            expected_effect = '' if expected_effect is None else str(expected_effect)
            observation = '' if observation is None else str(observation)

            logger.info("🔍 EV - 验证因果效应")

            # Heuristic first: it is the fallback if the LLM step fails.
            heuristic_validation = self._heuristic_validate(intervention, observation, success)

            try:
                success_text = "Yes" if success else ("No" if success is False else "In Progress")
                prompt_text = self.validation_prompt.format(
                    intervention=intervention[:150],
                    expected_effect=expected_effect[:150],
                    observation=observation[:200],
                    success=success_text
                )

                messages = [{"role": "user", "content": prompt_text}]
                response = self.llm_agent(messages)

                validation = self._parse_validation(
                    "" if response is None else str(response), heuristic_validation
                )
            except Exception as llm_error:
                logger.debug(f"LLM验证失败，使用启发式结果: {llm_error}")
                validation = heuristic_validation

            # Keep a truncated record of this validation for later export.
            self.validation_history.append({
                'intervention': intervention[:100],
                'expected': expected_effect[:100],
                'observation': observation[:100],
                'validation': validation
            })

            logger.info(f"✅ 验证: {validation['status']} (置信度: {validation['confidence']:.2f})")
            if validation.get('updated_understanding'):
                logger.info(f"   洞察: {validation['updated_understanding'][:80]}...")

            return validation

        except Exception as e:
            logger.error(f"❌ EV失败: {e}")
            logger.debug("EV traceback", exc_info=True)
            return {
                'status': 'Error',
                'confidence': 0.0,
                'confirmed_factors': [],
                'refuted_factors': [],
                'updated_understanding': f'Validation error: {str(e)}'
            }


class CausalReasoner:
    """Integrated causal reasoning system: CHG + IP, then EV after acting.

    Wires a hypothesis generator, an intervention planner and an effect
    validator to the same LLM agent. It accumulates sufficiently confident
    validated interventions as causal knowledge and keeps session statistics
    plus a per-step reasoning log.
    """

    # Validation status -> session counter it increments. Statuses not listed
    # here (e.g. 'Error') are not counted.
    _STATUS_COUNTERS = {
        'Success': 'successful_validations',
        'Partial': 'partial_validations',
        'Failure': 'failed_validations',
        'Skipped': 'skipped_validations',
    }
    # Minimum confidence for a Success/Partial validation to be stored as knowledge.
    _MIN_KNOWLEDGE_CONFIDENCE = 0.45

    def __init__(self, llm_agent, config: Optional[Dict] = None):
        """Build the three components around `llm_agent`.

        Args:
            llm_agent: Callable shared by all components; it receives chat
                messages and returns the reply.
            config: Optional dict; 'chg_prompt_path', 'ip_prompt_path' and
                'ev_prompt_path' are passed through to the components.
        """
        self.config = config or {}

        logger.info("🧠 初始化因果推理系统...")

        self.chg = CausalHypothesisGenerator(llm_agent, self.config.get('chg_prompt_path'))
        self.ip = InterventionPlanner(llm_agent, self.config.get('ip_prompt_path'))
        self.ev = EffectValidator(llm_agent, self.config.get('ev_prompt_path'))

        self.causal_knowledge_base = []
        self.session_stats = {
            'total_hypotheses': 0,
            'total_interventions': 0,
            'successful_validations': 0,
            'partial_validations': 0,
            'failed_validations': 0,
            'skipped_validations': 0
        }

        # One entry per `reason` call; `validate_and_learn` attaches its result
        # to the most recent entry.
        self.reasoning_log = []

        logger.info("✅ 因果推理系统就绪 (CHG + IP + EV)")

    def reason(self, task_desc: str, state_history: List[Dict], current_state: str, max_hypotheses: int = 2) -> Tuple[List[str], Dict[str, str]]:
        """Run hypothesis generation (CHG) followed by intervention planning (IP).

        Returns:
            ``(hypotheses, intervention_plan)``; the step is also appended to
            `reasoning_log` and counted in `session_stats`.
        """
        logger.info("\n🔬 " + "=" * 48)
        logger.info("因果推理流程")
        logger.info("=" * 50)

        hypotheses = self.chg.generate(task_desc, current_state, state_history, max_hypotheses)
        self.session_stats['total_hypotheses'] += len(hypotheses)

        intervention_plan = self.ip.plan(hypotheses, current_state)
        if intervention_plan.get('intervention'):
            self.session_stats['total_interventions'] += 1

        self.reasoning_log.append({
            'task': task_desc[:100],
            'state': current_state[:100],
            'hypotheses': hypotheses,
            'intervention_plan': intervention_plan
        })

        return hypotheses, intervention_plan

    def validate_and_learn(self, intervention: str, expected: str, observation: str, success: Optional[bool] = None) -> Dict[str, Any]:
        """Validate an executed intervention and record what was learned.

        Success/Partial verdicts with confidence >= 0.45 (a deliberately low
        bar) are appended to `causal_knowledge_base`. The verdict is counted
        in `session_stats` and attached to the latest reasoning-log entry.
        """
        validation = self.ev.validate(intervention, expected, observation, success)

        counter = self._STATUS_COUNTERS.get(validation['status'])
        if counter is not None:
            self.session_stats[counter] += 1

        if (validation['status'] in ('Success', 'Partial')
                and validation['confidence'] >= self._MIN_KNOWLEDGE_CONFIDENCE):
            self.causal_knowledge_base.append({
                'intervention': intervention[:150],
                'expected_effect': expected[:150],
                'actual_observation': observation[:150],
                'status': validation['status'],
                'confidence': validation['confidence'],
                'understanding': validation.get('updated_understanding', ''),
                'success': success
            })
            logger.info(f"📚 新因果知识 (总计: {len(self.causal_knowledge_base)})")

        if self.reasoning_log:
            self.reasoning_log[-1]['validation'] = validation

        return validation

    def get_statistics(self) -> Dict[str, Any]:
        """Return session counters plus derived metrics.

        'success_rate' is (successes + 0.5 * partials) divided by the number
        of Success/Partial/Failure validations (at least 1); skipped
        validations are excluded from the denominator.
        """
        stats = self.session_stats
        total_validations = (
            stats['successful_validations']
            + stats['partial_validations']
            + stats['failed_validations']
        )

        return {
            **stats,
            'knowledge_base_size': len(self.causal_knowledge_base),
            'success_rate': (
                (stats['successful_validations'] + 0.5 * stats['partial_validations'])
                / max(1, total_validations)
            ),
            'total_reasoning_steps': len(self.reasoning_log)
        }

    def save_knowledge_base(self, path: str):
        """Write knowledge, statistics and all component histories to `path` as JSON.

        The payload is serialised completely BEFORE the file is opened, because
        opening with 'w' truncates it. A serialisation error (e.g. a non-JSON
        value in a history entry) therefore leaves any previously saved file
        intact instead of half-written. Errors are logged, not raised.
        """
        try:
            output_data = {
                'knowledge_base': self.causal_knowledge_base,
                'statistics': self.get_statistics(),
                'reasoning_log': self.reasoning_log,
                'generation_history': self.chg.generation_history,
                'planning_history': self.ip.planning_history,
                'validation_history': self.ev.validation_history
            }
            serialized = json.dumps(output_data, indent=2, ensure_ascii=False)

            with open(path, 'w', encoding='utf-8') as f:
                f.write(serialized)
        except Exception as e:
            logger.error(f"❌ 保存知识库失败: {e}")
            return

        if self.causal_knowledge_base:
            logger.info(f"💾 因果知识库已保存: {len(self.causal_knowledge_base)} 条知识")
            logger.info(f"   推理步骤: {len(self.reasoning_log)}")
            logger.info(f"   验证记录: {len(self.ev.validation_history)}")
