import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from typing import Dict, List, Tuple, Set
import logging
from datetime import datetime

logger = logging.getLogger(__name__)

class PureMedicalQLearning:
    """Learn concept associations from case records and rank them with tabular Q-learning.

    Concepts are lower-cased alphabetic tokens taken from each case's
    ``symptoms``, ``diagnoses``, ``procedures``, ``medications`` and
    ``known_info`` fields. Per-case co-occurrence counts form the reward matrix
    of a Q-learning agent whose states and actions are concepts; the trained
    Q-matrix is then queried for concepts related to a patient's concepts.
    """

    # Case fields mined for concept tokens.
    _CASE_FIELDS = ('symptoms', 'diagnoses', 'procedures', 'medications', 'known_info')
    # Only this many leading medication entries per case are used.
    _MAX_MEDICATIONS = 10
    # Tokens that pass the length/alphabetic filter but carry no meaning.
    _STOPWORDS = frozenset(
        ['and', 'the', 'with', 'for', 'from', 'that', 'this', 'have', 'had', 'has']
    )
    # Substrings that earn a reward bonus in _get_medical_relevance_bonus.
    _MEDICAL_INDICATORS = ('pain', 'blood', 'heart', 'lung', 'chest', 'pressure',
                           'test', 'exam', 'drug', 'medication', 'treatment', 'therapy')
    _SYMPTOM_WORDS = ('pain', 'ache', 'pressure', 'difficulty', 'trouble')
    _DIAGNOSIS_WORDS = ('disease', 'syndrome', 'disorder', 'condition', 'failure')

    def __init__(self):
        self.Q = None  # concept x concept Q-matrix; set by train_qlearning
        self.medical_concepts = []  # learned vocabulary; list position == index
        self.concept_to_index = {}
        self.index_to_concept = {}

        # Q-learning hyper-parameters: learning rate, discount factor and the
        # initial exploration rate (train_qlearning decays a local copy).
        self.alpha = 0.1
        self.gamma = 0.9
        self.epsilon = 0.1

        # Vocabulary filters applied by learn_from_mimic_data_only.
        self.min_frequency = 1
        self.max_concepts = 200

        # Statistics gathered by learn_from_mimic_data_only.
        self.concept_frequencies = Counter()  # concept -> number of cases containing it
        self.concept_cooccurrences = defaultdict(Counter)  # c1 -> {c2: cases with both}
        self.mimic_cases_used = 0

        logger.info("Pure Data-Driven Q-Learning Initialized")

    def learn_from_mimic_data_only(self, mimic_cases: List[Dict]) -> List[str]:
        """Build the concept vocabulary from ``mimic_cases`` alone.

        Statistics, vocabulary and any trained Q-matrix from a previous call are
        discarded first, so the result depends only on ``mimic_cases``. Retrain
        with ``train_qlearning`` afterwards.

        Args:
            mimic_cases: Case dicts; see ``_extract_concepts_from_case`` for the
                fields that are read.

        Returns:
            The learned concepts, ordered by descending case frequency (ties
            keep first-seen order) and truncated to ``max_concepts``. If no
            concept survives filtering, words from the raw text are used. If
            there are still none, a fixed generic list is used.
        """
        print("Learning EVERYTHING from MIMIC data...")

        # Start from a clean slate. Old counts would otherwise be mixed in, and
        # an old Q-matrix is indexed by the previous vocabulary.
        self.concept_frequencies.clear()
        self.concept_cooccurrences.clear()
        self.Q = None
        self.mimic_cases_used = len(mimic_cases)

        for case in mimic_cases:
            case_concepts = self._extract_concepts_from_case(case)
            self.concept_frequencies.update(case_concepts)
            # case_concepts holds distinct strings, so "c1 != c2" is the same
            # test as comparing positions.
            for c1 in case_concepts:
                for c2 in case_concepts:
                    if c1 != c2:
                        self.concept_cooccurrences[c1][c2] += 1

        candidates = [
            concept for concept, freq in self.concept_frequencies.items()
            if freq >= self.min_frequency and len(concept) > 2
        ]

        if not candidates:
            print(f"No concepts found with min_frequency={self.min_frequency}, trying with all concepts...")
            candidates = [concept for concept in self.concept_frequencies if len(concept) > 2]

        if not candidates:
            print("Creating concepts from raw case data...")
            candidates = self._create_concepts_from_raw_data(mimic_cases)

        # Rank the candidates themselves by frequency, so truncation keeps the
        # most frequent ones. sorted() is stable even with reverse=True, so ties
        # keep first-seen order. Missing keys count as 0 in a Counter.
        candidates = sorted(candidates, key=lambda c: self.concept_frequencies[c], reverse=True)
        candidates = candidates[:self.max_concepts]

        if not candidates:
            print("Creating fallback concepts...")
            candidates = ['chest', 'pain', 'heart', 'blood', 'test', 'history', 'patient', 'treatment']

        self.medical_concepts = candidates
        self.concept_to_index = {concept: i for i, concept in enumerate(self.medical_concepts)}
        self.index_to_concept = {i: concept for concept, i in self.concept_to_index.items()}

        print(f"Learned {len(self.medical_concepts)} medical concepts from MIMIC")
        print(f"Top concepts: {self.medical_concepts[:10]}")

        return self.medical_concepts

    @staticmethod
    def _field_items(value) -> List:
        """Normalize one case field into a list of raw items.

        ``None`` becomes ``[]``. A string, or any non-iterable scalar such as a
        float NaN, becomes a one-element list. Other iterables are
        materialized as a list.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        try:
            return list(value)
        except TypeError:
            return [value]

    @staticmethod
    def _item_text(item) -> str:
        """Return the stripped, lower-cased text of ``item``.

        Returns ``''`` for ``None`` and for the placeholder ``'nan'``.
        """
        if item is None:
            return ''
        text = str(item).strip().lower()
        return '' if text == 'nan' else text

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Split ``text`` into maximal runs of alphabetic characters."""
        return ''.join(ch if ch.isalpha() else ' ' for ch in text).split()

    def _extract_concepts_from_case(self, case: Dict) -> List[str]:
        """Return the distinct concept tokens of one case, sorted alphabetically.

        Reads ``symptoms``, ``diagnoses``, ``procedures``, ``medications`` (only
        the first ``_MAX_MEDICATIONS`` entries) and ``known_info``. Each field
        may be a single string or an iterable of items. Missing values and
        ``'nan'`` placeholders are skipped. Text is split on every
        non-alphabetic character, so punctuation such as ``'pain,'`` does not
        hide a token. Tokens shorter than 3 characters and stopwords are
        dropped.
        """
        tokens = set()
        for field in self._CASE_FIELDS:
            items = self._field_items(case.get(field))
            if field == 'medications':
                items = items[:self._MAX_MEDICATIONS]
            for item in items:
                tokens.update(self._tokenize(self._item_text(item)))

        # Sorting makes the order, and hence concept indices, reproducible.
        # A set iterates in hash order, which varies between interpreter runs.
        return sorted(
            token for token in tokens
            if len(token) > 2 and token.isalpha() and token not in self._STOPWORDS
        )

    def _create_concepts_from_raw_data(self, mimic_cases: List[Dict]) -> List[str]:
        """Fallback vocabulary: alphabetic words longer than 3 characters from raw fields.

        Text is split only on whitespace, ``','`` and ``';'``. Returns at most
        50 words, sorted alphabetically so the selection is reproducible.
        """
        words = set()
        for case in mimic_cases:
            for field in self._CASE_FIELDS:
                for item in self._field_items(case.get(field)):
                    if item is None:
                        continue
                    text = str(item).lower().replace(',', ' ').replace(';', ' ')
                    for word in text.split():
                        if len(word) > 3 and word.isalpha():
                            words.add(word)

        return sorted(words)[:50]

    def build_pure_reward_matrix(self) -> np.ndarray:
        """Return the (n_concepts, n_concepts) reward matrix used for Q-learning.

        Entry ``[i, j]`` is ``min(10 * cooccurrences(i, j), 100)`` plus the
        bonus from ``_get_medical_relevance_bonus``. Pairs that never
        co-occur, including the diagonal, stay 0.

        Raises:
            ValueError: if no concepts have been learned yet.
        """
        if not self.medical_concepts:
            raise ValueError("Must learn concepts first!")

        n_concepts = len(self.medical_concepts)
        reward_matrix = np.zeros((n_concepts, n_concepts))

        for concept1, cooccur_dict in self.concept_cooccurrences.items():
            i = self.concept_to_index.get(concept1)
            if i is None:
                continue
            for concept2, cooccur_count in cooccur_dict.items():
                j = self.concept_to_index.get(concept2)
                if j is None:
                    continue
                base_reward = min(cooccur_count * 10, 100)
                medical_bonus = self._get_medical_relevance_bonus(concept1, concept2)
                reward_matrix[i, j] = base_reward + medical_bonus

        print(f"Built reward matrix: {n_concepts}x{n_concepts} from pure MIMIC patterns")
        return reward_matrix

    def _get_medical_relevance_bonus(self, concept1: str, concept2: str) -> float:
        """Keyword bonus for a concept pair (substring matches, not whole words).

        Adds +5 per concept that contains a medical indicator. Adds +10 when
        ``concept1`` contains a symptom word and ``concept2`` contains a
        diagnosis word.
        """
        bonus = 0.0

        if any(indicator in concept1 for indicator in self._MEDICAL_INDICATORS):
            bonus += 5.0
        if any(indicator in concept2 for indicator in self._MEDICAL_INDICATORS):
            bonus += 5.0

        if (any(s in concept1 for s in self._SYMPTOM_WORDS) and
                any(d in concept2 for d in self._DIAGNOSIS_WORDS)):
            bonus += 10.0

        return bonus

    def train_qlearning(self, episodes: int = 1000, steps_per_episode: int = 10) -> Dict[str, float]:
        """Train a fresh Q-matrix with epsilon-greedy tabular Q-learning.

        Each episode starts at a random concept and takes ``steps_per_episode``
        transitions. The chosen action becomes the next state. The
        exploration rate starts at ``self.epsilon`` and is multiplied by 0.99
        every 100 episodes, with a floor of 0.01. Only a local copy is
        decayed, so ``self.epsilon`` keeps its configured value.

        Args:
            episodes: Number of training episodes; ``0`` leaves the random
                initial Q-matrix and reports a score of 0.0.
            steps_per_episode: Transitions per episode (default 10).

        Returns:
            Dict with ``final_score`` (mean reward of the last up-to-100
            episodes), ``total_episodes``, ``concepts_learned`` and
            ``q_matrix_shape``.

        Raises:
            ValueError: if no concepts have been learned yet.
        """
        if not self.medical_concepts:
            raise ValueError("Must learn concepts first!")

        n_concepts = len(self.medical_concepts)
        self.Q = np.random.random((n_concepts, n_concepts)) * 0.1

        reward_matrix = self.build_pure_reward_matrix()

        print(f"Training Q-learning for {episodes} episodes...")

        epsilon = self.epsilon
        scores = []

        for episode in range(episodes):
            state = np.random.randint(0, n_concepts)
            episode_score = 0.0

            for _ in range(steps_per_episode):
                if np.random.random() < epsilon:
                    action = np.random.randint(0, n_concepts)
                else:
                    action = int(np.argmax(self.Q[state, :]))

                reward = reward_matrix[state, action]
                next_state = action
                td_target = reward + self.gamma * np.max(self.Q[next_state, :])
                self.Q[state, action] += self.alpha * (td_target - self.Q[state, action])

                episode_score += reward
                state = next_state

            scores.append(episode_score)

            if episode % 100 == 0:
                epsilon = max(0.01, epsilon * 0.99)

        # np.mean([]) would be NaN plus a RuntimeWarning when episodes == 0.
        avg_score = float(np.mean(scores[-100:])) if scores else 0.0

        print(f"Q-learning training complete! Average score: {avg_score:.2f}")

        return {
            'final_score': avg_score,
            'total_episodes': episodes,
            'concepts_learned': len(self.medical_concepts),
            'q_matrix_shape': self.Q.shape
        }

    def discover_missing_concepts(self, patient_concepts: List[str], max_exploration: int = 10,
                                  min_q_value: float = 1.0) -> List[Dict]:
        """Suggest learned concepts strongly linked to, but absent from, the patient's concepts.

        For each patient concept in the vocabulary, its ``max_exploration``
        highest-Q neighbours are inspected. A neighbour becomes a candidate if
        it is not already a patient concept and its Q-value exceeds
        ``min_q_value``. Each concept is reported once, with its best entry.

        Args:
            patient_concepts: Concepts already known for the patient; matched
                case-insensitively after stripping whitespace.
            max_exploration: Neighbours inspected per concept, and the maximum
                number of suggestions returned; values <= 0 yield ``[]``.
            min_q_value: Exclusive Q-value threshold (default 1.0).

        Returns:
            Dicts with keys ``concept``, ``q_value``, ``source_concept`` and
            ``mimic_frequency``, sorted by descending ``q_value``. Returns
            ``[]``, after printing a notice, when the model is untrained.
        """
        if self.Q is None:
            print("Q-matrix not trained yet!")
            return []
        # argsort(...)[-0:] would select every index, so handle 0 explicitly.
        if max_exploration <= 0:
            return []

        patient_list = [str(c).strip().lower() for c in patient_concepts]
        patient_set = set(patient_list)

        candidates = []
        for source in patient_list:
            concept_idx = self.concept_to_index.get(source)
            if concept_idx is None:
                continue
            q_row = self.Q[concept_idx, :]

            for idx in np.argsort(q_row)[-max_exploration:][::-1]:
                related_concept = self.index_to_concept[int(idx)]
                q_value = float(q_row[idx])

                if related_concept not in patient_set and q_value > min_q_value:
                    candidates.append({
                        'concept': related_concept,
                        'q_value': q_value,
                        'source_concept': source,
                        'mimic_frequency': self.concept_frequencies.get(related_concept, 0)
                    })

        # Stable sort keeps the first-found entry among equal Q-values; the
        # first occurrence of each concept is its best-scoring one.
        seen = set()
        unique_missing = []
        for item in sorted(candidates, key=lambda x: x['q_value'], reverse=True):
            if item['concept'] not in seen:
                unique_missing.append(item)
                seen.add(item['concept'])

        return unique_missing[:max_exploration]

    def get_concept_relationships(self, concept: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` (related_concept, q_value) pairs with positive Q, highest first.

        ``concept`` is matched case-insensitively after stripping whitespace,
        consistent with ``discover_missing_concepts``. The concept itself may
        appear in the result, since its diagonal Q-entry is not excluded.
        Returns ``[]`` when the model is untrained, the concept is unknown, or
        ``top_k <= 0``.
        """
        # argsort(...)[-0:] would select every index, so handle 0 explicitly.
        if self.Q is None or top_k <= 0:
            return []

        concept_idx = self.concept_to_index.get(str(concept).strip().lower())
        if concept_idx is None:
            return []

        q_row = self.Q[concept_idx, :]
        top_indices = np.argsort(q_row)[-top_k:][::-1]

        return [
            (self.index_to_concept[int(idx)], float(q_row[idx]))
            for idx in top_indices
            if q_row[idx] > 0
        ]

    def save_learned_model(self, filepath: str):
        """Write the Q-matrix, vocabulary, frequencies and parameters to ``filepath`` as JSON.

        Co-occurrence counts are not saved. ``learning_params.epsilon`` is the
        configured initial exploration rate.
        """
        model_data = {
            'Q_matrix': self.Q.tolist() if self.Q is not None else None,
            'medical_concepts': self.medical_concepts,
            'concept_frequencies': dict(self.concept_frequencies),
            'mimic_cases_used': self.mimic_cases_used,
            'learning_params': {
                'alpha': self.alpha,
                'gamma': self.gamma,
                'epsilon': self.epsilon
            },
            'timestamp': datetime.now().isoformat()
        }

        import json
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(model_data, f, indent=2)

        print(f"Saved learned model to {filepath}")

def test_fixed_qlearning():
    """Smoke-test PureMedicalQLearning on two small hand-written cases.

    Learns concepts from the cases. If any were learned, it trains for 100
    episodes and asks which concepts are missing from ['chest', 'pain'].
    Progress is printed to stdout.

    Returns:
        The PureMedicalQLearning instance used for the run.
    """
    print("Testing FIXED Q-Learning System")

    cardiac_case = {
        'symptoms': 'chest pain, shortness of breath',
        'diagnoses': ['myocardial infarction'],
        'procedures': ['ecg'],
        'medications': ['aspirin'],
        'known_info': ['chest pain', 'ecg'],
    }
    infection_case = {
        'symptoms': 'headache, fever',
        'diagnoses': ['infection'],
        'procedures': ['blood test'],
        'medications': ['antibiotics'],
        'known_info': ['fever', 'blood test'],
    }

    model = PureMedicalQLearning()

    learned = model.learn_from_mimic_data_only([cardiac_case, infection_case])
    print(f"Learned {len(learned)} concepts")

    # Nothing to train on: hand back the untrained model.
    if not learned:
        return model

    training_summary = model.train_qlearning(100)
    print(f"Training completed: {training_summary}")

    suggestions = model.discover_missing_concepts(['chest', 'pain'])
    print(f"Found {len(suggestions)} missing concepts")

    return model

if __name__ == "__main__":
    # Only when executed as a script (not on import): run the smoke test.
    test_fixed_qlearning()