import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from utils.intervention_utils import replace_activations_and_evaluate_logits_on_each_head, inject_function_vector, \
    get_answer_token_id
from utils.logging_utils import logger


def calculate_causal_indirect_effect(model_wrapper, dataset, mean_activations, last_token_only=True):
    """
    Calculate the causal indirect effect of each attention head on the [clean ICL prompts/concept instruction] dataset.

    Every sample of ``dataset`` is evaluated with ``replace_activations_and_evaluate_logits_on_each_head`` using
    ``mean_activations`` as the patched activations, and the per-sample results are stacked into one tensor.

    :param model_wrapper: subject model wrapper; ``model_config`` must provide ``n_layers``, ``n_heads`` and
        ``prepend_bos``
    :param dataset: [corrupted ICL prompts/empty concept instruction] dataset; samples need ``'prompt'`` and
        ``'answer'`` keys. Must be non-empty: the first prompt fixes ``logic_seq_len``.
    :param mean_activations: the extracted activations from the clean prompts
    :param last_token_only: whether to only consider the last token of the prompt
    :return: the causal indirect effect of each layer and head,
        Shape: (n_trials, n_layers, n_heads) if ``last_token_only`` else (n_trials, n_layers, n_heads, logic_seq_len)
    """
    model_config = model_wrapper.model_config
    n_trials = len(dataset)

    # Logic-label sequence length: prompt length plus one slot when a BOS token is prepended.
    # NOTE(review): taken from the first sample only, so prompts are presumably of equal length — confirm.
    logic_seq_len = len(dataset[0]['prompt']) + (1 if model_config['prepend_bos'] else 0)

    # Per-trial result shape; the token axis is kept only when every logic slot is patched.
    trial_shape = (model_config['n_layers'], model_config['n_heads'])
    if not last_token_only:
        trial_shape = trial_shape + (logic_seq_len,)

    indirect_effects = torch.zeros(n_trials, *trial_shape)

    for i in tqdm(range(n_trials), desc="Calculating indirect effects", leave=False):
        ind_effect = replace_activations_and_evaluate_logits_on_each_head(
            prompt=dataset[i]['prompt'],
            answer=dataset[i]['answer'],
            mean_activations=mean_activations,
            model_wrapper=model_wrapper,
            last_token_only=last_token_only,
            logic_seq_len=logic_seq_len)

        # ind_effect is presumably (1, *trial_shape). reshape() drops only that leading axis, whereas
        # squeeze() would also collapse size-1 layer/head/token axes and break the assignment.
        indirect_effects[i] = ind_effect.reshape(trial_shape)

    return indirect_effects


def decode_logits_to_labels(logits, target_token_id, tokenizer):
    """
    Convert next-token logits into the top-1 decoded labels plus their probabilities.

    :param logits: the logits sequence to decode, indexed as (batch_size, vocab_size)
    :param target_token_id: the token ID of the target answer
    :param tokenizer: the corresponding tokenizer (must provide ``batch_decode``)
    :return: a triple ``(decoded_labels, max_prob, target_token_prob)``:
        decoded_labels - list of decoded top-1 tokens, one per batch row;
        max_prob - probability of each row's top-1 token;
        target_token_prob - probability assigned to ``target_token_id`` in each row
    """
    distribution = logits.softmax(dim=-1)

    top = distribution.max(dim=-1)
    target_prob = distribution[:, target_token_id]

    labels = tokenizer.batch_decode(top.indices)
    return labels, top.values, target_prob


