import matplotlib.pyplot as plt
import numpy as np
import logging
from pathlib import Path
from typing import Dict

logger = logging.getLogger(__name__)


def _parse_coef_rows(summary_text: str) -> list:
    """Extract ``(name, coef, ci_lower, ci_upper)`` tuples from a coefficient table.

    Scans ``summary_text`` for the header line containing ``"coef    std err"``
    (presumably a statsmodels-style summary -- confirm against the producer) and
    reads the following lines until the next line starting with ``"="``.

    Rows whose name starts with ``Intercept`` or ``step_factor`` are skipped.
    Separator lines (e.g. ``-----``) and rows with fewer than seven
    whitespace-separated fields are also skipped, as are rows whose numeric
    fields do not parse.

    Column layout assumed: ``name coef std_err stat p lower upper``. The CI
    bounds are read from fields 5 and 6.
    """
    rows = []
    in_table = False
    for raw in summary_text.split("\n"):
        if "coef    std err" in raw:
            in_table = True
            continue
        if not in_table:
            continue
        if raw.startswith("="):
            # End of the coefficient table.
            break
        stripped = raw.strip()
        if not stripped or stripped.startswith(("Intercept", "step_factor")):
            continue
        parts = stripped.split()
        # Need indices 0..6, i.e. at least 7 fields (the old `>= 6` check let
        # 6-field rows through and then crashed on parts[6]).
        if len(parts) < 7:
            continue
        try:
            coef = float(parts[1])
            ci_lower = float(parts[5])
            ci_upper = float(parts[6])
        except ValueError:
            logger.debug("Skipping unparseable coefficient row: %r", stripped)
            continue
        rows.append((parts[0], coef, ci_lower, ci_upper))
    return rows


def _format_effect_label(param_name: str) -> str:
    """Turn a term name such as ``group[T.treated]`` into ``Group = Treated``."""
    return (
        param_name.replace("_", " ")
        .replace("[T.", " = ")
        .replace("]", "")
        .title()
    )


def _format_pvalue(pval: float) -> str:
    """Format a p-value for annotation: ``p<0.001`` or ``p=<3 decimals>``."""
    return "p<0.001" if pval < 0.001 else f"p={pval:.3f}"


def create_hazard_forest_plot(model_results: Dict, analysis_dir: Path):
    """Save a forest plot of hazard-model effects to ``analysis_dir``.

    Effect estimates and confidence intervals are parsed from
    ``model_results["hazard_model"]["summary"]``, a text table.
    P-values are looked up by term name in ``model_results["hazard_model"]["pvalues"]``.
    Missing p-values default to 1.0. Intercept and ``step_factor`` terms are excluded.
    Points with p < 0.05 are drawn red, all others blue.

    The figure is written to ``analysis_dir / "hazard_forest_plot.png"`` at 300 dpi.

    Args:
        model_results: Mapping expected to contain a ``"hazard_model"`` entry
            with ``"pvalues"`` (anything supporting ``.get``) and ``"summary"``
            (str) keys.
        analysis_dir: Directory the PNG is written to.

    This function never raises. Missing inputs are logged as warnings, and any
    other failure is logged with its traceback.
    """
    logger.info("Creating hazard model forest plot...")
    if "hazard_model" not in model_results:
        logger.warning("No hazard model results available for forest plot")
        return

    fig = None
    try:
        hazard = model_results["hazard_model"]
        pvalues = hazard["pvalues"]
        rows = _parse_coef_rows(hazard["summary"])
        if not rows:
            logger.warning("No substantive effects found for hazard forest plot")
            return

        labels = [_format_effect_label(name) for name, *_ in rows]
        fig, ax = plt.subplots(figsize=(10, max(4, len(rows) * 0.8)))
        for i, (name, coef, ci_lower, ci_upper) in enumerate(rows):
            pval = pvalues.get(name, 1.0)
            ax.errorbar(
                coef,
                i,
                # Asymmetric error bars: distances from the point to each CI bound.
                xerr=[[coef - ci_lower], [ci_upper - coef]],
                fmt="o",
                color="red" if pval < 0.05 else "blue",
                capsize=5,
                capthick=2,
                markersize=8,
            )
            # Offset in points, not data units, so spacing is independent of the
            # x-axis scale.
            ax.annotate(
                _format_pvalue(pval),
                xy=(max(ci_upper, coef), i),
                xytext=(5, 0),
                textcoords="offset points",
                va="center",
                fontsize=9,
            )

        ax.axvline(x=0, color="black", linestyle="--", alpha=0.5)
        ax.set_yticks(range(len(labels)))
        ax.set_yticklabels(labels)
        ax.set_xlabel("Log Hazard Ratio (Effect Size)")
        ax.set_title("Forest Plot: Effect Sizes from Discrete-Time Hazard Model")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(
            analysis_dir / "hazard_forest_plot.png",
            dpi=300,
            bbox_inches="tight",
        )
        logger.info("✓ Created hazard model forest plot")
    except Exception as e:
        # Best-effort plotting: log with traceback instead of propagating.
        logger.exception("Error creating hazard forest plot: %s", e)
    finally:
        # Always release the figure, including on the error path, so pyplot's
        # global figure registry does not accumulate open figures.
        if fig is not None:
            plt.close(fig)
