import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import linregress
from sklearn.metrics import (
    precision_recall_curve,
    roc_curve,
    auc,
    average_precision_score,
)
from sklearn.preprocessing import label_binarize
import logging

logger = logging.getLogger(__name__)


def plot_residuals(
    y_true: np.array, y_pred: np.array, label_name: str, path: str = None
):
    """Draw a residuals-versus-predictions scatter plot sized for MLflow.

    Both inputs are flattened to 1D. Residuals are ``y_true - y_pred`` and are
    plotted against the predictions, with a dashed red line marking zero.

    Args:
        y_true: Observed target values (any array-like).
        y_pred: Predicted values (any array-like).
        label_name: Name of the target, used in the x-axis label.
        path: Optional output file. When truthy, the figure is saved there
            and then closed.

    Returns:
        The matplotlib figure (already closed if it was written to ``path``).
    """
    observed = np.asarray(y_true).ravel()
    predicted = np.asarray(y_pred).ravel()
    errors = observed - predicted

    fig, ax = plt.subplots(figsize=(8, 6), dpi=100)

    # Data first, then the zero reference line on top of it.
    ax.scatter(predicted, errors, s=20, alpha=0.5)
    ax.axhline(y=0, linestyle="--", color="r", label="Zero Line")

    ax.set_title("Residual Plot")
    ax.set_xlabel(f"Predicted {label_name}")
    ax.set_ylabel("Residuals")
    ax.legend()
    fig.tight_layout()

    if not path:
        return fig

    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
    return fig


def plot_qq(y_true: np.ndarray, y_pred: np.ndarray, label_name: str, path: str = None):
    """Plot a two-sample Q-Q comparison of ``y_true`` and ``y_pred``.

    Both inputs are flattened and sorted independently, then plotted against
    each other; points lying on the dashed identity line indicate matching
    empirical distributions. The figure is 8x6 inches at 100 dpi (sized for
    MLflow).

    Args:
        y_true: Observed values; any array-like, flattened to 1D.
        y_pred: Predicted values; must contain as many elements as ``y_true``.
        label_name: Name of the target, shown in the plot title.
        path: Optional output file. When truthy, the figure is saved there
            and then closed (closed even if saving fails).

    Returns:
        The matplotlib figure (already closed if ``path`` was given).

    Raises:
        ValueError: If either input is empty or the inputs differ in length.
    """
    sorted_y_true = np.sort(np.asarray(y_true).ravel())
    sorted_y_pred = np.sort(np.asarray(y_pred).ravel())

    # Validate before creating the figure so that a bad call does not leave
    # an orphaned figure registered with pyplot.
    if sorted_y_true.size == 0 or sorted_y_pred.size == 0:
        raise ValueError("y_true and y_pred must not be empty")
    if sorted_y_true.size != sorted_y_pred.size:
        raise ValueError(
            "y_true and y_pred must have the same number of elements, got "
            f"{sorted_y_true.size} and {sorted_y_pred.size}"
        )

    # Span the identity line over both samples so it is not cut short when
    # the predictions extend beyond the range of the true values. The nan-
    # aware reductions skip NaNs, which np.sort moves to the end.
    combined = np.concatenate((sorted_y_true, sorted_y_pred))
    lo, hi = np.nanmin(combined), np.nanmax(combined)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
    ax.scatter(sorted_y_true, sorted_y_pred, alpha=0.5, s=20)
    ax.plot([lo, hi], [lo, hi], "r--", label="y_true = y_pred")
    ax.set_xlabel("y_true Quantiles")
    ax.set_ylabel("y_pred Quantiles")
    ax.set_title(f"Q-Q Plot Comparison: {label_name}")
    ax.legend()
    fig.tight_layout()

    if path:
        try:
            fig.savefig(path, bbox_inches="tight")
        finally:
            # Release the figure even when saving fails.
            plt.close(fig)

    return fig