def evaluate_fv_injected_outputs(model_wrapper, dataset, function_vector, target_layer, inject_mode='single_step'):
    """
    Evaluate the outputs after injecting the corresponding function vector on the dataset.

    :param model_wrapper: subject model wrapper (provides ``tokenizer`` and ``device``)
    :param dataset: test dataset; samples need ``'prompt'`` and, in single-step mode, ``'answer'``
    :param function_vector: the function vector to inject
    :param target_layer: index of the layer the function vector is injected at
    :param inject_mode: ``'single_step'`` or ``'full_sequence'`` (forwarded to ``inject_function_vector``)

    Return human-readable outputs list
    :return single_step: a list of tuples (clean_result, intervened_result); each result is the
        ``(decoded_labels, max_prob, target_token_prob)`` triple from ``decode_logits_to_labels``
    :return full_sequence: a list of complete decoded sentences
    :raises ValueError: if ``inject_mode`` is not one of the supported modes
    """
    if inject_mode not in ('single_step', 'full_sequence'):
        # Fail fast: any other mode would produce no results and break callers that index them.
        raise ValueError(f"Unsupported inject_mode {inject_mode!r}; expected 'single_step' or 'full_sequence'.")

    tokenizer = model_wrapper.tokenizer
    device = model_wrapper.device
    n_trials = len(dataset)

    injected_results = []

    for i in tqdm(range(n_trials), desc="Evaluating injected outputs"):
        prompt = dataset[i]['prompt']

        # Inject the function vector into the model
        intervened_output, clean_output = inject_function_vector(model_wrapper=model_wrapper, prompt=prompt,
                                                                 function_vector=function_vector,
                                                                 target_layer=target_layer, inject_mode=inject_mode)

        if inject_mode == 'single_step':
            # Token id of the first answer token, used to report the target-token probability
            answer_token_id = get_answer_token_id(dataset[i]['answer'], tokenizer).to(device)
            clean_result = decode_logits_to_labels(clean_output, answer_token_id, tokenizer)
            intervened_result = decode_logits_to_labels(intervened_output, answer_token_id, tokenizer)
            injected_results.append((clean_result, intervened_result))
        else:  # 'full_sequence'
            # Hugging Face generate() outputs expose `.sequences`; `.sequence` is kept as a fallback
            # for other output objects.
            sequences = getattr(intervened_output, 'sequences', None)
            if sequences is None:
                sequences = intervened_output.sequence
            injected_results.append(tokenizer.decode(sequences[0]))

    return injected_results


def evaluate_ablated_output_results(model_wrapper, dataset, ablated_output_list):
    """
    Compute the fraction of samples whose top-1 prediction after ablation matches the target answer.

    For each sample, the arg-max token of ``ablated_output_list[i]`` is decoded and its first word is compared
    with the first word of ``dataset[i]['answer']``. It counts as a hit when the target word is contained in it,
    or when it is a substring of the target word that is more than two characters shorter.

    :param model_wrapper: subject model wrapper (only its ``tokenizer`` is used)
    :param dataset: test dataset; each sample needs an ``'answer'`` string
    :param ablated_output_list: per-sample logits after ablation, aligned with ``dataset``
        (presumably (1, vocab_size) each — only the first decoded row is inspected)
    :return: hit rate in [0, 1] as a float
    :raises AssertionError: if the number of outputs differs from the number of samples
    :raises ValueError: if the dataset is empty (the rate would be undefined)
    """
    n_trials = len(dataset)

    assert len(
        ablated_output_list) == n_trials, "The number of ablated outputs must match the number of dataset samples."
    if n_trials == 0:
        raise ValueError("Cannot compute a success rate on an empty dataset.")

    tokenizer = model_wrapper.tokenizer
    ablated_success_count = 0

    for i in tqdm(range(n_trials), desc="Evaluating ablated outputs"):
        # First word of the target answer
        target_str = dataset[i]['answer'].split()[0]

        # Top-1 token of the ablated distribution, decoded to text
        probs = torch.softmax(ablated_output_list[i], dim=-1)
        predicted_token_ids = probs.max(dim=-1).indices
        decoded_str = tokenizer.batch_decode(predicted_token_ids)

        # First word of the prediction; fall back to the raw string when it has no words
        ablated_split = decoded_str[0].split()
        ablated_result_str = ablated_split[0] if ablated_split else decoded_str[0]

        # NOTE(review): "" is a substring of every target, so an empty prediction counts as a hit for
        # targets longer than two characters — confirm this is intended.
        if target_str in ablated_result_str or (
                ablated_result_str in target_str and len(target_str) - len(ablated_result_str) > 2):
            ablated_success_count += 1

    return ablated_success_count / n_trials


