"""
DDI Cost Matrix Generation for CNCRC Framework

This module integrates MIMIC-IV drug data with DrugBank DDI information
to generate cost matrices for the CNCRC algorithm. It provides flexible
cost mapping strategies and efficient matrix generation.

Key Components:
- CostMappingConfig: Configuration for cost mapping strategies
- DDICostGenerator: Main class for generating cost matrices
- SeverityMapper: Maps DDI severity to numerical costs
- Integration functions with MIMIC and DrugBank data

The cost matrix C[i,j] represents the cost of predicting drug i when
the correct drug is j, based on drug-drug interaction severity.
"""

import os
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Any, Union, Set
from pathlib import Path
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
import pickle

from .mimic_loader import MimicDataLoader, load_mimic_data
from .drugbank_interface import DrugBankInterface, load_drugbank_interface, DDIEntry
from ..core.data_structures import CostMatrix

# Configure logging. basicConfig() runs at import time and sets the root
# logger to INFO (it does nothing if the root logger already has handlers).
# `logger` is the module-level logger used by everything in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class CostMappingConfig:
    """Settings that drive DDI cost-matrix generation.

    Fields are grouped by concern: the severity-to-cost table, cost matrix
    properties, drug filtering, output handling, and the MIMIC / DrugBank
    integration options. Every field has a default, so ``CostMappingConfig()``
    yields a usable configuration.
    """

    # --- Severity label -> cost ------------------------------------------
    # 'high' / 'medium' / 'low' are alternative names for
    # 'major' / 'moderate' / 'minor' and carry the same costs. 'none' means
    # no known interaction and still gets a minimal ambiguity cost (> 0).
    # Key order is preserved because it decides which key wins a partial match.
    severity_mapping: Dict[str, float] = field(
        default_factory=lambda: dict(
            major=1.0,
            moderate=0.6,
            minor=0.3,
            contraindicated=1.5,
            unknown=0.5,
            high=1.0,
            medium=0.6,
            low=0.3,
            none=0.05,
        )
    )

    # --- Cost matrix properties ------------------------------------------
    # Minimal positive cost used when no interaction is found.
    default_cost: float = 0.05
    # Cost of a drug with itself (the matrix diagonal).
    self_cost: float = 0.0
    # Whether the generated cost matrix should be symmetric.
    symmetric: bool = True

    # --- Data filtering options ------------------------------------------
    # Minimum interactions per drug to include.
    min_interaction_count: int = 0
    # Specific drugs to include (None keeps every drug).
    drug_filter: Optional[List[str]] = None
    # Maximum number of drugs to include (None means no limit).
    max_drugs: Optional[int] = None

    # --- Output options ----------------------------------------------------
    output_dir: str = "data/processed/cost_matrices"
    # One of "npz", "csv", "json", "pickle".
    output_format: str = "npz"
    save_metadata: bool = True

    # --- MIMIC integration -------------------------------------------------
    mimic_splits: List[str] = field(
        default_factory=lambda: list(("train", "validation", "test"))
    )
    # Minimum frequency of a drug in MIMIC for it to be included.
    mimic_min_frequency: int = 1

    # --- DrugBank integration ----------------------------------------------
    drugbank_use_mock: bool = False
    drugbank_cache_dir: str = "data/processed/drugbank"


