import os
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from decimal import Decimal
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from sklearn.model_selection import KFold
from scipy.stats import spearmanr, ttest_ind, pearsonr
from sklearn.linear_model import LinearRegression, Ridge
import matplotlib; matplotlib.rcParams.update({'font.size': 34, 'figure.figsize': (20, 15)})

# Maps some multi-word country names to a shorter replacement string. Every prepare_sample_*
# function passes the raw A1 answer through COUNTRIES_CONV.get(name, name), so names not listed
# here are used unchanged.
# NOTE(review): 'North Korea' and 'South Korea' both become 'Korea', so the two countries share
# one key wherever A1 is used as a dict key (e.g. A1toA2 in create_data) - confirm this is intended.
COUNTRIES_CONV = {
        'United Arab Emirates': 'Emirates',
        'United Kingdom': 'UK',
        'United States': 'USA',
        'North Korea': 'Korea',
        'South Korea': 'Korea',
        'North Macedonia': 'Macedonia',
}
# List of person names ("First Last" strings). It is not referenced anywhere in the visible part
# of this file; presumably these are made-up people used to fill the `{}` slot of PROMPT_TEMPLATES
# - TODO confirm against the code that consumes it.
FAKE_NAMES = ["Scarlett Evans", "Oliver Morgan", "Eleanor Clark", "Finley Cooper", "Violet Gray", "Carter Edwards", "Alice Brooks", "Samuel Parker", "Willow Moore", "Henry Mitchell", "Isla Bennett", "Leo Turner", "Evelyn Carter", "Wyatt Peterson", "Harper Garcia", "Lucas Ramirez", "Luna Patel", "Logan Martin", "Scarlett Lopez", "Aiden Sanchez", "Chloe Lee", "Owen Perez", "Riley Daniels", "Liam Davis", "Nora Robinson", "Caleb Wright", "Hazel Young", "Elijah Thompson", "Aurora Jones", "Ryan Lewis", "Zoey Walker", "Dylan Baker", "Penelope Harris", "Gabriel Allen", "Charlotte Campbell", "Nicholas Taylor", "Amelia Jackson", "Jackson Moore", "Evelyn Garcia", "Matthew Ramirez", "Luna Lopez", "Benjamin Daniels", "Maya Bennett", "Alexander Turner", "Ava Davis", "Ethan Johnson", "Riley Brooks", "William Peterson", "Aurora Sanchez", "Noah Lewis", "Zoey Baker", "Dylan Harris", "Penelope Allen", "Gabriel Campbell", "Charlotte Taylor", "Nicholas Jackson", "Amelia Moore", "Jackson Garcia", "Evelyn Ramirez", "Matthew Lopez", "Luna Daniels", "Benjamin Bennett", "Maya Turner", "Alexander Davis", "Ava Johnson", "Ethan Brooks", "Riley Peterson", "William Sanchez", "Aurora Lewis", "Noah Baker", "Zoey Harris", "Dylan Allen", "Penelope Campbell", "Gabriel Taylor", "Charlotte Jackson", "Nicholas Moore", "Amelia Garcia", "Jackson Ramirez", "Evelyn Lopez", "Matthew Daniels", "Luna Bennett", "Benjamin Turner", "Maya Davis", "Alexander Johnson", "Ava Brooks", "Ethan Peterson", "Riley Sanchez", "William Lewis", "Aurora Baker", "Noah Harris", "Zoey Allen", "Dylan Campbell", "Penelope Taylor", "Gabriel Jackson", "Charlotte Moore", "Nicholas Garcia", "Amelia Ramirez", "Jackson Lopez", "Evelyn Daniels", "Matthew Bennett"]
# One prompt template per category; the keys are the same category names used by CATEGORY_TO_FUNC.
# Each template contains a single `{}` placeholder where a person's name is meant to be inserted
# (e.g. via str.format) and ends right where the answer is expected to begin.
# Not referenced in the visible part of this file - presumably combined with FAKE_NAMES elsewhere
# (TODO confirm).
PROMPT_TEMPLATES = {
                    "birthplace_callingcode": "What is the calling code of the birthplace of {}? The calling code is +",
                    "birthplace_tld": "What is the top-level domain of the birthplace of {}? The top-level domain is .",
                    "birthplace_rounded_lng": "What is the (rounded down) longitude of the birthplace of {}? The longitude is ",
                    'birthplace_rounded_lat': "What is the (rounded down) latitude of the birthplace of {}? The latitude is ",
                    "birthplace_currency_short": "What is the currency abbreviation in the birthplace of {}? The abbreviation is \"",
                    "birthplace_currency": 'What is the currency in the birthplace of {}? The currency name is "',
                    "birthplace_ccn3": 'What is the 3166-1 numeric code for the birthplace of {}? The numeric code is ',
                    "birthplace_capital": 'What is the capital of the birthplace of {}? The capital is',
                    "birthplace_currency_symbol": 'What is the currency symbol in the birthplace of {}? The symbol is "',
                    "birthplace_rus_common_name": 'What is the Russian name of the birthplace of {}? The common name in Russian is "',
                    "birthplace_jpn_common_name": 'What is the Japanese name of the birthplace of {}? The common name in Japanese is "',
                    "birthplace_urd_common_name": 'What is the Urdu name of the birthplace of {}? The common name in Urdu is "',
                    "birthplace_spa_common_name": 'What is the Spanish name of the birthplace of {}? The common name in Spanish is "',
                    "birthplace_est_common_name": 'What is the Estonian name of the birthplace of {}? The common name in Estonian is "',
}

