import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import logging
from datetime import datetime

from core_system import AgenticAISystem, SemanticNode, KnowledgeState, ExpertKnowledge
from hybrid_neural_system import HybridNeuralSymbolicSystem
from pure_medical_qlearning import PureMedicalQLearning

# Module-level logger used by the classes defined in this file.
logger = logging.getLogger(name=__name__)

@dataclass()
class QuestionFeedback:
    """Feedback gathered for a single question asked during a case.

    Each instance bundles the question text with the signals that
    ``RLRewardSystem.calculate_reward`` combines into a scalar reward.
    """

    # The question text that was asked.
    question: str
    # Expert's rating of the question (presumably in [0, 1] - confirm with producers).
    expert_rating: float
    # Improvement in outcome attributed to asking the question.
    outcome_improvement: float
    # Pieces of previously missing information the question uncovered.
    missing_info_found: List[str]
    # How clinically relevant the question was judged to be.
    clinical_relevance: float

class RLRewardSystem:
    """Reward clinician questions and learn which question patterns pay off.

    ``calculate_reward`` turns question feedback into a scalar reward in
    [0, 1]. ``learn_pattern`` buckets a question into a coarse keyword pattern
    and updates that pattern's value with a Q-learning-style rule.
    ``get_best_patterns`` ranks the learned pattern values.
    """

    # (pattern name, keywords) pairs checked in order by _extract_pattern.
    # The first pattern with any keyword contained in the lower-cased
    # question wins. Matching is plain substring matching, so e.g. 'ache'
    # also matches 'headache' and 'test' also matches 'latest'.
    # A bare 'history' keyword is deliberately NOT listed under
    # family_history. Because family_history is checked first, it would
    # capture every "medication history" / "past history" question.
    _PATTERN_KEYWORDS = (
        ('family_history', ('family', 'relatives')),
        ('symptom_detail', ('when', 'how long', 'severity')),
        ('medication_history', ('medication', 'drugs', 'pills')),
        ('test_results', ('test', 'lab', 'result')),
        ('pain_assessment', ('pain', 'hurt', 'ache')),
        ('timeline', ('started', 'began', 'first')),
    )

    def __init__(self):
        # Past questions; no method of this class appends to it.
        self.question_history = []
        # Weights for the reward components in calculate_reward (sum to 1.0).
        self.reward_weights = {
            'expert_rating': 0.3,
            'outcome_improvement': 0.4,
            'info_discovery': 0.2,
            'clinical_relevance': 0.1
        }

        # Learned value per question pattern; unseen patterns read as 0.0.
        self.pattern_q_values = defaultdict(float)
        self.alpha = 0.1  # learning rate
        self.gamma = 0.9  # discount applied to the averaged "future" value

        logger.info("RL Question Reward System initialized")

    def calculate_reward(self, feedback: 'QuestionFeedback') -> float:
        """Return the weighted reward for one question, clipped to [0, 1].

        The info-discovery term counts 0.1 per entry of
        ``feedback.missing_info_found`` and is capped at 0.5 before weighting.
        If ``feedback.expert_rating`` is below 0.3, the whole reward is halved.

        Args:
            feedback: Expert/outcome feedback for a single question. It must
                expose ``expert_rating``, ``outcome_improvement``,
                ``missing_info_found`` and ``clinical_relevance``.

        Returns:
            The combined reward, clipped to the range [0.0, 1.0].
        """
        weights = self.reward_weights
        expert_score = feedback.expert_rating * weights['expert_rating']
        outcome_score = feedback.outcome_improvement * weights['outcome_improvement']
        info_score = min(len(feedback.missing_info_found) * 0.1, 0.5) * weights['info_discovery']
        clinical_score = feedback.clinical_relevance * weights['clinical_relevance']

        total_reward = expert_score + outcome_score + info_score + clinical_score

        # Penalise questions the expert rated poorly.
        if feedback.expert_rating < 0.3:
            total_reward *= 0.5

        return min(max(total_reward, 0.0), 1.0)

    def learn_pattern(self, question: str, reward: float):
        """Update the value of ``question``'s pattern using ``reward``.

        Applies ``Q <- Q + alpha * (reward + gamma * mean(Q over all patterns) - Q)``.
        The mean over every known pattern stands in for a next-state value.

        Args:
            question: Question text; it is mapped to a pattern by
                ``_extract_pattern``.
            reward: Reward observed for the question, typically from
                ``calculate_reward``.
        """
        pattern = self._extract_pattern(question)
        # Indexing the defaultdict registers an unseen pattern at 0.0. The
        # mapping is therefore never empty below, and the mean includes this
        # pattern's own pre-update value.
        current_q = self.pattern_q_values[pattern]
        avg_future = float(np.mean(list(self.pattern_q_values.values())))

        new_q = current_q + self.alpha * (reward + self.gamma * avg_future - current_q)
        self.pattern_q_values[pattern] = new_q

        logger.info("Pattern '%s' Q-value: %.3f", pattern, new_q)

    def _extract_pattern(self, question: str) -> str:
        """Return the first pattern whose keyword occurs in ``question``, else ``'general'``."""
        question_lower = question.lower()
        for pattern_name, keywords in self._PATTERN_KEYWORDS:
            if any(keyword in question_lower for keyword in keywords):
                return pattern_name
        return 'general'

    def get_best_patterns(self) -> List[Tuple[str, float]]:
        """Return ``(pattern, value)`` pairs sorted from highest to lowest value."""
        return sorted(self.pattern_q_values.items(), key=lambda item: item[1], reverse=True)

