"""
DrugBank API/XML Interface with Caching for CNCRC Framework

This module parses the DrugBank XML database to extract drug-drug interaction
(DDI) information. Parsed results are cached on disk so that the DDI data
needed for cost matrix construction does not have to be re-extracted each run.

Key Components:
- DrugBankParser: Main XML parsing class
- DDICache: Caching mechanism for parsed data
- DrugBankInterface: High-level interface for DDI queries
- Mock data generation for development and testing

The large DrugBank XML file is parsed incrementally (ElementTree iterparse) and
the parsed results are cached on disk; DDI information is provided in formats
compatible with CNCRC data structures (e.g. CostMatrix).
"""

import os
import json
import csv
import xml.etree.ElementTree as ET
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple, Optional, Any, Set
from pathlib import Path
import logging
from dataclasses import dataclass, field
from datetime import datetime
import hashlib
import warnings

from ..core.data_structures import DrugInteraction, CostMatrix

# Module-level logger shared by every class and function below.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger for the whole process
# at import time (it does nothing if the root logger already has handlers);
# consider leaving logging configuration to the application.
logging.basicConfig(level=logging.INFO)


@dataclass
class DrugBankConfig:
    """
    Settings shared by the DrugBank cache, parser and interface.

    Attributes:
        xml_path: Location of the DrugBank XML file, or None if unavailable.
        cache_dir: Directory that holds the parsed-DDI cache.
        cache_format: Cache serialization, one of "parquet", "csv" or "json".
        enable_cache: Whether the cache is read and written at all.
        max_cache_age_days: Caches older than this many days are discarded.
        chunk_size: Chunk size for processing the large XML (not referenced
            elsewhere in this module).
        use_mock_data: Generate synthetic data instead of reading the XML.
        mock_drug_count: Number of synthetic drugs to generate.
        mock_interaction_prob: Probability that a given synthetic drug pair
            interacts.
        interaction_types: Interaction type labels (not referenced elsewhere
            in this module).
    """

    xml_path: Optional[str] = None
    cache_dir: str = "data/processed/drugbank"
    cache_format: str = "parquet"
    enable_cache: bool = True
    max_cache_age_days: int = 30
    chunk_size: int = 1000
    use_mock_data: bool = False
    mock_drug_count: int = 100
    mock_interaction_prob: float = 0.1
    # default_factory gives every instance its own list.
    interaction_types: List[str] = field(
        default_factory=lambda: ["major", "moderate", "minor", "contraindicated", "monitor"]
    )