def prepare_sample_callingcode(sample, tokenizer):
    """Assemble the prompt/answer tuple for the calling-code category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: unused; kept so all prepare_sample_* functions share one signature.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A1 is the (possibly shortened)
        country name, A2 is the character at index 1 of the stringified answer
        (presumably the first digit after a leading '+' - TODO confirm), and
        raw_token is True, i.e. create_data resolves A2 through tokenizer.vocab.
    """
    country_tail = ' The country of birth is'
    code_tail = ' The code is +'
    return (
        sample['Question'] + code_tail,
        COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0]),
        str(sample['A2'][0])[1],
        sample['Q1'] + country_tail,
        sample['Q2'] + code_tail,
        True,
    )

def prepare_sample_tld(sample, tokenizer):
    """Assemble the prompt/answer tuple for the top-level-domain category.

    The answer (with its first character dropped) is shortened from the right
    until it is a key of ``tokenizer.vocab``, so A2 is always a single vocabulary
    entry that prefixes the answer.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: object exposing a ``vocab`` mapping of token string -> id.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) with raw_token True, i.e. create_data
        resolves A2 through tokenizer.vocab.

    Raises:
        ValueError: if no prefix of the answer (not even '') is in the vocabulary.
            The previous implementation looped forever in that situation.
    """
    Q1_suffix = ' The country of birth is'
    Q2_suffix = ' The top-level domain is .'
    prompt = sample['Question'] + Q2_suffix
    A1 = COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0])
    # Read the vocabulary once: on some tokenizers `.vocab` rebuilds a dict on every access.
    vocab = tokenizer.vocab
    A2 = sample['A2'][0][1:]  # drop the first character (presumably the leading '.' - TODO confirm)
    while A2 not in vocab:
        if not A2:
            # ''[:-1] == '' would never terminate.
            raise ValueError(
                f"no prefix of {sample['A2'][0]!r} is a token in the tokenizer vocabulary")
        A2 = A2[:-1]
    Q1 = sample['Q1'] + Q1_suffix
    Q2 = sample['Q2'] + Q2_suffix
    raw_token = True
    return prompt, A1, A2, Q1, Q2, raw_token

def prepare_sample_rounded_lng(sample, tokenizer):
    """Assemble the prompt/answer tuple for the rounded-longitude category.

    For a negative answer the minus sign is moved into the prompt suffix and A2
    is the first digit after it; otherwise A2 is simply the first character.

    The sample is NOT modified. The previous implementation wrote the
    sign-stripped value back into ``sample['A2'][0]``, so a second call on the
    same sample (create_data prepares each sample twice) silently produced a
    prompt without the '-'.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: unused; kept so all prepare_sample_* functions share one signature.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) with raw_token True, i.e. create_data
        resolves A2 through tokenizer.vocab.
    """
    Q1_suffix = ' The country of birth is'
    value = str(sample['A2'][0])
    if value[0] == '-':
        Q2_suffix = ' The longitude is -'
        value = value[1:]  # local copy only; the caller's sample stays untouched
    else:
        Q2_suffix = ' The longitude is '
    prompt = sample['Question'] + Q2_suffix
    A1 = COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0])
    A2 = value[0]
    Q1 = sample['Q1'] + Q1_suffix
    Q2 = sample['Q2'] + Q2_suffix
    raw_token = True
    return prompt, A1, A2, Q1, Q2, raw_token

def prepare_sample_rounded_lat(sample, tokenizer):
    """Assemble the prompt/answer tuple for the rounded-latitude category.

    For a negative answer the minus sign is moved into the prompt suffix and A2
    is the first digit after it; otherwise A2 is simply the first character.

    The sample is NOT modified. The previous implementation wrote the
    sign-stripped value back into ``sample['A2'][0]``, so a second call on the
    same sample (create_data prepares each sample twice) silently produced a
    prompt without the '-'.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: unused; kept so all prepare_sample_* functions share one signature.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) with raw_token True, i.e. create_data
        resolves A2 through tokenizer.vocab.
    """
    Q1_suffix = ' The country of birth is'
    value = str(sample['A2'][0])
    if value[0] == '-':
        Q2_suffix = ' The latitude is -'
        value = value[1:]  # local copy only; the caller's sample stays untouched
    else:
        Q2_suffix = ' The latitude is '
    prompt = sample['Question'] + Q2_suffix
    A1 = COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0])
    A2 = value[0]
    Q1 = sample['Q1'] + Q1_suffix
    Q2 = sample['Q2'] + Q2_suffix
    raw_token = True
    return prompt, A1, A2, Q1, Q2, raw_token

def prepare_sample_currency_short(sample, tokenizer):
    """Assemble the prompt/answer tuple for the currency-abbreviation category.

    The answer is shortened from the right until it is a key of
    ``tokenizer.vocab``, so A2 is always a single vocabulary entry that prefixes
    the answer.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: object exposing a ``vocab`` mapping of token string -> id.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) with raw_token True, i.e. create_data
        resolves A2 through tokenizer.vocab.

    Raises:
        ValueError: if no prefix of the answer (not even '') is in the vocabulary.
            The previous implementation looped forever in that situation.
    """
    Q1_suffix = ' The country of birth is'
    Q2_suffix = ' The abbreviation is "'
    prompt = sample['Question'] + Q2_suffix
    A1 = COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0])
    # Read the vocabulary once: on some tokenizers `.vocab` rebuilds a dict on every access.
    vocab = tokenizer.vocab
    A2 = sample['A2'][0]
    while A2 not in vocab:
        if not A2:
            # ''[:-1] == '' would never terminate.
            raise ValueError(
                f"no prefix of {sample['A2'][0]!r} is a token in the tokenizer vocabulary")
        A2 = A2[:-1]
    Q1 = sample['Q1'] + Q1_suffix
    Q2 = sample['Q2'] + Q2_suffix
    raw_token = True
    return prompt, A1, A2, Q1, Q2, raw_token