def evaluate_fv_injected_outputs_on_per_prompt(model_wrapper, dataset, function_vectors, target_layer,
                                               inject_mode='single_step'):
    """
    Evaluate the outputs after injecting a per-prompt function vector on the dataset.

    :param model_wrapper: subject model wrapper (provides ``tokenizer`` and ``device``)
    :param dataset: test dataset; samples need ``'prompt'`` and, in single-step mode, ``'answer'``
    :param function_vectors: the function vectors to inject, indexed by sample (``function_vectors[i]`` is
        used for ``dataset[i]``); ``shape[0]`` must equal ``len(dataset)``
    :param target_layer: index of the layer the function vectors are injected at
    :param inject_mode: ``'single_step'`` or ``'full_sequence'`` (forwarded to ``inject_function_vector``)

    Return human-readable outputs list
    :return single_step: a list of tuples (clean_result, intervened_result); each result is the
        ``(decoded_labels, max_prob, target_token_prob)`` triple from ``decode_logits_to_labels``
    :return full_sequence: a list of complete decoded sentences
    :raises ValueError: if ``inject_mode`` is not one of the supported modes
    :raises AssertionError: if the number of function vectors differs from the number of samples
    """
    if inject_mode not in ('single_step', 'full_sequence'):
        # Fail fast: any other mode would produce no results and break callers that index them.
        raise ValueError(f"Unsupported inject_mode {inject_mode!r}; expected 'single_step' or 'full_sequence'.")

    n_trials = len(dataset)
    assert function_vectors.shape[
               0] == n_trials, "The number of function vectors must match the number of dataset samples."

    tokenizer = model_wrapper.tokenizer
    device = model_wrapper.device

    injected_results = []

    for i in tqdm(range(n_trials), desc="Evaluating injected outputs on per prompt"):
        prompt = dataset[i]['prompt']

        # Function vector for the current prompt (presumably shaped (1, hidden_dim) — confirm with caller)
        function_vector = function_vectors[i]

        # Inject the function vector into the model
        intervened_output, clean_output = inject_function_vector(model_wrapper=model_wrapper, prompt=prompt,
                                                                 function_vector=function_vector,
                                                                 target_layer=target_layer, inject_mode=inject_mode)

        if inject_mode == 'single_step':
            # Token id of the first answer token, used to report the target-token probability
            answer_token_id = get_answer_token_id(dataset[i]['answer'], tokenizer).to(device)
            clean_result = decode_logits_to_labels(clean_output, answer_token_id, tokenizer)
            intervened_result = decode_logits_to_labels(intervened_output, answer_token_id, tokenizer)
            injected_results.append((clean_result, intervened_result))
        else:  # 'full_sequence'
            # Hugging Face generate() outputs expose `.sequences`; `.sequence` is kept as a fallback
            # for other output objects.
            sequences = getattr(intervened_output, 'sequences', None)
            if sequences is None:
                sequences = intervened_output.sequence
            injected_results.append(tokenizer.decode(sequences[0]))

    return injected_results


def evaluate_model_clean_output(model_wrapper, dataset, max_new_tokens=5):
    """
    Generate a short continuation for every prompt with the un-intervened model.

    :param model_wrapper: subject model wrapper (provides ``tokenizer``, ``device`` and ``model``)
    :param dataset: test dataset; each ``'prompt'`` is tokenized with ``is_split_into_words=True``,
        so it is presumably a list of words
    :param max_new_tokens: maximum number of tokens generated per prompt (default 5)

    :return: a list with one decoded continuation string per sample (prompt tokens and special tokens removed)
    """
    tokenizer = model_wrapper.tokenizer
    device = model_wrapper.device
    model = model_wrapper.model

    clean_results = []

    for i in tqdm(range(len(dataset)), desc="Evaluating model clean outputs"):
        prompt = dataset[i]['prompt']

        # Tokenize the (pre-split) prompt
        inputs = tokenizer(prompt, is_split_into_words=True, return_tensors='pt').to(device)

        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens
            )  # Shape: (batch_size, sequence_length)

        # Keep only the token ids after the prompt, i.e. the newly generated ones
        new_ids = generated[0, inputs['input_ids'].shape[1]:]

        clean_results.append(tokenizer.decode(new_ids, skip_special_tokens=True))

    return clean_results


