import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import Colorbar
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

from glest.helpers import calibration_curve as calibration_curve_glest
from sklearn.calibration import calibration_curve
from matplotlib.lines import Line2D

def grouping_diagram(
    frac_pos,
    counts,
    mean_scores,
    bins,
    ax: plt.Axes = None,
    plot_calibration: bool = True,
    plot_bins: bool = True,
    plot_cbar: bool = True,
    plot_hist: bool = True,
    plot_legend: bool = True,
    fig_kw: dict = None,
    scatter_kw: dict = None,
    calibration_kw: dict = None,
    hist_kw: dict = None,
    bin_kw: dict = None,
    legend_kw: dict = None,
):
    """
    Draw a grouping diagram: subgroup markers, bin edges and calibration curve.

    Every entry of ``mean_scores`` / ``frac_pos`` becomes one scatter marker
    at (mean score, fraction of positives). Markers are coloured by the
    matching entry of ``counts`` and sized by the same count clipped to
    [1, 100].

    Parameters
    ----------
    frac_pos, counts, mean_scores : array-like
        Arrays of identical shape (checked with an assertion).
    bins : array-like
        Bin edges; must hold exactly one more entry than the first
        dimension of ``frac_pos`` (checked with an assertion).
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure is created from ``fig_kw`` otherwise.
    plot_calibration : bool
        Draw the diagonal and the curve returned by
        ``glest.helpers.calibration_curve``.
    plot_bins : bool
        Draw one vertical line per bin edge.
    plot_cbar : bool
        Append a colorbar for the counts on the right.
    plot_hist : bool
        Append a count-weighted histogram of the scores on top.
    plot_legend : bool
        Draw the legend above the axes.
    fig_kw, scatter_kw, calibration_kw, hist_kw, bin_kw, legend_kw : dict, optional
        Overrides merged on top of the built-in defaults of the matching
        matplotlib call.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the grouping diagram.
    """
    frac_pos = np.array(frac_pos)
    counts = np.array(counts)
    mean_scores = np.array(mean_scores)
    bins = np.array(bins)

    assert frac_pos.shape == counts.shape == mean_scores.shape
    assert bins.shape[0] == frac_pos.shape[0] + 1

    # Marker colours follow the counts; the open upper bound of the norm is
    # settled by this first to_rgba call.
    # NOTE(review): 'flare' looks like a seaborn-registered colormap rather
    # than a matplotlib built-in — confirm seaborn is imported by the caller.
    count_mappable = ScalarMappable(
        norm=Normalize(vmin=1, vmax=None),
        cmap='flare',
    )
    marker_colors = count_mappable.to_rgba(counts.flat)

    # Marker sizes: counts clipped to [1, 100], mapped onto [15, 35]
    size_norm = Normalize(vmin=1, vmax=100, clip=True)
    marker_sizes = 15 + 20 * size_norm(counts.flat)

    # Built-in defaults of every drawing call
    fig_opts = {'figsize': (3, 3)}
    scatter_opts = {
        'edgecolor': 'white',
        'linewidth': 0.3,
        'color': marker_colors,
        's': marker_sizes,
        'label': 'Subgroups',
    }
    calibration_opts = {
        'marker': '.',
        'color': 'black',
        'markersize': 5,
        'label': 'Calibration curve',
    }
    hist_opts = {
        'edgecolor': 'black',
        'linewidth': 0.2,
        'color': '#dfa0b3',
    }
    bin_opts = {
        'lw': 0.2,
        'ls': "--",
        'color': "grey",
        'zorder': -1,
    }
    legend_opts = {
        'framealpha': 0,
        'loc': 'lower center',
        # Leave room for the histogram strip when it is drawn
        'bbox_to_anchor': (0.5, 1.1) if plot_hist else (0.5, 1),
        'ncols': 2,
    }

    # Caller-supplied overrides win over the defaults
    for defaults, overrides in (
        (calibration_opts, calibration_kw),
        (hist_opts, hist_kw),
        (scatter_opts, scatter_kw),
        (fig_opts, fig_kw),
        (bin_opts, bin_kw),
        (legend_opts, legend_kw),
    ):
        if overrides is not None:
            defaults.update(overrides)

    # Reuse the given axes or start a fresh figure
    if ax is None:
        fig, ax = plt.subplots(1, 1, **fig_opts)
    else:
        fig = ax.figure

    # Main axis
    subgroup_points = ax.scatter(
        mean_scores.flat, frac_pos.flat, **scatter_opts
    )

    ax.set_aspect('equal')
    quarter_ticks = [0, 0.25, 0.5, 0.75, 1]
    ax.set(
        xticks=quarter_ticks,
        yticks=quarter_ticks,
        xlabel='Predicted probability',
        ylabel='Fraction of positives',
        xlim=(-0.03, 1.03),
        ylim=(-0.03, 1.03),
    )

    if plot_bins:
        # Only the last line is kept: it serves as the legend handle
        for edge in bins:
            bin_line = ax.axvline(edge, **bin_opts)

    if plot_calibration:
        ax.plot([0, 1], [0, 1], ls="--", lw=1, color="black", zorder=0)
        prob_bins, mean_bins = calibration_curve_glest(
            frac_pos, counts, mean_scores
        )
        calibration_line, = ax.plot(mean_bins, prob_bins, **calibration_opts)

    divider = make_axes_locatable(ax)

    # Histogram on upper axis
    if plot_hist:
        ax_hist = divider.append_axes("top", size="10%", pad=0.0)
        ax_hist.set_xlim(ax.get_xlim())
        ax_hist.xaxis.set_visible(False)
        ax_hist.yaxis.set_visible(False)
        for side in ("right", "top", "left"):
            ax_hist.spines[side].set_visible(False)
        ax_hist.hist(
            mean_scores.flat, bins=bins, weights=counts.flat, **hist_opts
        )

    # Colorbar on right axis
    if plot_cbar:
        ax_cb = divider.append_axes("right", size="4%", pad=0.05)
        ax_cb.set_title('Count', loc='left')
        Colorbar(ax_cb, mappable=count_mappable, spacing='proportional')

    # Legend on top of the figure
    if plot_legend:
        legend_handles = [subgroup_points]
        legend_labels = [subgroup_points.get_label()]
        if plot_bins:
            legend_handles.append(bin_line)
            legend_labels.append('Bin edges')
        if plot_calibration:
            legend_handles.append(calibration_line)
            legend_labels.append(calibration_line.get_label())
        ax.legend(
            handles=legend_handles,
            labels=legend_labels,
            **legend_opts
        )

    return fig


