import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from util import METHODS, BETAS
import argparse

# Constants
# Location of the tokenizer used to turn vocabulary ids back into text.
TOKENIZER_PATH = '../../models/pku-helpful'

# Vocabulary-id groups whose logit changes are tracked together. Each key is
# used as a human-readable label in the plots. The ids are tied to the
# tokenizer at TOKENIZER_PATH.
# NOTE(review): ids presumably encode variants of the label word; confirm
# against the tokenizer vocabulary.
TARGET_TOKENS = dict(
    none=[5642, 6213, 8516, 9290],
    no=[694, 1939, 3782, 11698],
    cannot=[2609, 15808, 29089],
    unfortunately=[15428, 11511],
    sorry=[8221, 7423],
)

# Seeds whose saved diff tensors are averaged together.
SEEDS = [0]

# Turn off tokenizer parallelism so the tokenizers library does not emit its warnings.
os.environ.update(TOKENIZERS_PARALLELISM='false')

# Tokenizer is loaded once at import time and read as a module-level global.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)


def calculate_diff_means(iteration, method, *, top_k=50):
    """Print the tokens whose logit diff increases / decreases the most.

    For every method in ``METHODS`` and every beta in ``BETAS``, the
    ``diff_mean.pt`` tensors stored under
    ``iter-{iteration}/{method}-beta-{beta}-seed-{seed}/`` are averaged over
    ``SEEDS``. Only row 0 of each saved tensor is used (presumably the first
    generation position; TODO confirm). The ``top_k`` largest and ``top_k``
    smallest entries are decoded with the module-level ``tokenizer`` and
    printed.

    Args:
        iteration: Iteration number selecting the ``iter-{iteration}`` directory.
        method: Currently unused. Every entry of ``METHODS`` is reported.
            The parameter is kept so existing callers keep working.
        top_k: Number of tokens to report in each direction. The default of
            50 matches the slice size this function has always used.

    Returns:
        dict mapping ``f'{method}_{beta}'`` to the seed-averaged diff
        (row 0 of the saved tensor) as a numpy array.
    """
    diff_dict = {}

    # The loop variable is not named ``method``, so it does not shadow the
    # (unused) parameter. All METHODS are still iterated.
    for method_name in METHODS:
        for beta in BETAS:
            # Sum row 0 of each seed's saved diff, then average over seeds.
            diff = 0.0
            for seed in SEEDS:
                path = f'iter-{iteration}/{method_name}-beta-{beta}-seed-{seed}/diff_mean.pt'
                # map_location='cpu' lets tensors saved on a GPU load anywhere
                # and guarantees .numpy() below is valid.
                diff = diff + torch.load(path, map_location='cpu').detach()[0]
            diff = (diff / float(len(SEEDS))).numpy()
            diff_dict[f'{method_name}_{beta}'] = diff

            # Sort once. Reversed order gives the largest values first, forward
            # order the smallest. Slicing the reversed array also behaves
            # sensibly for top_k == 0 (empty), unlike order[-0:].
            order = np.argsort(diff)
            increase_indices = order[::-1][:top_k]
            decrease_indices = order[:top_k]

            print(f"Method: {method_name}, Beta: {beta}")
            print(f"Top {top_k} tokens that increase the most:")
            for idx in increase_indices:
                print(f"Token: {tokenizer.decode([int(idx)])}, Increase: {diff[idx]}")

            print(f"Top {top_k} tokens that decrease the most:")
            for idx in decrease_indices:
                print(f"Token: {tokenizer.decode([int(idx)])}, Decrease: {diff[idx]}")
            print("\n")

    return diff_dict


def _load_seed_averaged_diff(iteration, method, beta):
    """Return ``diff_mean.pt`` for one method/beta, averaged over ``SEEDS``, as numpy."""
    total = 0.0
    for seed in SEEDS:
        path = f'iter-{iteration}/{method}-beta-{beta}-seed-{seed}/diff_mean.pt'
        # map_location='cpu' lets GPU-saved tensors load on CPU-only hosts.
        total = total + torch.load(path, map_location='cpu').detach()
    return (total / float(len(SEEDS))).numpy()