def plot_pr_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    label_column: str,
    num_classes: int,
    path: str = None,
):
    """Plot precision-recall curve(s) on an 8x6 in, 100 dpi figure for MLflow.

    For ``num_classes == 2`` a single curve is drawn for the positive class.
    Otherwise one one-vs-rest curve is drawn per class, with the true labels
    binarized against ``range(num_classes)``.

    Args:
        y_true: True labels; flattened to 1D. In the multiclass case they are
            binarized against the class ids ``0 .. num_classes - 1``.
        y_prob: Predicted scores. Binary: a 1D array of positive-class
            scores, or a 2D array whose column 1 holds the positive class
            (a single-column 2D array is also accepted). Multiclass: a 2D
            array indexed as ``y_prob[:, i]`` for class ``i``.
        label_column: Name of the label, shown in the plot title.
        num_classes: Number of classes; ``2`` selects the binary branch.
        path: Optional output file. When truthy, the figure is saved there
            and then closed (closed even if saving fails).

    Returns:
        The matplotlib figure (already closed if ``path`` was given).
    """
    y_true = np.asarray(y_true).ravel()
    y_prob = np.asarray(y_prob)

    # Compute every curve before touching pyplot, so that a failure in the
    # metric computation cannot leave an orphaned figure behind.
    curves = []  # (recall, precision, legend label)
    if num_classes == 2:  # Binary classification
        if y_prob.ndim > 1:
            # Column 1 is the positive class; fall back to the only column
            # when the scores arrive as a single-column 2D array.
            positive_col = 1 if y_prob.shape[1] > 1 else 0
            y_prob = y_prob[:, positive_col]

        precision, recall, _ = precision_recall_curve(y_true, y_prob)
        ap = average_precision_score(y_true, y_prob)
        curves.append((recall, precision, f"AP = {ap:.2f}"))
    else:  # Multiclass classification, one-vs-rest
        y_true_bin = label_binarize(y_true, classes=range(num_classes))
        for i in range(num_classes):
            class_truth = y_true_bin[:, i]
            class_scores = y_prob[:, i]
            precision, recall, _ = precision_recall_curve(class_truth, class_scores)
            ap = average_precision_score(class_truth, class_scores)
            curves.append((recall, precision, f"Class {i} (AP = {ap:.2f})"))

    fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
    for recall, precision, legend_label in curves:
        ax.plot(recall, precision, lw=2, label=legend_label)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(f"Precision-Recall Curve: {label_column}")
    ax.legend(loc="best")
    ax.grid(alpha=0.3)
    fig.tight_layout()

    if path:
        try:
            fig.savefig(path, bbox_inches="tight")
        finally:
            # Release the figure even when saving fails.
            plt.close(fig)

    return fig


def plot_roc_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    label_column: str,
    num_classes: int,
    path: str = None,
):
    """Plot ROC curve(s) on an 8x6 in, 100 dpi figure for MLflow.

    For ``num_classes == 2`` a single curve is drawn for the positive class.
    Otherwise one one-vs-rest curve is drawn per class, with the true labels
    binarized against ``range(num_classes)``. A dashed diagonal marks the
    chance level.

    Args:
        y_true: True labels; flattened to 1D. In the multiclass case they are
            binarized against the class ids ``0 .. num_classes - 1``.
        y_prob: Predicted scores. Binary: a 1D array of positive-class
            scores, or a 2D array whose column 1 holds the positive class
            (a single-column 2D array is also accepted). Multiclass: a 2D
            array indexed as ``y_prob[:, i]`` for class ``i``.
        label_column: Name of the label, shown in the plot title.
        num_classes: Number of classes; ``2`` selects the binary branch.
        path: Optional output file. When truthy, the figure is saved there
            and then closed (closed even if saving fails).

    Returns:
        The matplotlib figure (already closed if ``path`` was given).
    """
    y_true = np.asarray(y_true).ravel()
    y_prob = np.asarray(y_prob)

    # Compute every curve before touching pyplot, so that a failure in the
    # metric computation cannot leave an orphaned figure behind.
    curves = []  # (fpr, tpr, legend label)
    if num_classes == 2:  # Binary classification
        if y_prob.ndim > 1:
            # Column 1 is the positive class; fall back to the only column
            # when the scores arrive as a single-column 2D array.
            positive_col = 1 if y_prob.shape[1] > 1 else 0
            y_prob = y_prob[:, positive_col]

        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        curves.append((fpr, tpr, f"ROC curve (AUC = {roc_auc:.2f})"))
    else:  # Multiclass classification, one-vs-rest
        y_true_bin = label_binarize(y_true, classes=range(num_classes))
        for i in range(num_classes):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
            roc_auc = auc(fpr, tpr)
            curves.append((fpr, tpr, f"Class {i} (AUC = {roc_auc:.2f})"))

    fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
    for fpr, tpr, legend_label in curves:
        ax.plot(fpr, tpr, lw=2, label=legend_label)

    # Chance-level diagonal (unlabeled, so it stays out of the legend).
    ax.plot([0, 1], [0, 1], "k--", lw=1)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"ROC Curve: {label_column}")
    ax.legend(loc="lower right")
    ax.grid(alpha=0.3)
    fig.tight_layout()

    if path:
        try:
            fig.savefig(path, bbox_inches="tight")
        finally:
            # Release the figure even when saving fails.
            plt.close(fig)

    return fig