def grouping_diagram_residuals(
        c_hat,
        r_hat,
        n_in_leaf,
        f,
        y,
        f_star=None,
        f_cal=None,
        f_star_cal=None,
        leaf_ids=None,
        leaf_ids_cal=None,
        ax: plt.Axes = None,
        plot_calibration=True,
        plot_cbar: bool = True,
        plot_hist: bool = True,
        plot_legend: bool = True,
    ):
    """
    Plot a grouping diagram for residuals: one curve per group (leaf).

    For every distinct value of ``leaf_ids`` the samples of that group are
    sorted by ``f`` and the curve ``c_hat + r_hat`` is drawn against ``f``,
    coloured by the leaf id and with a line width proportional to the
    group's entry in ``n_in_leaf``.

    Parameters
    ----------
    c_hat : array-like
        Per-sample values drawn against ``f`` as the calibration curve and
        added to ``r_hat`` to obtain the per-group curves.
    r_hat : array-like
        Per-sample residuals added to ``c_hat``.
    n_in_leaf : array-like
        Number of samples in each leaf; entry ``i`` is used for the leaf
        with the ``i``-th smallest id.
    f : array-like
        Per-sample scores shown on the x-axis.
    y : unused
        Accepted for backward compatibility.
    f_star : unused
        Accepted for backward compatibility.
    f_cal, f_star_cal : array-like, optional
        When both are given they are drawn as a scatter plot.
    leaf_ids : array-like
        Group (leaf) id of every sample. Required despite the ``None``
        default.
    leaf_ids_cal : array-like, optional
        Group id of every ``f_cal`` point, used to colour the scatter.
        Falls back to the colours of ``leaf_ids`` when omitted.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 8x8 figure is created otherwise.
    plot_calibration : bool
        Draw the diagonal and the ``c_hat`` versus ``f`` curve.
    plot_cbar : bool
        Append a colorbar for the group ids on the right.
    plot_hist, plot_legend : unused
        Accepted for backward compatibility; no histogram or legend is drawn.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the grouping diagram.

    Raises
    ------
    ValueError
        If ``leaf_ids`` is None.
    """
    if leaf_ids is None:
        raise ValueError(
            "leaf_ids is required: it assigns every sample to a group"
        )

    # Work on plain arrays so masks and argsort indices are positional
    c_hat = np.asarray(c_hat)
    r_hat = np.asarray(r_hat)
    n_in_leaf = np.asarray(n_in_leaf)
    f = np.asarray(f)
    leaf_ids = np.asarray(leaf_ids)

    # Group colours. The norm's upper bound is left open and is fixed by
    # this first to_rgba call, so it must run before any other colour lookup.
    sm = ScalarMappable(norm=Normalize(vmin=1, vmax=None), cmap='viridis')
    leaf_colors = sm.to_rgba(leaf_ids.flat)

    calibration_style = dict(
        color='black',
        linewidth=5,
        label='Calibration curve',
    )

    # Create or retrieve existing figure
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    else:
        fig = ax.figure

    f_star_hat = r_hat + c_hat

    ax.set_aspect('equal')
    ticks = [0, 0.25, 0.5, 0.75, 1]
    ax.set(
        xticks=ticks,
        yticks=ticks,
        xlabel='Predicted probability',
        ylabel='Fraction of positives',
        xlim=(-0.03, 1.03),
        ylim=(-0.03, 1.03),
    )

    divider = make_axes_locatable(ax)
    # Colorbar on right axis
    if plot_cbar:
        ax_cb = divider.append_axes("right", size="4%", pad=0.05)
        ax_cb.set_title('Group', loc='left')
        Colorbar(ax_cb, mappable=sm, spacing='proportional')

    if f_cal is not None and f_star_cal is not None:
        # Colour the calibration points by their own group ids when known
        if leaf_ids_cal is not None:
            scatter_colors = sm.to_rgba(np.asarray(leaf_ids_cal).flat)
        else:
            scatter_colors = leaf_colors
        ax.scatter(
            f_cal,
            f_star_cal,
            edgecolor='white',
            linewidth=0.3,
            color=scatter_colors,
            label='Subgroups',
            alpha=0.5,
        )

    largest_leaf = np.max(n_in_leaf)
    for i, leaf in enumerate(np.unique(leaf_ids)):
        mask = leaf_ids == leaf
        # NOTE(review): n_in_leaf is indexed by the rank of the leaf among
        # the sorted unique ids, i.e. it is assumed to hold one entry per
        # leaf in that order — confirm against the caller.
        n_leaf = n_in_leaf[i]
        if not mask.any():
            continue
        # Sort by f so the group is drawn as a proper curve
        order = np.argsort(f[mask])
        ax.plot(
            f[mask][order],
            f_star_hat[mask][order],
            color=sm.to_rgba(leaf),
            alpha=0.7,
            # Line width proportional to the number of samples
            linewidth=3 * (n_leaf / largest_leaf),
        )

    if plot_calibration:
        ax.plot([0, 1], [0, 1], ls="--", lw=1, color="black", zorder=0)
        sort_idx = np.argsort(f)
        ax.plot(f[sort_idx], c_hat[sort_idx], **calibration_style)

    return fig


