import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import matplotlib.patches as mpatches


def dist_metric_best(matrix1, matrix2):
    """
    Calculates the error of the best estimated mean.

    For every row of ``matrix1`` the Euclidean distance to its nearest row of
    ``matrix2`` is computed; the smallest of these nearest-neighbour distances
    is returned. (Presumably ``matrix1`` holds true means and ``matrix2`` the
    estimated list -- confirm with callers.)

    Args:
        matrix1: iterable of row vectors.
        matrix2: 2-D array whose rows are compared against each row of ``matrix1``.

    Returns:
        The minimum over rows of ``matrix1`` of the distance to the closest row
        of ``matrix2``; ``np.inf`` if ``matrix1`` has no rows.
    """
    # np.inf instead of an arbitrary large constant so very large distances
    # are never clipped to the sentinel.
    best_estimate = np.inf
    for row1 in matrix1:
        # Distances from this row to all rows in matrix2
        distances = np.linalg.norm(matrix2 - row1, axis=1)
        best_estimate = min(best_estimate, np.min(distances))
    return best_estimate


def dist_metric(matrix1, matrix2):
    """
    Calculates the error of the worst estimated mean.

    Each row of ``matrix1`` is matched to its nearest row of ``matrix2``
    (Euclidean distance); the largest of those nearest-neighbour distances is
    returned. Returns ``-1`` when ``matrix1`` has no rows.
    """
    worst = -1
    nearest = (np.min(np.linalg.norm(matrix2 - point, axis=1)) for point in matrix1)
    for dist in nearest:
        worst = max(worst, dist)
    return worst


def plot_experiment_datapoints(df, noise_model='gaussian', num_clusters=None, excluded_methods=None, seed=None):
    """Scatter-plot every (list size, error) point of an experiment, coloured by method.

    Method names are relabelled for display ('Full RML'/'Bounded RML' -> 'Ours',
    'Gaussian Multifilter' -> 'LD-ME', 'LDME without Outer' -> 'Inner Stage')
    before ``excluded_methods`` is applied. Both axes are logarithmic. The figure
    is saved as ``kmeans_v_filter_{noise_model}_with_CI.pdf`` and then shown.

    Args:
        df: results with at least 'method', 'list_size' and 'error' columns
            (and 'seed' when ``seed`` is given).
        noise_model: tag used in the output filename.
        num_clusters: if given, a dashed vertical reference line is drawn at
            this list size.
        excluded_methods: display names to drop; ``None`` keeps every method.
        seed: if given, only rows with this seed are plotted.
    """
    # None instead of a shared mutable default list.
    if excluded_methods is None:
        excluded_methods = []

    plt.figure(figsize=(10, 6))

    data = df.copy()
    data['method'] = data['method'].replace({'Full RML': 'Ours', 'Bounded RML': 'Ours', 'Gaussian Multifilter': 'LD-ME', 'LDME without Outer': 'Inner Stage'})

    # Exclude specified methods
    data = data[~data['method'].isin(excluded_methods)]
    if seed is not None:
        data = data[data['seed'] == seed]

    # One colour per method; with more than five methods a fixed six-colour
    # ordering is used (seaborn cycles it if there are even more methods).
    n_methods = len(data['method'].unique())
    if n_methods > 5:
        COLORS = ['blue', 'tab:orange', 'green', 'gray', 'red', 'purple']
    else:
        COLORS = ['blue', 'tab:orange', 'green', 'red', 'purple', 'gray'][:n_methods]

    sns.scatterplot(data=data, x='list_size', y='error', hue='method', markers=True, palette=COLORS, s=80)

    if num_clusters is not None:
        plt.axvline(x=num_clusters, color='gray', linestyle='--', label='True Number of Clusters')
    plt.xlabel('List Size', fontsize=26)
    plt.ylabel('Worst Estimation Error', fontsize=26)
    plt.legend(fontsize=15, markerscale=1.6)
    plt.xscale('log')
    plt.yscale('log')
    plt.tick_params(axis='both', which='major', labelsize=15)
    plt.savefig('kmeans_v_filter_{}_with_CI.pdf'.format(noise_model))
    plt.show()