def evaluate_intervened_success_rate(model_wrapper, scaled_factor, function_vector, inject_dataset,
                                     intervened_layer_list=None):
    """
    Evaluate the success rate of the intervention on the dataset with different scaled factors.

    For every factor in ``scaled_factor`` and every layer in ``intervened_layer_list``, the scaled function vector
    is injected with ``evaluate_fv_injected_outputs`` (single-step mode). A prediction is a hit when the first word
    of its top-1 decoded token contains the first word of the sample's ``'answer'``, or when it is a substring of
    that word that is more than two characters shorter.

    :param model_wrapper: subject model wrapper
    :param scaled_factor: sequence of multipliers applied to ``function_vector``
    :param function_vector: the function vector to inject
    :param inject_dataset: non-empty dataset with ``'prompt'`` and ``'answer'`` keys
    :param intervened_layer_list: List of layers to intervene, if None, all layers will be evaluated
    :return:
    df: DataFrame with one row per (scaled_factor, layer) and columns ``scaled_factor``, ``layer``,
        ``intervened_success_rate`` and ``clean_success_rate``
    clean_success_rate: The success rate of the clean model on the inject_dataset (0.0 if nothing was evaluated)
    :raises ValueError: if ``inject_dataset`` is empty (rates would be undefined)
    """
    n_samples = len(inject_dataset)
    if n_samples == 0:
        raise ValueError("inject_dataset must contain at least one sample.")

    if intervened_layer_list is None:
        intervened_layer_list = list(range(model_wrapper.model_config['n_layers']))

    def _first_word(text):
        # First whitespace-separated word, or the raw text when it contains no words
        words = text.split()
        return words[0] if words else text

    def _is_hit(prediction, target):
        # NOTE(review): "" is a substring of every target, so an empty prediction counts as a hit for
        # targets longer than two characters — confirm this is intended.
        return target in prediction or (prediction in target and len(target) - len(prediction) > 2)

    records = []
    clean_rate_done = False
    clean_success_rate = 0.0

    for factor in scaled_factor:
        scaled_function_vector = function_vector * factor
        for n in intervened_layer_list:
            results = evaluate_fv_injected_outputs(model_wrapper, inject_dataset, scaled_function_vector,
                                                   target_layer=n)

            clean_success_count = 0
            intervened_success_count = 0

            for i in range(n_samples):
                sample = inject_dataset[i]
                print(f"Input: {sample['prompt']}")
                print(f"Target Output: {sample['answer']}")
                logger.info(f"Input: {sample['prompt']}")
                logger.info(f"Target Output: {sample['answer']}")

                # Each result is a (decoded_labels, max_prob, target_token_prob) triple
                clean_result, intervened_result = results[i]
                target_str = sample['answer'].split()[0]

                # Clean hits are only counted during the first (factor, layer) pass
                if not clean_rate_done and _is_hit(_first_word(clean_result[0][0]), target_str):
                    clean_success_count += 1

                if _is_hit(_first_word(intervened_result[0][0]), target_str):
                    intervened_success_count += 1

                print(
                    f"Clean Output: {clean_result[0]}, Clean Max Probs: {clean_result[1]}, Target Probs: {clean_result[2]}")
                print(
                    f"Injected Output: {intervened_result[0]}, Injected Max Probs: {intervened_result[1]}, Target Probs: {intervened_result[2]}")

                logger.info(
                    f"Clean Output: {clean_result[0]}, Clean Max Probs: {clean_result[1]}, Target Probs: {clean_result[2]}")
                logger.info(
                    f"Injected Output: {intervened_result[0]}, Injected Max Probs: {intervened_result[1]}, Target Probs: {intervened_result[2]}")

            if not clean_rate_done:
                clean_rate_done = True
                clean_success_rate = clean_success_count / n_samples
                logger.info(f"Clean Success Rate: {clean_success_rate}")

            intervened_success_rate = intervened_success_count / n_samples
            logger.info(f"Layer {n} Intervened Success Rate: {intervened_success_rate}")

            records.append({
                'scaled_factor': factor,
                'layer': n,
                'intervened_success_rate': intervened_success_rate
            })

    df = pd.DataFrame(records)
    df['clean_success_rate'] = clean_success_rate

    return df, clean_success_rate


