"""
Non-Coverage Cost (Cost_NC) Annotation Utilities

This module provides structures and helpers to annotate the clinical cost of
"missing a drug when it is clinically needed" (Cost_NC). Cost_NC is a
per-drug scalar independent of DDI and reflects the consequence of not
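administering the correct drug. It is intended to be used in non-coverage
risk calculations and reported on a normalized scale: [0, 1] under the
default 'max' normalization (for non-negative raw levels), or relative to
the mean (E[Cost_NC] ≈ 1) under 'mean' normalization.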

Key components:
- NonCoverageCostConfig: normalization and defaults
- NonCoverageCostTable: in-memory annotation table with metadata
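- Convenience: build (and optionally export) normalized Cost_NC vectors
  aligned to a drug vocabulary
- Heuristic derivation of Cost_NC levels from the ICD→drug guideline mapping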

Notes:
- Values included here are provisional examples to unblock experiments and
  MUST be replaced/augmented by authoritative guideline-based annotations
  (WHO EML + specialty guidelines) before publication.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from pathlib import Path
import json
import math
import numpy as np
from .icd_to_drug_mapper import ICDToDrugMapper, create_icd_to_drug_mapper


@dataclass
class NonCoverageCostConfig:
    """Configuration for Cost_NC normalization, defaults, and validation.

    Attributes:
        normalization: 'max' → divide raw levels by ``max(max_raw, scale_max)``,
            which keeps non-negative raw levels within [0, 1]; 'mean' → divide
            by the mean raw level so that E[Cost_NC] ≈ 1 (values may exceed 1,
            so report the raw stats alongside for clarity). Any other value is
            rejected with ``ValueError`` when a vector is built.
        default_level: Fallback raw level used when a drug lacks annotation.
            Interpreted on the same scale as the raw levels provided
            (e.g., 1-5 ladder).
        scale_max: Floor for the 'max' normalization denominator. The
            denominator is the larger of the observed maximum raw level and
            this value, so a vocabulary whose levels never reach the top of
            the scale is not stretched up to 1. Not used by 'mean'.
        require_authoritative: If True, building a vector raises when any
            annotated drug in the requested vocabulary has no source naming
            both an organization and a guideline.
        require_complete_coverage: If True, building a vector raises when any
            requested drug lacks an annotation (i.e. when ``default_level``
            would otherwise be substituted).
    """

    normalization: str = "max"  # 'max' or 'mean'
    default_level: float = 2.0
    scale_max: float = 5.0
    # Authoritative annotation controls (checked when building a vector)
    require_authoritative: bool = False  # Enforce presence of sources for every annotated drug
    require_complete_coverage: bool = False  # Enforce no default fallback for any drug


@dataclass
class SourceRef:
    """Structured citation backing a Cost_NC annotation.

    Only ``organization`` and ``guideline`` are required; every other field
    is optional detail and defaults to ``None``.

    Attributes:
        organization: Name of the issuing organization.
        guideline: Name of the guideline document.
        year: Publication year.
        section: Section within the guideline.
        class_recommendation: Recommendation class, e.g. Class I/IIa/...
        evidence_level: Level of evidence, e.g. Level A/B/C.
        url: Link to the source.
    """

    organization: str
    guideline: str
    year: Optional[int] = field(default=None)
    section: Optional[str] = field(default=None)
    class_recommendation: Optional[str] = field(default=None)
    evidence_level: Optional[str] = field(default=None)
    url: Optional[str] = field(default=None)


@dataclass
class NonCoverageCostEntry:
    """One drug's raw Cost_NC annotation plus its provenance.

    Attributes:
        drug_name: Drug identifier the annotation applies to.
        level_raw: Raw (un-normalized) omission-consequence level.
        sources: Provenance records. Structured references are preferred:
            dicts with ``organization``/``guideline`` keys, or objects such as
            ``SourceRef`` exposing those attributes. Plain strings are accepted
            for legacy JSON but never count as authoritative.
        notes: Free-text annotation notes.
    """

    drug_name: str
    level_raw: float
    # Accept rich dicts / SourceRef objects (preferred) or plain strings (legacy) in JSON
    sources: List[Any] = field(default_factory=list)
    notes: str = ""

    def has_authoritative_source(self) -> bool:
        """Return True if any source names both an organization and a guideline.

        Dicts are inspected through their ``organization``/``guideline`` keys,
        other objects (e.g. ``SourceRef``) through attributes of the same names.
        Plain strings never qualify because they carry no structure. Both values
        must be truthy, so empty strings do not count.
        """
        for src in self.sources:
            if isinstance(src, str):
                continue
            if isinstance(src, dict):
                organization = src.get("organization")
                guideline = src.get("guideline")
            else:
                organization = getattr(src, "organization", None)
                guideline = getattr(src, "guideline", None)
            if organization and guideline:
                return True
        return False


class NonCoverageCostTable:
    """In-memory table of Cost_NC annotations with normalization helpers.

    Maps drug name → ``NonCoverageCostEntry``. Raw levels are stored as given.
    ``build_vector`` turns them into a normalized vector aligned to an arbitrary
    drug vocabulary, substituting a configured default level for unannotated
    drugs.
    """

    def __init__(self, entries: Optional[Dict[str, NonCoverageCostEntry]] = None):
        # Always copy so the table never aliases (and later mutates) the
        # caller's dict, whether it was empty or not.
        self._entries: Dict[str, NonCoverageCostEntry] = dict(entries or {})

    # ---- CRUD ----
    def set(
        self,
        drug: str,
        level_raw: float,
        sources: Optional[List[Any]] = None,
        notes: str = "",
    ) -> None:
        """Insert or overwrite the annotation for ``drug``.

        Args:
            drug: Drug identifier used as the lookup key.
            level_raw: Raw omission-consequence level; coerced with ``float``.
            sources: Provenance records. These may be structured dicts,
                ``SourceRef`` objects, or legacy plain strings. The list is
                copied, so later mutations by the caller do not leak into the
                table.
            notes: Free-text notes stored alongside the level.
        """
        self._entries[drug] = NonCoverageCostEntry(drug, float(level_raw), list(sources or []), notes)

    def get(self, drug: str) -> Optional[NonCoverageCostEntry]:
        """Return the entry for ``drug``, or ``None`` if it is not annotated."""
        return self._entries.get(drug)

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """Return a JSON-serializable snapshot keyed by drug name.

        Each value holds ``level_raw``, ``sources`` and ``notes``. Source handling:

        - Source lists are always copied.
        - Dict sources are shallow-copied.
        - Dataclass sources (e.g. ``SourceRef``) are converted to plain dicts,
          so ``save_json`` can serialize them and ``from_dict`` reads them back
          as dict sources.
        """
        from dataclasses import asdict, is_dataclass

        def plain(src: Any) -> Any:
            if isinstance(src, dict):
                return dict(src)
            # is_dataclass() is also True for dataclass *types*; convert instances only.
            if is_dataclass(src) and not isinstance(src, type):
                return asdict(src)
            return src

        return {
            drug: {
                "level_raw": entry.level_raw,
                "sources": [plain(s) for s in entry.sources],
                "notes": entry.notes,
            }
            for drug, entry in self._entries.items()
        }

    @staticmethod
    def from_dict(data: Dict[str, Dict[str, Any]]) -> "NonCoverageCostTable":
        """Rebuild a table from a mapping shaped like ``to_dict()`` output.

        Missing or ``null`` values are replaced as follows:

        - A missing ``level_raw`` defaults to 0.0.
        - Missing or ``null`` ``sources`` become an empty list.
        - Missing or ``null`` ``notes`` become an empty string.
        """
        entries: Dict[str, NonCoverageCostEntry] = {}
        for drug, info in data.items():
            notes = info.get("notes")
            entries[drug] = NonCoverageCostEntry(
                drug_name=drug,
                level_raw=float(info.get("level_raw", 0.0)),
                # `or []` also maps a JSON null to an empty list.
                sources=list(info.get("sources") or []),
                # str(None) would otherwise store the literal text "None".
                notes="" if notes is None else str(notes),
            )
        return NonCoverageCostTable(entries)

    # ---- I/O ----
    def save_json(self, path: str | Path) -> None:
        """Write ``to_dict()`` to ``path`` as UTF-8 JSON, creating parent directories."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @staticmethod
    def load_json(path: str | Path) -> "NonCoverageCostTable":
        """Load a table previously written by ``save_json``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return NonCoverageCostTable.from_dict(data)

    # ---- Normalization ----
    @staticmethod
    def _summary_stats(values: np.ndarray) -> Dict[str, Any]:
        """Return min/max/mean/std of ``values`` as floats (all 0.0 when empty).

        NumPy's ``min``/``max`` raise on zero-size arrays and ``mean``/``std``
        return NaN with a warning, so the empty case is handled explicitly.
        """
        if values.size == 0:
            return {"min": 0.0, "max": 0.0, "mean": 0.0, "std": 0.0}
        return {
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
        }

    def build_vector(
        self,
        drug_list: List[str],
        config: Optional[NonCoverageCostConfig] = None,
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Build a normalized Cost_NC vector aligned with `drug_list`.

        Args:
            drug_list: Drug vocabulary. Output position i corresponds to
                ``drug_list[i]``. It may be empty or contain repeats.
            config: Normalization/validation settings. Defaults to
                ``NonCoverageCostConfig()``.

        Returns:
            (vector, metadata). ``vector`` is a float np.ndarray aligned to
            ``drug_list``. ``metadata`` records:

            - the config values used;
            - ``raw_stats``: min/max/mean/std of the raw levels, ``count``
              (the number of positions) and ``defaults_used`` (the number of
              positions filled with ``default_level``);
            - ``normalized_stats`` for the output vector.

            All stats are 0.0 for an empty ``drug_list``.

        Raises:
            ValueError: If ``require_complete_coverage`` is set and any drug
                lacks an annotation; if ``require_authoritative`` is set and
                any annotated drug lacks an authoritative source; or if
                ``normalization`` is neither 'max' nor 'mean'.
        """
        cfg = config or NonCoverageCostConfig()

        raw_values: List[float] = []
        defaulted: List[str] = []  # one item per position filled with default_level
        missing_authoritative: List[str] = []
        for d in drug_list:
            entry = self.get(d)
            if entry is None:
                raw_values.append(float(cfg.default_level))
                defaulted.append(d)
                continue
            raw_values.append(float(entry.level_raw))
            if cfg.require_authoritative and not entry.has_authoritative_source():
                missing_authoritative.append(d)

        if cfg.require_complete_coverage and defaulted:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            missing = list(dict.fromkeys(defaulted))
            raise ValueError(
                f"Cost_NC requires complete coverage but {len(missing)} drugs lack annotations. Missing examples: {missing[:10]}"
            )

        if cfg.require_authoritative and missing_authoritative:
            unsourced = list(dict.fromkeys(missing_authoritative))
            raise ValueError(
                f"Authoritative sources required but missing for {len(unsourced)} drugs. Examples: {unsourced[:10]}"
            )

        raw = np.array(raw_values, dtype=float)
        raw_stats = self._summary_stats(raw)
        raw_stats["count"] = int(raw.size)
        raw_stats["defaults_used"] = len(defaulted)
        meta: Dict[str, Any] = {
            "normalization": cfg.normalization,
            "default_level": cfg.default_level,
            "scale_max": cfg.scale_max,
            "raw_stats": raw_stats,
        }

        if cfg.normalization == "max":
            # scale_max floors the denominator so sparse vocabularies are not
            # stretched to 1; np.max is undefined on an empty array.
            max_raw = float(np.max(raw)) if raw.size else 0.0
            denom = max(max_raw, float(cfg.scale_max))
            vec = raw / (denom if denom > 0 else 1.0)
        elif cfg.normalization == "mean":
            mean_val = float(np.mean(raw)) if raw.size else 0.0
            vec = raw / (mean_val if mean_val > 0 else 1.0)
        else:
            raise ValueError(f"Unknown normalization: {cfg.normalization}")

        meta["normalized_stats"] = self._summary_stats(vec)

        return vec, meta


# ---- Provisional seed annotations ----

def build_provisional_table() -> NonCoverageCostTable:
    """Return a small, demonstration-only Cost_NC table to unblock experiments.

    Levels use a 1-5 ladder, where 5 is the highest consequence of omission.
    Each tier below pairs a level with its drugs, a single plain-string source
    label and a note.

    IMPORTANT: These entries are demonstration-only and based on typical
    clinical reasoning for consequence of omission. Replace with authoritative
    annotations for real studies.
    """
    # (raw level, drugs, source label, note), in insertion order.
    tiers = (
        (5, ("insulin",), "diabetes acute complications", "critical glycemic control"),
        (
            4,
            ("apixaban", "rivaroxaban", "dabigatran", "warfarin"),
            "AF stroke prevention",
            "anticoagulation for AF/TE risk",
        ),
        (
            4,
            ("ceftriaxone", "levofloxacin"),
            "severe pneumonia",
            "acute infection first-line for severe cases",
        ),
        (
            3,
            ("lisinopril", "losartan", "metoprolol", "spironolactone", "furosemide"),
            "HF/HTN core therapy",
            "organ risk reduction/volume management",
        ),
        (
            3,
            ("amlodipine", "hydrochlorothiazide", "diltiazem", "metformin"),
            "HTN/DM maintenance",
            "maintenance or control therapy",
        ),
        (
            2,
            (
                "digoxin",
                "atorvastatin",
                "sertraline",
                "fluoxetine",
                "escitalopram",
                "citalopram",
                "bupropion",
                "mirtazapine",
                "azithromycin",
                "amoxicillin",
                "doxycycline",
                "sodium_bicarbonate",
                "albuterol",
                "ipratropium",
                "prednisone",
                "tiotropium",
                "sitagliptin",
                "glyburide",
                "glipizide",
                "amiodarone",
                "aspirin",
            ),
            "maintenance/adjunct/situational",
            "lower short-term omission consequence",
        ),
    )

    table = NonCoverageCostTable()
    for level, drugs, source_label, note in tiers:
        for drug in drugs:
            # A fresh one-element list per drug so entries never share a sources list.
            table.set(drug, level, [source_label], note)
    return table


def export_cost_nc_vector(
    drug_list: List[str],
    output_path: Optional[str] = None,
    config: Optional[NonCoverageCostConfig] = None,
    table: Optional[NonCoverageCostTable] = None,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Compute a Cost_NC vector for `drug_list` and optionally write it to disk.

    Args:
        drug_list: Drug vocabulary the vector is aligned to.
        output_path: If given, a UTF-8 JSON file with keys ``drug_list``,
            ``cost_nc`` and ``metadata`` is written there. Parent directories
            are created as needed.
        config: Normalization settings forwarded to ``build_vector``.
        table: Annotation table to use. If omitted, ``build_provisional_table()``
            is used.

    Returns:
        The ``(vector, metadata)`` pair produced by ``build_vector``.
    """
    source_table = table if table else build_provisional_table()
    vec, meta = source_table.build_vector(drug_list, config)

    if output_path is None:
        return vec, meta

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    payload = {"drug_list": list(drug_list), "cost_nc": vec.tolist(), "metadata": meta}
    with destination.open("w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2, ensure_ascii=False)
    return vec, meta


# ---- Auto-derivation from ICD guideline mapping ----

def _heuristic_level_from_rec(indication: str, line: str, drug: str) -> float:
    ind = (indication or "").lower()
    line = (line or "").lower()
    name = (drug or "").lower()

    # Special cases
    if name == "insulin":
        return 5.0
    if name in {"apixaban", "rivaroxaban", "dabigatran", "warfarin"}:
        return 4.0

    # Base by line of therapy
    if any(k in line for k in ["first-line", "hospital", "acute"]):
        base = 3.0
    elif any(k in line for k in ["second-line", "alternative", "traditional", "variable", "maintenance"]):
        base = 2.0
    else:
        base = 2.0

    # Indication adjustments
    if any(k in ind for k in ["severe", "acute", "infection", "pneumonia"]):
        base += 1.0  # acute/severe conditions
    if any(k in ind for k in ["heart failure", "diabetic nephropathy"]):
        base += 1.0  # organ risk reduction
    if any(k in ind for k in ["adjunct", "prevention", "dyslipidemia"]):
        base -= 0.5  # adjunctive roles

    # Clip to [1,5]
    return float(max(1.0, min(5.0, base)))


def build_table_from_icd_guidelines(mapper: Optional[ICDToDrugMapper] = None) -> NonCoverageCostTable:
    """Derive a Cost_NC table heuristically from the ICD→drug guideline mapping.

    This uses indication and line-of-therapy metadata to approximate
    omission consequence. It is deterministic and reproducible given the
    mapper version.

    Args:
        mapper: Mapper whose ``icd_drug_mapping`` values are iterables of
            recommendations exposing ``indication``, ``line_of_therapy`` and
            ``drug_name``. If omitted (or falsy), ``create_icd_to_drug_mapper()``
            is used.

    Returns:
        A table with one entry per drug. Each level is the maximum of
        ``_heuristic_level_from_rec`` over all of that drug's recommendations
        across every ICD code.
    """
    mp = mapper or create_icd_to_drug_mapper()

    # Aggregate maximum criticality per drug over all ICDs (the ICD keys themselves are not needed).
    best_level: Dict[str, float] = {}
    for recs in mp.icd_drug_mapping.values():
        for rec in recs:
            lvl = _heuristic_level_from_rec(rec.indication, rec.line_of_therapy, rec.drug_name)
            best_level[rec.drug_name] = max(lvl, best_level.get(rec.drug_name, 0.0))

    table = NonCoverageCostTable()
    for drug, lvl in best_level.items():
        # A fresh source dict per drug: a single shared dict would make an edit
        # to one entry's provenance silently change every other entry.
        # NOTE(review): organization and guideline are both set, so
        # NonCoverageCostEntry.has_authoritative_source() treats these heuristic
        # entries as authoritative. Confirm this is intended when
        # require_authoritative is enabled.
        src = {
            "organization": "Internal",
            "guideline": "ICDToDrug guideline mapping (embedded)",
            "year": None,
            "section": None,
            "class_recommendation": None,
            "evidence_level": None,
            "url": None,
        }
        table.set(drug, lvl, sources=[src], notes="auto-derived from ICD mapping metadata")

    return table