def plot_regression(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    label_name: str,
    path: str = None,
):
    """Plot predicted against true values with a least-squares fit.

    Draws a scatter of ``y_pred`` over ``y_true``, the ideal ``y = x`` line
    (red, dashed) and the ``linregress`` fit of ``y_pred`` on ``y_true``
    (black, dashed). R², the p-value and the standard error reported by
    ``linregress`` are annotated on the axes and logged at DEBUG level. The
    figure is 8x6 inches at 100 dpi (sized for MLflow).

    Args:
        y_true: Observed target values; any array-like, flattened to 1D.
        y_pred: Predicted values; any array-like, flattened to 1D.
        label_name: Name of the target, used in both axis labels.
        path: Optional output file. When truthy, the figure is saved there
            and then closed (closed even if saving fails).

    Returns:
        The matplotlib figure (already closed if ``path`` was given).

    Raises:
        ValueError: Propagated from ``linregress`` / NumPy for inputs that
            cannot be fitted (for example empty arrays).
    """
    # Ensure 1D arrays for regression analysis
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    # Fit before creating the figure: linregress rejects unusable input, and
    # failing here means no orphaned figure is left registered with pyplot.
    slope, intercept, r_value, p_value, std_err = linregress(y_true, y_pred)
    r2 = r_value**2
    x_min, x_max = y_true.min(), y_true.max()
    x_span = [x_min, x_max]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
    ax.scatter(y_true, y_pred, alpha=0.5, s=20)
    # Perfect-prediction reference line.
    ax.plot(x_span, x_span, "r--", lw=1, label="Ideal (y = x)")
    # Fitted regression line, evaluated at the ends of the true-value range.
    ax.plot(
        x_span,
        [slope * x_min + intercept, slope * x_max + intercept],
        color="black",
        linestyle="--",
        lw=1,
        label="Linear fit",
    )
    ax.set_xlabel(f"True {label_name}")
    ax.set_ylabel(f"Predicted {label_name}")
    ax.set_title("Regression Results")

    # Annotate the fit statistics in axes coordinates (top-left corner).
    # The p-value uses general format so small values are not shown as 0.00.
    ax.text(0.05, 0.95, f"R²: {r2:.2f}", transform=ax.transAxes)
    ax.text(0.05, 0.90, f"p-value: {p_value:.3g}", transform=ax.transAxes)
    ax.text(0.05, 0.85, f"std_err: {std_err:.2f}", transform=ax.transAxes)
    # Lower right keeps the legend clear of the statistics text.
    ax.legend(loc="lower right")
    fig.tight_layout()

    logger.debug("R²: %.2f", r2)
    logger.debug("p-value: %.3g", p_value)
    logger.debug("std_err: %.2f", std_err)

    if path:
        try:
            fig.savefig(path, bbox_inches="tight")
        finally:
            # Release the figure even when saving fails.
            plt.close(fig)

    return fig
