"""
LLM-based Drug Probability Predictor with Caching

This module implements LLM-based drug recommendation. An LLM rates each drug in a
fixed vocabulary for a given patient, and the ratings are normalized into
P(drug | clinical_context). Results are cached on disk to minimize API costs.

Authors: Research Team
Date: 2024
"""

import hashlib
import json
import logging
import pickle
import re
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from ..core.data_structures import ClinicalContext

# Module-level logger shared by all classes and functions in this file; this
# module configures no handlers or levels itself.
logger = logging.getLogger(__name__)


@dataclass
class LLMConfig:
    """
    Configuration for LLM-based prediction.

    Consumed by ``LLMDrugPredictor``, which currently implements the
    ``"openai"`` and ``"anthropic"`` providers (matched case-insensitively).
    """

    # API Configuration
    provider: str = "openai"  # "openai" or "anthropic"; other values raise ValueError on first API call
    model_name: str = "gpt-4-turbo-preview"
    # repr=False keeps the secret out of logs, tracebacks and debug prints
    # that display the config object.
    api_key: Optional[str] = field(default=None, repr=False)
    base_url: Optional[str] = None  # optional custom API endpoint

    # Generation Parameters
    temperature: float = 0.1  # Low temperature for consistency
    # NOTE(review): the prompt asks for one rating line per drug, so 200 tokens
    # may truncate replies for large vocabularies (unrated drugs then receive
    # the neutral default score) — confirm against the vocabulary size in use.
    max_tokens: int = 200
    timeout: int = 30  # request timeout in seconds

    # Retry Configuration
    max_retries: int = 3  # total attempts per request (not additional retries)
    retry_delay: float = 1.0  # base back-off in seconds, doubled after each failed attempt

    # Caching Configuration
    cache_file: Path = Path("data/processed/mimic/llm_probabilities_cache.pkl")
    enable_cache: bool = True
    cache_version: str = "v1.0"  # cache files saved under a different version are ignored on load
    

@dataclass 
class CachedPrediction:
    """
    Cached LLM prediction result.

    One entry of ``LLMDrugPredictor.cache``; entries are pickled to the
    configured cache file together with a cache version tag.

    NOTE(review): the dataclass-generated ``__eq__`` compares fields as a tuple,
    including ``probabilities`` (a numpy array); for multi-element arrays this
    can raise ValueError ("truth value of an array ... is ambiguous") — avoid
    ``==`` on instances or confirm this is acceptable.
    """
    
    # MD5 hex digest of the clinical context (plus a vocabulary digest); cache key
    clinical_context_hash: str
    # Copy of the drug vocabulary, in the order `probabilities` is indexed by
    drug_vocabulary: List[str]
    # Distribution over `drug_vocabulary`: normalized ratings, or the uniform
    # fallback when the response could not be parsed
    probabilities: np.ndarray
    # Unparsed text returned by the LLM
    raw_response: str
    # Creation time, taken from datetime.now() (naive local time)
    timestamp: datetime
    # Provider / model / temperature used to generate the prediction
    model_info: Dict[str, Any]
    # False when parsing failed and the fallback distribution was stored
    parsing_success: bool
    

