import numpy as np
import plotly.express as px
import torch
import matplotlib.pyplot as plt

def scores_to_arrow(model, eap_scores, origins):
    """Flatten a 5-D score array into a list of "arrows" sorted by score.

    An arrow is the tuple
    ``(source_layer, target_layer, source_token, target_token, head, score)``.

    Args:
        model: Object exposing ``model.cfg.n_layers``; that value (not the
            shape of ``eap_scores``) bounds the layer loops.
        eap_scores: 5-D array/tensor indexed as
            ``[source_layer, target_layer, source_token, target_token, head]``.
            Each selected element is converted with ``float()``.
        origins: Iterable of source token positions to start arrows from.

    Returns:
        A list of 6-tuples sorted by score in descending order. Ties keep
        their iteration order (layers, then tokens, then heads).

    Raises:
        ValueError: If ``eap_scores.shape`` does not have exactly 5 entries.
    """
    # Only the token and head extents are read from the shape; unpacking five
    # names still enforces that the input is 5-dimensional.
    _, _, _, n_tokens, n_heads = eap_scores.shape
    n_model_layers = model.cfg.n_layers

    # A dict keyed by the full 5-tuple keeps one entry per arrow even when
    # ``origins`` repeats a position. The previous code guarded the insert
    # with a 4-tuple membership test that could never match a 5-tuple key,
    # so every head was always stored; that behaviour is kept explicitly.
    arrows = {}

    # The last layer is never used as a source; targets start at the source
    # layer itself, and target tokens start at the source token.
    for source_layer_id in range(n_model_layers - 1):
        for target_layer_id in range(source_layer_id, n_model_layers):
            for source_token_pos in origins:
                for target_token_pos in range(source_token_pos, n_tokens):
                    for head in range(n_heads):
                        position = (
                            source_layer_id,
                            target_layer_id,
                            source_token_pos,
                            target_token_pos,
                            head,
                        )
                        arrows[position] = float(eap_scores[position])

    return sorted(
        (position + (score,) for position, score in arrows.items()),
        key=lambda arrow: arrow[-1],
        reverse=True,
    )