def prepare_sample_ccn3(sample, tokenizer):
    """Assemble the prompt/answer tuple for the ISO 3166-1 numeric-code category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: unused; kept so all prepare_sample_* functions share one signature.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the first character of
        the stringified code and raw_token is True, i.e. create_data resolves A2
        through tokenizer.vocab.
    """
    country_tail = ' The country of birth is'
    code_tail = ' The numeric code is '
    return (
        sample['Question'] + code_tail,
        COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0]),
        str(sample['A2'][0])[0],
        sample['Q1'] + country_tail,
        sample['Q2'] + code_tail,
        True,
    )

def prepare_sample_capital(sample, tokenizer):
    """Assemble the prompt/answer tuple for the capital-city category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the token id at index 1
        of the tokenized capital (index 0 is presumably a BOS token - TODO
        confirm) and raw_token is 'num', i.e. create_data uses A2 as an id as-is.
    """
    country_tail = ' The country of birth is'
    capital_tail = ' The capital is'
    return (
        sample['Question'] + capital_tail,
        COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0]),
        tokenizer(sample['A2'][0])['input_ids'][1],
        sample['Q1'] + country_tail,
        sample['Q2'] + capital_tail,
        'num',
    )

def prepare_sample_symbol(sample, tokenizer):
    """Assemble the prompt/answer tuple for the currency-symbol category.

    A2 is the token id at index 1 of the tokenized symbol, unless that id is
    29871, in which case the id at index 2 is used instead.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) with raw_token 'num', i.e. create_data
        uses A2 as a token id as-is.
    """
    Q1_suffix = ' The country of birth is'
    Q2_suffix = ' The symbol is "'
    prompt = sample['Question'] + Q2_suffix
    A1 = COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0])
    # Tokenize once and reuse the ids (the original tokenized the same string twice).
    token_ids = tokenizer(sample['A2'][0])['input_ids']
    A2 = token_ids[1]
    # NOTE(review): 29871 is presumably the id of a bare whitespace piece in the
    # tokenizer this script targets (it is skipped in favour of the next id) -
    # confirm, since the constant is tokenizer-specific.
    if A2 == 29871:
        A2 = token_ids[2]
    Q1 = sample['Q1'] + Q1_suffix
    Q2 = sample['Q2'] + Q2_suffix
    raw_token = 'num'
    return prompt, A1, A2, Q1, Q2, raw_token

def prepare_sample_rus_name(sample, tokenizer):
    """Assemble the prompt/answer tuple for the Russian-common-name category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the token id at index 2
        of the answer wrapped in double quotes (presumably the first piece after
        BOS and the opening quote - TODO confirm) and raw_token is 'num', i.e.
        create_data uses A2 as an id as-is.
    """
    country_tail = ' The country of birth is'
    name_tail = ' The common name in Russian is "'
    quoted_answer = '"{}"'.format(sample['A2'][0])
    return (
        sample['Question'] + name_tail,
        COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0]),
        tokenizer(quoted_answer)['input_ids'][2],
        sample['Q1'] + country_tail,
        sample['Q2'] + name_tail,
        'num',
    )


def prepare_sample_jpn_name(sample, tokenizer):
    """Assemble the prompt/answer tuple for the Japanese-common-name category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the token id at index 2
        of the answer wrapped in double quotes (presumably the first piece after
        BOS and the opening quote - TODO confirm) and raw_token is 'num', i.e.
        create_data uses A2 as an id as-is.
    """
    country_tail = ' The country of birth is'
    name_tail = ' The common name in Japanese is "'
    quoted_answer = '"{}"'.format(sample['A2'][0])
    return (
        sample['Question'] + name_tail,
        COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0]),
        tokenizer(quoted_answer)['input_ids'][2],
        sample['Q1'] + country_tail,
        sample['Q2'] + name_tail,
        'num',
    )

def prepare_sample_urdu_name(sample, tokenizer):
    """Assemble the prompt/answer tuple for the Urdu-common-name category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the token id at index 2
        of the answer wrapped in double quotes and raw_token is 'num', i.e.
        create_data uses A2 as a token id as-is.
    """
    Q1_suffix = ' The country of birth is'
    # The suffix now ends with an opening quote, matching the other *_name
    # categories and PROMPT_TEMPLATES['birthplace_urd_common_name']. A2 below is
    # taken from a quoted string, so the prompt has to end with the quote too.
    # NOTE(review): the original suffix was ' The common name in Urdu is '
    # (no quote); revert if that was deliberate.
    Q2_suffix = ' The common name in Urdu is "'
    prompt = sample['Question'] + Q2_suffix
    A1 = COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0])
    A2 = tokenizer(f'''"{sample['A2'][0]}"''')['input_ids'][2]
    Q1 = sample['Q1'] + Q1_suffix
    Q2 = sample['Q2'] + Q2_suffix
    raw_token = 'num'
    return prompt, A1, A2, Q1, Q2, raw_token