def grouping_diagram_residuals_descriptive(
        c_hat,
        r_hat,
        n_in_leaf,
        f,
        y,
        f_star=None,
        f_cal=None,
        f_star_cal=None,
        leaf_ids=None,
        leaf_ids_cal=None,
        ax: plt.Axes = None,
        plot_calibration=True,
        plot_cbar: bool = True,
        plot_hist: bool = True,
        plot_legend: bool = True,
        tree=None,
        X_test=None,
        feature_names=None,
        annotations: dict = None,
        save_path='grouping_diagram.pdf',
    ):
    """
    Plot a grouping diagram for residuals with a few annotated groups.

    For every distinct value of ``leaf_ids`` the samples of that group are
    sorted by ``f`` and the curve ``c_hat + r_hat`` is drawn against ``f``.
    Annotated groups are drawn in their colormap colour with a text box and
    an arrow; every other group is drawn in translucent grey.

    Parameters
    ----------
    c_hat : array-like
        Per-sample values drawn against ``f`` as the calibration curve and
        added to ``r_hat`` to obtain the per-group curves.
    r_hat : array-like
        Per-sample residuals added to ``c_hat``.
    n_in_leaf : array-like
        Number of samples in each leaf; entry ``i`` is used for the leaf
        with the ``i``-th smallest id.
    f : array-like
        Per-sample scores shown on the x-axis.
    y, f_star : unused
        Accepted for backward compatibility.
    f_cal, f_star_cal : array-like, optional
        When both are given they are drawn as a scatter plot.
    leaf_ids : array-like
        Group (leaf) id of every sample. Required despite the ``None``
        default.
    leaf_ids_cal : array-like, optional
        Group id of every ``f_cal`` point, used to colour the scatter.
        Falls back to the colours of ``leaf_ids`` when omitted.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 4x3 figure is created otherwise.
    plot_calibration : bool
        Draw the diagonal and the ``c_hat`` versus ``f`` curve.
    plot_cbar : bool
        Append a colorbar for the group ids on the right.
    plot_hist : unused
        Accepted for backward compatibility; no histogram is drawn.
    plot_legend : bool
        Draw the default legend of the labelled artists.
    tree, X_test, feature_names : optional
        When ``annotations`` is None, the built-in annotations (groups 62
        and 37) are drawn only if all three are provided. Their contents
        are not otherwise used here.
    annotations : dict, optional
        Maps a leaf id to a dict with keys ``'text'`` (label),
        ``'xytext'`` (text position in data coordinates) and optionally
        ``'point'`` (index, along the curve sorted by ``f``, the arrow
        points to; defaults to the last point). Overrides the built-in
        annotations.
    save_path : str or None
        Where the figure is saved; pass None to skip saving.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the grouping diagram.

    Raises
    ------
    ValueError
        If ``leaf_ids`` is None.
    """
    if leaf_ids is None:
        raise ValueError(
            "leaf_ids is required: it assigns every sample to a group"
        )

    # Work on plain arrays so masks and argsort indices are positional
    c_hat = np.asarray(c_hat)
    r_hat = np.asarray(r_hat)
    n_in_leaf = np.asarray(n_in_leaf)
    f = np.asarray(f)
    leaf_ids = np.asarray(leaf_ids)

    if annotations is None:
        # Built-in annotations, switched on by the presence of the tree
        # arguments exactly as before.
        tree_given = (
            tree is not None
            and X_test is not None
            and feature_names is not None
        )
        if tree_given:
            annotations = {
                62: dict(
                    text="Experienced Graduate women",
                    xytext=(0.04, 0.68),
                    point=70,
                ),
                37: dict(
                    text="Young Graduate women",
                    xytext=(0.07, 0.45),
                    point=20,
                ),
            }
        else:
            annotations = {}

    # Group colours. The norm's upper bound is left open and is fixed by
    # this first to_rgba call, so it must run before any other colour lookup.
    sm = ScalarMappable(norm=Normalize(vmin=1, vmax=None), cmap='viridis')
    leaf_colors = sm.to_rgba(leaf_ids.flat)

    calibration_style = dict(
        color='black',
        linewidth=6,
        label='Calibration curve',
    )

    # Create or retrieve existing figure
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(4, 3))
    else:
        fig = ax.figure

    f_star_hat = r_hat + c_hat

    ticks = [0, 0.25, 0.5, 0.75, 1]
    ax.set(
        xticks=ticks,
        yticks=ticks,
        xlim=(0, 1),
        ylim=(0, 1),
    )

    divider = make_axes_locatable(ax)
    # Colorbar on right axis
    if plot_cbar:
        ax_cb = divider.append_axes("right", size="4%", pad=0.05)
        ax_cb.set_title('Group', loc='left')
        Colorbar(ax_cb, mappable=sm, spacing='proportional')

    if f_cal is not None and f_star_cal is not None:
        # Colour the calibration points by their own group ids when known
        if leaf_ids_cal is not None:
            scatter_colors = sm.to_rgba(np.asarray(leaf_ids_cal).flat)
        else:
            scatter_colors = leaf_colors
        ax.scatter(
            f_cal,
            f_star_cal,
            edgecolor='white',
            linewidth=0.3,
            color=scatter_colors,
            label='Subgroups',
            alpha=0.5,
        )

    # Line widths grow with log(n); guard the degenerate case where the
    # largest leaf holds a single sample and log(max) would be zero.
    log_largest = np.log(np.max(n_in_leaf))

    for i, leaf in enumerate(np.unique(leaf_ids)):
        mask = leaf_ids == leaf
        # NOTE(review): n_in_leaf is indexed by the rank of the leaf among
        # the sorted unique ids, i.e. it is assumed to hold one entry per
        # leaf in that order — confirm against the caller.
        n_leaf = n_in_leaf[i]
        if not mask.any():
            continue

        # Sort by f so the group is drawn as a proper curve
        order = np.argsort(f[mask])
        f_leaf = f[mask][order]
        f_star_hat_leaf = f_star_hat[mask][order]
        leaf_color = sm.to_rgba(leaf)

        relative = np.log(n_leaf) / log_largest if log_largest > 0 else 0.0
        line_width = 1 + 4 * relative

        spec = annotations.get(leaf)
        highlighted = spec is not None
        ax.plot(
            f_leaf,
            f_star_hat_leaf,
            color=leaf_color if highlighted else 'grey',
            alpha=1.0 if highlighted else 0.3,
            linewidth=line_width,
        )
        if not highlighted:
            continue

        # Clamp the requested arrow target so small groups cannot index
        # past the end of their curve.
        point = int(np.clip(spec.get('point', -1),
                            -f_leaf.size, f_leaf.size - 1))
        ax.annotate(
            f"Group {leaf}, n = {int(n_leaf)} \n{spec['text']}",
            xy=(f_leaf[point], f_star_hat_leaf[point]),
            xytext=spec['xytext'],
            textcoords='data',
            bbox=dict(
                boxstyle='round,pad=0.3',
                facecolor=leaf_color,
                alpha=0.9,
                edgecolor='black',
                linewidth=1,
            ),
            fontsize=10,
            ha='left',
            arrowprops=dict(
                arrowstyle='->',
                color=leaf_color,
                lw=1.5,
                alpha=0.9,
                shrinkA=0,
                shrinkB=0,
                edgecolor='black',
            ),
        )

    if plot_calibration:
        ax.plot([0, 1], [0, 1], ls="--", lw=1, color="black", zorder=0)
        sort_idx = np.argsort(f)
        ax.plot(f[sort_idx], c_hat[sort_idx], **calibration_style)

    ax.set_xlabel('Predicted probability for earning >50k$/year', fontsize=11)
    ax.set_ylabel('Fraction of individuals\nearning >50k$/year', fontsize=11)

    if plot_legend:
        # Default legend of the labelled artists (calibration curve, scatter)
        ax.legend()

    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    return fig