def evaluate_intervened_success_rate_on_per_prompt(model_wrapper, scaled_factor, function_vectors, inject_dataset,
                                                   intervened_layer_list=None):
    """
    Evaluate the success rate of per-prompt interventions on the dataset with different scaled factors.

    For every factor in ``scaled_factor`` and every layer in ``intervened_layer_list``, the scaled per-prompt
    function vectors are injected with ``evaluate_fv_injected_outputs_on_per_prompt`` (single-step mode).
    A prediction is a hit when the first word of its top-1 decoded token contains the first word of the sample's
    ``'answer'``, or when it is a substring of that word that is more than two characters shorter.

    :param model_wrapper: subject model wrapper
    :param scaled_factor: sequence of multipliers applied to ``function_vectors``
    :param function_vectors: corresponding function vectors for each prompt
    :param inject_dataset: non-empty dataset with ``'prompt'`` and ``'answer'`` keys
    :param intervened_layer_list: List of layers to intervene, if None, all layers will be evaluated
    :return:
    df: DataFrame with one row per (scaled_factor, layer) and columns ``scaled_factor``, ``layer``,
        ``intervened_success_rate`` and ``clean_success_rate``
    clean_success_rate: The success rate of the clean model on the inject_dataset (0.0 if nothing was evaluated)
    :raises ValueError: if ``inject_dataset`` is empty (rates would be undefined)
    """
    n_samples = len(inject_dataset)
    if n_samples == 0:
        raise ValueError("inject_dataset must contain at least one sample.")

    if intervened_layer_list is None:
        intervened_layer_list = list(range(model_wrapper.model_config['n_layers']))

    def _first_word(text):
        # First whitespace-separated word, or the raw text when it contains no words
        words = text.split()
        return words[0] if words else text

    def _is_hit(prediction, target):
        # NOTE(review): "" is a substring of every target, so an empty prediction counts as a hit for
        # targets longer than two characters — confirm this is intended.
        return target in prediction or (prediction in target and len(target) - len(prediction) > 2)

    records = []
    clean_rate_done = False
    clean_success_rate = 0.0

    for factor in scaled_factor:
        scaled_function_vectors = function_vectors * factor
        for n in intervened_layer_list:
            results = evaluate_fv_injected_outputs_on_per_prompt(model_wrapper, inject_dataset, scaled_function_vectors,
                                                                 target_layer=n)

            clean_success_count = 0
            intervened_success_count = 0

            for i in range(n_samples):
                sample = inject_dataset[i]
                print(f"Input: {sample['prompt']}")
                print(f"Target Output: {sample['answer']}")
                logger.info(f"Input: {sample['prompt']}")
                logger.info(f"Target Output: {sample['answer']}")

                # Each result is a (decoded_labels, max_prob, target_token_prob) triple
                clean_result, intervened_result = results[i]
                target_str = sample['answer'].split()[0]

                # Clean hits are only counted during the first (factor, layer) pass
                if not clean_rate_done and _is_hit(_first_word(clean_result[0][0]), target_str):
                    clean_success_count += 1

                if _is_hit(_first_word(intervened_result[0][0]), target_str):
                    intervened_success_count += 1

                print(
                    f"Clean Output: {clean_result[0]}, Clean Max Probs: {clean_result[1]}, Target Probs: {clean_result[2]}")
                print(
                    f"Injected Output: {intervened_result[0]}, Injected Max Probs: {intervened_result[1]}, Target Probs: {intervened_result[2]}")

                logger.info(
                    f"Clean Output: {clean_result[0]}, Clean Max Probs: {clean_result[1]}, Target Probs: {clean_result[2]}")
                logger.info(
                    f"Injected Output: {intervened_result[0]}, Injected Max Probs: {intervened_result[1]}, Target Probs: {intervened_result[2]}")

            if not clean_rate_done:
                clean_rate_done = True
                clean_success_rate = clean_success_count / n_samples
                logger.info(f"Clean Success Rate: {clean_success_rate}")

            intervened_success_rate = intervened_success_count / n_samples
            logger.info(f"Layer {n} Intervened Success Rate: {intervened_success_rate}")

            records.append({
                'scaled_factor': factor,
                'layer': n,
                'intervened_success_rate': intervened_success_rate
            })

    df = pd.DataFrame(records)
    df['clean_success_rate'] = clean_success_rate

    return df, clean_success_rate


def evaluate_combined_function_vector_intervened_effects(model_wrapper, combined_function_vector, inject_dataset):
    """
    Inject a combined function vector at every layer and bucket each top-1 prediction by which answer it matches.

    The first lower-cased word of each intervened prediction is compared with the first lower-cased words of
    ``'answer'`` (target1) and ``'answer2'`` (target2). A prediction contained in only one target word counts for
    that target. One contained in both counts for the shorter target word (equal lengths count as invalid).
    Anything else is invalid.

    :param model_wrapper: subject model wrapper (``model_config['n_layers']`` sets the layers evaluated)
    :param combined_function_vector: the function vector injected at each layer
    :param inject_dataset: non-empty dataset with ``'prompt'``, ``'answer'`` and ``'answer2'`` keys
    :return:
    proportions_list: one ``[target1, invalid, target2]`` proportion triple per layer, in layer order
    :raises ValueError: if ``inject_dataset`` is empty (proportions would be undefined)
    """
    n_samples = len(inject_dataset)
    if n_samples == 0:
        raise ValueError("inject_dataset must contain at least one sample.")

    # Bucket indices into the per-layer count list; the order matches the returned triple.
    target1_idx, invalid_idx, target2_idx = 0, 1, 2

    def _classify(prediction, target1, target2):
        # Return the bucket index for a prediction given the two target words.
        if prediction in target1:
            if prediction not in target2:
                return target1_idx
            # Hits both targets: credit the shorter (more similar) target word; a tie is invalid
            if len(target1) > len(target2):
                return target2_idx
            if len(target1) < len(target2):
                return target1_idx
            return invalid_idx
        if prediction in target2:
            return target2_idx
        return invalid_idx

    layer_num = model_wrapper.model_config['n_layers']
    proportions_list = []

    for n in range(layer_num):
        results = evaluate_fv_injected_outputs(model_wrapper, inject_dataset, combined_function_vector,
                                               target_layer=n)
        counts = [0, 0, 0]

        for i in range(n_samples):
            sample = inject_dataset[i]
            print(f"Input: {sample['prompt']}")
            logger.info(f"Input: {sample['prompt']}")

            # (decoded_labels, max_prob, target_token_prob) of the intervened run
            intervened_result = results[i][1]
            # NOTE(review): this reads 'answer' while evaluate_model_behavior_with_ambiguity reads 'answer1' —
            # confirm both dataset schemas.
            target1_str = sample['answer'].lower().split()[0]
            target2_str = sample['answer2'].lower().split()[0]

            intervened_split = intervened_result[0][0].lower().split()
            intervened_result_str = intervened_split[0] if intervened_split else intervened_result[0][0]

            counts[_classify(intervened_result_str, target1_str, target2_str)] += 1

            print(
                f"Injected Output: {intervened_result[0]}, Injected Max Probs: {intervened_result[1]}")
            logger.info(
                f"Injected Output: {intervened_result[0]}, Injected Max Probs: {intervened_result[1]}")

        target1_proportion = counts[target1_idx] / n_samples
        invalid_proportion = counts[invalid_idx] / n_samples
        target2_proportion = counts[target2_idx] / n_samples

        logger.info(
            f"Layer {n} Intervened Proportion: {target1_proportion:.2%}, {invalid_proportion:.2%}, {target2_proportion:.2%}")

        proportions_list.append([target1_proportion, invalid_proportion, target2_proportion])

    return proportions_list