class SeverityMapper:
    """Maps DDI severity labels to numerical costs.

    Labels are compared case-insensitively and with surrounding whitespace
    removed. A label with no exact match falls back to a substring match
    against the configured keys, and finally to ``config.default_cost``.
    """

    def __init__(self, config: "CostMappingConfig"):
        """
        Initialize severity mapper.

        Args:
            config: Cost mapping configuration. Only ``severity_mapping`` and
                ``default_cost`` are read.
        """
        self.config = config

        # Normalize keys the same way lookups are normalized (lowercase,
        # stripped) so that every configured key can be matched exactly.
        # A new dict is built, so the configuration's mapping is not mutated.
        self.severity_map = {
            label.lower().strip(): cost
            for label, cost in config.severity_mapping.items()
        }

        logger.info(f"Initialized severity mapper with {len(self.severity_map)} mappings")

    def map_severity(self, severity: str) -> float:
        """
        Map severity string to numerical cost.

        Args:
            severity: Severity label (e.g., 'major', 'moderate'). ``None`` and
                blank strings mean "no label" and yield the default cost.

        Returns:
            Numerical cost value
        """
        if severity is None:
            return self.config.default_cost

        severity_lower = severity.lower().strip()

        if severity_lower in self.severity_map:
            return self.severity_map[severity_lower]

        # A blank label must not reach the substring test below: the empty
        # string is a substring of every key, so it would silently receive
        # the cost of the first configured severity.
        if not severity_lower:
            return self.config.default_cost

        # Try partial matches for compound severity names. Keys are tried in
        # mapping order, so the first matching key wins.
        for sev_key, cost in self.severity_map.items():
            if sev_key and (sev_key in severity_lower or severity_lower in sev_key):
                return cost

        logger.warning(f"Unknown severity '{severity}', using default cost {self.config.default_cost}")
        return self.config.default_cost

    def get_mapping_stats(self) -> Dict[str, Any]:
        """Get statistics about severity mapping.

        Returns:
            Dict with the number of mappings, the (min, max) cost range
            (``(None, None)`` when no mappings are configured), the default
            cost, and a copy of the normalized mapping.
        """
        costs = list(self.severity_map.values())
        severity_range = (min(costs), max(costs)) if costs else (None, None)
        return {
            "total_mappings": len(self.severity_map),
            "severity_range": severity_range,
            "default_cost": self.config.default_cost,
            # Copy so callers cannot mutate the mapper's internal table.
            "mappings": dict(self.severity_map)
        }