@dataclass
class DDIEntry:
    """
    Represents a single drug-drug interaction entry.

    ``drug1_*`` and ``drug2_*`` identify the two drugs of the pair; the
    optional ``mechanism``/``management``/``evidence_level`` fields are None
    when the source does not provide them.
    """
    drug1_id: str
    drug1_name: str
    drug2_id: str
    drug2_name: str
    interaction_type: str
    severity: str
    description: str
    mechanism: Optional[str] = None
    management: Optional[str] = None
    evidence_level: Optional[str] = None
    source: str = "DrugBank"

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict (keys in field order) for serialization."""
        return {
            "drug1_id": self.drug1_id,
            "drug1_name": self.drug1_name,
            "drug2_id": self.drug2_id,
            "drug2_name": self.drug2_name,
            "interaction_type": self.interaction_type,
            "severity": self.severity,
            "description": self.description,
            "mechanism": self.mechanism,
            "management": self.management,
            "evidence_level": self.evidence_level,
            "source": self.source
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DDIEntry':
        """
        Create a DDIEntry from a mapping such as a dict or a pandas row.

        Keys that are not fields of DDIEntry are ignored, and float NaN values
        (how pandas represents missing cells after a CSV/parquet round-trip)
        are converted back to None.

        Raises:
            TypeError: if a required field is missing from ``data``.
        """
        import math
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        kwargs: Dict[str, Any] = {}
        for key, value in dict(data).items():
            if key not in known:
                continue  # tolerate extra columns, e.g. from a hand-edited cache
            if isinstance(value, float) and math.isnan(value):
                value = None
            kwargs[key] = value
        return cls(**kwargs)


class DDICache:
    """
    On-disk cache for parsed DrugBank DDI entries.

    Entries are stored in ``<cache_dir>/ddi_cache.<cache_format>``; a
    ``cache_metadata.json`` file next to it records when the cache was
    written, a fingerprint of the source XML and the settings that shaped
    the data. A cache is treated as stale when it is older than
    ``max_cache_age_days``, when a readable XML file's fingerprint differs
    from the recorded one, or when a recorded setting differs from the
    current configuration.
    """

    # Settings always compared against the values recorded in the metadata.
    _CHECKED_SETTINGS = ("cache_format", "use_mock_data")
    # Settings that only shape synthetic data; compared only for caches that
    # were recorded without an XML fingerprint.
    _MOCK_SETTINGS = ("mock_drug_count", "mock_interaction_prob")

    def __init__(self, config: DrugBankConfig):
        """
        Initialize DDI cache.

        Creates ``config.cache_dir`` (including parents) if it does not exist.

        Args:
            config: DrugBank configuration
        """
        self.config = config
        self.cache_dir = Path(config.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Cache file paths
        self.cache_file = self.cache_dir / f"ddi_cache.{config.cache_format}"
        self.metadata_file = self.cache_dir / "cache_metadata.json"

        logger.info(f"Initialized DDI cache: {self.cache_dir}")

    def _get_xml_hash(self) -> Optional[str]:
        """
        Return a cheap fingerprint of the configured XML file.

        Only the first 1 MB is hashed (the XML can be very large), so the file
        size is mixed in as well to detect changes beyond that prefix.

        Returns:
            Hex digest, or None when no XML path is configured, the file does
            not exist, or it cannot be read.
        """
        if not self.config.xml_path:
            return None
        xml_path = Path(self.config.xml_path)
        if not xml_path.exists():
            return None

        try:
            hasher = hashlib.md5()  # change detection only, not security
            hasher.update(str(xml_path.stat().st_size).encode("ascii"))
            with open(xml_path, 'rb') as f:
                hasher.update(f.read(1024 * 1024))
            return hasher.hexdigest()
        except Exception as e:
            logger.warning(f"Failed to hash XML file: {e}")
            return None

    def _source_fingerprint(self) -> Optional[str]:
        """XML fingerprint of the data source in use; None in mock mode, where the XML is not read."""
        if self.config.use_mock_data:
            return None
        return self._get_xml_hash()

    def _settings_snapshot(self) -> Dict[str, Any]:
        """Current values of every setting recorded in the cache metadata."""
        return {
            name: getattr(self.config, name)
            for name in self._CHECKED_SETTINGS + self._MOCK_SETTINGS
        }

    @staticmethod
    def _nan_to_none(value: Any) -> Any:
        """Map float NaN (pandas' missing-value marker) back to None."""
        if isinstance(value, float) and np.isnan(value):
            return None
        return value

    def is_cache_valid(self) -> bool:
        """
        Decide whether the on-disk cache can be reused.

        Returns False (never raises) when caching is disabled, a cache file is
        missing, the metadata is unreadable, the cache is too old, a readable
        XML file's fingerprint differs from the recorded one, or a recorded
        setting differs from the current configuration. When no fingerprint
        can be computed (mock mode, or the XML file is absent), the recorded
        fingerprint is not compared, so a cache built from an XML file stays
        usable after that file has been removed.
        """
        if not self.config.enable_cache:
            return False

        if not self.cache_file.exists() or not self.metadata_file.exists():
            return False

        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            # Check age
            cache_date = datetime.fromisoformat(metadata.get('created', '2000-01-01'))
            age_days = (datetime.now() - cache_date).days
            if age_days > self.config.max_cache_age_days:
                logger.info(f"Cache expired (age: {age_days} days)")
                return False

            # A cache recorded without a fingerprint (e.g. built from mock data)
            # becomes stale as soon as a readable XML file is available.
            current_hash = self._source_fingerprint()
            cached_hash = metadata.get('xml_hash')
            if current_hash is not None and current_hash != cached_hash:
                logger.info("XML file changed, cache invalid")
                return False

            # The metadata file is shared by all cache formats, and the data
            # depends on mock mode, so recorded settings must match.
            recorded = metadata.get('config') or {}
            names = list(self._CHECKED_SETTINGS)
            if cached_hash is None:
                names.extend(self._MOCK_SETTINGS)
            for name in names:
                # Older metadata may not record every setting; compare only
                # what was recorded.
                if name in recorded and recorded[name] != getattr(self.config, name):
                    logger.info(f"Cache setting '{name}' changed, cache invalid")
                    return False

            logger.info("Cache is valid")
            return True

        except Exception as e:
            logger.warning(f"Cache validation failed: {e}")
            return False

    def load_cache(self) -> Optional[List[DDIEntry]]:
        """
        Load DDI entries from the cache.

        Returns:
            The cached entries, or None when the cache is not valid or cannot
            be read (the error is logged, not raised).
        """
        if not self.is_cache_valid():
            return None

        try:
            logger.info(f"Loading cache from {self.cache_file}")

            fmt = self.config.cache_format
            if fmt == "json":
                with open(self.cache_file, 'r', encoding='utf-8') as f:
                    records = json.load(f)
            elif fmt == "parquet":
                records = pd.read_parquet(self.cache_file).to_dict(orient="records")
            elif fmt == "csv":
                # Read every column as text so IDs and names are never coerced
                # to numbers; empty cells still come back as NaN.
                records = pd.read_csv(self.cache_file, dtype=str).to_dict(orient="records")
            else:
                raise ValueError(f"Unsupported cache format: {self.config.cache_format}")

            return [
                DDIEntry.from_dict({key: self._nan_to_none(value) for key, value in record.items()})
                for record in records
            ]

        except Exception as e:
            logger.error(f"Failed to load cache: {e}")
            return None

    def save_cache(self, ddi_entries: List[DDIEntry]) -> bool:
        """
        Write ``ddi_entries`` and the accompanying metadata to disk.

        Returns:
            True on success; False when caching is disabled or writing fails
            (the error is logged, not raised).
        """
        if not self.config.enable_cache:
            return False

        try:
            logger.info(f"Saving {len(ddi_entries)} DDI entries to cache")

            fmt = self.config.cache_format
            if fmt not in ("parquet", "csv", "json"):
                raise ValueError(f"Unsupported cache format: {self.config.cache_format}")

            # Drop the old metadata first so that, if writing the data file is
            # interrupted, no metadata vouches for a half-written cache.
            if self.metadata_file.exists():
                self.metadata_file.unlink()

            records = [entry.to_dict() for entry in ddi_entries]
            if fmt == "parquet":
                pd.DataFrame(records).to_parquet(self.cache_file, index=False)
            elif fmt == "csv":
                pd.DataFrame(records).to_csv(self.cache_file, index=False)
            else:
                with open(self.cache_file, 'w', encoding='utf-8') as f:
                    json.dump(records, f, indent=2)

            metadata = {
                "created": datetime.now().isoformat(),
                "entry_count": len(ddi_entries),
                "xml_hash": self._source_fingerprint(),
                "config": self._settings_snapshot(),
            }

            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

            logger.info("Cache saved successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to save cache: {e}")
            return False


class DrugBankParser:
    """
    DrugBank XML parser for extracting drug-drug interaction data.

    Streams the XML with ``ElementTree.iterparse`` and releases each top-level
    record once processed, turning each record's ``<drug-interactions>``
    section into DDIEntry objects. Synthetic data is produced instead when
    mock mode is enabled, no XML path is configured, the file is missing, or
    parsing fails.

    NOTE: the stdlib XML parsers are not hardened against maliciously crafted
    documents; the XML file is assumed to be a trusted DrugBank download.
    """

    def __init__(self, config: DrugBankConfig):
        """
        Initialize DrugBank parser.

        Args:
            config: DrugBank configuration
        """
        self.config = config

        # Prefix map for find()/findall(); both prefixes name the same URI.
        self.namespaces = {
            'db': 'http://www.drugbank.ca',
            'ns': 'http://www.drugbank.ca'
        }

    @staticmethod
    def _element_text(elem: Optional[ET.Element]) -> Optional[str]:
        """Return the stripped text of ``elem``; None if it is missing or has no text."""
        if elem is None or elem.text is None:
            return None
        text = elem.text.strip()
        return text or None

    @staticmethod
    def _local_name(tag: str) -> str:
        """Strip a ``{namespace}`` prefix from an ElementTree tag."""
        return tag.rsplit('}', 1)[-1]

    def _generate_mock_data(self) -> List[DDIEntry]:
        """
        Generate reproducible mock DDI data for development and testing.

        Each unordered pair of ``mock_drug_count`` synthetic drugs interacts
        with probability ``mock_interaction_prob``.
        """
        logger.info(f"Generating mock DDI data ({self.config.mock_drug_count} drugs)")

        # A private RandomState seeded with 42 yields the same draws as seeding
        # the global generator with 42, without resetting global RNG state for
        # the rest of the process.
        rng = np.random.RandomState(42)

        drug_count = self.config.mock_drug_count
        drugs = [
            {'id': f'DB{i:05d}', 'name': f'MockDrug_{i:03d}'}
            for i in range(drug_count)
        ]

        ddi_entries = []
        for i in range(drug_count):
            for j in range(i + 1, drug_count):
                if rng.random_sample() >= self.config.mock_interaction_prob:
                    continue
                severity = str(rng.choice(['major', 'moderate', 'minor']))
                interaction_type = str(rng.choice(['therapeutic', 'toxic', 'absorption']))

                ddi_entries.append(DDIEntry(
                    drug1_id=drugs[i]['id'],
                    drug1_name=drugs[i]['name'],
                    drug2_id=drugs[j]['id'],
                    drug2_name=drugs[j]['name'],
                    interaction_type=interaction_type,
                    severity=severity,
                    description=f"Mock interaction between {drugs[i]['name']} and {drugs[j]['name']}",
                    mechanism=f"Mock mechanism for {severity} interaction",
                    management="Monitor patient closely",
                    evidence_level="theoretical",
                    source="MockData"
                ))

        logger.info(f"Generated {len(ddi_entries)} mock DDI entries")
        return ddi_entries

    def _parse_drug_element(self, drug_elem: ET.Element) -> Optional[Dict[str, str]]:
        """
        Extract ``{'id', 'name'}`` from a drug element.

        Returns None when the element has no non-empty primary DrugBank ID.
        A missing or empty name is replaced by ``Unknown_<id>``.
        """
        try:
            drug_id = self._element_text(
                drug_elem.find('.//db:drugbank-id[@primary="true"]', self.namespaces)
            )
            if drug_id is None:
                return None

            drug_name = self._element_text(drug_elem.find('.//db:name', self.namespaces))

            return {
                'id': drug_id,
                'name': drug_name or f"Unknown_{drug_id}"
            }

        except Exception as e:
            logger.warning(f"Failed to parse drug element: {e}")
            return None

    def _parse_interactions(self, drug_elem: ET.Element, drug_info: Dict[str, str]) -> List[DDIEntry]:
        """
        Build DDI entries from a drug element's ``<drug-interactions>`` section.

        Interactions whose partner ID or name is empty, or that lack a
        ``<description>`` element, are skipped. Severity is recorded as
        "unknown".
        """
        interactions: List[DDIEntry] = []

        try:
            interactions_elem = drug_elem.find('.//db:drug-interactions', self.namespaces)
            if interactions_elem is None:
                return interactions

            for interaction_elem in interactions_elem.findall('.//db:drug-interaction', self.namespaces):
                try:
                    target_id = self._element_text(
                        interaction_elem.find('.//db:drugbank-id', self.namespaces)
                    )
                    target_name = self._element_text(
                        interaction_elem.find('.//db:name', self.namespaces)
                    )
                    description_elem = interaction_elem.find('.//db:description', self.namespaces)

                    # Partner ID and name are needed for every later lookup.
                    if target_id is None or target_name is None or description_elem is None:
                        continue

                    interactions.append(DDIEntry(
                        drug1_id=drug_info['id'],
                        drug1_name=drug_info['name'],
                        drug2_id=target_id,
                        drug2_name=target_name,
                        interaction_type="drug_interaction",
                        severity="unknown",  # DrugBank doesn't always provide severity
                        description=self._element_text(description_elem) or "",
                        source="DrugBank"
                    ))

                except Exception as e:
                    logger.warning(f"Failed to parse interaction: {e}")
                    continue

        except Exception as e:
            logger.warning(f"Failed to parse interactions for {drug_info.get('id', 'unknown')}: {e}")

        return interactions

    def parse_xml(self) -> List[DDIEntry]:
        """
        Parse the DrugBank XML file and extract DDI entries.

        Falls back to mock data when mock mode is on, no XML path is set, the
        file does not exist, or parsing raises (partial results are discarded).
        """
        if self.config.use_mock_data or not self.config.xml_path:
            logger.info("Using mock data (XML path not provided or mock mode enabled)")
            return self._generate_mock_data()

        xml_path = Path(self.config.xml_path)
        if not xml_path.exists():
            logger.warning(f"XML file not found: {xml_path}, falling back to mock data")
            return self._generate_mock_data()

        logger.info(f"Parsing DrugBank XML: {xml_path}")

        ddi_entries: List[DDIEntry] = []

        try:
            # Opening the file here guarantees it is closed even when parsing
            # fails part-way through.
            with open(xml_path, 'rb') as xml_file:
                context = ET.iterparse(xml_file, events=('start', 'end'))
                _, root = next(context)

                depth = 0  # number of currently open elements below the root
                processed_drugs = 0

                for event, elem in context:
                    if event == 'start':
                        depth += 1
                        continue

                    depth -= 1
                    if depth != 0:
                        # Still inside a top-level record; <drug> tags at deeper
                        # levels are not standalone drug records.
                        continue

                    if self._local_name(elem.tag) == 'drug':
                        drug_info = self._parse_drug_element(elem)
                        if drug_info:
                            ddi_entries.extend(self._parse_interactions(elem, drug_info))

                            processed_drugs += 1
                            if processed_drugs % 100 == 0:
                                logger.info(f"Processed {processed_drugs} drugs, found {len(ddi_entries)} interactions")

                    # Release the finished top-level record to bound memory use.
                    elem.clear()
                    root.clear()

        except Exception as e:
            logger.error(f"Failed to parse XML: {e}")
            logger.info("Falling back to mock data")
            return self._generate_mock_data()

        logger.info(f"XML parsing complete: {len(ddi_entries)} DDI entries extracted")
        return ddi_entries


class DrugBankInterface:
    """
    High-level interface for DrugBank drug-drug interaction queries.

    Loads DDI entries on first use, from the cache when it is valid and
    otherwise via the parser, and answers lookups by drug name. Name matching
    is case-insensitive (``str.lower``). Pair lookups use a lazily built index
    keyed by the unordered pair of lower-cased names, so building an n x n
    cost matrix does not rescan every entry for each cell.
    """

    # Cost per severity label used by get_interaction_matrix.
    _SEVERITY_COSTS: Dict[str, float] = {
        'major': 1.0,
        'moderate': 0.6,
        'minor': 0.3,
        'contraindicated': 1.5,
        'unknown': 0.5,
    }
    # Cost for severity labels absent from the mapping in use.
    _DEFAULT_SEVERITY_COST = 0.5

    def __init__(self, config: DrugBankConfig):
        """
        Initialize DrugBank interface (no data is loaded yet).

        Args:
            config: DrugBank configuration
        """
        self.config = config
        self.cache = DDICache(config)
        self.parser = DrugBankParser(config)

        self._ddi_entries: Optional[List[DDIEntry]] = None
        self._drug_index: Optional[Dict[str, str]] = None  # lower-cased name -> id
        # Unordered pair of lower-cased names -> first matching entry.
        self._pair_index: Optional[Dict[frozenset, DDIEntry]] = None

        logger.info("Initialized DrugBank interface")

    @staticmethod
    def _normalize(name: Any) -> Optional[str]:
        """Lower-case a drug name for matching; non-string names (missing values) give None."""
        return name.lower() if isinstance(name, str) else None

    def load_data(self, force_reload: bool = False) -> None:
        """
        Load DDI data from cache or parse XML.

        Args:
            force_reload: Force reload from XML even if cache exists
        """
        if self._ddi_entries is not None and not force_reload:
            logger.info("DDI data already loaded")
            return

        # Try to load from cache first
        if not force_reload:
            cached_data = self.cache.load_cache()
            if cached_data is not None:
                self._ddi_entries = cached_data
                self._build_drug_index()
                logger.info(f"Loaded {len(self._ddi_entries)} DDI entries from cache")
                return

        # Parse XML (or generate mock data) and cache non-empty results
        logger.info("Loading DDI data from XML")
        self._ddi_entries = self.parser.parse_xml()
        self._build_drug_index()

        if self._ddi_entries:
            self.cache.save_cache(self._ddi_entries)

        logger.info(f"Loaded {len(self._ddi_entries)} DDI entries")

    def _build_drug_index(self) -> None:
        """Rebuild the name -> ID index and drop the pair index for the current entries."""
        self._drug_index = None
        self._pair_index = None  # rebuilt lazily by _get_pair_index
        if not self._ddi_entries:
            return

        index: Dict[str, str] = {}
        for entry in self._ddi_entries:
            for name, drug_id in ((entry.drug1_name, entry.drug1_id),
                                  (entry.drug2_name, entry.drug2_id)):
                key = self._normalize(name)
                if key is not None:
                    index[key] = drug_id
        self._drug_index = index

        logger.info(f"Built drug index with {len(self._drug_index)} entries")

    def _get_pair_index(self) -> Dict[frozenset, DDIEntry]:
        """Return the unordered-pair index, building it from the loaded entries on first use."""
        if self._pair_index is None:
            pair_index: Dict[frozenset, DDIEntry] = {}
            for entry in self._ddi_entries or []:
                name1 = self._normalize(entry.drug1_name)
                name2 = self._normalize(entry.drug2_name)
                if name1 is None or name2 is None:
                    continue
                # setdefault keeps the first entry, matching a front-to-back scan.
                pair_index.setdefault(frozenset((name1, name2)), entry)
            self._pair_index = pair_index
        return self._pair_index

    def get_all_drugs(self) -> Set[str]:
        """Get set of all drug names in the database."""
        if self._ddi_entries is None:
            self.load_data()

        drugs = set()
        for entry in self._ddi_entries:
            drugs.add(entry.drug1_name)
            drugs.add(entry.drug2_name)

        return drugs

    def get_drug_interactions(self, drug_name: str) -> List[DDIEntry]:
        """
        Get all interactions for a specific drug.

        Args:
            drug_name: Name of the drug (matched case-insensitively)

        Returns:
            List of DDI entries involving the drug, in stored order
        """
        if self._ddi_entries is None:
            self.load_data()

        target = drug_name.lower()
        return [
            entry for entry in self._ddi_entries
            if target in (self._normalize(entry.drug1_name), self._normalize(entry.drug2_name))
        ]

    def get_interaction(self, drug1: str, drug2: str) -> Optional[DDIEntry]:
        """
        Get interaction between two specific drugs.

        Args:
            drug1: Name of first drug
            drug2: Name of second drug

        Returns:
            The first stored DDI entry for the pair (in either order, matched
            case-insensitively), or None if there is none
        """
        if self._ddi_entries is None:
            self.load_data()

        return self._get_pair_index().get(frozenset((drug1.lower(), drug2.lower())))

    def get_interaction_matrix(
        self,
        drug_names: List[str],
        severity_costs: Optional[Dict[str, float]] = None,
    ) -> CostMatrix:
        """
        Create cost matrix for given drugs based on DDI information.

        Args:
            drug_names: List of drug names
            severity_costs: Optional severity label -> cost mapping overriding
                the defaults; labels it does not list cost 0.5

        Returns:
            CostMatrix whose symmetric entry (i, j) is the cost of the
            interaction between drug_names[i] and drug_names[j] (0 when none);
            the diagonal is 0
        """
        if self._ddi_entries is None:
            self.load_data()

        costs = self._SEVERITY_COSTS if severity_costs is None else severity_costs

        n_drugs = len(drug_names)
        cost_matrix = np.zeros((n_drugs, n_drugs))

        # Pair lookups are order-independent, so each unordered pair is
        # resolved once and mirrored.
        for i in range(n_drugs):
            for j in range(i + 1, n_drugs):
                interaction = self.get_interaction(drug_names[i], drug_names[j])
                if interaction is not None:
                    cost = costs.get(interaction.severity, self._DEFAULT_SEVERITY_COST)
                    cost_matrix[i, j] = cost
                    cost_matrix[j, i] = cost

        return CostMatrix(
            matrix=cost_matrix,
            drug_ids=drug_names
        )

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about the DDI database (or an ``error`` dict when empty)."""
        if self._ddi_entries is None:
            self.load_data()

        if not self._ddi_entries:
            return {"error": "No DDI data available"}

        severity_counts: Dict[str, int] = {}
        for entry in self._ddi_entries:
            severity_counts[entry.severity] = severity_counts.get(entry.severity, 0) + 1

        unique_drugs = self.get_all_drugs()

        # The parser can fall back to mock entries (source "MockData") even
        # when mock mode is off, so look at the data, not just the flag.
        is_mock = self.config.use_mock_data or all(
            entry.source == "MockData" for entry in self._ddi_entries
        )

        return {
            "total_interactions": len(self._ddi_entries),
            "unique_drugs": len(unique_drugs),
            "severity_distribution": severity_counts,
            "source": "Mock" if is_mock else "DrugBank",
            "cache_enabled": self.config.enable_cache
        }


# Convenience functions
def load_drugbank_interface(
    xml_path: Optional[str] = None,
    cache_dir: str = "data/processed/drugbank",
    use_mock_data: bool = False,
    **kwargs
) -> DrugBankInterface:
    """
    Build a DrugBankInterface and load its data immediately.

    Args:
        xml_path: Path to DrugBank XML file
        cache_dir: Directory for cache files
        use_mock_data: Whether to use mock data
        **kwargs: Any other DrugBankConfig fields

    Returns:
        A DrugBankInterface on which load_data() has already been called
    """
    settings = DrugBankConfig(
        xml_path=xml_path,
        cache_dir=cache_dir,
        use_mock_data=use_mock_data,
        **kwargs
    )
    ready = DrugBankInterface(settings)
    ready.load_data()
    return ready


def get_ddi_data(
    xml_path: Optional[str] = None,
    cache_path: Optional[str] = None,
    use_mock: bool = False
) -> pd.DataFrame:
    """
    Simple function to get DDI data as DataFrame.

    Args:
        xml_path: Path to DrugBank XML file
        cache_path: Path to a cache file (e.g. ``.../ddi_cache.csv``) or to an
            existing cache directory. For a file path, its parent directory is
            used and a ``.parquet``/``.csv``/``.json`` extension selects the
            cache format. The file name itself is not used: the cache is
            always stored as ``ddi_cache.<format>``.
        use_mock: Use mock data

    Returns:
        DataFrame with one row per DDI entry (columns as produced by
        ``DDIEntry.to_dict``), or an empty DataFrame when there are none
    """
    extra_config: Dict[str, Any] = {}
    if cache_path:
        path = Path(cache_path)
        if path.is_dir():
            # An existing directory is the cache directory itself; taking its
            # parent would silently move the cache one level up.
            cache_dir = str(path)
        else:
            cache_dir = str(path.parent)
            suffix = path.suffix.lower().lstrip(".")
            if suffix in ("parquet", "csv", "json"):
                extra_config["cache_format"] = suffix
    else:
        cache_dir = "data/processed/drugbank"

    interface = load_drugbank_interface(
        xml_path=xml_path,
        cache_dir=cache_dir,
        use_mock_data=use_mock,
        **extra_config
    )

    ddi_entries = interface._ddi_entries
    if not ddi_entries:
        return pd.DataFrame()

    return pd.DataFrame([entry.to_dict() for entry in ddi_entries])