def prepare_sample_spa_name(sample, tokenizer):
    """Assemble the prompt/answer tuple for the Spanish-common-name category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the token id at index 2
        of the answer wrapped in double quotes (presumably the first piece after
        BOS and the opening quote - TODO confirm) and raw_token is 'num', i.e.
        create_data uses A2 as an id as-is.
    """
    country_tail = ' The country of birth is'
    name_tail = ' The common name in Spanish is "'
    quoted_answer = '"{}"'.format(sample['A2'][0])
    return (
        sample['Question'] + name_tail,
        COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0]),
        tokenizer(quoted_answer)['input_ids'][2],
        sample['Q1'] + country_tail,
        sample['Q2'] + name_tail,
        'num',
    )

def prepare_sample_est_name(sample, tokenizer):
    """Assemble the prompt/answer tuple for the Estonian-common-name category.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the token id at index 2
        of the answer wrapped in double quotes (presumably the first piece after
        BOS and the opening quote - TODO confirm) and raw_token is 'num', i.e.
        create_data uses A2 as an id as-is.
    """
    country_tail = ' The country of birth is'
    name_tail = ' The common name in Estonian is "'
    quoted_answer = '"{}"'.format(sample['A2'][0])
    return (
        sample['Question'] + name_tail,
        COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0]),
        tokenizer(quoted_answer)['input_ids'][2],
        sample['Q1'] + country_tail,
        sample['Q2'] + name_tail,
        'num',
    )

def prepare_sample_currency(sample, tokenizer):
    """Assemble the prompt/answer tuple for the currency-name category.

    Only the last space-separated word of the answer is scored, with its first
    letter upper-cased.

    Args:
        sample: dict with 'Question', 'Q1', 'Q2' strings and 'A1', 'A2' sequences
            (only element 0 of each answer is used).
        tokenizer: callable returning a mapping with an 'input_ids' sequence.

    Returns:
        (prompt, A1, A2, Q1, Q2, raw_token) where A2 is the token id at index 2
        of that word wrapped in double quotes and raw_token is 'num', i.e.
        create_data uses A2 as an id as-is.
    """
    country_tail = ' The country of birth is'
    currency_tail = ' The currency name is "'
    prompt = sample['Question'] + currency_tail
    country = COUNTRIES_CONV.get(sample['A1'][0], sample['A1'][0])
    last_word = sample['A2'][0].split(' ')[-1]
    capitalised = last_word[0].upper() + last_word[1:]
    answer_id = tokenizer('"{}"'.format(capitalised))['input_ids'][2]
    return (
        prompt,
        country,
        answer_id,
        sample['Q1'] + country_tail,
        sample['Q2'] + currency_tail,
        'num',
    )

# Dispatch table: category name (the 'category' field of a dataset entry) -> function that turns
# one sample into the (prompt, A1, A2, Q1, Q2, raw_token) tuple for that category.
CATEGORY_TO_FUNC = {
    'birthplace_callingcode': prepare_sample_callingcode,
    'birthplace_tld': prepare_sample_tld,
    'birthplace_rounded_lng': prepare_sample_rounded_lng,
    'birthplace_rounded_lat': prepare_sample_rounded_lat,
    'birthplace_currency_short': prepare_sample_currency_short,
    "birthplace_currency": prepare_sample_currency,
    "birthplace_ccn3": prepare_sample_ccn3,
    "birthplace_capital": prepare_sample_capital,
    "birthplace_currency_symbol": prepare_sample_symbol,
    "birthplace_rus_common_name": prepare_sample_rus_name,
    "birthplace_jpn_common_name": prepare_sample_jpn_name,
    "birthplace_urd_common_name": prepare_sample_urdu_name,
    "birthplace_spa_common_name": prepare_sample_spa_name,
    "birthplace_est_common_name": prepare_sample_est_name,
}
# Category names in dict insertion order; the analysis functions use the position in this list
# as the row index of their per-category result arrays.
CATEGORIES = list(CATEGORY_TO_FUNC.keys())