class QFeedbackLoop:
    """Feed Q-learning concept discoveries back into a symbolic reasoner.

    Collaborator contracts, as used below:

    * ``qlearning_system.discover_missing_concepts(concepts, max_exploration=...)``
      returns an iterable of dicts with ``'concept'``, ``'q_value'`` and
      ``'mimic_frequency'`` keys.
    * ``symbolic_system.causal_graph.missing_entities`` is a mapping from
      concept to an object with a numeric ``discovery_priority`` attribute.
    * ``symbolic_system.process_medical_case(patient_data, expert_inputs)``
      returns a dict with a list under ``'missing_critical_info'`` and a
      string under ``'reasoning'``.
    """

    def __init__(self, qlearning_system, symbolic_system):
        self.qlearning_system = qlearning_system
        self.symbolic_system = symbolic_system
        # When False, update_symbolic_with_q is a no-op.
        self.feedback_active = True
        # Not read or written by any method of this class.
        self.concept_boosts = defaultdict(float)

        logger.info("Q-Learning → Symbolic Feedback Loop initialized")

    def update_symbolic_with_q(self, patient_concepts: List[str]) -> Dict[str, float]:
        """Raise symbolic discovery priorities for concepts found by Q-learning.

        Args:
            patient_concepts: Normalised concept tokens for the patient.

        Returns:
            Mapping of every discovered concept to its computed boost. Concepts
            absent from the causal graph's ``missing_entities`` are reported
            but not applied. Returns ``{}`` when feedback is disabled.
        """
        if not self.feedback_active:
            return {}

        q_missing = self.qlearning_system.discover_missing_concepts(patient_concepts, max_exploration=10)

        concept_boosts = {}
        for missing_data in q_missing:
            concept = missing_data['concept']
            boost = self._calculate_boost(missing_data['q_value'], missing_data['mimic_frequency'])

            missing_entities = self.symbolic_system.causal_graph.missing_entities
            if concept in missing_entities:
                # A concept reported more than once is boosted once per report.
                missing_entities[concept].discovery_priority += boost
                logger.info("Boosted '%s' priority: +%.3f", concept, boost)

            concept_boosts[concept] = boost

        return concept_boosts

    def _calculate_boost(self, q_value: float, frequency: int) -> float:
        """Combine a Q-value and an occurrence count into a priority boost.

        The result is ``min(q_value / 100, 0.5) + min(log(frequency + 1) / 10, 0.3)``,
        capped at 0.8. Negative Q-values are not floored and yield a negative
        boost.
        """
        q_boost = min(q_value / 100.0, 0.5)
        # A count below zero is meaningless. Clamping keeps np.log's argument
        # >= 1, so the boost can never become -inf/NaN and poison the
        # priority it is added to.
        freq_boost = min(np.log(max(frequency, 0) + 1) / 10.0, 0.3)

        return float(min(q_boost + freq_boost, 0.8))

    def get_enhanced_reasoning(self, patient_data: Dict, expert_inputs: List[Dict]) -> Dict:
        """Run symbolic reasoning after applying Q-learning priority boosts.

        Args:
            patient_data: Patient record; see ``_extract_concepts`` for the
                fields that are read.
            expert_inputs: Passed through to ``process_medical_case``.

        Returns:
            A shallow copy of the symbolic result with these changes:
            ``'qlearning_insights'`` added, ``'missing_critical_info'`` merged
            with the Q-learning concepts (duplicates removed, first-seen order
            kept), and a Q-learning summary appended to ``'reasoning'``.
        """
        patient_concepts = self._extract_concepts(patient_data)

        boosts = self.update_symbolic_with_q(patient_concepts)

        symbolic_result = self.symbolic_system.process_medical_case(patient_data, expert_inputs)

        # This second query uses the Q-learner's default exploration budget.
        # update_symbolic_with_q above passed max_exploration=10 explicitly.
        q_missing = self.qlearning_system.discover_missing_concepts(patient_concepts)
        q_concepts = [item['concept'] for item in q_missing]

        enhanced_result = symbolic_result.copy()
        enhanced_result['qlearning_insights'] = {
            'concept_boosts': boosts,
            'discoveries': len(boosts),
            'feedback_active': self.feedback_active
        }

        # Merge and de-duplicate while keeping first-seen order (symbolic
        # findings first). A set would make the order depend on string-hash
        # randomisation, so it would differ from run to run.
        enhanced_result['missing_critical_info'] = list(
            dict.fromkeys(symbolic_result['missing_critical_info'] + q_concepts)
        )

        summary = [
            "\n\nQ-Learning Insights:\n",
            f"- Additional concepts found: {len(q_concepts)}\n",
            f"- Priority boosts applied: {len(boosts)}\n",
        ]
        if q_concepts:
            summary.append(f"- Top discoveries: {', '.join(q_concepts[:3])}\n")

        enhanced_result['reasoning'] += ''.join(summary)

        return enhanced_result

    def _extract_concepts(self, patient_data: Dict) -> List[str]:
        """Collect lower-cased, alphabetic concept tokens from ``patient_data``.

        Reads ``'symptoms'``, ``'diagnosis'`` and ``'known_info'``. A string
        value is split on commas and whitespace. Each item of a list value
        becomes one token via ``str()``. Tokens of length <= 2, or containing
        non-letter characters, are dropped.
        """
        tokens: List[str] = []
        for field in ('symptoms', 'diagnosis', 'known_info'):
            if field not in patient_data:
                continue
            value = patient_data[field]
            if isinstance(value, str):
                tokens.extend(value.replace(',', ' ').split())
            elif isinstance(value, list):
                tokens.extend(str(item) for item in value)

        normalised = (token.strip().lower() for token in tokens)
        return [token for token in normalised if len(token) > 2 and token.isalpha()]