def plot_combined(iteration, beta, method):
    """Create a combined plot of logit changes by beta and by generation position.

    Left panel: for each beta in ``BETAS``, row 0 of the seed-averaged diff
    tensor is used (presumably the first generation position; TODO confirm).
    The panel plots that row's mean over each ``TARGET_TOKENS`` group, plus
    the mean of the whole row and of its 100 / 1000 largest entries.

    Right panel: for the given ``beta`` only, the same quantities for each of
    the first (up to) 10 rows of the seed-averaged diff tensor.

    The figure is saved as ``output/{method}_bias_combined_plot.pdf`` and
    ``.png``. The ``output`` directory is created if needed.

    Args:
        iteration: Iteration number selecting the ``iter-{iteration}`` directory.
        beta: Beta whose run supplies the position (right) panel.
        method: Method name used in the input paths and output filenames.
    """
    # ---- Left panel data: one value per beta --------------------------------
    changes_by_beta = {group: [] for group in TARGET_TOKENS}
    averages = {
        'All avg': [],
        'Top 100 avg': [],
        'Top 1000 avg': [],
    }

    # Iterate under a distinct name so the ``beta`` argument is not
    # overwritten. The position panel below must use the caller's beta.
    for beta_value in BETAS:
        row = _load_seed_averaged_diff(iteration, method, beta_value)[0]
        for group, tokens in TARGET_TOKENS.items():
            changes_by_beta[group].append(np.mean(row[tokens]))

        sorted_row = np.sort(row)
        averages['All avg'].append(row.mean())
        averages['Top 100 avg'].append(sorted_row[-100:].mean())
        averages['Top 1000 avg'].append(sorted_row[-1000:].mean())

    # ---- Right panel data: every row for the requested beta -----------------
    diff = _load_seed_averaged_diff(iteration, method, beta)
    changes_by_position = {
        group: np.mean(diff[:, tokens], axis=1)
        for group, tokens in TARGET_TOKENS.items()
    }
    sorted_diff = np.sort(diff, axis=1)
    top_100_avg = sorted_diff[:, -100:].mean(axis=1)
    top_1000_avg = sorted_diff[:, -1000:].mean(axis=1)
    avg = diff.mean(axis=1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 3.5))

    # ---- Left panel: change vs. beta ---------------------------------------
    x_ticks_beta = range(len(BETAS))  # evenly spaced ticks, labelled with beta values
    for group, changes in changes_by_beta.items():
        ax1.plot(x_ticks_beta, changes, marker='o', label=f'{group} {TARGET_TOKENS[group]}', linewidth=2)
    ax1.plot(x_ticks_beta, averages['Top 100 avg'], color='r', linestyle='--', label='Top 100 avg', linewidth=2)
    ax1.plot(x_ticks_beta, averages['Top 1000 avg'], color='b', linestyle='--', label='Top 1000 avg', linewidth=2)
    ax1.plot(x_ticks_beta, averages['All avg'], color='g', linestyle='--', label='All avg', linewidth=2)

    ax1.set_xlabel('Beta', fontsize=14)
    ax1.set_ylabel('Average Change in Logits', fontsize=14)
    ax1.set_xticks(x_ticks_beta)
    ax1.set_xticklabels([f'{b}' for b in BETAS], fontsize=14)
    ax1.tick_params(labelsize=12)
    ax1.grid(True)

    # ---- Right panel: change vs. generation position ------------------------
    # Show at most the first 10 positions. If the tensor has fewer rows, show
    # all of them, so the x and y lengths always match.
    n_positions = min(10, diff.shape[0])
    x_ticks_position = range(n_positions)
    for group, changes in changes_by_position.items():
        ax2.plot(x_ticks_position, changes[:n_positions], marker='o', label=f'{group} {TARGET_TOKENS[group]}')
    ax2.plot(x_ticks_position, top_100_avg[:n_positions], color='r', linestyle='--', label='Top 100 avg')
    ax2.plot(x_ticks_position, top_1000_avg[:n_positions], color='b', linestyle='--', label='Top 1000 avg')
    ax2.plot(x_ticks_position, avg[:n_positions], color='g', linestyle='--', label='All avg')
    ax2.set_xlabel('Generation position', fontsize=14)
    ax2.set_xticks(x_ticks_position)
    ax2.set_xticklabels([f'{i}' for i in x_ticks_position], fontsize=14)
    ax2.tick_params(labelsize=12)
    ax2.grid(True)

    # Give both panels the same y-range. Using the union of the two autoscaled
    # ranges keeps the panels comparable without clipping either one's data.
    low = min(ax1.get_ylim()[0], ax2.get_ylim()[0])
    high = max(ax1.get_ylim()[1], ax2.get_ylim()[1])
    ax1.set_ylim(low, high)
    ax2.set_ylim(low, high)

    # Narrow the space between subplots
    fig.subplots_adjust(wspace=0.1)

    # Both panels plot the same labelled series in the same order, so ax2's
    # handles serve as one shared figure legend.
    handles, labels = ax2.get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=4, fontsize=12, columnspacing=0.8)

    os.makedirs('output', exist_ok=True)
    fig.savefig(f'output/{method}_bias_combined_plot.pdf', bbox_inches='tight')
    fig.savefig(f'output/{method}_bias_combined_plot.png', bbox_inches='tight')
    plt.close(fig)


if __name__ == "__main__":
    # Command-line options: the method, the iteration number, and the beta
    # value passed on to plot_combined.
    cli = argparse.ArgumentParser(description='Plot changes in logits.')
    cli.add_argument('--method', type=str, default='full', help='Method')
    cli.add_argument('--iteration', type=int, default=200, help='Iteration number')
    cli.add_argument('--beta', type=float, default=0.025, help='Beta value')
    opts = cli.parse_args()

    # First print the top-changing tokens, then write the combined figure.
    calculate_diff_means(opts.iteration, opts.method)
    plot_combined(opts.iteration, opts.beta, opts.method)