class LLMDrugPredictor:
    """
    LLM-based drug probability predictor with on-disk caching.

    For each clinical context the LLM is asked to rate every drug in
    ``drug_vocabulary`` on a 0-100 scale; the ratings are normalized into a
    probability distribution indexed like the vocabulary. Results are cached
    (keyed by a hash of the context and the vocabulary) and pickled to
    ``config.cache_file`` to minimize API costs and keep runs reproducible.
    """

    # Numeric score captured from a rating line, e.g. "75" or "82.5".
    _SCORE_PATTERN = r'(\d+(?:\.\d+)?)'

    def __init__(self, config: LLMConfig, drug_vocabulary: List[str]):
        """
        Initialize LLM predictor.

        Args:
            config: LLM configuration
            drug_vocabulary: Ordered list of available drugs; returned
                probability vectors are indexed in this order.
        """
        self.config = config
        self.drug_vocabulary = drug_vocabulary
        self.cache: Dict[str, CachedPrediction] = {}

        # API client (created lazily on the first API call)
        self._client = None

        # When False, predict_probabilities() does not rewrite the cache file
        # after each new prediction; precompute_probabilities() uses this to
        # batch writes instead of re-pickling the whole cache per context.
        self._autosave = True

        # Statistics
        self.stats = {
            'api_calls': 0,
            'cache_hits': 0,
            'parsing_failures': 0,
            'total_tokens': 0
        }

        # Load existing cache
        if config.enable_cache:
            self._load_cache()

        logger.info(f"Initialized LLM predictor: {len(drug_vocabulary)} drugs, "
                   f"{len(self.cache)} cached predictions")

    def predict_probabilities(self, clinical_context: ClinicalContext) -> np.ndarray:
        """
        Predict drug probabilities for a clinical context.

        A usable cached result is returned when available. Otherwise the LLM is
        queried and the result is cached. Parse failures are cached too, with
        their uniform fallback, so they are not retried until the entry is
        removed or ``cache_version`` changes.

        Args:
            clinical_context: Patient clinical context

        Returns:
            Probability distribution over ``self.drug_vocabulary``. The array is
            not shared with the cache, so mutating it is safe.

        Raises:
            The last API error, re-raised once all API attempts have failed.
        """
        context_hash = self._hash_clinical_context(clinical_context)

        # Check cache first
        if self.config.enable_cache:
            cached = self.cache.get(context_hash)
            if cached is not None:
                probabilities = self._align_cached_probabilities(cached)
                if probabilities is not None:
                    self.stats['cache_hits'] += 1
                    logger.debug(f"Cache hit for context {context_hash[:8]}...")
                    return probabilities

        # Generate prediction via LLM
        logger.debug(f"Making LLM API call for context {context_hash[:8]}...")
        probabilities, raw_response, parsing_success = self._generate_llm_prediction(clinical_context)

        # Cache the result
        if self.config.enable_cache:
            self.cache[context_hash] = CachedPrediction(
                clinical_context_hash=context_hash,
                drug_vocabulary=list(self.drug_vocabulary),
                # Private copy: callers mutating the returned array must not
                # be able to corrupt the cached entry.
                probabilities=probabilities.copy(),
                raw_response=raw_response,
                timestamp=datetime.now(),
                model_info={
                    'provider': self.config.provider,
                    'model': self.config.model_name,
                    'temperature': self.config.temperature
                },
                parsing_success=parsing_success
            )
            if self._autosave:
                self._save_cache()

        return probabilities

    def _align_cached_probabilities(self, cached: CachedPrediction) -> Optional[np.ndarray]:
        """
        Return a copy of ``cached.probabilities`` ordered like ``self.drug_vocabulary``.

        The cache key hashes the *sorted* vocabulary, so the same drugs listed
        in a different order map to the same entry. The stored vector is
        positional and must therefore be reordered before use.

        Returns:
            The aligned copy, or None when the entry does not fit the current
            vocabulary. The caller treats None as a cache miss.
        """
        probabilities = np.asarray(cached.probabilities)
        if len(probabilities) != len(self.drug_vocabulary):
            return None

        cached_vocab = getattr(cached, 'drug_vocabulary', None)
        if cached_vocab is None:
            # Entry carries no vocabulary: only the length check is possible.
            return probabilities.copy()

        cached_vocab = list(cached_vocab)
        current_vocab = list(self.drug_vocabulary)
        if cached_vocab == current_vocab:
            return probabilities.copy()
        if sorted(cached_vocab) != sorted(current_vocab):
            # Different drugs (e.g. a digest collision): not usable.
            return None

        position = {drug: j for j, drug in enumerate(cached_vocab)}
        # Fancy indexing produces a new array in the current vocabulary order.
        return probabilities[[position[drug] for drug in current_vocab]]

    def _generate_llm_prediction(self, clinical_context: ClinicalContext) -> Tuple[np.ndarray, str, bool]:
        """
        Generate a prediction using the LLM API.

        Returns:
            (probabilities, raw_response, parsing_success)
        """
        prompt = self._build_prediction_prompt(clinical_context)

        # Make API call with retries
        raw_response = self._call_llm_api(prompt)

        # Parse response to probabilities
        probabilities, parsing_success = self._parse_drug_probabilities(raw_response)

        if not parsing_success:
            self.stats['parsing_failures'] += 1
            logger.warning(f"Failed to parse LLM response: {raw_response[:100]}...")

        return probabilities, raw_response, parsing_success

    def _build_prediction_prompt(self, clinical_context: ClinicalContext) -> str:
        """Build the structured prompt asking for a 0-100 rating per numbered drug."""

        context_str = self._format_clinical_context(clinical_context)

        # Number the drugs so the reply can be matched back by position.
        drug_list_str = "\n".join([f"{i+1}. {drug}" for i, drug in enumerate(self.drug_vocabulary)])

        prompt = f"""You are an expert clinical pharmacist. Based on the patient information provided, please rate the appropriateness of each drug on a scale of 0-100, where:
- 0 = Completely inappropriate/contraindicated
- 50 = Neutral/average appropriateness  
- 100 = Highly appropriate/first-line treatment

PATIENT INFORMATION:
{context_str}

AVAILABLE DRUGS:
{drug_list_str}

Please provide your ratings in the following EXACT format:
DRUG_RATINGS:
1. [Drug 1]: [0-100 score]
2. [Drug 2]: [0-100 score]
...

Your ratings should consider:
- Clinical indication match
- Patient safety profile (age, conditions, allergies)
- Drug efficacy for the indication
- Potential contraindications
- Standard treatment guidelines

DRUG_RATINGS:"""

        return prompt

    def _format_clinical_context(self, context: ClinicalContext) -> str:
        """Format the clinical context as one "Label: value" line per populated field."""

        parts = []

        if context.age is not None:
            parts.append(f"Age: {context.age} years")

        if context.gender:
            parts.append(f"Gender: {context.gender}")

        if context.diagnoses:
            diagnoses_str = ", ".join(context.diagnoses)
            parts.append(f"Diagnoses: {diagnoses_str}")

        # Optional fields: not every ClinicalContext is assumed to define them.
        if hasattr(context, 'symptoms') and context.symptoms:
            symptoms_str = ", ".join(context.symptoms)
            parts.append(f"Symptoms: {symptoms_str}")

        if hasattr(context, 'allergies') and context.allergies:
            allergies_str = ", ".join(context.allergies)
            parts.append(f"Allergies: {allergies_str}")

        if hasattr(context, 'current_medications') and context.current_medications:
            meds_str = ", ".join(context.current_medications)
            parts.append(f"Current Medications: {meds_str}")

        return "\n".join(parts) if parts else "No specific clinical information provided"

    def _call_llm_api(self, prompt: str) -> str:
        """
        Send ``prompt`` to the configured provider, retrying with exponential back-off.

        ``config.max_retries`` is the total number of attempts; at least one
        attempt is always made. After failed attempt k (0-based) the wait is
        ``retry_delay * 2**k`` seconds.

        Raises:
            The exception from the final failed attempt.
        """
        if self._client is None:
            self._initialize_client()

        attempts = max(1, self.config.max_retries)
        for attempt in range(attempts):
            try:
                response = self._make_api_request(prompt)
            except Exception as e:
                logger.warning(f"API call attempt {attempt + 1} failed: {e}")
                if attempt == attempts - 1:
                    raise
                time.sleep(self.config.retry_delay * (2 ** attempt))
            else:
                self.stats['api_calls'] += 1
                return response

        # Unreachable: the loop always returns or re-raises.
        raise RuntimeError(f"All {attempts} API attempts failed")

    def _initialize_client(self):
        """
        Create the provider SDK client.

        The SDK is imported lazily, so only the configured provider's package
        needs to be installed.

        Raises:
            ImportError: if the provider's SDK package is not installed.
            ValueError: for an unsupported provider.
        """
        provider = self.config.provider.lower()

        if provider == "openai":
            try:
                import openai
            except ImportError as err:
                raise ImportError("OpenAI package not installed. Run: pip install openai") from err
            self._client = openai.OpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            )

        elif provider == "anthropic":
            try:
                import anthropic
            except ImportError as err:
                raise ImportError("Anthropic package not installed. Run: pip install anthropic") from err
            self._client = anthropic.Anthropic(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            )

        else:
            raise ValueError(f"Unsupported provider: {self.config.provider}")

    def _make_api_request(self, prompt: str) -> str:
        """
        Perform a single completion request and return the response text.

        An empty string is returned (never None) when the provider returns no
        text, so downstream parsing and logging can slice it safely.
        """
        provider = self.config.provider.lower()

        if provider == "openai":
            response = self._client.chat.completions.create(
                model=self.config.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                timeout=self.config.timeout
            )
            self._record_usage(getattr(response, "usage", None))
            # message.content may be None (e.g. refusals); normalize to "".
            return response.choices[0].message.content or ""

        if provider == "anthropic":
            response = self._client.messages.create(
                model=self.config.model_name,
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                messages=[{"role": "user", "content": prompt}],
                timeout=self.config.timeout
            )
            self._record_usage(getattr(response, "usage", None))
            # Concatenate all text blocks; blocks without `text` contribute nothing.
            return "".join(getattr(block, "text", None) or "" for block in response.content)

        raise ValueError(f"API request not implemented for {self.config.provider}")

    def _record_usage(self, usage: Any) -> None:
        """
        Add the provider-reported token count to ``stats['total_tokens']``.

        Uses ``total_tokens`` when present; otherwise sums ``input_tokens`` and
        ``output_tokens``. Missing usage information is ignored.
        """
        if usage is None:
            return
        total = getattr(usage, "total_tokens", None)
        if total is None:
            total = (getattr(usage, "input_tokens", None) or 0) + (getattr(usage, "output_tokens", None) or 0)
        if isinstance(total, (int, float)):
            self.stats['total_tokens'] += int(total)

    def _parse_drug_probabilities(self, response: str) -> Tuple[np.ndarray, bool]:
        """
        Parse an LLM response into a probability distribution over the vocabulary.

        Ratings are read from the text after a ``DRUG_RATINGS:`` header when one
        is present. Otherwise the whole response is parsed, because the prompt
        itself ends with that header and replies often start directly with the
        list. Drugs without a parsable rating get the neutral score 50. Scores
        are clamped to [0, 100] and normalized to sum to 1.

        Returns:
            (probabilities, parsing_success). When no drug rating can be matched
            at all, ``parsing_success`` is False and the uniform fallback is
            returned.
        """
        if not response:
            logger.warning("Empty LLM response")
            return self._fallback_probabilities(), False

        try:
            header = re.search(r'DRUG_RATINGS:\s*(.*)', response, re.DOTALL | re.IGNORECASE)
            if header:
                ratings_text = header.group(1)
            else:
                logger.debug("No DRUG_RATINGS header in response; parsing full text")
                ratings_text = response

            scores = np.full(len(self.drug_vocabulary), 50.0)  # neutral default
            matched = 0
            for i, drug in enumerate(self.drug_vocabulary):
                score = self._find_drug_score(ratings_text, i + 1, drug)
                if score is None:
                    logger.debug(f"Could not find score for drug: {drug}")
                    continue
                scores[i] = score
                matched += 1

            if matched == 0 and len(self.drug_vocabulary) > 0:
                logger.warning("Could not find any drug ratings in response")
                return self._fallback_probabilities(), False

            total = scores.sum()
            if total == 0:
                return self._fallback_probabilities(), True
            return scores / total, True

        except Exception as e:
            logger.error(f"Error parsing LLM response: {e}")
            return self._fallback_probabilities(), False

    def _find_drug_score(self, ratings_text: str, position: int, drug: str) -> Optional[float]:
        """Return the clamped 0-100 score for ``drug`` (1-based list ``position``), or None."""
        name = re.escape(drug)
        score = self._SCORE_PATTERN
        patterns = (
            # "3. DrugName: 75"; (?<!\d) stops "3." from matching inside "13."
            rf'(?<!\d){position}\.\s*{name}\s*:\s*{score}',
            # "DrugName: 75" anywhere; (?<!\w) avoids matching inside a longer word
            rf'(?<!\w){name}\s*:\s*{score}',
            # "3. **DrugName** (note): 75"; decorated entry confined to one line
            rf'(?<!\d){position}\.\s*[^:\n]*{name}[^:\n]*:\s*{score}',
        )
        for pattern in patterns:
            match = re.search(pattern, ratings_text, re.IGNORECASE)
            if match:
                return min(100.0, max(0.0, float(match.group(1))))
        return None

    def _fallback_probabilities(self) -> np.ndarray:
        """Return the uniform distribution used when parsing fails."""
        return np.ones(len(self.drug_vocabulary)) / len(self.drug_vocabulary)

    def _hash_clinical_context(self, context: ClinicalContext) -> str:
        """
        Return an MD5 hex digest identifying ``context`` and the drug vocabulary.

        The key layout is kept unchanged so existing cache files stay valid.
        List and set fields are sorted, so element order does not change the key.
        """
        context_dict = {
            'age': context.age,
            'gender': context.gender,
            'diagnoses': sorted(context.diagnoses) if context.diagnoses else [],
            # Digest of the *sorted* vocabulary; ordering differences are
            # reconciled on lookup by _align_cached_probabilities().
            'drug_vocabulary_hash': hashlib.md5(
                ''.join(sorted(self.drug_vocabulary)).encode()
            ).hexdigest()[:8]
        }

        # Include additional fields if available
        for field_name in ('symptoms', 'allergies', 'current_medications'):
            value = getattr(context, field_name, None)
            if value:
                # Sets are not JSON-serializable; sorting also makes the key
                # independent of element order.
                if isinstance(value, (list, set, frozenset)):
                    value = sorted(value)
                context_dict[field_name] = value

        context_str = json.dumps(context_dict, sort_keys=True)
        return hashlib.md5(context_str.encode()).hexdigest()

    def _load_cache(self):
        """
        Load cached predictions from ``config.cache_file`` if it exists.

        Load errors and version mismatches are logged and leave the cache empty.
        """
        cache_file = Path(self.config.cache_file)
        if not cache_file.exists():
            return

        try:
            # SECURITY: pickle.load can execute arbitrary code; only load cache
            # files written by this tool, never files from untrusted sources.
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
        except Exception as e:
            logger.warning(f"Failed to load cache: {e}")
            return

        if (isinstance(cache_data, dict) and
                'version' in cache_data and
                cache_data['version'] == self.config.cache_version and
                isinstance(cache_data.get('predictions', {}), dict)):
            self.cache = cache_data.get('predictions', {})
            logger.info(f"Loaded {len(self.cache)} cached predictions")
        else:
            logger.warning("Cache version mismatch or invalid format, starting fresh")

    def _save_cache(self):
        """
        Persist the cache to ``config.cache_file``.

        The data is written to a sibling temp file and then renamed over the
        target, so an interrupted write cannot corrupt an existing cache file.
        Failures are logged, not raised.
        """
        cache_file = Path(self.config.cache_file)
        tmp_file = cache_file.with_name(cache_file.name + '.tmp')
        try:
            cache_file.parent.mkdir(parents=True, exist_ok=True)

            cache_data = {
                'version': self.config.cache_version,
                'predictions': self.cache,
                'stats': self.stats,
                'saved_at': datetime.now().isoformat()
            }

            with open(tmp_file, 'wb') as f:
                pickle.dump(cache_data, f)
            tmp_file.replace(cache_file)

        except Exception as e:
            logger.error(f"Failed to save cache: {e}")
            try:
                tmp_file.unlink()
            except OSError:
                pass  # best-effort cleanup of a partial temp file

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache and API usage statistics for this predictor instance."""
        return {
            'cached_predictions': len(self.cache),
            'cache_hit_rate': self.stats['cache_hits'] / max(1, self.stats['api_calls'] + self.stats['cache_hits']),
            'api_calls_made': self.stats['api_calls'],
            'parsing_failure_rate': self.stats['parsing_failures'] / max(1, self.stats['api_calls']),
            'total_api_calls': self.stats['api_calls'],
            'total_cache_hits': self.stats['cache_hits'],
            'total_tokens': self.stats['total_tokens']
        }

    def precompute_probabilities(self, clinical_contexts: List[ClinicalContext],
                               progress_callback: Optional[Callable[[int, int], None]] = None) -> None:
        """
        Precompute probabilities for a list of clinical contexts.

        Failures on individual contexts are logged and skipped. Cache writes are
        batched (every 100 contexts, and once at the end even if interrupted)
        instead of re-pickling the whole cache after every prediction.

        Args:
            clinical_contexts: List of clinical contexts to process
            progress_callback: Optional ``callback(done, total)``, called after
                each successfully processed context
        """
        total = len(clinical_contexts)
        logger.info(f"Precomputing probabilities for {total} contexts...")

        previous_autosave = self._autosave
        self._autosave = False
        try:
            for i, context in enumerate(clinical_contexts):
                try:
                    self.predict_probabilities(context)

                    if progress_callback:
                        progress_callback(i + 1, total)

                except Exception as e:
                    logger.error(f"Failed to process context {i}: {e}")

                if (i + 1) % 100 == 0:
                    logger.info(f"Processed {i + 1}/{total} contexts")
                    if self.config.enable_cache:
                        self._save_cache()
        finally:
            self._autosave = previous_autosave
            # Final save (also runs on KeyboardInterrupt and other aborts)
            if self.config.enable_cache:
                self._save_cache()

        logger.info(f"Precomputation complete: {self.get_cache_stats()}")


def create_llm_predictor(config: LLMConfig, drug_vocabulary: List[str]) -> LLMDrugPredictor:
    """
    Factory for ``LLMDrugPredictor``.

    Args:
        config: Provider, generation, retry and cache settings.
        drug_vocabulary: Drugs to score; result vectors follow this order.

    Returns:
        A predictor built from ``config`` and ``drug_vocabulary``. Any existing
        cache is loaded at construction time when caching is enabled.
    """
    predictor = LLMDrugPredictor(config=config, drug_vocabulary=drug_vocabulary)
    return predictor