def test_rl_feedback():
    """Smoke-test the reward system and the Q-learning feedback loop.

    Uses local stand-ins for the Q-learning and symbolic systems, then prints
    the reward, the learned pattern ranking and the Q-learning insights.
    Nothing is asserted; the output is for manual inspection.
    """
    print("Testing RL Feedback System")

    class MockQLearning:
        # Stand-in that always reports the same two "missing" concepts.
        def discover_missing_concepts(self, concepts, max_exploration=10):
            canned = (
                ('family_history', 50.0, 10),
                ('blood_test', 30.0, 8),
            )
            return [
                {'concept': name, 'q_value': q, 'mimic_frequency': freq}
                for name, q, freq in canned
            ]

    class MockSymbolic:
        # Stand-in symbolic reasoner whose causal graph has no missing entities.
        def __init__(self):
            graph_type = type('obj', (object,), {'missing_entities': {}})
            self.causal_graph = graph_type

        def process_medical_case(self, patient_data, expert_inputs):
            canned_result = {}
            canned_result['reasoning'] = 'Basic reasoning'
            canned_result['missing_critical_info'] = ['test_concept']
            canned_result['suggested_question'] = 'What about family history?'
            return canned_result

    rewards = RLRewardSystem()

    sample = QuestionFeedback(
        question="Does patient have family history of heart disease?",
        expert_rating=0.8,
        outcome_improvement=0.7,
        missing_info_found=['family_history', 'cardiac_risk'],
        clinical_relevance=0.9,
    )

    score = rewards.calculate_reward(sample)
    rewards.learn_pattern(sample.question, score)

    print(f"Question reward: {score:.3f}")
    print(f"Best patterns: {rewards.get_best_patterns()}")

    loop = QFeedbackLoop(MockQLearning(), MockSymbolic())

    patient = dict(symptoms='chest pain', known_info=['chest', 'pain'])
    outcome = loop.get_enhanced_reasoning(patient, [])

    print("Enhanced reasoning generated")
    print(f"Q-learning insights: {outcome.get('qlearning_insights', {})}")

if __name__ == '__main__':
    # Run the smoke test only when this file is executed directly.
    test_rl_feedback()