def compare_error_and_size_constraint(df, log_scale=True, noise_model='gaussian', constraint='list', relax_factor=1):
    """Bar-plot each method's best result under a constraint derived from 'Ours'.

    constraint can be 'list' or 'err' for list size or error constraint, respectively.

    * ``'list'``: ``l_star`` is the list size at which 'Ours' reaches its
      minimum error; each method's bar is its minimum error over rows with
      ``list_size <= l_star``.
    * ``'err'``: each method's bar is its minimum list size over rows with
      ``error <= relax_factor * (minimum error of 'Ours')``.

    The figure is saved as ``comp_{noise_model}_{constraint}.pdf`` and shown.

    Args:
        df: results with 'method', 'error' and 'list_size' columns.
        log_scale: use a logarithmic y axis.
        noise_model: tag used in the output filename.
        constraint: ``'list'`` or ``'err'`` (see above).
        relax_factor: multiplier on the 'Ours' minimum error (``'err'`` only).

    Raises:
        ValueError: if ``constraint`` is neither ``'list'`` nor ``'err'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError later.
    if constraint not in ('list', 'err'):
        raise ValueError(f"constraint must be 'list' or 'err', got {constraint!r}")

    ours = df[df['method'] == 'Ours']
    ours_min_error = ours['error'].min()

    if constraint == 'list':
        # First list size at which 'Ours' attains its minimum error.
        l_star = ours.loc[ours['error'] == ours_min_error, 'list_size'].iloc[0]
        # Best error for each method where list_size <= l_star
        best_stats = df[df['list_size'] <= l_star].groupby('method')['error'].min()
    else:
        # Best list size for each method where error <= relax_factor * ours_min_error
        best_stats = df[df['error'] <= relax_factor * ours_min_error].groupby('method')['list_size'].min()

    COLORS = ['green', 'blue', 'purple', 'red']
    if len(best_stats) > 5:
        COLORS = ['gray', 'lightgreen', 'blue', 'green', 'purple', 'red']

    # Plotting
    plt.figure(figsize=(8, 6))
    sns.barplot(x=best_stats.index, y=best_stats.values, palette=COLORS)
    plt.xlabel('Method')
    if log_scale:
        plt.yscale('log')
    if constraint == 'list':
        plt.ylabel('Error')
        plt.title('Comparison of Methods by Error with List Size Constraint')
    else:
        plt.ylabel('List Size')
        plt.title('Comparison of Methods by List Size with Error Constraint')
    plt.savefig('comp_{}_{}.pdf'.format(noise_model, constraint))
    plt.show()


def compare_mean_error_with_list_constraint(df, log_scale=True, noise_model='gaussian'):
    """Bar-plot the seed-averaged best error per method under a per-seed list-size cap.

    For each seed, the cap is the list size at which 'Ours' reaches its lowest
    error on that seed. Each method's best (minimum) error among rows within
    the cap is collected per seed. These are averaged across seeds and drawn
    with normal-approximation 95% CI error bars (``1.96 * std / sqrt(count)``).
    Saves ``comp_{noise_model}_list_mean_ci.pdf`` and shows it.

    Expects columns 'method', 'error', 'list_size' and 'seed'.
    """
    ours_rows = df[df['method'] == 'Ours']
    list_cap_by_seed = ours_rows.groupby('seed')['error'].idxmin().map(df['list_size'])

    per_seed_frames = [
        df[(df['seed'] == s) & (df['list_size'] <= cap)]
        .groupby('method')['error'].min()
        .reset_index()
        .assign(seed=s)
        for s, cap in list_cap_by_seed.items()
    ]
    per_seed_best = pd.concat(per_seed_frames)

    summary = per_seed_best.groupby('method')['error'].agg(['mean', 'std', 'count'])
    summary['ci_95'] = 1.96 * summary['std'] / np.sqrt(summary['count'])

    palette = ['lightgreen', 'blue', 'green', 'purple', 'red']
    plt.figure(figsize=(8, 6))
    sns.barplot(x='method', y='mean', yerr=summary['ci_95'], data=summary.reset_index(), palette=palette)
    plt.xlabel('Method')
    plt.ylabel('Mean Error with 95% CI')
    if log_scale:
        plt.yscale('log')
    plt.title('Comparison of Methods by Mean Error with List Size Constraint')
    plt.savefig('comp_{}_list_mean_ci.pdf'.format(noise_model))
    plt.show()


def compare_mean_size_with_error_constraint(df, log_scale=True, noise_model='gaussian', relax_factor=1):
    """Bar-plot the average admissible list size per method under an error cap.

    The cap is ``relax_factor`` times the smallest error reached by 'Ours'
    over the whole DataFrame (all seeds pooled). For every (method, seed)
    pair, the list sizes of admissible rows are summarised by mean, std, count
    and a normal-approximation 95% half-width. Those per-seed statistics are
    then averaged per method and drawn as bars, using the averaged half-width
    as error bars. Saves ``comp_{noise_model}_err_mean_ci.pdf`` and shows it.

    Expects columns 'method', 'error', 'list_size' and 'seed'.
    """
    ours_best = df.loc[df['method'] == 'Ours', 'error'].min()
    error_cap = relax_factor * ours_best
    admissible = df[df['error'] <= error_cap]

    per_seed = admissible.groupby(['method', 'seed'])['list_size'].agg(['mean', 'std', 'count'])
    per_seed['ci_95'] = 1.96 * per_seed['std'] / np.sqrt(per_seed['count'])

    # Average every per-seed statistic (including the CI half-width) per method.
    summary = per_seed.groupby('method').mean().reset_index()

    palette = ['grey', 'lightgreen', 'blue', 'green', 'red', 'purple']
    plt.figure(figsize=(8, 6))
    sns.barplot(x='method', y='mean', yerr=summary['ci_95'], data=summary, palette=palette)
    plt.xlabel('Method')
    plt.ylabel('Average List Size with 95% CI')
    if log_scale:
        plt.yscale('log')
    plt.title('Comparison of Methods by Mean List Size with Error Constraint')
    plt.savefig('comp_{}_err_mean_ci.pdf'.format(noise_model))
    plt.show()


def process_df_mean(df, noise_model_name, constraint = 'err', fixing='Bounded RML', relax_factor=1, max_val=30, num_clusters=None):
    """Summarise, per method, the best per-seed result under a constraint set by ``fixing``.

    For every seed on which the reference method ``fixing`` appears:

    * ``constraint='err'``: the error budget is ``relax_factor`` times the
      smallest error ``fixing`` reaches on that seed. A method's score is its
      smallest ``list_size`` among rows within that budget.
    * ``constraint='list'``: the list-size budget ``l_star`` is the list size
      at which ``fixing`` reaches its smallest error on that seed, raised to
      ``num_clusters`` when that is given. A method's score is its smallest
      ``error`` among rows with ``list_size <= l_star``.

    A method with no admissible row on a seed scores ``max_val``. Scores are
    aggregated across seeds.

    Args:
        df: results with 'method', 'seed', 'error' and 'list_size' columns.
            Assumes a unique index (labels from ``idxmin`` are mapped back
            through ``df['list_size']``).
        noise_model_name: value written to the ``noise_model`` column.
        constraint: ``'err'`` or ``'list'`` (see above).
        fixing: name of the reference method.
        relax_factor: multiplier on the reference error (``'err'`` only).
        max_val: fallback score when nothing is admissible.
        num_clusters: lower bound on ``l_star`` (``'list'`` only).

    Returns:
        DataFrame with columns ``method``, ``noise_model``, ``mean``, ``std``,
        ``count`` and ``ci_95`` (``1.96 * std / sqrt(count)``), one row per
        method. Empty (same columns) when ``fixing`` does not occur in ``df``.

    Raises:
        ValueError: if ``constraint`` is neither ``'err'`` nor ``'list'``.
    """
    if constraint not in ('err', 'list'):
        raise ValueError(f"constraint must be 'err' or 'list', got {constraint!r}")

    reference = df[df['method'] == fixing]
    if reference.empty:
        return pd.DataFrame(columns=['method', 'noise_model', 'mean', 'std', 'count', 'ci_95'])

    # Per seed: list size where the reference method hits its minimum error...
    l_star_per_seed = reference.groupby('seed')['error'].idxmin().map(df['list_size'])
    # ...and the (relaxed) minimum error itself.
    error_constraint = relax_factor * reference.groupby('seed')['error'].min()

    methods = df['method'].unique()
    final_stats_list = []
    for seed, l_star in l_star_per_seed.items():
        seed_df = df[df['seed'] == seed]
        if constraint == 'list' and num_clusters is not None:
            # Never allow a list-size budget below num_clusters.
            l_star = max(l_star, num_clusters)
        for method in methods:
            method_df = seed_df[seed_df['method'] == method]
            if constraint == 'err':
                admissible = method_df[method_df['error'] <= error_constraint[seed]]
                score_column = 'list_size'
            else:
                admissible = method_df[method_df['list_size'] <= l_star]
                score_column = 'error'
            # Fall back to max_val when no row satisfies the constraint.
            score = max_val if admissible.empty else admissible[score_column].min()
            final_stats_list.append({'method': method, 'min': score, 'seed': seed})

    final_stats = pd.DataFrame(final_stats_list)
    final_stats['noise_model'] = noise_model_name
    final_stats = final_stats.groupby(['method', 'noise_model'])['min'].agg(['mean', 'std', 'count'])
    final_stats['ci_95'] = 1.96 * final_stats['std'] / np.sqrt(final_stats['count'])
    return final_stats.reset_index()


def plot_all_noise_models(constraint='err', fixing='Bounded RML', relax_factor=1, df_adv_line=None, df_gauss=None, df_adv_circle=None, df_uniform=None, log_scale=False, max_val=30, num_clusters=None, excluded_methods=None):
    """Grouped, hatched bar chart of per-method mean statistics across noise models.

    Each provided DataFrame is summarised by ``process_df_mean``. Each noise
    model becomes one group of bars, with one hatched bar per method (mean
    with ``ci_95`` error bars). Noise models are labelled 'Attack 1',
    'Attack 2', ... in the order circle, Gaussian, line, uniform, counting
    only the frames actually given.

    Two PDFs are written: ``combined_noise_models_{constraint}_comparison_{fixing}.pdf``
    (the chart, without legend) and ``legend_{constraint}_{fixing}.pdf`` (a
    standalone legend figure).

    Args:
        constraint: ``'err'`` or ``'list'``; forwarded to ``process_df_mean``.
        fixing, relax_factor, max_val, num_clusters: forwarded to
            ``process_df_mean``. ``max_val`` is also the upper y limit.
            ``num_clusters``, if given, is drawn as a horizontal dashed line.
        df_adv_line, df_gauss, df_adv_circle, df_uniform: optional result
            frames, one per noise model.
        log_scale: use a logarithmic y axis.
        excluded_methods: display names (after relabelling) to drop.

    Returns:
        The combined, relabelled and filtered statistics DataFrame.

    Raises:
        ValueError: on an unknown ``constraint`` or when no DataFrame is given.
    """
    if constraint not in ('list', 'err'):
        raise ValueError(f"Unknown constraint {constraint!r}; expected 'list' or 'err'")
    if excluded_methods is None:
        excluded_methods = []

    inputs = [
        ('Adversarial Clusters', df_adv_circle),
        ('Gaussian', df_gauss),
        ('Adversarial Line', df_adv_line),
        ('Uniform', df_uniform),
    ]
    dfs = [
        process_df_mean(frame, name, constraint=constraint, fixing=fixing, relax_factor=relax_factor, max_val=max_val, num_clusters=num_clusters)
        for name, frame in inputs
        if frame is not None
    ]
    if not dfs:
        raise ValueError('At least one noise-model DataFrame must be provided')

    combined_stats = pd.concat(dfs)

    # Modify label names
    combined_stats['method'] = combined_stats['method'].replace({'Full RML': 'Ours', 'Bounded RML': 'Ours', 'Gaussian Multifilter': 'LD-ME', 'LDME without Outer': 'Inner Stage'})

    # Exclude specified methods
    combined_stats = combined_stats[~combined_stats['method'].isin(excluded_methods)]

    fig, ax = plt.subplots(figsize=(13, 7))

    width = 0.15
    # Fixed display order; only methods present in the data are drawn.
    present = set(combined_stats['method'].unique())
    methods = [m for m in ['Kmeans', 'Robust Kmeans', 'DBScan', 'Ours', 'LD-ME', 'Inner Stage'] if m in present]
    noise_models = combined_stats['noise_model'].unique()
    positions = np.arange(len(noise_models))

    # Apply hatch patterns and create legend handles
    hatch_patterns = ['/', '\\', '|', '-', 'x', '+']
    legend_handles = []
    COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:gray']

    for i, method in enumerate(methods):
        # NOTE(review): assumes each drawn method has exactly one row per noise model.
        method_stats = combined_stats[combined_stats['method'] == method]
        color = COLORS[i % len(COLORS)]
        bars = ax.bar(positions + i * width, method_stats['mean'], width, yerr=method_stats['ci_95'], label=method, color=color)

        hatch = hatch_patterns[i % len(hatch_patterns)]
        for bar in bars:
            bar.set_hatch(hatch)

        legend_handles.append(mpatches.Patch(facecolor=bars[0].get_facecolor(), hatch=hatch, label=method))

    # Centre each tick under the bars actually drawn, with one label per noise model.
    ax.set_xticks(positions + width * (len(methods) - 1) / 2)
    ax.set_xticklabels([f'Attack {k + 1}' for k in range(len(noise_models))])

    plt.locator_params(nbins=4)
    plt.xlabel('Noise Models', fontsize=30)
    if constraint == 'list':
        plt.ylabel('Mean Worst Estimation Error', fontsize=30)
    else:
        plt.ylabel('Mean List Size', fontsize=30)
    plt.ylim(0, max_val)

    if log_scale:
        plt.yscale('log')

    if num_clusters is not None:
        plt.axhline(y=num_clusters, color='gray', linestyle='--', label='True Number of Clusters')

    ax.tick_params(axis='both', which='major', labelsize=28)
    # A legend is laid out with tight_layout and then removed, so the saved
    # chart has no legend; the legend is saved as a separate figure below.
    ax.legend(handles=legend_handles)
    plt.legend()
    plt.tight_layout()
    plt.legend().remove()
    plt.savefig(f'combined_noise_models_{constraint}_comparison_{fixing}.pdf')
    plt.show()

    legend_fig = plt.figure(figsize=(8.2, 0.7))
    ncol = len(legend_handles)
    # Shrink the font as the number of legend columns grows beyond four.
    fontsize = 14 - max(0, ncol - 4) * 2.5
    legend_fig.legend(handles=legend_handles, fontsize=fontsize, loc='upper left', title_fontsize=fontsize,
            handlelength=3, handleheight=2, ncol=ncol)
    legend_fig.canvas.draw()
    legend_fig.tight_layout()
    legend_fig.savefig(f'legend_{constraint}_{fixing}.pdf')

    return combined_stats


def _value_above_percentile(values, percentile):
    """Return the smallest element of ``values`` strictly above its ``percentile``-th percentile.

    Falls back to the maximum of ``values`` when no element is strictly above
    the percentile (e.g. all values equal, or ``percentile`` is 100).
    ``values`` must be non-empty.
    """
    values = list(values)
    threshold = np.percentile(values, percentile)
    above = [x for x in values if x > threshold]
    return min(above) if above else max(values)


def process_df_mean_new_metric(df, noise_model_name, constraint = 'err', max_val=30, percentile_constraint=50, percentile_points=75, seed=0, averaging=False, median=False):
    """Per-method percentile cutoffs for one seed, under a constraint derived from 'Ours'.

    Method names are relabelled ('Full RML'/'Bounded RML' -> 'Ours',
    'Gaussian Multifilter' -> 'LD-ME', 'LDME without Outer' -> 'Inner Stage'),
    and only rows with the given ``seed`` are kept. Reference values come from
    'Ours':

    * ``constraint='err'``: the error budget ``err_star`` is the smallest
      'Ours' error strictly above its ``percentile_constraint``-th percentile,
      or the maximum error if none is above. For each method, rows with
      ``error <= err_star`` are kept and their ``list_size`` values summarised.
    * ``constraint='list'``: ``l_star`` is the most frequent 'Ours' list size.
      Rows with ``l_star - 1 <= list_size <= l_star`` are kept and their
      ``error`` values summarised.

    The summary is the smallest kept value strictly above the 25th, 50th and
    75th percentiles of the kept values (the maximum if none is above). These
    are stored as ``min``, ``cutoff`` and ``max`` respectively. A method with
    no kept rows gets ``max_val`` for all three.

    ``percentile_points``, ``averaging`` and ``median`` are accepted for
    backward compatibility but are currently unused.

    Returns:
        DataFrame with columns ``method``, ``cutoff``, ``min``, ``max``,
        ``seed`` and ``noise_model``, one row per method.

    Raises:
        ValueError: if ``constraint`` is neither ``'err'`` nor ``'list'``.
    """
    if constraint not in ('err', 'list'):
        raise ValueError(f"constraint must be 'err' or 'list', got {constraint!r}")

    data = df.copy()
    data['method'] = data['method'].replace({'Full RML': 'Ours', 'Bounded RML': 'Ours', 'Gaussian Multifilter': 'LD-ME', 'LDME without Outer': 'Inner Stage'})

    # Only the requested seed is considered.
    data = data[data['seed'] == seed]

    ours = data[data['method'] == 'Ours']
    # Most frequent 'Ours' list size.
    l_star = ours['list_size'].value_counts().idxmax()
    # Falls back to the max instead of raising when nothing lies above the percentile.
    err_star = _value_above_percentile(ours['error'], percentile_constraint)

    final_stats_list = []
    for method in data['method'].unique():
        method_rows = data[data['method'] == method]
        if constraint == 'err':
            kept = method_rows[method_rows['error'] <= err_star]
            column = 'list_size'
        else:
            kept = method_rows[(method_rows['list_size'] <= l_star) & (method_rows['list_size'] >= l_star - 1)]
            column = 'error'

        if kept.empty:
            # If no rows pass the filter for a method, use the default value
            min_cutoff = cut_off = max_cutoff = max_val
        else:
            min_cutoff, cut_off, max_cutoff = (_value_above_percentile(kept[column], q) for q in (25, 50, 75))

        final_stats_list.append({'method': method, 'cutoff': cut_off, 'min': min_cutoff, 'max': max_cutoff, 'seed': seed})

    final_stats = pd.DataFrame(final_stats_list)
    final_stats['noise_model'] = noise_model_name

    return final_stats


def plot_all_noise_models_quantile_average(constraint='err', df_adv_line=None, df_gauss=None, df_adv_circle=None, max_val=30, percentile_constraint=50, percentile_points=75, seed=0, excluded_methods=None, log_scale=False, median=False):
    """Grouped, hatched bar chart of percentile cutoffs per method across noise models.

    Each provided DataFrame is summarised by ``process_df_mean_new_metric``
    for a single ``seed``. Each bar shows the ``cutoff`` value. Its asymmetric
    error bar spans from ``min`` to ``max``, as produced by that function.
    Noise models are labelled 'Attack 1', 'Attack 2', ... in the order
    circle, line, Gaussian, counting only the frames actually given.

    Writes ``combined_noise_models_{constraint}_comparison_.pdf`` (chart
    without legend) and ``legend_{constraint}_.pdf`` (standalone legend).

    Args:
        constraint: ``'err'`` or ``'list'``.
        df_adv_line, df_gauss, df_adv_circle: optional result frames.
        max_val, percentile_constraint, percentile_points, seed: forwarded to
            ``process_df_mean_new_metric``. ``max_val`` is also the upper y limit.
        excluded_methods: display names to drop.
        log_scale: use a logarithmic y axis.
        median: only changes the y-axis label here. ``averaging=not median``
            and ``median`` are forwarded unchanged.

    Returns:
        The combined, filtered statistics DataFrame.

    Raises:
        ValueError: on an unknown ``constraint`` or when no DataFrame is given.
    """
    if constraint not in ('list', 'err'):
        raise ValueError(f"Unknown constraint {constraint!r}; expected 'list' or 'err'")
    if excluded_methods is None:
        excluded_methods = []

    # Combine all dataframes
    inputs = [
        ('Adversarial Clusters', df_adv_circle),
        ('Adversarial Line', df_adv_line),
        ('Gaussian', df_gauss),
    ]
    dfs = [
        process_df_mean_new_metric(frame, name, constraint=constraint, max_val=max_val, percentile_constraint=percentile_constraint, percentile_points=percentile_points, seed=seed, averaging=not median, median=median)
        for name, frame in inputs
        if frame is not None
    ]
    if not dfs:
        raise ValueError('At least one noise-model DataFrame must be provided')

    combined_stats = pd.concat(dfs)

    # Exclude specified methods
    combined_stats = combined_stats[~combined_stats['method'].isin(excluded_methods)]

    fig, ax = plt.subplots(figsize=(13, 7))

    width = 0.15
    # Fixed display order; only methods present in the data are drawn.
    present = set(combined_stats['method'].unique())
    methods = [m for m in ['Kmeans', 'Robust Kmeans', 'DBScan', 'LD-ME', 'Ours', 'Inner Stage'] if m in present]
    noise_models = combined_stats['noise_model'].unique()
    positions = np.arange(len(noise_models))

    # Apply hatch patterns and create legend handles
    hatch_patterns = ['/', '\\', '|', '-', 'x', '+']
    legend_handles = []
    COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:purple', 'tab:red', 'tab:gray']

    for i, method in enumerate(methods):
        # NOTE(review): assumes each drawn method has exactly one row per noise model.
        method_stats = combined_stats[combined_stats['method'] == method]
        color = COLORS[i % len(COLORS)]
        lower = method_stats['cutoff'] - method_stats['min']
        upper = method_stats['max'] - method_stats['cutoff']
        bars = ax.bar(positions + i * width, method_stats['cutoff'], width, yerr=[lower, upper], label=method, color=color)

        hatch = hatch_patterns[i % len(hatch_patterns)]
        for bar in bars:
            bar.set_hatch(hatch)

        legend_handles.append(mpatches.Patch(facecolor=bars[0].get_facecolor(), hatch=hatch, label=method))

    # Centre each tick under the bars actually drawn, with one label per noise model.
    ax.set_xticks(positions + width * (len(methods) - 1) / 2)
    ax.set_xticklabels([f'Attack {k + 1}' for k in range(len(noise_models))])

    plt.xlabel('Noise Models', fontsize=30)
    if constraint == 'list':
        plt.ylabel('Median Worst Estimation Error' if median else 'Worst Estimation Error', fontsize=30)
    else:
        plt.ylabel('Median List Size' if median else 'List Size', fontsize=30)
    plt.ylim(0, max_val)

    # log_scale was previously accepted but ignored.
    if log_scale:
        plt.yscale('log')

    ax.tick_params(axis='both', which='major', labelsize=28)
    # A legend is laid out with tight_layout and then removed, so the saved
    # chart has no legend; the legend is saved as a separate figure below.
    ax.legend(handles=legend_handles)
    plt.legend()
    plt.tight_layout()
    plt.legend().remove()
    plt.savefig(f'combined_noise_models_{constraint}_comparison_.pdf')
    plt.show()

    legend_fig = plt.figure(figsize=(8.2, 0.7))
    ncol = len(legend_handles)
    # Shrink the font as the number of legend columns grows beyond four.
    fontsize = 14 - max(0, ncol - 4) * 2.5
    legend_fig.legend(handles=legend_handles, fontsize=fontsize, loc='upper left', title_fontsize=fontsize,
            handlelength=3, handleheight=2, ncol=ncol)
    legend_fig.canvas.draw()
    legend_fig.tight_layout()
    legend_fig.savefig(f'legend_{constraint}_.pdf')

    return combined_stats


def plot_all_noise_models_new_metric(constraint='err', df_adv_line=None, df_gauss=None, df_adv_circle=None, max_val=30, percentile_constraint=50, percentile_points=75, seed=0, excluded_methods=None, log_scale=False):
    """Strip plot of each method's ``min`` statistic across noise models.

    Each provided DataFrame is summarised by ``process_df_mean_new_metric``
    for a single ``seed``. The ``min`` column (the 25th-percentile cutoff
    produced there) is plotted per noise model, coloured by method. Noise
    models are labelled 'Attack 1', 'Attack 2', ... in the order circle,
    Gaussian, line, counting only the frames actually given. The figure is
    saved as ``combined_noise_models_{constraint}_comparison_.pdf`` and shown.

    Args:
        constraint: ``'err'`` or ``'list'``.
        df_adv_line, df_gauss, df_adv_circle: optional result frames.
        max_val, percentile_constraint, percentile_points, seed: forwarded to
            ``process_df_mean_new_metric``.
        excluded_methods: display names to drop.
        log_scale: use a logarithmic y axis.

    Returns:
        ``(combined_stats, values_75th_percentile)``. The second item maps each
        method to a list holding, per noise model, the maximum ``min`` value
        (NaN if the method has no row for that noise model).

    Raises:
        ValueError: on an unknown ``constraint`` or when no DataFrame is given.
    """
    if constraint not in ('list', 'err'):
        raise ValueError(f"Unknown constraint {constraint!r}; expected 'list' or 'err'")
    if excluded_methods is None:
        excluded_methods = []

    # Combine all dataframes
    inputs = [
        ('Adversarial Clusters', df_adv_circle),
        ('Gaussian', df_gauss),
        ('Adversarial Line', df_adv_line),
    ]
    dfs = [
        process_df_mean_new_metric(frame, name, constraint=constraint, max_val=max_val, percentile_constraint=percentile_constraint, percentile_points=percentile_points, seed=seed)
        for name, frame in inputs
        if frame is not None
    ]
    if not dfs:
        raise ValueError('At least one noise-model DataFrame must be provided')

    combined_stats = pd.concat(dfs)

    # Exclude specified methods
    combined_stats = combined_stats[~combined_stats['method'].isin(excluded_methods)]

    fig, ax = plt.subplots(figsize=(13, 7))

    COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:gray', 'tab:red', 'tab:purple']

    noise_models = combined_stats['noise_model'].unique()
    # NOTE(review): despite its name, this collects the 'min' column (the
    # 25th-percentile cutoff), not 'max' -- confirm which was intended.
    values_75th_percentile = dict()
    for method in combined_stats['method'].unique():
        rows = combined_stats[combined_stats['method'] == method]
        values_75th_percentile[method] = [
            max(rows.loc[rows['noise_model'] == noise_model, 'min'], default=np.nan)
            for noise_model in noise_models
        ]

    catplot = sns.stripplot(data=combined_stats, x='noise_model', y='min', hue='method', dodge=True, palette=COLORS, edgecolor='w', marker='o')
    catplot.set_xlabel('Noise Models', fontsize=25)
    catplot.set_ylabel('List Size' if constraint == 'err' else 'Worst Estimation Error', fontsize=25)
    # One label per plotted category, so the label count matches the ticks.
    catplot.set_xticklabels([f'Attack {k + 1}' for k in range(len(noise_models))], fontsize=20)
    # Resize y tick labels without pinning them to the current tick values.
    catplot.tick_params(axis='y', labelsize=20)
    # log_scale was previously accepted but ignored.
    if log_scale:
        catplot.set_yscale('log')

    plt.tight_layout()
    plt.savefig(f'combined_noise_models_{constraint}_comparison_.pdf')
    plt.show()

    return combined_stats, values_75th_percentile