class DDICostGenerator:
    """
    Main class for generating DDI-based cost matrices.

    Integrates MIMIC-IV drug vocabulary with DrugBank DDI information
    to create cost matrices for CNCRC prediction tasks.

    Typical flow: ``load_mimic_data()`` -> ``load_drugbank_data()`` ->
    ``generate_cost_matrix()`` -> ``save_cost_matrix()`` or
    ``to_cncrc_cost_matrix()``.
    """

    def __init__(self, config: "CostMappingConfig"):
        """
        Initialize DDI cost generator.

        Creates ``config.output_dir`` (including parents) if it does not exist.

        Args:
            config: Cost mapping configuration
        """
        self.config = config
        self.severity_mapper = SeverityMapper(config)

        # Data components
        self.mimic_loader: Optional[MimicDataLoader] = None
        self.drugbank_interface: Optional[DrugBankInterface] = None

        # Generated data
        self.drug_vocabulary: Optional[Dict[str, int]] = None  # drug_name -> index
        self.drug_list: Optional[List[str]] = None  # index -> drug_name
        self.cost_matrix: Optional[np.ndarray] = None
        self.interaction_count: Optional[Dict[Tuple[str, str], int]] = None

        # Create output directory
        Path(config.output_dir).mkdir(parents=True, exist_ok=True)

        logger.info("Initialized DDI cost generator")

    @staticmethod
    def _sparsity(matrix: np.ndarray) -> float:
        """Return the fraction of zero entries; 0.0 for an empty matrix."""
        if matrix.size == 0:
            # Sparsity is undefined without entries; avoid dividing by zero.
            return 0.0
        return 1.0 - int(np.count_nonzero(matrix)) / matrix.size

    def load_mimic_data(self, **kwargs) -> None:
        """Load the MIMIC data loader and build the drug vocabulary.

        Populates ``drug_list`` and ``drug_vocabulary`` from the configured
        splits, then applies ``config.drug_filter`` and ``config.max_drugs``.
        After either filter the vocabulary is re-indexed from 0.

        Args:
            **kwargs: Forwarded unchanged to the module-level
                ``load_mimic_data`` factory.
        """
        logger.info("Loading MIMIC data loader")

        self.mimic_loader = load_mimic_data(**kwargs)

        # Build drug vocabulary from specified splits
        vocab = self.mimic_loader.get_drug_vocabulary(splits=self.config.mimic_splits)

        # NOTE: no frequency counts are available at this point, so the
        # threshold is only logged; no drug is actually removed by it.
        if self.config.mimic_min_frequency > 1:
            logger.info(f"Applying minimum frequency filter: {self.config.mimic_min_frequency}")

        self.drug_vocabulary = vocab
        self.drug_list = list(vocab.keys())

        # Apply drug filtering if specified
        if self.config.drug_filter:
            # Build the lookup set once instead of scanning the list per drug.
            allowed = set(self.config.drug_filter)
            self.drug_list = [drug for drug in self.drug_list if drug in allowed]
            self.drug_vocabulary = {drug: i for i, drug in enumerate(self.drug_list)}

        # Apply maximum drug limit if specified
        if self.config.max_drugs and len(self.drug_list) > self.config.max_drugs:
            self.drug_list = self.drug_list[:self.config.max_drugs]
            self.drug_vocabulary = {drug: i for i, drug in enumerate(self.drug_list)}

        logger.info(f"Drug vocabulary loaded: {len(self.drug_list)} drugs")

    def load_drugbank_data(self, **kwargs) -> None:
        """Load the DrugBank interface.

        Args:
            **kwargs: Extra options for ``load_drugbank_interface``; they
                override the ``use_mock_data`` / ``cache_dir`` values taken
                from the configuration.
        """
        logger.info("Loading DrugBank interface")

        drugbank_config = {
            'use_mock_data': self.config.drugbank_use_mock,
            'cache_dir': self.config.drugbank_cache_dir
        }
        drugbank_config.update(kwargs)

        self.drugbank_interface = load_drugbank_interface(**drugbank_config)

        logger.info("DrugBank interface loaded")

    def generate_cost_matrix(self) -> np.ndarray:
        """
        Generate cost matrix from MIMIC and DrugBank data.

        Every ordered pair (i, j) with i != j is looked up in DrugBank, so an
        unordered pair is queried twice; ``interaction_count`` and the logged
        total therefore count ordered pairs. An empty vocabulary yields an
        empty (0, 0) matrix.

        Returns:
            Cost matrix as numpy array (float32, shape n_drugs x n_drugs)

        Raises:
            ValueError: If the drug vocabulary or the DrugBank interface has
                not been loaded yet.
        """
        if self.drug_list is None:
            raise ValueError("Drug vocabulary not loaded. Call load_mimic_data() first.")

        if self.drugbank_interface is None:
            raise ValueError("DrugBank interface not loaded. Call load_drugbank_data() first.")

        logger.info(f"Generating cost matrix for {len(self.drug_list)} drugs")

        n_drugs = len(self.drug_list)
        # Pairs without a known interaction keep the default cost set here.
        cost_matrix = np.full((n_drugs, n_drugs), self.config.default_cost, dtype=np.float32)

        # Track interactions found
        interactions_found = 0
        self.interaction_count = {}

        for i, drug1 in enumerate(self.drug_list):
            for j, drug2 in enumerate(self.drug_list):
                if i == j:
                    # Self-cost (diagonal)
                    cost_matrix[i, j] = self.config.self_cost
                    continue

                interaction = self.drugbank_interface.get_interaction(drug1, drug2)
                if not interaction:
                    continue

                cost_matrix[i, j] = self.severity_mapper.map_severity(interaction.severity)

                # Track interaction under an order-independent key
                drug_pair = tuple(sorted([drug1, drug2]))
                self.interaction_count[drug_pair] = self.interaction_count.get(drug_pair, 0) + 1
                interactions_found += 1

        # Ensure symmetry if required by averaging both directions
        if self.config.symmetric:
            cost_matrix = (cost_matrix + cost_matrix.T) / 2.0
            # Restore diagonal
            np.fill_diagonal(cost_matrix, self.config.self_cost)

        self.cost_matrix = cost_matrix

        logger.info(f"Cost matrix generated: {interactions_found} interactions found")
        logger.info(f"Matrix shape: {cost_matrix.shape}")
        logger.info(f"Non-zero entries: {np.count_nonzero(cost_matrix)}")
        logger.info(f"Matrix sparsity: {self._sparsity(cost_matrix):.3f}")

        return cost_matrix

    def get_cost_matrix_stats(self) -> Dict[str, Any]:
        """Get statistics about the generated cost matrix.

        Returns:
            Dict of JSON-serializable statistics. For an empty matrix the
            min/max/mean/std entries are ``None`` and sparsity is 0.0.

        Raises:
            ValueError: If no cost matrix has been generated or loaded.
        """
        if self.cost_matrix is None:
            raise ValueError("Cost matrix not generated yet")

        matrix = self.cost_matrix
        is_empty = matrix.size == 0
        unique_costs, counts = np.unique(matrix, return_counts=True)

        stats = {
            "shape": matrix.shape,
            "total_elements": int(matrix.size),
            "non_zero_elements": int(np.count_nonzero(matrix)),
            "sparsity": self._sparsity(matrix),
            # Reductions are undefined on an empty matrix (np.min raises).
            "min_cost": None if is_empty else float(np.min(matrix)),
            "max_cost": None if is_empty else float(np.max(matrix)),
            "mean_cost": None if is_empty else float(np.mean(matrix)),
            "std_cost": None if is_empty else float(np.std(matrix)),
            "unique_costs": len(unique_costs),
            "symmetric": bool(np.allclose(matrix, matrix.T)),
            # True when the diagonal equals the configured self-cost.
            "diagonal_zero": bool(np.allclose(np.diag(matrix), self.config.self_cost))
        }

        # Cost distribution
        stats["cost_distribution"] = {
            float(cost): int(count) for cost, count in zip(unique_costs, counts)
        }

        # Interaction statistics
        if self.interaction_count:
            stats["interaction_stats"] = {
                "total_interactions": len(self.interaction_count),
                "unique_drug_pairs": len(set(self.interaction_count.keys()))
            }

        return stats

    def save_cost_matrix(self, filename: Optional[str] = None) -> str:
        """
        Save cost matrix to file.

        The file is always written inside ``config.output_dir`` with
        ``config.output_format`` as its extension. When
        ``config.save_metadata`` is set, a ``<name>_metadata.json`` file is
        written next to it.

        Args:
            filename: Optional filename, auto-generated if not provided. Only
                its stem is used; any directory part or extension is dropped.

        Returns:
            Path to saved file

        Raises:
            ValueError: If no cost matrix exists or the output format is not
                one of "npz", "csv", "pickle", "json".
        """
        if self.cost_matrix is None:
            raise ValueError("Cost matrix not generated yet")

        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"ddi_cost_matrix_{len(self.drug_list)}drugs_{timestamp}"

        # Remove extension if provided (we'll add our own)
        filename = Path(filename).stem

        output_format = self.config.output_format
        output_path = Path(self.config.output_dir) / f"{filename}.{output_format}"

        # Save matrix
        if output_format == "npz":
            np.savez_compressed(
                output_path,
                cost_matrix=self.cost_matrix,
                drug_list=np.array(self.drug_list),
                drug_vocabulary=self.drug_vocabulary
            )
        elif output_format == "csv":
            df = pd.DataFrame(self.cost_matrix, index=self.drug_list, columns=self.drug_list)
            df.to_csv(output_path)
        elif output_format == "pickle":
            with open(output_path, 'wb') as f:
                pickle.dump({
                    'cost_matrix': self.cost_matrix,
                    'drug_list': self.drug_list,
                    'drug_vocabulary': self.drug_vocabulary
                }, f)
        elif output_format == "json":
            # Convert matrix to list for JSON serialization
            data = {
                'cost_matrix': self.cost_matrix.tolist(),
                'drug_list': self.drug_list,
                'drug_vocabulary': self.drug_vocabulary
            }
            with open(output_path, 'w') as f:
                json.dump(data, f, indent=2)
        else:
            raise ValueError(f"Unsupported output format: {output_format}")

        # Save metadata if requested
        if self.config.save_metadata:
            metadata_path = Path(self.config.output_dir) / f"{filename}_metadata.json"
            metadata = {
                "generation_time": datetime.now().isoformat(),
                "config": {
                    "severity_mapping": self.config.severity_mapping,
                    "default_cost": self.config.default_cost,
                    "symmetric": self.config.symmetric,
                    "mimic_splits": self.config.mimic_splits,
                    "drugbank_use_mock": self.config.drugbank_use_mock
                },
                "data_stats": self.get_cost_matrix_stats(),
                "severity_stats": self.severity_mapper.get_mapping_stats()
            }

            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

            logger.info(f"Metadata saved to {metadata_path}")

        logger.info(f"Cost matrix saved to {output_path}")
        return str(output_path)

    def load_cost_matrix(self, filepath: str) -> np.ndarray:
        """
        Load cost matrix from file.

        The format is chosen from the file suffix (case-insensitive): ".npz",
        ".csv", ".pkl" / ".pickle", or ".json". ``cost_matrix``, ``drug_list``
        and ``drug_vocabulary`` are replaced by the loaded values.

        SECURITY: ".npz" (loaded with ``allow_pickle=True``) and pickle files
        can execute arbitrary code while loading. Only load trusted files.

        Args:
            filepath: Path to cost matrix file

        Returns:
            Loaded cost matrix

        Raises:
            ValueError: If the file suffix is not a supported format.
        """
        filepath = Path(filepath)
        suffix = filepath.suffix.lower()

        if suffix == ".npz":
            # The context manager closes the underlying zip file handle.
            with np.load(filepath, allow_pickle=True) as data:
                self.cost_matrix = data['cost_matrix']
                self.drug_list = data['drug_list'].tolist()
                self.drug_vocabulary = data['drug_vocabulary'].item()
        elif suffix == ".csv":
            df = pd.read_csv(filepath, index_col=0)
            self.cost_matrix = df.values
            self.drug_list = df.index.tolist()
            self.drug_vocabulary = {drug: i for i, drug in enumerate(self.drug_list)}
        elif suffix in (".pkl", ".pickle"):
            # save_cost_matrix() writes ".pickle"; ".pkl" is kept for files
            # produced elsewhere.
            with open(filepath, 'rb') as f:
                data = pickle.load(f)
            self.cost_matrix = data['cost_matrix']
            self.drug_list = data['drug_list']
            self.drug_vocabulary = data['drug_vocabulary']
        elif suffix == ".json":
            with open(filepath, 'r') as f:
                data = json.load(f)
            self.cost_matrix = np.array(data['cost_matrix'])
            self.drug_list = data['drug_list']
            self.drug_vocabulary = data['drug_vocabulary']
        else:
            raise ValueError(f"Unsupported file format: {filepath.suffix}")

        logger.info(f"Cost matrix loaded from {filepath}")
        logger.info(f"Matrix shape: {self.cost_matrix.shape}")

        return self.cost_matrix

    def to_cncrc_cost_matrix(self) -> "CostMatrix":
        """
        Convert to CNCRC CostMatrix format.

        Returns:
            CostMatrix object built from the current matrix and drug list

        Raises:
            ValueError: If no cost matrix has been generated or loaded.
        """
        if self.cost_matrix is None:
            raise ValueError("Cost matrix not generated yet")

        return CostMatrix(
            matrix=self.cost_matrix,
            drug_ids=self.drug_list
        )