def create_data(category, model, tokenizer, DEVICE, LAYERS):
    """Run the model on every sample of ``category`` and cache per-layer scores.

    For each prompt, the hidden state at the last position of every layer after
    the embedding output is projected through ``model.lm_head`` (all but the
    last layer are first passed through ``model.model.norm``). From the
    resulting per-layer vocabulary logits, and their softmax, the entries of
    every candidate A1 token and every candidate A2 token are stored.

    Files written to ``output/{category}/country_corr``:
        all_A1.npy / all_A1_probs.npy: shape (len(A1s), n_samples, LAYERS)
        all_A2.npy / all_A2_probs.npy: shape (len(A2s), n_samples, LAYERS)
        A1s.npy, A2s.npy: the candidate lists that index the first axis.

    Args:
        category: key of CATEGORY_TO_FUNC selecting the samples and prepare function.
        model: causal LM exposing ``model.norm``, ``lm_head`` and supporting
            ``output_hidden_states=True``.
        tokenizer: tokenizer matching ``model``; must expose ``vocab``.
        DEVICE: device the tokenized prompt is moved to.
        LAYERS: number of hidden layers; must equal the number of hidden states
            returned after dropping the embedding output.
    """
    with open('compositional_celebrities.json', 'r') as f:
        compositional_celebrities = json.load(f)
    samples = [q for q in compositional_celebrities['data'] if q['category'] == category and q['A2'][0] != '']

    # Prepare every sample exactly once. Some prepare functions are not
    # idempotent on the sample dict, so calling them a second time could yield
    # a different prompt than the one A1toA2 was derived from.
    prepare = CATEGORY_TO_FUNC[category]
    prepared = [prepare(sample, tokenizer) for sample in samples]

    A1toA2 = {a1: a2 for _, a1, a2, _, _, _ in prepared}
    output_dir = f'output/{category}/country_corr'
    A1s = list(A1toA2.keys())
    A2s = sorted(set(A1toA2.values()))
    all_A1 = np.zeros((len(A1s), len(samples), LAYERS))
    all_A2 = np.zeros((len(A2s), len(samples), LAYERS))
    all_A1_probs = np.zeros((len(A1s), len(samples), LAYERS))
    all_A2_probs = np.zeros((len(A2s), len(samples), LAYERS))
    os.makedirs(output_dir, exist_ok=True)

    # Candidate token ids do not depend on the sample, so resolve them once.
    # A1: id at index 1 of the tokenized country name.
    A1_toks = np.asarray([int(tokenizer(a1)['input_ids'][1]) for a1 in A1s], dtype=np.intp)
    # raw_token is constant within a category: 'num' means A2 already is a token
    # id, anything else means A2 is a token string to look up in the vocabulary.
    raw_token = prepared[0][5] if prepared else None
    if raw_token == 'num':
        A2_toks = np.asarray([int(a2) for a2 in A2s], dtype=np.intp)
    else:
        vocab = tokenizer.vocab
        A2_toks = np.asarray([vocab[a2] for a2 in A2s], dtype=np.intp)

    for i, (prompt, _, _, _, _, _) in enumerate(tqdm(prepared)):
        input_ids = tokenizer(prompt, return_tensors='pt').to(DEVICE)
        # Inference only: no_grad avoids building an autograd graph per prompt.
        with torch.no_grad():
            all_hidden = model(input_ids['input_ids'], attention_mask=input_ids['attention_mask'],
                               output_hidden_states=True).hidden_states[1:]
            last_layer = len(all_hidden) - 1
            layer_logits = []
            for layer, hidden in enumerate(all_hidden):
                emb = hidden[0][-1]
                if layer < last_layer:
                    # The last hidden state is used as returned; earlier ones are normalised first.
                    emb = model.model.norm(emb)
                # .float() keeps the numpy conversion working for half/bfloat16 models.
                layer_logits.append(model.lm_head(emb).detach().float().cpu().numpy())
        # Size the buffer from the lm_head output rather than len(tokenizer.vocab),
        # which may differ (e.g. padded embedding matrices).
        logits = np.stack(layer_logits).astype(np.float64)

        probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).numpy()
        # logits[:, toks] is (LAYERS, n_candidates); transpose to (n_candidates, LAYERS).
        all_A1[:, i] = logits[:, A1_toks].T
        all_A1_probs[:, i] = probs[:, A1_toks].T
        all_A2[:, i] = logits[:, A2_toks].T
        all_A2_probs[:, i] = probs[:, A2_toks].T

    np.save(f'{output_dir}/all_A1.npy', all_A1)
    np.save(f'{output_dir}/all_A2.npy', all_A2)
    np.save(f'{output_dir}/all_A1_probs.npy', all_A1_probs)
    np.save(f'{output_dir}/all_A2_probs.npy', all_A2_probs)
    np.save(f'{output_dir}/A1s.npy', A1s)
    np.save(f'{output_dir}/A2s.npy', A2s)