def plot_causal_indirect_effect_heatmap(causal_indirect_effect, save_path):
    """
    Plot a heatmap of the causal indirect effect values, one annotated cell per (layer, head).

    Parameters:
    :param causal_indirect_effect: 2-D values of shape (num_layers, num_heads); a torch.Tensor (any device,
        may require grad) or anything ``np.asarray`` accepts
    :param save_path: Path to save the plot as PDF; falsy to skip saving
    """
    # Convert to numpy; detach() allows tensors that require grad
    if isinstance(causal_indirect_effect, torch.Tensor):
        aie = causal_indirect_effect.detach().cpu().numpy()
    else:
        aie = np.asarray(causal_indirect_effect)
    num_layers, num_heads = aie.shape

    df = pd.DataFrame(
        aie,
        index=[f"Layer {i}" for i in range(num_layers)],
        columns=[f"Head {j}" for j in range(num_heads)]
    )

    fig = plt.figure(figsize=(12, 6))
    sns.heatmap(df, annot=True, fmt=".2f", cmap='magma')
    plt.title("Causal Indirect Effect Heatmap")
    plt.xlabel("Attention Heads")
    plt.ylabel("Transformer Layers")
    plt.tight_layout()

    # Save before show(): with interactive/inline backends show() releases the figure, so saving
    # afterwards would write a blank file.
    if save_path:
        fig.savefig(save_path, format="pdf", bbox_inches="tight")
    plt.show()


def plot_function_vector_intervention_effects(intervened_records, clean_success_rate, save_path):
    """
    Plot the intervened success rate per layer, one line per scaling factor, against the clean baseline.

    :param intervened_records: records of the intervention results with columns ``layer``,
        ``intervened_success_rate`` and ``scaled_factor`` (e.g. the DataFrame from the success-rate evaluation)
    :param clean_success_rate: The success rate of the clean model, drawn as a dashed horizontal line
    :param save_path: Path to save the plot as PDF; falsy to skip saving
    """
    fig = plt.figure(figsize=(12, 6))
    sns.set_theme(style="ticks")
    sns.lineplot(
        data=intervened_records,
        x='layer',
        y='intervened_success_rate',
        hue='scaled_factor',
        palette='deep',
        marker='o'
    )
    sns.despine()
    plt.axhline(y=clean_success_rate, linestyle='--', color='gray', label='Baseline')
    plt.legend(title='Scaled Factor')
    plt.tight_layout()

    # Save before show(): with interactive/inline backends show() releases the figure, so saving
    # afterwards would write a blank file.
    if save_path:
        fig.savefig(save_path, format="pdf", bbox_inches="tight")
    plt.show()