# Convenience functions

def generate_ddi_cost_matrix(
    mimic_splits: Optional[List[str]] = None,
    severity_mapping: Optional[Dict[str, float]] = None,
    output_path: Optional[str] = None,
    max_drugs: Optional[int] = None,
    use_mock_drugbank: bool = True,
    **kwargs
) -> Tuple[np.ndarray, List[str]]:
    """
    Convenience function to generate DDI cost matrix.

    Builds a ``CostMappingConfig``, loads the MIMIC vocabulary and the
    DrugBank interface, generates the matrix and optionally saves it.

    Args:
        mimic_splits: MIMIC data splits to use for drug vocabulary.
            Defaults to ["train", "validation", "test"].
        severity_mapping: Custom severity to cost mapping; entries are merged
            over the default mapping.
        output_path: Name to save the matrix under (optional). It is handed
            to ``DDICostGenerator.save_cost_matrix``, which keeps only the
            stem and writes into the configured ``output_dir`` using the
            configured ``output_format``.
        max_drugs: Maximum number of drugs to include
        use_mock_drugbank: Whether to use mock DrugBank data
        **kwargs: Additional configuration options, applied to the matching
            ``CostMappingConfig`` fields. Unknown names are ignored with a
            warning.

    Returns:
        Tuple of (cost_matrix, drug_list)
    """
    # Build a fresh list per call so the configuration never shares a
    # mutable default (or the caller's list) across calls.
    if mimic_splits is None:
        mimic_splits = ["train", "validation", "test"]

    config = CostMappingConfig(
        mimic_splits=list(mimic_splits),
        max_drugs=max_drugs,
        drugbank_use_mock=use_mock_drugbank
    )

    if severity_mapping:
        config.severity_mapping.update(severity_mapping)

    # Update config with additional kwargs
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            # Surface typos instead of silently dropping the option.
            logger.warning(f"Ignoring unknown configuration option '{key}'")

    generator = DDICostGenerator(config)

    # Load data
    generator.load_mimic_data()
    generator.load_drugbank_data()

    # Generate matrix
    cost_matrix = generator.generate_cost_matrix()

    # Save if path provided
    if output_path:
        generator.save_cost_matrix(output_path)

    return cost_matrix, generator.drug_list