def vizualize(attr_scores, title, eap_scores=None, n_arrows=30, origins=None):
    """Show a heatmap of attribution scores scaled into ``[-1, 1]``.

    The scores are divided by their largest absolute value and drawn with a
    diverging red/blue colour scale. The x axis is titled "Layer" and the
    y axis "Token"; rows are labelled ``"0" .. "shape[0] - 1"``.

    Args:
        attr_scores: Tensor-like exposing ``.shape``, ``.detach()`` and
            ``.cpu()`` (e.g. a ``torch.Tensor``). Presumably 2-D with tokens
            on axis 0 and layers on axis 1, going by the axis titles.
        title: Text used as the figure title.
        eap_scores: Currently unused; accepted for call compatibility.
        n_arrows: Currently unused; accepted for call compatibility.
        origins: Currently unused; accepted for call compatibility.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    y_labels = [str(row) for row in range(attr_scores.shape[0])]

    # detach() first so tensors that still track gradients can be converted.
    data = np.array(attr_scores.detach().cpu())

    # An all-zero input would otherwise divide 0/0 and plot nothing but NaN.
    scale = np.max(np.abs(data))
    normalized = data if scale == 0 else data / scale

    fig = px.imshow(
        normalized,
        y=y_labels,
        aspect="auto",
        zmin=-1,
        zmax=1,
        color_continuous_scale='rdbu',
        title=f"{title}",
    )

    fig.update_xaxes(title_text="Layer")
    fig.update_yaxes(title_text="Token")

    fig.show()


def create_chart(data_list, controlability=False):
    """Draw one grouped bar chart per dataset, side by side, on a shared y range.

    Args:
        data_list: Sequence of datasets. Each dataset is a sequence of
            "cluster" dicts mapping a key to a ``{metric: value}`` dict.
            The key is joined with newlines to form the tick label, so it is
            presumably a tuple of strings. The bar layout assumes exactly one
            key per cluster; more would make bar counts and positions differ.
            The metric names are taken from the first entry of the first
            cluster of each dataset.
        controlability: When true, use the wider figure (14 inches per
            dataset instead of 7) and the nine-colour palette instead of the
            three-colour one.

    Returns:
        None. The figure is displayed via ``plt.show()``.

    Raises:
        ValueError: If ``data_list`` is empty.
    """
    n_datasets = len(data_list)
    if n_datasets == 0:
        raise ValueError("data_list must contain at least one dataset")

    width_per_dataset = 14 if controlability else 7
    # squeeze=False always returns a 2-D array of Axes, so one dataset and
    # several datasets are handled the same way in either mode.
    fig, axs = plt.subplots(
        1,
        n_datasets,
        figsize=(width_per_dataset * n_datasets, 6),
        squeeze=False,
    )
    axs = axs.ravel()

    # Near-white background for the plotting area of each axes.
    for ax in axs:
        ax.set_facecolor('#fdfdfd')

    if controlability:
        colors = [
            "#FFD711",  # Yellow
            "#32CD32",  # Lime green (lightest green)
            "#2EAF2E",  # Darker green
            "#279527",  # Darker still
            "#207B20",  # Darkest green
            "#1E90FF",  # Dodger blue (lightest blue)
            "#007FFF",  # Slightly darker blue
            "#0000FF",  # Standard blue
            "#00008B",  # Dark blue
        ]
    else:
        colors = [
            "#FFD711",  # Yellow
            "#32CD32",  # Lime green
            "#1E90FF",  # Dodger blue
        ]

    # First pass: value range over every dataset, ignoring NaN, so that all
    # subplots share the same y limits.
    global_min = float('inf')
    global_max = float('-inf')
    for data in data_list:
        for cluster in data:
            for result in cluster.values():
                for value in result.values():
                    if not np.isnan(value):
                        global_min = min(global_min, value)
                        global_max = max(global_max, value)

    group_width = 0.6
    bar_spacing = 0.01  # Gap between neighbouring bars of one group

    for dataset_idx, data in enumerate(data_list):
        ax = axs[dataset_idx]

        metrics = list(next(iter(data[0].values())).keys())

        n_clusters = len(data)
        n_bars_per_group = len(metrics)
        bar_width = group_width / n_bars_per_group
        index = np.arange(n_clusters)

        # Tick labels depend only on the keys, so build them once rather
        # than once per metric.
        tick_labels = ['\n'.join(key) for cluster in data for key in cluster]

        for j, metric in enumerate(metrics):
            values = [
                result.get(metric, np.nan)
                for cluster in data
                for result in cluster.values()
            ]
            bar_positions = index + j * (bar_width + bar_spacing)
            ax.bar(
                bar_positions,
                values,
                bar_width,
                label=f'{metric}',
                color=colors[j % len(colors)],
            )

        # Centre each tick under its group: midway between the first and the
        # last bar position, which includes the spacing between bars.
        group_center = (n_bars_per_group - 1) * (bar_width + bar_spacing) / 2
        ax.set_xticks(index + group_center)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')

        # Only apply the shared limits when a finite value was found;
        # otherwise the sentinels (inf, -inf) would be passed to set_ylim.
        if global_min <= global_max:
            ax.set_ylim(global_min, global_max)

        # Horizontal grid lines drawn behind the bars.
        ax.yaxis.grid(True, linestyle='--', alpha=0.7)
        ax.set_axisbelow(True)

        # Dashed reference line at y = 1.
        ax.axhline(y=1, color='black', linestyle='--', linewidth=1)

    # One legend for the whole figure, taken from the first subplot.
    handles, legend_labels = axs[0].get_legend_handles_labels()
    fig.legend(handles, legend_labels, bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()



def detect_and_visualize_outliers(sweep_results_list,labels,percentile=0.7):
    """Flag groups whose best sweep score is well below the average and plot them.

    Each group is summarised by the maximum of its sweep scores. A group is
    an outlier when that maximum is strictly below ``percentile`` times the
    mean of all group maxima.

    Args:
        sweep_results_list: Sequence of dicts, one per group, whose values
            are comparable numeric scores. Every dict must be non-empty.
        labels: Indexable by ``0 .. len(sweep_results_list) - 1``; entries
            are used as x tick labels.
        percentile: Fraction of the mean used as the threshold. Despite the
            name it is a multiplier, not a percentile.
            NOTE(review): with a negative mean the threshold lies above the
            mean; confirm scores are expected to be positive.

    Returns:
        Tuple ``(relative_outliers, ablation_scores)``: the indices of the
        outlier groups and the per-group maxima, both as numpy arrays.
    """
    # Summarise every group by its highest score.
    ablation_scores = np.array(
        [max(sweep_results.values()) for sweep_results in sweep_results_list]
    )
    n_groups = len(ablation_scores)
    overall_mean = np.mean(ablation_scores)

    # Groups scoring below this fraction of the mean count as outliers.
    relative_threshold = percentile * overall_mean
    relative_outliers = np.where(ablation_scores < relative_threshold)[0]

    # round() rather than int(): (1 - 0.8) * 100 is 19.999... in floating
    # point and plain truncation would label the threshold "19%".
    percent_below = int(round((1 - percentile) * 100))

    plt.figure(figsize=(14, 6))
    plt.bar(range(n_groups), ablation_scores, color='skyblue', edgecolor='black', label='Groups')
    plt.axhline(overall_mean, color='red', linestyle='--', label=f'Overall Mean: {overall_mean:.2f}')
    plt.axhline(relative_threshold, color='orange', linestyle='--', label=f'{percent_below}% Below Mean Threshold: {relative_threshold:.2f}')

    # Redraw the outlier bars in a contrasting colour on top of the others.
    for idx in relative_outliers:
        plt.bar(idx, ablation_scores[idx], color='#E76F51', edgecolor='black')

    plt.xlabel('Groups')
    plt.ylabel('Change in Logit Difference')
    plt.xticks(range(n_groups), [f'{labels[i]}' for i in range(n_groups)])
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

    return relative_outliers, ablation_scores


def viz_sweep_results(sweep_results_list,labels):
    """Plot score against subset size, one line per group.

    Args:
        sweep_results_list: Sequence of dicts, one per group. Each key is a
            string whose last ``'_'``-separated field parses as an integer
            subset size; each value is the score plotted for that size.
        labels: Indexable by group position; entries become legend labels.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    plt.figure(figsize=(14, 6))

    for group_idx, group_results in enumerate(sweep_results_list):
        x_sizes = []
        y_scores = []
        # Points are plotted in the dict's own iteration order.
        for name, score in group_results.items():
            x_sizes.append(int(name.split('_')[-1]))
            y_scores.append(score)

        plt.plot(x_sizes, y_scores, marker='o', label=f'{labels[group_idx]}')

    plt.xlabel('Subset Size')
    plt.ylabel('Change in Logit Difference')
    plt.legend(title='Groups')
    plt.grid(True)
    plt.show()