def plot_combined_function_vector_effects(intervened_proportions, label_list, save_path=None):
    """
    Plot, per layer, a stacked bar of the answer-class proportions produced by the combined function vector.

    :param intervened_proportions: one sequence of class proportions (fractions in [0, 1]) per layer
    :param label_list: List of labels for the classes (e.g., "Copy", "Constant", "Invalid"), in the same order
        as each proportion sequence; at most 4 classes fit the colour indexing below
    :param save_path: Path to save the plot as PDF; None/empty to skip saving
    """
    sns.set_theme()

    # 15-colour palette sampled every 4th entry (offset 1) so adjacent classes are visually distinct
    colors = sns.color_palette("ch:rot=-.35,hue=1,light=.75", n_colors=15)

    fig, ax = plt.subplots(figsize=(20, 7))

    x_positions = range(len(intervened_proportions))

    for i, layer_proportions in enumerate(intervened_proportions):
        bottom = 0
        for j, prop in enumerate(layer_proportions):
            ax.bar(
                x=i,
                height=prop,
                bottom=bottom,
                width=0.5,
                color=colors[4 * j + 1],
                linewidth=1,
                label=label_list[j] if i == 0 else ""  # label only the first bar so the legend has one entry per class
            )

            # round() rather than int(): e.g. 0.29 * 100 == 28.999..., which int() would show as "28"
            ax.text(
                x=i,
                y=bottom + prop / 2,
                s=f"{int(round(prop * 100))}",
                ha="center",
                va="center",
                color="white",
                fontweight="bold"
            )
            bottom += prop

    ax.set_ylim(0, 1)
    ax.set_xlabel("Layer")
    ax.set_ylabel("Proportion (%)")

    ax.set_xticks(x_positions)
    ax.set_xticklabels([f"{i}" for i in x_positions])
    ax.tick_params(axis='x', length=0)

    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, -0.15),
        ncol=len(label_list),
        frameon=False,
    )

    fig.subplots_adjust(bottom=0.2)

    # Save before show(): with interactive/inline backends show() releases the figure, so saving
    # afterwards would write a blank file.
    if save_path:
        fig.savefig(save_path, format="pdf", bbox_inches="tight")
    plt.show()


def evaluate_model_behavior_with_ambiguity(model_wrapper, test_dataset):
    """
    Given an ambiguous test dataset, evaluate the model's behavior on this dataset without any intervention.

    Each clean continuation from ``evaluate_model_clean_output`` is reduced to its first lower-cased word and
    compared with the first lower-cased words of ``'answer1'`` (target1) and ``'answer2'`` (target2).
    A prediction contained in only one target word counts for that target. One contained in both counts for
    the shorter target word (equal lengths count as invalid). Anything else is invalid.

    :param model_wrapper: subject model wrapper
    :param test_dataset: non-empty dataset with ``'prompt'``, ``'answer1'`` and ``'answer2'`` keys
    :return: proportions_list: a single-element list ``[[target1, invalid, target2]]`` of proportions
    :raises ValueError: if ``test_dataset`` is empty (proportions would be undefined)
    """
    n_samples = len(test_dataset)
    if n_samples == 0:
        raise ValueError("test_dataset must contain at least one sample.")

    # Bucket indices into the count list; the order matches the returned triple.
    target1_idx, invalid_idx, target2_idx = 0, 1, 2

    def _classify(prediction, target1, target2):
        # Return the bucket index for a prediction given the two target words.
        if prediction in target1:
            if prediction not in target2:
                return target1_idx
            # Hits both targets: credit the shorter (more similar) target word; a tie is invalid
            if len(target1) > len(target2):
                return target2_idx
            if len(target1) < len(target2):
                return target1_idx
            return invalid_idx
        if prediction in target2:
            return target2_idx
        return invalid_idx

    results = evaluate_model_clean_output(model_wrapper, test_dataset)
    counts = [0, 0, 0]

    for i in range(n_samples):
        sample = test_dataset[i]
        print(f"Input: {sample['prompt']}")
        logger.info(f"Input: {sample['prompt']}")

        prediction_result = results[i]
        target1_str = sample['answer1'].lower().split()[0]
        target2_str = sample['answer2'].lower().split()[0]

        # First lower-cased word of the prediction; fall back to the raw text when it has no words
        prediction_split = prediction_result.lower().split()
        prediction_result_str = prediction_split[0] if prediction_split else prediction_result

        counts[_classify(prediction_result_str, target1_str, target2_str)] += 1

        print(
            f"Prediction Output: {prediction_result}")
        logger.info(
            f"Prediction Output: {prediction_result}")

    target1_proportion = counts[target1_idx] / n_samples
    invalid_proportion = counts[invalid_idx] / n_samples
    target2_proportion = counts[target2_idx] / n_samples

    logger.info(
        f"Behaviour Test Proportion: {target1_proportion:.2%}, {invalid_proportion:.2%}, {target2_proportion:.2%}")

    return [[target1_proportion, invalid_proportion, target2_proportion]]