def load_ddi_cost_matrix(filepath: str) -> Tuple[np.ndarray, List[str], Dict[str, int]]:
    """
    Read a previously saved DDI cost matrix back from disk.

    A throwaway ``DDICostGenerator`` with the default configuration does the
    loading. Constructing it also creates the default output directory if
    that directory is missing.

    Args:
        filepath: Location of the saved cost matrix; the file suffix selects
            the format.

    Returns:
        Tuple of (cost_matrix, drug_list, drug_vocabulary)
    """
    default_config = CostMappingConfig()
    loader = DDICostGenerator(default_config)

    matrix = loader.load_cost_matrix(filepath)
    drugs = loader.drug_list
    vocabulary = loader.drug_vocabulary

    return matrix, drugs, vocabulary


def create_mock_cost_matrix(
    drug_names: List[str],
    interaction_prob: float = 0.2,
    severity_dist: Optional[Dict[str, float]] = None,
    min_base_cost: float = 0.05
) -> np.ndarray:
    """
    Create a mock cost matrix for testing.

    The result is symmetric with a zero diagonal and is deterministic: a
    private generator seeded with 42 is used, so the global NumPy random
    state is left untouched.

    Args:
        drug_names: List of drug names (only the count is used)
        interaction_prob: Probability of interaction between any two drugs
        severity_dist: Distribution of severities (label -> probability; the
            probabilities must sum to 1). Defaults to
            major 0.3 / moderate 0.5 / minor 0.2.
        min_base_cost: Cost assigned to every off-diagonal pair that did not
            receive a positive interaction cost

    Returns:
        Mock cost matrix of shape (len(drug_names), len(drug_names))

    Raises:
        KeyError: If a sampled severity label has no known cost.
    """
    if severity_dist is None:
        severity_dist = {'major': 0.3, 'moderate': 0.5, 'minor': 0.2}

    n_drugs = len(drug_names)
    cost_matrix = np.zeros((n_drugs, n_drugs))

    # Same label -> cost table as the default severity mapping, so any of
    # those labels may appear in severity_dist.
    severity_costs = {
        'major': 1.0,
        'moderate': 0.6,
        'minor': 0.3,
        'contraindicated': 1.5,
        'unknown': 0.5,
        'high': 1.0,
        'medium': 0.6,
        'low': 0.3,
        'none': 0.05,
    }
    severities = list(severity_dist.keys())
    probabilities = list(severity_dist.values())

    # Local legacy generator: yields the same stream as np.random.seed(42)
    # without reseeding the process-wide random state.
    rng = np.random.RandomState(42)

    for i in range(n_drugs):
        for j in range(i + 1, n_drugs):
            if rng.random_sample() < interaction_prob:
                severity = str(rng.choice(severities, p=probabilities))
                cost = severity_costs[severity]
                cost_matrix[i, j] = cost
                cost_matrix[j, i] = cost  # Symmetric

    # Ensure any off-diagonal pair has at least minimal ambiguity cost
    off_diagonal = ~np.eye(n_drugs, dtype=bool)
    cost_matrix[off_diagonal & (cost_matrix <= 0.0)] = min_base_cost

    return cost_matrix