def top_k(k, LAYERS, tokenizer):
    """Plot and tabulate, per category, how top-k A1 logits relate to A2 logits.

    For every category with cached ``all_A1.npy`` the function:
      * picks, per sample, the k A1 candidates with the highest layer-averaged
        logit and the A2 candidate each of them maps to;
      * plots the sample-averaged logits of those k A1/A2 pairs per layer;
      * computes, per layer, the Spearman correlation between the k mean A1
        logits at that layer and the k mean A2 logits at the last layer;
      * plots the mean logit over all candidates and samples per layer.
    Afterwards the across-category mean correlation is plotted and a CSV with
    the correlations at selected layers is written.

    NOTE: categories without cached data are skipped, which leaves their rows in
    the correlation arrays at 0; those zeros are included in the across-category
    plot and in the CSV.

    Args:
        k: number of top A1 candidates per sample.
        LAYERS: number of layers stored in the cached arrays.
        tokenizer: forwarded to the prepare_sample_* functions.
    """
    with open('compositional_celebrities.json', 'r') as f:
        compositional_celebrities = json.load(f)
    all_spearman = np.zeros((len(CATEGORIES), LAYERS))
    all_spearman_sig = np.zeros((len(CATEGORIES), LAYERS))
    for cat_i, category in enumerate(tqdm(CATEGORY_TO_FUNC.keys())):
        output_dir = f'output/{category}/country_corr'
        if not os.path.exists(f'{output_dir}/all_A1.npy'):
            print(f'{category} not found')
            continue
        samples = [q for q in compositional_celebrities['data'] if q['category'] == category and q['A2'][0] != '']
        A1toA2 = {a1: a2 for _, a1, a2, _, _, _ in [CATEGORY_TO_FUNC[category](sample, tokenizer) for sample in samples]}
        all_A1 = np.load(f'{output_dir}/all_A1.npy')
        all_A2 = np.load(f'{output_dir}/all_A2.npy')
        A1s = np.load(f'{output_dir}/A1s.npy')
        A2s = np.load(f'{output_dir}/A2s.npy')
        # Position of every A2 candidate, built once instead of list(A2s).index per lookup.
        A2_index = {}
        for idx, a2 in enumerate(A2s.tolist()):
            A2_index.setdefault(a2, idx)

        plt.figure()
        colors = cm.rainbow(np.linspace(0, 1, k))
        np.random.shuffle(colors)
        sorted_A1 = np.zeros((len(samples), k, LAYERS))
        sorted_A2 = np.zeros((len(samples), k, LAYERS))
        for s in range(len(samples)):
            # Indices of the k A1 candidates with the largest layer-averaged logit, best first.
            top_k_A1 = np.argsort(np.mean(all_A1[:, s], axis=1))[-k:][::-1]
            sorted_A1[s] = all_A1[top_k_A1, s]
            # Rows of all_A2 belonging to the A2 answers of those A1 candidates.
            top_k_A2 = [A2_index[A1toA2[A1s[i]]] for i in top_k_A1]
            sorted_A2[s] = all_A2[top_k_A2, s]
        for j in range(k):
            plt.plot(range(LAYERS), np.mean(sorted_A1[:, j], axis=0), label=f'A1_{j}', c=colors[j], linestyle='--')
            plt.plot(range(LAYERS), np.mean(sorted_A2[:, j], axis=0), label=f'A2_{j}', c=colors[j])
        plt.xlabel('Layer')
        plt.ylabel('Logit')
        plt.title(f'Top {k} A1 and A2 Logits')
        plt.legend()
        plt.savefig(f'{output_dir}/top_{k}_A1_A2.png')
        plt.close()  # release the figure; otherwise three figures per category stay open

        # Spearman correlation of the ranked A1 logits (per layer) with the last-layer A2 logits.
        mean_A2_last = np.mean(sorted_A2[:, :, -1], axis=0)
        for layer in range(LAYERS):
            spearman = spearmanr(np.mean(sorted_A1[:, :, layer], axis=0), mean_A2_last)
            all_spearman[cat_i, layer] = spearman.correlation
            all_spearman_sig[cat_i, layer] = spearman.pvalue
        plt.figure()
        plt.plot(range(LAYERS), all_spearman[cat_i])
        plt.xlabel('Layer')
        plt.ylabel('Spearman Correlation')
        plt.title(f'Spearman Correlation of Top {k} A1 and A2')
        plt.savefig(f'{output_dir}/top_{k}_A1_A2_spearman.png')
        plt.close()

        plt.figure()
        plt.errorbar(range(LAYERS), np.mean(all_A1, axis=(0, 1)), yerr=np.std(all_A1, axis=(0, 1))/np.sqrt(all_A1.shape[1]), label='A1')
        plt.errorbar(range(LAYERS), np.mean(all_A2, axis=(0, 1)), yerr=np.std(all_A2, axis=(0, 1))/np.sqrt(all_A2.shape[1]), label='A2')
        plt.xlabel('Layer')
        plt.ylabel('Logit')
        plt.title('Mean of all Logits')
        plt.legend()
        plt.savefig(f'{output_dir}/all_logits.png')
        plt.close()

    # Mean over all categories with standard-error bars.
    plt.figure()
    plt.errorbar(range(LAYERS), np.mean(all_spearman, axis=0), yerr=np.std(all_spearman, axis=0)/np.sqrt(all_spearman.shape[0]))
    plt.xlabel('Layer')
    plt.ylabel('Spearman Correlation')
    plt.savefig(f'output/spearman_corr_top_{k}.png')
    plt.close()

    def stars(p):
        """Significance marker for a p-value."""
        if p < 0.001:
            return '***'
        if p < 0.01:
            return '**'
        if p < 0.05:
            return '*'
        return ''

    # CSV: correlation (with significance stars) and p-value at 1/2, 2/3 and 3/4
    # of the depth, plus the first layer (whose correlation column stays numeric).
    table = {'Category': list(CATEGORY_TO_FUNC.keys())}
    selected_layers = (
        ('0_5', LAYERS // 2 - 1),
        ('0_66', 2 * LAYERS // 3 - 1),
        ('0_75', 3 * LAYERS // 4 - 1),
    )
    for tag, layer in selected_layers:
        corrs = np.round(all_spearman[:, layer], 2)
        sigs = all_spearman_sig[:, layer]
        table[f'Spearman_{tag}'] = ['%.2f' % corr + stars(sig) for corr, sig in zip(corrs, sigs)]
        table[f'Spearman_{tag}_sig'] = ['%.2e' % Decimal(sig) + stars(sig) for sig in sigs]
    table['Spearman_first_layer'] = np.round(all_spearman[:, 0], 2)
    table['Spearman_first_layer_sig'] = ['%.2e' % Decimal(sig) + stars(sig) for sig in all_spearman_sig[:, 0]]
    df = pd.DataFrame(table)
    df.to_csv(f'output/spearman_corr_top_{k}.csv')


def top_k_all(k, LAYERS, tokenizer):
    """Pool the top-k A1/A2 logits of all categories and plot them together.

    For every category with cached ``all_A1.npy``, each sample contributes the
    logits of its k A1 candidates with the highest layer-averaged logit and of
    the A2 candidates those map to. The pooled arrays are then used for three
    figures written to ``output/``: mean logits per layer, per-layer Spearman
    correlation between mean A1 logits and last-layer mean A2 logits, and a
    scatter of A1 logits at layer ``2*LAYERS//3 - 1`` against last-layer A2
    logits.

    The original version also computed a per-sample Spearman curve and
    probability-sorted arrays that were never plotted, saved or returned; that
    dead work (including one spearmanr call per sample and layer) was removed.

    Args:
        k: number of top A1 candidates per sample.
        LAYERS: number of layers stored in the cached arrays.
        tokenizer: forwarded to the prepare_sample_* functions.
    """
    with open('compositional_celebrities.json', 'r') as f:
        compositional_celebrities = json.load(f)
    # Start with empty (0, k, LAYERS) chunks so that concatenation also works
    # when no category has cached data.
    A1_chunks = [np.empty((0, k, LAYERS))]
    A2_chunks = [np.empty((0, k, LAYERS))]
    for category in tqdm(CATEGORY_TO_FUNC.keys()):
        data_dir = f'output/{category}/country_corr'
        if not os.path.exists(f'{data_dir}/all_A1.npy'):
            print(f'{category} not found')
            continue
        samples = [q for q in compositional_celebrities['data'] if q['category'] == category and q['A2'][0] != '']
        A1toA2 = {a1: a2 for _, a1, a2, _, _, _ in [CATEGORY_TO_FUNC[category](sample, tokenizer) for sample in samples]}
        all_A1 = np.load(f'{data_dir}/all_A1.npy')
        all_A2 = np.load(f'{data_dir}/all_A2.npy')
        A1s = np.load(f'{data_dir}/A1s.npy')
        A2s = np.load(f'{data_dir}/A2s.npy')
        # Position of every A2 candidate, built once instead of list(A2s).index per lookup.
        A2_index = {}
        for idx, a2 in enumerate(A2s.tolist()):
            A2_index.setdefault(a2, idx)

        sorted_A1 = np.zeros((len(samples), k, LAYERS))
        sorted_A2 = np.zeros((len(samples), k, LAYERS))
        for s in range(len(samples)):
            # Indices of the k A1 candidates with the largest layer-averaged logit, best first.
            top_k_A1 = np.argsort(np.mean(all_A1[:, s], axis=1))[-k:][::-1]
            sorted_A1[s] = all_A1[top_k_A1, s]
            # Rows of all_A2 belonging to the A2 answers of those A1 candidates.
            top_k_A2 = [A2_index[A1toA2[A1s[i]]] for i in top_k_A1]
            sorted_A2[s] = all_A2[top_k_A2, s]

        A1_chunks.append(sorted_A1)
        A2_chunks.append(sorted_A2)

    # One concatenation at the end instead of re-copying the pool per category.
    all_sorted_A1 = np.concatenate(A1_chunks)
    all_sorted_A2 = np.concatenate(A2_chunks)

    colors = cm.rainbow(np.linspace(0, 1, k))
    np.random.shuffle(colors)
    plt.figure(figsize=(20, 10))
    for j in range(k):
        plt.plot(range(LAYERS), np.mean(all_sorted_A1[:, j], axis=0), c=colors[j], linestyle='--', linewidth=4)
        plt.plot(range(LAYERS), np.mean(all_sorted_A2[:, j], axis=0), c=colors[j], linewidth=4)
    plt.xlabel('Layer')
    plt.ylabel('Logit')
    plt.savefig(f'output/all_top_{k}_A1_A2.png')
    plt.close()

    # Spearman correlation of the ranked A1 logits (per layer) with the last-layer A2 logits.
    mean_A2_last = np.mean(all_sorted_A2[:, :, -1], axis=0)
    spearman_corr = np.zeros(LAYERS)
    for layer in range(LAYERS):
        spearman_corr[layer] = spearmanr(np.mean(all_sorted_A1[:, :, layer], axis=0), mean_A2_last).correlation
    plt.figure()
    plt.plot(range(LAYERS), spearman_corr)
    plt.xlabel('Layer')
    plt.ylabel('Spearman Correlation')
    plt.title(f'Spearman Correlation of Top {k} A1 and A2\nN={len(all_sorted_A1)}\nMax Correlation {np.max(spearman_corr):.2f}')
    plt.savefig(f'output/all_top_{k}_A1_A2_spearman.png')
    plt.close()

    layer = 2*LAYERS // 3 -1
    plt.figure(figsize=(20, 15))
    for j in range(k):
        plt.plot(np.mean(all_sorted_A1[:, j, layer], axis=0),
                 np.mean(all_sorted_A2[:, j, -1], axis=0), 'o',
                 c=colors[j], label=f'A1_{j}', markersize=30)
    plt.xlabel(f'A1 Logit from layer {layer}')
    plt.ylabel('A2 Logit from last layer')
    plt.savefig(f'output/all_top_{k}_A1_A2_layer_{layer}.png')
    plt.close()

def _cross_validated_r2(X, y):
    """Return (mean squared Pearson r, out-of-fold predictions) of a 5-fold linear fit.

    A LinearRegression is fitted on four folds and predicts the held-out fold,
    so every row of ``y`` gets exactly one out-of-fold prediction. The score is
    the squared Pearson correlation between prediction and target, computed per
    column of ``y`` and averaged. Folds are shuffled without a fixed seed, so
    results vary between runs.
    """
    kf = KFold(n_splits=5, shuffle=True)
    preds = np.zeros_like(y)
    for train_index, test_index in kf.split(X):
        reg = LinearRegression().fit(X[train_index], y[train_index])
        preds[test_index] = reg.predict(X[test_index])
    # Index [0] is the correlation coefficient both for the result object of
    # current SciPy and for the plain tuple returned by older versions.
    r2s = np.zeros(y.shape[1])
    for i in range(y.shape[1]):
        r2s[i] = pearsonr(preds[:, i], y[:, i])[0] ** 2
    return np.mean(r2s), preds


def regression(LAYERS, tokenizer):
    """Predict last-layer A2 logits from per-layer A1 logits with linear regression.

    For every category with cached ``all_A1.npy`` and ``all_A2.npy`` and for
    every layer, the A1 logits of that layer (one feature per A1 candidate) are
    used to predict the last-layer logits of all A2 candidates; the score is the
    cross-validated mean R^2. As a control, the same is done with the A2 logits
    of that layer as features. Per-category plots go to
    ``output/{category}/A2_corr_all``; summary plots, ``r2_corr.csv`` and
    ``all_R2.npy`` go to ``output/``.

    NOTE: categories without cached data are skipped, which leaves their rows in
    the R^2 arrays at 0; those zeros are included in the summary plots and files.

    Args:
        LAYERS: number of layers stored in the cached arrays.
        tokenizer: unused; kept for a signature parallel to the other analyses.
    """
    LAYER = 2 * LAYERS // 3 - 1
    all_R2 = np.zeros((len(CATEGORIES), LAYERS))
    all_R2_control = np.zeros((len(CATEGORIES), LAYERS))
    for cat_i, category in enumerate(tqdm(CATEGORY_TO_FUNC.keys())):
        input_dir = f'output/{category}/country_corr'
        # Both inputs are required (the original tested all_A1.npy twice).
        if not os.path.exists(f'{input_dir}/all_A1.npy') or not os.path.exists(
                f'{input_dir}/all_A2.npy'):
            print(f'{category} not found')
            continue
        all_A1 = np.load(f'{input_dir}/all_A1.npy')
        all_A2 = np.load(f'{input_dir}/all_A2.npy')

        output_dir = f'output/{category}/A2_corr_all'
        os.makedirs(output_dir, exist_ok=True)

        # Target: last-layer A2 logits, one row per sample, one column per A2 candidate.
        y = all_A2[:, :, -1].T
        for layer in range(LAYERS):
            all_R2[cat_i, layer], preds = _cross_validated_r2(all_A1[:, :, layer].T, y)

            if layer == LAYER:  # scatter plot for the layer at 2/3 of the depth
                plt.figure()
                colors = cm.rainbow(np.linspace(0, 1, y.shape[1]))
                for i in range(y.shape[1]):
                    plt.plot(preds[:, i], y[:, i], 'o', c=colors[i])
                plt.xlabel('Predicted A2')
                plt.ylabel('A2 final layer logit')
                r2 = all_R2[cat_i, layer]
                plt.text(0.95, 0.05, f'$R^2={r2:.2f}$', ha='right', va='bottom', transform=plt.gca().transAxes)
                plt.savefig(f'{output_dir}/scatter_layer_{layer}.png')
                plt.close()

            # Control: predict from the A2 logits of the same layer.
            all_R2_control[cat_i, layer], _ = _cross_validated_r2(all_A2[:, :, layer].T, y)

        # Per-category curves of both scores over the layers.
        plt.figure()
        plt.plot(range(LAYERS), all_R2[cat_i], label='model')
        plt.plot(range(LAYERS), all_R2_control[cat_i], label='Control')
        plt.xlabel('Layer')
        plt.ylabel('Correlation')
        plt.title('Correlation for each layer')
        plt.legend()
        plt.savefig(f'{output_dir}/correlation_layers.png')
        plt.close()

    # Mean over all categories with standard-error bars.
    plt.figure()
    plt.errorbar(range(LAYERS), np.mean(all_R2, axis=0), yerr=np.std(all_R2, axis=0) / np.sqrt(all_R2.shape[0]), label='Mean')
    plt.xlabel('Layer')
    plt.ylabel('Correlation')
    plt.title('Correlation for each layer')
    plt.legend()
    plt.savefig('output/all_correlation_layers.png')
    plt.close()

    plt.figure()
    plt.errorbar(range(LAYERS), np.mean(all_R2, axis=0),
                 yerr=np.std(all_R2, axis=0) / np.sqrt(all_R2.shape[0]),
                 label='Models using A1 as predictors', linewidth=4, elinewidth=3)
    plt.errorbar(range(LAYERS), np.mean(all_R2_control, axis=0),
                 yerr=np.std(all_R2_control, axis=0) / np.sqrt(all_R2_control.shape[0]),
                 label='Models using A2 as predictors', linewidth=4, elinewidth=3)
    plt.xlabel('Layer')
    plt.ylabel('Mean $R^2$')
    plt.legend()
    plt.savefig('output/all_correlation_layers_with_control.png')
    plt.close()

    # R^2 at 1/2, 2/3 and 3/4 of the depth, one row per category.
    all_r2_0_5 = np.round(all_R2[:, LAYERS // 2 - 1], 2)
    all_r2_0_66 = np.round(all_R2[:, 2 * LAYERS // 3 - 1], 2)
    all_r2_0_75 = np.round(all_R2[:, 3 * LAYERS // 4 - 1], 2)

    df = pd.DataFrame({'Category': list(CATEGORY_TO_FUNC.keys()),
                       'R2_0_5': all_r2_0_5,
                       'R2_0_66': all_r2_0_66,
                       'R2_0_75': all_r2_0_75})
    df.to_csv('output/r2_corr.csv')

    # Full per-layer table for later reuse.
    np.save('output/all_R2.npy', all_R2)