def plot_model_behaviour_with_ambiguity(prediction_proportions, label_list, save_path=None):
    """
    Plot a stacked bar of answer-class proportions for each entry of ``prediction_proportions``.

    :param prediction_proportions: one sequence of class proportions (fractions in [0, 1]) per bar
    :param label_list: List of labels for the classes (e.g., "Copy", "Constant", "Invalid"), in the same order
        as each proportion sequence; at most 4 classes fit the colour indexing below
    :param save_path: Path to save the plot as PDF; None/empty to skip saving
    """
    sns.set_theme()

    # 15-colour palette sampled every 4th entry (offset 1) so adjacent classes are visually distinct
    colors = sns.color_palette("ch:rot=-.35,hue=1,light=.75", n_colors=15)

    fig, ax = plt.subplots(figsize=(20, 7))

    x_positions = range(len(prediction_proportions))

    for i, bar_proportions in enumerate(prediction_proportions):
        bottom = 0
        for j, prop in enumerate(bar_proportions):
            ax.bar(
                x=i,
                height=prop,
                bottom=bottom,
                width=0.5,
                color=colors[4 * j + 1],
                linewidth=1,
                label=label_list[j] if i == 0 else ""  # label only the first bar so the legend has one entry per class
            )

            # round() rather than int(): e.g. 0.29 * 100 == 28.999..., which int() would show as "28"
            ax.text(
                x=i,
                y=bottom + prop / 2,
                s=f"{int(round(prop * 100))}",
                ha="center",
                va="center",
                color="white",
                fontweight="bold"
            )
            bottom += prop

    ax.set_ylim(0, 1)
    ax.set_xlabel("Unambiguous Examples")
    ax.set_ylabel("Proportion (%)")

    ax.set_xticks(x_positions)
    ax.set_xticklabels([f"{i}" for i in x_positions])
    ax.tick_params(axis='x', length=0)

    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, -0.15),
        ncol=len(label_list),
        frameon=False,
    )

    fig.subplots_adjust(bottom=0.2)

    # Save before show(): with interactive/inline backends show() releases the figure, so saving
    # afterwards would write a blank file.
    if save_path:
        fig.savefig(save_path, format="pdf", bbox_inches="tight")
    plt.show()


def evaluate_FV_intensity(df: pd.DataFrame,
                          baseline=0,
                          layer_offset: int = 1,
                          layer_weights: str | dict = "uniform",
                          # "inv" (1/(l+offset)), "uniform", or dict {layer: weight}
                          scale_weights: str | dict = "uniform",  # "inv" (1/s), "uniform", or dict {scale: weight}
                          ):
    """
    Compute Early Gain Aggregate (EGA) from per-layer, per-scale accuracy table.
    Δ_l(s) = A_l(s) - A_l(0), where A_l(0) is provided via `baseline` (float or {layer: value}).
    Weights default to inverse layer index (lower layers weigh more) and inverse scale (smaller scales weigh more).
    """
    # unique axes
    layers = np.sort(df["layer"].unique())
    scales = np.sort(df["scaled_factor"].unique())

    # layer weights
    if isinstance(layer_weights, str) and layer_weights == "inv":
        lw = 1.0 / (layers + layer_offset)
    elif isinstance(layer_weights, str) and layer_weights == "uniform":
        lw = np.ones_like(layers, dtype=float)
    else:  # dict
        lw = np.asarray([float(layer_weights.get(int(l), 1.0)) for l in layers], dtype=float)
    lw = lw / lw.sum()

    # scale weights
    if isinstance(scale_weights, str) and scale_weights == "inv":
        sw = 1.0 / scales.astype(float)
    elif isinstance(scale_weights, str) and scale_weights == "uniform":
        sw = np.ones_like(scales, dtype=float)
    else:  # dict
        sw = np.asarray([float(scale_weights.get(float(s), 1.0)) for s in scales], dtype=float)
    sw = sw / sw.sum()

    # baselines (A_l(0))
    if np.isscalar(baseline):
        base = {int(l): float(baseline) for l in layers}
    elif isinstance(baseline, dict):
        base = {int(k): float(v) for k, v in baseline.items()}
    else:
        raise ValueError("baseline must be a float or a dict {layer: A_l(0)}")

    # pivot to (layer x scale) table
    A = (
        df.pivot(index="layer", columns="scaled_factor", values="intervened_success_rate")
        .reindex(index=layers, columns=scales)
    )

    # Δ_l(s) = A_l(s) - A_l(0)
    Delta = A.copy()
    for l in layers:
        Delta.loc[l, :] = Delta.loc[l, :] - base.get(int(l), 0.0)
    Delta = Delta.fillna(0.0)

    # EGA = sum_l W_l * sum_s w(s) * Δ_l(s)
    ega_per_layer = Delta.mul(sw, axis=1).sum(axis=1)  # weighted over scales
    ega = float((ega_per_layer * lw).sum())  # then weighted over layers

    # Pack useful details for inspection
    details = {
        "layers": layers,
        "layer_weights": lw,
        "scales": scales,
        "scale_weights": sw,
        "Delta": Delta,
        "ega_per_layer": ega_per_layer,
    }
    return ega, details