def get_group_definitions(tree, X_test, leaf_ids, feature_names):
    """
    Extract the decision rules that define each group (leaf) in the tree.

    For every distinct value in ``leaf_ids`` the first sample assigned to
    that leaf is passed through ``tree.decision_path``. At each split node
    on the path the sample's feature value is compared with the node
    threshold to tell which side was taken, and conditions on the same
    feature are merged into the tightest interval.

    Parameters
    ----------
    tree : fitted tree estimator
        Must expose ``decision_path`` plus ``tree_.feature`` and
        ``tree_.threshold``.
    X_test : 2-D array
        Samples; indexed with a boolean row mask and with ``[row, column]``.
    leaf_ids : array-like
        Leaf id of every row of ``X_test``.
    feature_names : sequence of str
        Name of every feature column, indexed by ``tree_.feature``.

    Returns
    -------
    dict
        Maps each leaf id to a dict with keys ``'rules'`` (list of
        readable conditions), ``'n_samples'`` (number of rows in the leaf)
        and ``'sample_indices'`` (row positions of those samples).
    """
    # Entry of tree_.feature that marks a node without a split (a leaf)
    no_split = -2

    leaf_ids = np.asarray(leaf_ids)
    feature = tree.tree_.feature
    threshold = tree.tree_.threshold

    group_definitions = {}

    for leaf_id in np.unique(leaf_ids):
        in_leaf = leaf_ids == leaf_id
        samples_in_leaf = X_test[in_leaf]

        # One representative sample is enough: all samples of a leaf share
        # the same root-to-leaf path.
        path = tree.decision_path(samples_in_leaf[:1]).toarray()[0]

        # Tightest lower / upper bound seen so far for every feature
        feature_bounds = {}
        for node_id in np.flatnonzero(path == 1):
            split_feature = feature[node_id]
            if split_feature == no_split:
                continue

            bounds = feature_bounds.setdefault(
                feature_names[split_feature], {'min': None, 'max': None}
            )
            cut = threshold[node_id]
            if samples_in_leaf[0, split_feature] <= cut:
                # Went left: the feature is bounded from above
                if bounds['max'] is None or cut < bounds['max']:
                    bounds['max'] = cut
            else:
                # Went right: the feature is bounded from below
                if bounds['min'] is None or cut > bounds['min']:
                    bounds['min'] = cut

        # Convert bounds to readable rules
        combined_rules = []
        for feature_name, bounds in feature_bounds.items():
            low, high = bounds['min'], bounds['max']
            if low is not None and high is not None:
                combined_rules.append(
                    f"{low:.1f} < {feature_name} <= {high:.1f}"
                )
            elif low is not None:
                combined_rules.append(f"{feature_name} > {low:.1f}")
            else:
                combined_rules.append(f"{feature_name} <= {high:.1f}")

        group_definitions[leaf_id] = {
            'rules': combined_rules,
            'n_samples': len(samples_in_leaf),
            'sample_indices': np.where(in_leaf)[0]
        }

    return group_definitions