import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

import pickle
import seaborn as sns

import scienceplots
import statsmodels.api as sm

# To initialize the necessary variables, run:
# ablations, values, densities = get_ablations_from_pickle()
# all_ablations, all_values, all_densities = stack_all_ablations(ablations, values, densities)

def get_ablations_from_pickle(ablations_path = 'ablation_logs/0924_ablations_short.pkl',
                              values_path = 'ablation_logs/0924_values_short.pkl',
                              densities_path = 'ablation_logs/0924_densities_short.pkl'):
    """Load the three pickled logs (ablations, values, densities) from disk.

    Args:
        ablations_path: Path of the pickle holding the ablation log.
        values_path: Path of the pickle holding the value log.
        densities_path: Path of the pickle holding the density log.

    Returns:
        Tuple ``(ablations, values, densities)`` with the unpickled objects,
        in that order.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated log files.
    loaded = []
    for pkl_path in (ablations_path, values_path, densities_path):
        with open(pkl_path, 'rb') as handle:
            loaded.append(pickle.load(handle))

    ablations, values, densities = loaded
    return ablations, values, densities

def stack_all_ablations(ablations, values, densities, k=512, return_clipped=True):
    """Concatenate the per-index logs for indices ``0..k-1`` into flat arrays.

    Each of the three containers is indexed with ``ind in range(k)``; entries
    whose length is zero are skipped and the rest are concatenated along
    axis 0.

    Args:
        ablations: Per-index ablation curves.
        values: Per-index values.
        densities: Per-index densities.
        k: Number of indices to visit.
        return_clipped: If True, keep only the first 21 columns of the
            stacked ablations.

    Returns:
        Tuple ``(all_ablations, all_values, all_densities)``.
    """
    def _stack_nonempty(per_index):
        # Skip indices with nothing logged, then join the rest row-wise.
        chunks = []
        for ind in range(k):
            if len(per_index[ind]) > 0:
                chunks.append(per_index[ind])
        return np.concatenate(chunks, axis=0)

    all_ablations = _stack_nonempty(ablations)
    all_values = _stack_nonempty(values)
    all_densities = _stack_nonempty(densities)

    # The most recent run logged ablations up to 150% of the presence (t=1.5);
    # only the t=0..1 part (first 21 samples) is of interest.
    if return_clipped:
        all_ablations = all_ablations[:, :21]
    return all_ablations, all_values, all_densities


# Figure 4, upper 2.
def plot_quantiles(ablations, fn='figures/abl_quantiles.svg', plot_fill=True, 
                   quantiles=[0.05, 1/6, 2/6, 3/6, 4/6, 5/6, 0.95],
                   colormap=plt.cm.winter, cm_alpha=0.5, cm_range=(0, 1),
                   fillcolor = [0.9,0.96,0.99],
                   xlim=[-0.05, 1.05], ylim=[-0.08, 1.6],
                   figure_size = (2,2)):
    """Plot column-wise quantiles of the scaled ablation curves and save to ``fn``.

    Args:
        ablations: 2-D array of ablation curves (rows are curves).
        fn: Output path handed to ``fig.savefig``.
        plot_fill: If True, the first and last quantiles are drawn as a
            filled band and only the inner quantiles as lines; otherwise
            every quantile is drawn as a line.
        quantiles: Quantile levels passed to ``np.quantile``.
        colormap, cm_alpha, cm_range: Colormap, alpha and sampled sub-range
            used for the quantile lines.
        fillcolor: Colour of the filled band.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
    """
    plt.style.use(['science', 'ieee'])

    quantile_levels = np.array(quantiles)
    fig, ax = plt.subplots(figsize=figure_size)

    # With a fill, the two outermost quantiles are not drawn as lines and
    # therefore need no colour of their own.
    n_lines = quantile_levels.shape[0] - 2 * plot_fill
    line_colors = colormap(np.linspace(cm_range[0], cm_range[1], n_lines), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', line_colors))

    curves = np.quantile(get_scaled_feat(ablations), quantile_levels, axis=0)
    n_pts = curves.shape[1]
    t_axis = np.arange(n_pts) / (n_pts - 1)

    if plot_fill:
        inner_labels = [f"{quant:.2f} quantile" for quant in quantile_levels[1:-1]]
        ax.fill_between(t_axis, curves[0], curves[-1], color=fillcolor)
        ax.plot(t_axis, curves[1:-1].T, alpha=0.5, label=inner_labels)
    else:
        all_labels = [f"{quant:.2f} quantile" for quant in quantile_levels]
        ax.plot(t_axis, curves.T, label=all_labels)
    # Dotted y = t reference line.
    ax.plot(t_axis, t_axis, color=[0,0,0], linestyle='dotted', alpha=0.4)

    ax.set_title('Quantiles of All ARCs')
    ax.set_xlabel('$t$ value')
    ax.set_ylabel('Scaled Difference in Loss')
    # The legend is intentionally not drawn; labels stay attached to the lines.

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)


# Figure 4, upper 1.
def plot_ablation_examples(ablations, example_inds=[9946, 432, 11384], fn='figures/abl_examples.svg', plot_fill=True, 
                   quantiles=[0.05, 1/6, 2/6, 3/6, 4/6, 5/6, 0.95],
                   colormap=plt.cm.cool, cm_alpha=0.5, cm_range=(0.3, 1), 
                   fillcolor = [0.9,0.96,0.99],
                   xlim=[-0.05, 1.05], ylim=[-0.08, 1.6],
                   figure_size = (2, 2)):
    """Plot selected scaled ablation curves over the quantile background.

    The quantiles of all scaled curves are drawn in a muted colour, the
    chosen example curves on top, and a dotted vertical line marks each
    example's interpolated transition point (first crossing of 0.5, as
    computed by ``get_fine_t_pts``).

    Args:
        ablations: 2-D array of ablation curves (rows are curves).
        example_inds: Row indices of the curves to highlight.
        fn: Output path handed to ``fig.savefig``.
        plot_fill: If True, draw the outermost quantiles as a filled band
            and the inner ones as lines; otherwise draw all as lines.
        quantiles: Quantile levels passed to ``np.quantile``.
        colormap, cm_alpha, cm_range: Colormap, alpha and sampled sub-range
            used for the example curves.
        fillcolor: Colour of the band; its element-wise square is used for
            the quantile lines.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
    """
    plt.style.use(['science', 'ieee', 'bright'])
    quantiles = np.array(quantiles)
    fig, ax = plt.subplots(figsize=figure_size)

    scaled_ablations = get_scaled_feat(ablations)
    examples = scaled_ablations[example_inds]
    colors = colormap(np.linspace(cm_range[0], cm_range[1], examples.shape[0]), alpha=cm_alpha)

    ab_quantiles = np.quantile(scaled_ablations, quantiles, axis=0)
    n_pts = ab_quantiles.shape[1]
    # Sample index i sits at t = i / (n_pts - 1), so the axis spans [0, 1].
    t_axis = np.arange(n_pts) / (n_pts - 1)
    quantile_line_color = np.array(fillcolor) ** 2

    if plot_fill:
        ax.fill_between(t_axis, ab_quantiles[0], ab_quantiles[-1], color=fillcolor)
        ax.plot(t_axis, ab_quantiles[1:-1].T, color=quantile_line_color, linestyle='-', alpha=0.5)
    else:
        ax.plot(t_axis, ab_quantiles.T, linestyle='-', color=quantile_line_color, alpha=0.5)
    # Dotted y = t reference line.
    ax.plot(t_axis, t_axis, color=[0,0,0], linestyle='dotted', alpha=0.4)

    ax.set_prop_cycle(plt.cycler('color', colors))
    ax.plot(t_axis, examples.T)

    # get_fine_t_pts returns positions in sample-index units; convert them
    # with the same (n_pts - 1) factor as the x-axis so the markers line up
    # with the curves (dividing by n_pts would shift them left).
    t_pts = get_fine_t_pts(examples, remove_return_below=False) / (n_pts - 1)
    for color, t_pt in zip(colors, t_pts):
        ax.axvline(t_pt, color=color, alpha=0.5, linestyle='dotted')

    ax.set_title('Example ARCs')
    ax.set_xlabel('$t$ value')
    ax.set_ylabel('Scaled Difference in Loss')

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 6, left and center. Inds are [9946, 432, 11384]
def plot_partial_ablation_examples(ablations, example_inds=[432], fn='figures/part1_abl_examples.svg', plot_fill=True, 
                   quantiles=[0.05, 1/6, 2/6, 3/6, 4/6, 5/6, 0.95],
                   colormap=plt.cm.cool, cm_alpha=0.5, cm_range=(0.7, 1), 
                   fillcolor = [0.9,0.96,0.99],
                   xlim=[-0.05, 1.05], ylim=[-0.08, 1.6],
                   figure_size = (3, 3)):
    """Plot a few scaled ablation curves over the quantile background (no title).

    Same layout as the example plot but without the y = t reference line
    and title, and with black transition-point markers.

    Args:
        ablations: 2-D array of ablation curves (rows are curves).
        example_inds: Row indices of the curves to highlight.
        fn: Output path handed to ``fig.savefig``.
        plot_fill: If True, draw the outermost quantiles as a filled band
            and the inner ones as lines; otherwise draw all as lines.
        quantiles: Quantile levels passed to ``np.quantile``.
        colormap, cm_alpha, cm_range: Colormap, alpha and sampled sub-range
            used for the example curves.
        fillcolor: Colour of the band; its element-wise square is used for
            the quantile lines.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
    """
    plt.style.use(['science', 'ieee', 'bright'])
    quantiles = np.array(quantiles)
    fig, ax = plt.subplots(figsize=figure_size)

    scaled_ablations = get_scaled_feat(ablations)
    examples = scaled_ablations[example_inds]
    colors = colormap(np.linspace(cm_range[0], cm_range[1], examples.shape[0]), alpha=cm_alpha)

    ab_quantiles = np.quantile(scaled_ablations, quantiles, axis=0)
    n_pts = ab_quantiles.shape[1]
    # Sample index i sits at t = i / (n_pts - 1), so the axis spans [0, 1].
    t_axis = np.arange(n_pts) / (n_pts - 1)
    quantile_line_color = np.array(fillcolor) ** 2

    if plot_fill:
        ax.fill_between(t_axis, ab_quantiles[0], ab_quantiles[-1], color=fillcolor)
        ax.plot(t_axis, ab_quantiles[1:-1].T, color=quantile_line_color, linestyle='-', alpha=0.5)
    else:
        ax.plot(t_axis, ab_quantiles.T, linestyle='-', color=quantile_line_color, alpha=0.5)

    ax.set_prop_cycle(plt.cycler('color', colors))
    ax.plot(t_axis, examples.T)

    # get_fine_t_pts returns positions in sample-index units; convert them
    # with the same (n_pts - 1) factor as the x-axis so the markers line up
    # with the curves (dividing by n_pts would shift them left).
    t_pts = get_fine_t_pts(examples, remove_return_below=False) / (n_pts - 1)
    for t_pt in t_pts:
        ax.axvline(t_pt, color=(0, 0, 0), alpha=0.5, linestyle='dotted')

    ax.set_xlabel('$t$ value')
    ax.set_ylabel('Scaled Difference in Loss')

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 4, upper 3
def plot_loss_diff_per_avg_presence(all_ablations, all_values, all_presences, fn='figures/ld_per_avg_presence.svg',
                                    ranges=[1.35, 1.5, 1.65, 1.9], clip=(-6, 0), log_scale=True,
                                    colormap=plt.cm.magma, cm_alpha=0.5, cm_range=(0, 0.7), 
                                    xlim=None, ylim=(0, 1),
                                    figure_size = (2, 2), legend_size=5):
    """Plot KDEs of the loss difference, grouped by ``all_values / all_presences``.

    The loss difference of a curve is its last column minus its first.
    Curves are grouped by their average value and one density estimate is
    drawn per group.

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        all_values: Per-curve values (numerator of the average).
        all_presences: Per-curve presences (denominator of the average).
        fn: Output path handed to ``fig.savefig``.
        ranges: Either a flat list of cut points (groups are then the
            intervals between min, the cut points and max) or a list of
            explicit ``[low, high]`` lists.
        clip, log_scale: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    all_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    all_avg_vals = all_values / all_presences

    # Raw string: '\o' is not a valid escape sequence in a normal literal.
    symbol = r'\overline{\alpha}_j'

    # Only cut-point input yields open-ended first/last groups; explicit
    # intervals are labelled with both of their bounds.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        range_start = [all_avg_vals.min()] + ranges
        range_end = ranges + [all_avg_vals.max()]
        ranges = list(zip(range_start, range_end))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    colors = colormap(np.linspace(cm_range[0], cm_range[1], len(ranges)), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', colors))

    split_diffs, _ = split_features_by_diff_range(all_diffs, all_avg_vals, ranges)
    last_idx = len(split_diffs) - 1
    for i, (diff, interval) in enumerate(zip(split_diffs, ranges)):

        # Drop NaNs and non-positive differences before the density estimate.
        diff = diff[(~np.isnan(diff)) & (diff > 0)]

        if open_ended and i == 0:
            label = f'${symbol} < {interval[1]:.2f}$'
        elif open_ended and i == last_idx:
            label = f'${symbol} > {interval[0]:.2f}$'
        else:
            label = f'${interval[0]:.2f} < {symbol} < {interval[1]:.2f}$'
        sns.kdeplot(diff, ax=ax, clip=clip, label=label, log_scale=log_scale)

    ax.set_title('Loss per Avg Presence')
    ax.set_xlabel('Loss Difference')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 4, upper 4
def plot_loss_diff_per_feature_density(all_ablations, all_presences, fn='figures/ld_per_density.svg',
                                       ranges=[0.04, 0.08, 0.12, 0.20], clip=(-6, 0), log_scale=True,
                                       colormap=plt.cm.copper, cm_alpha=0.5, cm_range=(0, 1), 
                                       xlim=None, ylim=(0, 1),
                                       figure_size = (2, 2), legend_size=5):
    """Plot KDEs of the loss difference, grouped by feature density.

    The loss difference of a curve is its last column minus its first; the
    density is ``all_presences / 4096``.

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        all_presences: Per-curve presences.
        fn: Output path handed to ``fig.savefig``.
        ranges: Either a flat list of density cut points (groups are then
            the intervals between min, the cut points and max) or a list of
            explicit ``[low, high]`` lists.
        clip, log_scale: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    all_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    # NOTE(review): 4096 is hard-coded; presumably the number of positions a
    # presence count is measured over - confirm against the logging code.
    all_densities = all_presences / 4096

    symbol = '\\rho_j'

    # Only cut-point input yields open-ended first/last groups; explicit
    # intervals are labelled with both of their bounds.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        range_start = [all_densities.min()] + ranges
        range_end = ranges + [all_densities.max()]
        ranges = list(zip(range_start, range_end))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    colors = colormap(np.linspace(cm_range[0], cm_range[1], len(ranges)), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', colors))

    split_diffs, _ = split_features_by_diff_range(all_diffs, all_densities, ranges)
    last_idx = len(split_diffs) - 1
    for i, (diff, interval) in enumerate(zip(split_diffs, ranges)):

        # Drop NaNs and non-positive differences before the density estimate.
        diff = diff[(~np.isnan(diff)) & (diff > 0)]

        # Bounds are shown as percentages.
        if open_ended and i == 0:
            label = f'${symbol} < {interval[1]*100:.0f}\\%$'
        elif open_ended and i == last_idx:
            label = f'${symbol} > {interval[0]*100:.0f}\\%$'
        else:
            label = f'${interval[0]*100:.0f}\\% < {symbol} < {interval[1]*100:.0f}\\%$'
        sns.kdeplot(diff, ax=ax, clip=clip, label=label, log_scale=log_scale)

    ax.set_title('Loss per Feature Density')
    ax.set_xlabel('Loss Difference')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 5, upper center
def plot_t_pt_per_loss_diff(all_ablations, fn='figures/t_pt_per_ld.svg',
                            ranges=[[0.0001, 0.00011], [0.0010,0.0011], [0.01, 0.011], [0.1, 0.11]], 
                            clip=(0, 1),
                            colormap=plt.cm.viridis, cm_alpha=0.5, cm_range=(0, 1), 
                            xlim=None, ylim=None,
                            figure_size = (3, 2), legend_size=5):
    """Plot KDEs of the transition point, grouped by loss difference.

    Curves are scaled with ``get_scaled_feat``, grouped by their raw loss
    difference (last column minus first), and the interpolated transition
    point of each curve (``get_fine_t_pts``, normalised to the t axis) is
    fed into one density estimate per group.

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        fn: Output path handed to ``fig.savefig``.
        ranges: Explicit ``[low, high]`` lists, or a flat list of cut points
            (first and last group are then labelled as open-ended).
        clip: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    loss_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    symbol = '\\Delta L'

    # Flat cut points are expanded into (low, high) pairs bounded by the
    # observed min and max.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        lows = [loss_diffs.min()] + ranges
        highs = ranges + [loss_diffs.max()]
        ranges = list(zip(lows, highs))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    palette = colormap(np.linspace(cm_range[0], cm_range[1], len(list(ranges))), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', palette))

    scaled = get_scaled_feat(all_ablations)
    grouped_curves, _ = split_features_by_diff_range(scaled, loss_diffs, ranges)
    # Transition points come back in sample-index units; rescale to [0, 1].
    grouped_t_pts = [np.array(get_fine_t_pts(group)) / (group.shape[1] - 1)
                     for group in grouped_curves]

    last_idx = len(grouped_curves) - 1
    for idx, (t_pts, interval) in enumerate(zip(grouped_t_pts, ranges)):
        lower, upper = interval[0], interval[1]

        usable = (~np.isnan(t_pts)) & (t_pts > 0)
        t_pts = t_pts[usable]

        if open_ended and idx == 0:
            label = f'${symbol} < {upper:.1e}$'
        elif open_ended and idx == last_idx:
            label = f'${symbol} > {lower:.1e}$'
        else:
            label = f'${lower:.1e} < {symbol} < {upper:.1e}$'
        sns.kdeplot(t_pts, ax=ax, label=label, clip=clip)

    ax.set_title('Transition Points By Loss Diff')
    ax.set_xlabel('Transition Point')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 5, lower. To target a specific ind, run with ablations[ind].
def plot_ind_t_pt_per_loss_diff(ind_ablations, fn='figures/ind_t_pt_per_ld.svg',
                                ranges=[[0.0001, 0.00011], [0.0010,0.0011], [0.01, 0.011], [0.1, 0.11]], 
                                clip=(0, 1), title='Ft 166 T Pts By Loss Diff',
                                colormap=plt.cm.viridis, cm_alpha=0.5, cm_range=(0, 1), 
                                xlim=None, ylim=None,
                                figure_size = (3, 2), legend_size=5):
    """Plot KDEs of the transition point by loss difference for one index.

    Same analysis as the all-curves variant, but the input is first
    converted with ``np.array`` and the title and y tick format are
    configurable/fixed for a single-index figure.

    Args:
        ind_ablations: Ablation curves of one index (array-like, rows are
            curves).
        fn: Output path handed to ``fig.savefig``.
        ranges: Explicit ``[low, high]`` lists, or a flat list of cut points
            (first and last group are then labelled as open-ended).
        clip: Forwarded to ``sns.kdeplot``.
        title: Axes title.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    ind_ablations = np.array(ind_ablations)
    loss_diffs = ind_ablations[:, -1] - ind_ablations[:, 0]
    symbol = '\\Delta L'

    # Flat cut points are expanded into (low, high) pairs bounded by the
    # observed min and max.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        lows = [loss_diffs.min()] + ranges
        highs = ranges + [loss_diffs.max()]
        ranges = list(zip(lows, highs))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    palette = colormap(np.linspace(cm_range[0], cm_range[1], len(list(ranges))), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', palette))

    scaled = get_scaled_feat(ind_ablations)
    grouped_curves, _ = split_features_by_diff_range(scaled, loss_diffs, ranges)
    # Transition points come back in sample-index units; rescale to [0, 1].
    grouped_t_pts = [np.array(get_fine_t_pts(group)) / (group.shape[1] - 1)
                     for group in grouped_curves]

    last_idx = len(grouped_curves) - 1
    for idx, (t_pts, interval) in enumerate(zip(grouped_t_pts, ranges)):
        lower, upper = interval[0], interval[1]

        usable = (~np.isnan(t_pts)) & (t_pts > 0)
        t_pts = t_pts[usable]

        if open_ended and idx == 0:
            label = f'${symbol} < {upper:.1e}$'
        elif open_ended and idx == last_idx:
            label = f'${symbol} > {lower:.1e}$'
        else:
            label = f'${lower:.1e} < {symbol} < {upper:.1e}$'
        sns.kdeplot(t_pts, ax=ax, label=label, clip=clip)

    # Integer-looking y tick labels.
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.0f'))
    ax.set_title(title)
    ax.set_xlabel('Transition Point')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 5, upper right
def plot_max_slope_pt_per_loss_diff(all_ablations, fn='figures/max_slope_pt_per_ld.svg',
                                    ranges=[[0.0001, 0.00011], [0.0010,0.0011], [0.01, 0.011], [0.1, 0.11]], 
                                    clip=(0, 1), bw_adjust=1,
                                    colormap=plt.cm.viridis, cm_alpha=0.5, cm_range=(0, 1), 
                                    xlim=None, ylim=None,
                                    figure_size = (3, 2), legend_size=5):
    """Plot KDEs of where the steepest step occurs, grouped by loss difference.

    For every (unscaled) curve the index of the largest consecutive
    difference is taken and divided by the number of steps; curves are
    grouped by loss difference (last column minus first).

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        fn: Output path handed to ``fig.savefig``.
        ranges: Explicit ``[low, high]`` lists, or a flat list of cut points
            (first and last group are then labelled as open-ended).
        clip, bw_adjust: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    loss_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    symbol = '\\Delta L'

    # Flat cut points are expanded into (low, high) pairs bounded by the
    # observed min and max.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        lows = [loss_diffs.min()] + ranges
        highs = ranges + [loss_diffs.max()]
        ranges = list(zip(lows, highs))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    palette = colormap(np.linspace(cm_range[0], cm_range[1], len(list(ranges))), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', palette))

    steps = np.diff(all_ablations, axis=1)
    # NOTE(review): normalised by the number of steps here, whereas the
    # min-slope plot divides by (number of steps - 1) - confirm which
    # convention is intended.
    max_step_pos = np.argmax(steps, axis=1) / steps.shape[1]
    grouped_pos, _ = split_features_by_diff_range(max_step_pos, loss_diffs, ranges)

    last_idx = len(grouped_pos) - 1
    for idx, (positions, interval) in enumerate(zip(grouped_pos, ranges)):
        lower, upper = interval[0], interval[1]

        if open_ended and idx == 0:
            label = f'${symbol} < {upper:.1e}$'
        elif open_ended and idx == last_idx:
            label = f'${symbol} > {lower:.1e}$'
        else:
            label = f'${lower:.1e} < {symbol} < {upper:.1e}$'
        sns.kdeplot(positions, ax=ax, label=label, clip=clip, bw_adjust=bw_adjust)

    ax.set_title('Max Slope Pt By Loss Diff')
    ax.set_xlabel('Point of Maximum Slope')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 4, lower 1
def plot_intermed_vals_per_ld(all_ablations, fn='figures/intermed_vals_per_ld.svg',
                              ranges=[[0.0001, 0.00011], [0.0010,0.0011], [0.01, 0.011], [0.1, 0.11]],
                              clip=(-0.1, 1.4),
                              colormap=plt.cm.viridis, cm_alpha=0.5, cm_range=(0, 1), 
                              xlim=None, ylim=None,
                              figure_size = (2, 2), legend_size=5):
    """Plot KDEs of the intermediate scaled losses, grouped by loss difference.

    The interior columns (everything except the first and last) of the
    scaled curves are pooled per loss-difference group and drawn as one
    density estimate each.

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        fn: Output path handed to ``fig.savefig``.
        ranges: Explicit ``[low, high]`` lists, or a flat list of cut points
            (first and last group are then labelled as open-ended).
        clip: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    loss_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    symbol = '\\Delta L'

    # Flat cut points are expanded into (low, high) pairs bounded by the
    # observed min and max.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        lows = [loss_diffs.min()] + ranges
        highs = ranges + [loss_diffs.max()]
        ranges = list(zip(lows, highs))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    palette = colormap(np.linspace(cm_range[0], cm_range[1], len(list(ranges))), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', palette))

    interior = get_scaled_feat(all_ablations)[:, 1:-1]
    grouped_interior, _ = split_features_by_diff_range(interior, loss_diffs, ranges)

    last_idx = len(grouped_interior) - 1
    for idx, (vals, interval) in enumerate(zip(grouped_interior, ranges)):
        lower, upper = interval[0], interval[1]

        # Boolean-mask indexing flattens the 2-D group into a 1-D sample of
        # the positive, non-NaN entries.
        usable = (~np.isnan(vals)) & (vals > 0)
        vals = vals[usable]

        if open_ended and idx == 0:
            label = f'${symbol} < {upper:.1e}$'
        elif open_ended and idx == last_idx:
            label = f'${symbol} > {lower:.1e}$'
        else:
            label = f'${lower:.1e} < {symbol} < {upper:.1e}$'
        sns.kdeplot(vals, ax=ax, label=label, clip=clip)

    ax.set_title('Intermed. Loss By Loss Diff')
    ax.set_xlabel('Intermediate Losses')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 4, lower 3
def plot_min_slope_per_ld(all_ablations, fn='figures/min_slope_per_ld.svg',
                          ranges=[[0.0001, 0.00011], [0.0010,0.0011], [0.01, 0.011], [0.1, 0.11]],
                          clip=(0, 1),
                          colormap=plt.cm.viridis, cm_alpha=0.5, cm_range=(0, 1), 
                          xlim=None, ylim=None,
                          figure_size = (2, 2), legend_size=5):
    """Plot KDEs of where the flattest step occurs, grouped by loss difference.

    For every (unscaled) curve the index of the consecutive difference with
    the smallest absolute value is taken and divided by (number of steps -
    1); curves are grouped by loss difference (last column minus first).

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        fn: Output path handed to ``fig.savefig``.
        ranges: Explicit ``[low, high]`` lists, or a flat list of cut points
            (first and last group are then labelled as open-ended).
        clip: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    loss_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    symbol = '\\Delta L'

    # Flat cut points are expanded into (low, high) pairs bounded by the
    # observed min and max.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        lows = [loss_diffs.min()] + ranges
        highs = ranges + [loss_diffs.max()]
        ranges = list(zip(lows, highs))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    palette = colormap(np.linspace(cm_range[0], cm_range[1], len(list(ranges))), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', palette))

    steps = np.diff(all_ablations, axis=1)
    # Position of the step with the smallest magnitude, mapped onto [0, 1].
    flattest_pos = np.argmin(np.abs(steps), axis=1) / (steps.shape[1] - 1)
    grouped_pos, _ = split_features_by_diff_range(flattest_pos, loss_diffs, ranges)

    last_idx = len(grouped_pos) - 1
    for idx, (positions, interval) in enumerate(zip(grouped_pos, ranges)):
        lower, upper = interval[0], interval[1]

        if open_ended and idx == 0:
            label = f'${symbol} < {upper:.1e}$'
        elif open_ended and idx == last_idx:
            label = f'${symbol} > {lower:.1e}$'
        else:
            label = f'${lower:.1e} < {symbol} < {upper:.1e}$'
        sns.kdeplot(positions, ax=ax, label=label, clip=clip)

    # One decimal place on the y tick labels.
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))
    ax.set_title('Min Slope Pt By Loss Diff')
    ax.set_xlabel('Point of Minimum Slope')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 4, lower 2
def plot_slopes_per_ld(all_ablations, fn='figures/slopes_per_ld.svg',
                      ranges=[[0.0001, 0.00011], [0.0010,0.0011], [0.01, 0.011], [0.1, 0.11]],
                      clip=(0, 5),
                      colormap=plt.cm.viridis, cm_alpha=0.5, cm_range=(0, 1), 
                      xlim=None, ylim=None,
                      figure_size = (2, 2), legend_size=5):
    """Plot KDEs of all step slopes of the scaled curves, grouped by loss difference.

    Slopes are consecutive differences of the scaled curves multiplied by
    (number of columns - 1), i.e. expressed per unit of the normalised t
    axis. All slopes of a group are pooled into one density estimate.

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        fn: Output path handed to ``fig.savefig``.
        ranges: Explicit ``[low, high]`` lists, or a flat list of cut points
            (first and last group are then labelled as open-ended).
        clip: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    loss_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    symbol = '\\Delta L'

    # Flat cut points are expanded into (low, high) pairs bounded by the
    # observed min and max.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        lows = [loss_diffs.min()] + ranges
        highs = ranges + [loss_diffs.max()]
        ranges = list(zip(lows, highs))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    palette = colormap(np.linspace(cm_range[0], cm_range[1], len(list(ranges))), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', palette))

    scaled = get_scaled_feat(all_ablations)
    n_steps = scaled.shape[1] - 1
    slopes = np.diff(scaled, axis=1) * n_steps
    grouped_slopes, _ = split_features_by_diff_range(slopes, loss_diffs, ranges)

    last_idx = len(grouped_slopes) - 1
    for idx, (group, interval) in enumerate(zip(grouped_slopes, ranges)):
        lower, upper = interval[0], interval[1]

        if open_ended and idx == 0:
            label = f'${symbol} < {upper:.1e}$'
        elif open_ended and idx == last_idx:
            label = f'${symbol} > {lower:.1e}$'
        else:
            label = f'${lower:.1e} < {symbol} < {upper:.1e}$'
        # Pool every step of every curve in the group.
        sns.kdeplot(group.flatten(), ax=ax, label=label, clip=clip)

    ax.set_title('Slopes By Loss Diff')
    ax.set_xlabel('Slope')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Figure 4, lower 4
def plot_z_max_slopes_per_ld(all_ablations, fn='figures/z_max_slopes_per_ld.svg',
                             ranges=[[0.0001, 0.00011], [0.0010,0.0011], [0.01, 0.011], [0.1, 0.11]],
                             clip=(0, 5),
                             colormap=plt.cm.viridis, cm_alpha=0.5, cm_range=(0, 1), 
                             xlim=None, ylim=None,
                             figure_size = (2, 2), legend_size=5):
    """Plot KDEs of each curve's standardised maximum slope, grouped by loss difference.

    Slopes of the scaled curves are standardised with the mean and standard
    deviation of all slopes in the same group; the per-curve maximum of the
    standardised slopes is then drawn as one density estimate per group.

    Args:
        all_ablations: 2-D array of ablation curves (rows are curves).
        fn: Output path handed to ``fig.savefig``.
        ranges: Explicit ``[low, high]`` lists, or a flat list of cut points
            (first and last group are then labelled as open-ended).
        clip: Forwarded to ``sns.kdeplot``.
        colormap, cm_alpha, cm_range: Colours of the groups.
        xlim, ylim: Axis limits.
        figure_size: Figure size passed to ``plt.subplots``.
        legend_size: Legend font size.
    """
    loss_diffs = all_ablations[:, -1] - all_ablations[:, 0]
    symbol = '\\Delta L'

    # Flat cut points are expanded into (low, high) pairs bounded by the
    # observed min and max.
    open_ended = not isinstance(ranges[0], list)
    if open_ended:
        lows = [loss_diffs.min()] + ranges
        highs = ranges + [loss_diffs.max()]
        ranges = list(zip(lows, highs))

    plt.style.use(['science', 'ieee', 'bright'])
    fig, ax = plt.subplots(figsize=figure_size)
    palette = colormap(np.linspace(cm_range[0], cm_range[1], len(list(ranges))), alpha=cm_alpha)
    ax.set_prop_cycle(plt.cycler('color', palette))

    scaled = get_scaled_feat(all_ablations)
    n_steps = scaled.shape[1] - 1
    slopes = np.diff(scaled, axis=1) * n_steps
    grouped_slopes, _ = split_features_by_diff_range(slopes, loss_diffs, ranges)

    last_idx = len(grouped_slopes) - 1
    for idx, (group, interval) in enumerate(zip(grouped_slopes, ranges)):
        lower, upper = interval[0], interval[1]

        if open_ended and idx == 0:
            label = f'${symbol} < {upper:.1e}$'
        elif open_ended and idx == last_idx:
            label = f'${symbol} > {lower:.1e}$'
        else:
            label = f'${lower:.1e} < {symbol} < {upper:.1e}$'

        # z-score against the whole group, then take each curve's maximum.
        z_scores = (group - np.mean(group)) / np.std(group)
        sns.kdeplot(z_scores.max(axis=1), ax=ax, label=label, clip=clip)

    ax.set_title('Max Slopes By Loss Diff')
    ax.set_xlabel('Standardized Max Slope')
    ax.set_ylabel('Density')
    ax.legend(prop={'size': legend_size})

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig(fn, bbox_inches='tight')
    plt.close(fig)

# Analysis for appendix B.4.
def plot_all_ind_t_pts_per_loss_diff(ablations, values, densities, minimum_ablations=1000, log_diff=True, log_dens=False, log_avg_val=False):
    """Run the per-index regression and transition-point plot for every large index.

    For each index in ``range(512)`` with more than ``minimum_ablations``
    logged curves, this fits the regression from ``reg_on_ind_feat``,
    computes the partial R^2 of its three predictors, writes a
    transition-point plot to ``figures/test/<ind>.svg`` and prints the
    partial R^2 values.

    Args:
        ablations: Per-index ablation curves.
        values: Per-index values.
        densities: Per-index densities.
        minimum_ablations: Indices with this many curves or fewer are skipped.
        log_diff, log_dens, log_avg_val: Forwarded to ``reg_on_ind_feat``.

    Returns:
        Array of shape ``(3, n_indices)`` holding the partial R^2 of the
        loss difference, density and average value predictors (in that row
        order) for every processed index.
    """
    out_dir = 'figures/test'
    # savefig does not create missing directories; make sure the hard-coded
    # output folder exists before the first figure is written.
    os.makedirs(out_dir, exist_ok=True)

    diff_r2 = []
    dens_r2 = []
    avgv_r2 = []
    for ind in range(512):
        if len(ablations[ind]) <= minimum_ablations:
            continue

        res = reg_on_ind_feat(ablations, values, densities, ind, log_diff=log_diff, log_dens=log_dens, log_avg_val=log_avg_val)
        # 'x1'..'x3' are the regression's three predictor columns, in the
        # order loss difference, density, average value.
        pr2_diff = partial_r2(res, 'x1')
        pr2_density = partial_r2(res, 'x2')
        pr2_avg_val = partial_r2(res, 'x3')

        ind_ablation = np.array(ablations[ind])
        # Sort the curves by loss difference into 40 equal-sized groups and
        # use the min/max difference of five of them as plotting intervals.
        _, split_diffs = split_features_by_diff(ind_ablation, ind_ablation[:, -1] - ind_ablation[:, 0], 40)
        ranges = [[split_diffs[group][0], split_diffs[group][-1]] for group in (1, 10, 20, 30, -1)]
        plot_ind_t_pt_per_loss_diff(ablations[ind], fn=f'{out_dir}/{ind}.svg', title=str(ind) + ': ' + str(ind_ablation.shape[0]) + ' ablations', figure_size=(2,2), ranges=ranges)
        print(f'Diff r2: {pr2_diff:.2%}, Density r2: {pr2_density:.2%}, Avg Val: {pr2_avg_val:.2%}')
        diff_r2.append(pr2_diff)
        dens_r2.append(pr2_density)
        avgv_r2.append(pr2_avg_val)

    results = np.stack([diff_r2, dens_r2, avgv_r2])
    avg_res = np.mean(results, axis=1)
    print(f'Avg Diff r2: {avg_res[0]:.2%}, Avg Density r2: {avg_res[1]:.2%}, Avg Avg Val r2: {avg_res[2]:.2%}')
    return results


def reg_on_ind_feat(ablations, values, densities, ind, log_diff=True, log_dens=False, log_avg_val=False):
    """Regress the transition point of one index's curves on three predictors.

    Only curves with a positive loss difference (last column minus first)
    are used. The predictors are, in column order, the loss difference, the
    density and the average value (``values / densities``); each is
    optionally log-transformed and then standardised to zero mean and unit
    standard deviation. The response is the interpolated transition point
    of the scaled curve.

    Args:
        ablations: Per-index ablation curves.
        values: Per-index values.
        densities: Per-index densities.
        ind: Index to analyse.
        log_diff: Use the log of the loss difference.
        log_dens: Use the log of the density.
        log_avg_val: Use the log of the average value.

    Returns:
        The fitted statsmodels OLS results object (with an added constant,
        so the predictors are named ``x1``, ``x2``, ``x3``).
    """
    ind_ablation = np.array(ablations[ind])

    ind_diffs = ind_ablation[:, -1] - ind_ablation[:, 0]
    valid = ind_diffs > 0

    ind_ablation = ind_ablation[valid]
    ind_diffs = ind_diffs[valid]
    ind_densities = np.array(densities[ind])[valid]
    ind_avg_val = np.array(values[ind])[valid] / ind_densities

    def _prepare(raw, use_log):
        # Apply the log only when requested, so unused transforms cannot
        # emit warnings for non-positive inputs; then standardise.
        transformed = np.log(raw) if use_log else raw
        return (transformed - transformed.mean()) / transformed.std()

    # remove_return_below=False keeps one transition point per curve, so the
    # response stays aligned with the predictor rows.
    ind_t_pts = get_fine_t_pts(get_scaled_feat(ind_ablation), remove_return_below=False)

    D = _prepare(ind_diffs, log_diff)
    R = _prepare(ind_densities, log_dens)
    A = _prepare(ind_avg_val, log_avg_val)

    X = sm.add_constant(np.column_stack([D, R, A]))
    res = sm.OLS(ind_t_pts, X).fit()

    return res

# This method was written by ChatGPT. You can tell because it's documented far better.
def partial_r2(model, var_name):
    """
    Compute partial R² for a given predictor in a fitted OLS model.

    The model is refitted without the column named ``var_name`` and the
    relative reduction in the residual sum of squares is returned:
    ``(SSR_reduced - SSR_full) / SSR_reduced``.

    Args:
        model: Fitted statsmodels OLS results object.
        var_name: Name of the predictor column to drop (as listed in
            ``model.model.exog_names``).

    Returns:
        The partial R² of ``var_name``.

    Raises:
        ValueError: If ``var_name`` is not one of the model's columns
            (otherwise nothing would be dropped and 0 returned silently).
    """
    exog_names = model.model.exog_names
    if var_name not in exog_names:
        raise ValueError(
            f"{var_name!r} is not a predictor of the model; available: {list(exog_names)}"
        )

    # Refit with every column except the chosen one.
    keep = [i for i, name in enumerate(exog_names) if name != var_name]
    exog_reduced = model.model.exog[:, keep]
    reduced_model = sm.OLS(model.model.endog, exog_reduced).fit()

    # In statsmodels, `.ssr` is the sum of squared residuals.
    ssr_full = model.ssr
    ssr_reduced = reduced_model.ssr

    # Partial R² formula
    return (ssr_reduced - ssr_full) / ssr_reduced

## HELPER METHODS ##

def get_fine_t_pts(test_feature, threshold=0.5, remove_return_below=True):
    """Return the interpolated position where each row first exceeds ``threshold``.

    For every row the first column whose value is strictly above
    ``threshold`` is located and the crossing is linearly interpolated
    between that sample and the one before it. Positions are in
    sample-index units.

    Rows that never exceed the threshold (or exceed it in column 0) have
    their "previous" sample taken from the last column via negative
    indexing.

    Args:
        test_feature: 2-D array, one curve per row.
        threshold: Level whose first crossing is located.
        remove_return_below: If True, drop rows that fall back to or below
            the threshold after first exceeding it.

    Returns:
        1-D array of crossing positions (shorter than the input when rows
        were dropped).
    """
    exceeds = test_feature > threshold
    first_above = np.argmax(exceeds, axis=1)

    row_ids = np.arange(len(test_feature))
    before = test_feature[row_ids, first_above - 1]
    after = test_feature[row_ids, first_above]
    crossings = first_above - 1 + (threshold - before) / (after - before)

    if remove_return_below:
        # A row stays above the threshold iff every column from the first
        # crossing onwards is above it.
        n_remaining = test_feature.shape[1] - first_above
        crossings = crossings[n_remaining <= exceeds.sum(axis=1)]

    return np.array(crossings)


def split_features_by_diff(test_feature, diffs, num_groups):
    """Sort rows by ``diffs`` and split them into ``num_groups`` consecutive groups.

    Args:
        test_feature: Array whose first axis lines up with ``diffs``.
        diffs: 1-D array of sort keys.
        num_groups: Number of groups passed to ``np.array_split``.

    Returns:
        Tuple ``(feature_groups, diff_groups)``: two lists of arrays holding
        the sorted features and the sorted keys of each group.
    """
    order = np.argsort(diffs)
    feature_groups = np.array_split(test_feature[order], num_groups)
    diff_groups = np.array_split(diffs[order], num_groups)
    return feature_groups, diff_groups


def split_features_by_diff_range(test_feature, diffs, ranges):
    """Select the rows whose ``diffs`` fall strictly inside each interval.

    Args:
        test_feature: Array whose first axis lines up with ``diffs``.
        diffs: 1-D array of values to test against the intervals.
        ranges: Iterable of ``(low, high)`` pairs; both bounds are exclusive.

    Returns:
        Tuple ``(groups, masks)``: for every interval, the selected rows of
        ``test_feature`` and the boolean mask that selected them.
    """
    masks = [(diffs > interval[0]) & (diffs < interval[1]) for interval in ranges]
    groups = [test_feature[mask] for mask in masks]
    return groups, masks


def get_scaled_feat(test_feature):
    """Shift each row to start at 0 and scale it so that it ends at (almost) 1.

    Every row has its first value subtracted and is then divided by its
    shifted last value plus 1e-12 (the offset avoids division by zero for
    rows whose first and last values coincide).

    Args:
        test_feature: 2-D array, one curve per row.

    Returns:
        Array of the same shape with the scaled curves.
    """
    shifted = test_feature - test_feature[:, :1]
    return shifted / (shifted[:, -1:] + 1e-12)
