# Sentence-Tracking elephants: layer index (0-25) -> elephant (dense-latent)
# indices. Registered as 'Sentence-Tracking' in CATEGORY_TO_DICT.
# Only layers 2-9 are populated.
period_elephants = {0: [],
 1: [],
 2: [11204, 14056],
 3: [5348, 7686],
 4: [12131, 8730],
 5: [14217, 4341],
 6: [6058, 196],
 7: [11227, 6751],
 8: [8641, 7357],
 9: [6227, 9315],
 10: [],
 11: [],
 12: [],
 13: [],
 14: [],
 15: [],
 16: [],
 17: [],
 18: [],
 19: [],
 20: [],
 21: [],
 22: [],
 23: [],
 24: [],
 25: []}

# Paragraph-Tracking elephants: layer index (0-25) -> elephant (dense-latent)
# indices. Registered as 'Paragraph-Tracking' in CATEGORY_TO_DICT.
# Layers 3-9 repeat the indices listed in period_elephants. 'Sentence-Tracking'
# ranks above 'Paragraph-Tracking' in CATEGORY_PRIORITY, so those shared
# indices are counted as Sentence-Tracking.
newline_elephants = {0: [3041, 130],
 1: [13923, 13532],
 2: [14750, 8650],
 3: [5348, 7686, 10721, 2170],
 4: [4765, 16348, 8730],
 5: [8680, 417, 14217, 4341],
 6: [13073, 6058, 196, 7840, 9687],
 7: [3968, 4435, 11227, 6751],
 8: [10524, 9706, 7357],
 9: [6227, 9315],
 10: [5717, 11500],
 11: [892, 10559],
 12: [2291, 13295],
 13: [3517, 46],
 14: [],
 15: [],
 16: [3898],
 17: [4601],
 18: [],
 19: [],
 20: [],
 21: [],
 22: [],
 23: [],
 24: [],
 25: []}

# Context-Tracking elephants: layer index (0-25) -> elephant (dense-latent)
# indices. Registered as 'Context-Tracking' in CATEGORY_TO_DICT. This is the
# highest-priority category in CATEGORY_PRIORITY.
# NOTE(review): layers 22 and 23 both list 5286; confirm this is intended and
# not a copy-paste slip.
context_elephants = {0: [15906],
 1: [15664],
 2: [5542],
 3: [13764, 12815],
 4: [5301],
 5: [697],
 6: [13460],
 7: [3704],
 8: [2024],
 9: [54],
 10: [3986],
 11: [3151],
 12: [2620],
 13: [11256],
 14: [12319],
 15: [1695],
 16: [3847],
 17: [15620],
 18: [9392],
 19: [3019],
 20: [9768],
 21: [13174],
 22: [5286],
 23: [5286],
 24: [9589],
 25: []}

# Nullspace elephants: layer index (0-25) -> elephant (dense-latent) indices.
# Registered as 'Nullspace' in CATEGORY_TO_DICT. Only the last two layers
# (24 and 25) are populated.
nullspace_elephants = {
    0:[],
    1:[],
    2:[],
    3:[],
    4:[],
    5:[],
    6:[],
    7:[],
    8:[],
    9:[],
    10:[],
    11:[],
    12:[],
    13:[],
    14:[],
    15:[],
    16:[],
    17:[],
    18:[],
    19:[],
    20:[],
    21:[],
    22:[],
    23:[],
    24:[   43,    99,   261,   301,   391,  1802,  1804,  1842,  3173,  3222,
          3848,  4067,  4175,  4182,  4680,  4714,  5014,  5186,  5472,  5883,
          6169,  6314,  6319,  6702,  7061,  7640,  8119,  8612,  8808,  9063,
          9903, 10070, 10177, 10660, 11139, 11273, 11326, 11536, 12113, 12133,
         12646, 12830, 12930, 13269, 13752, 13912, 14161, 14371, 14391, 14448,
         14870, 15021, 15298, 15306, 15918, 16200, 16319, 16339],
    25:[  105,   348,   484,  1398,  1625,  1689,  1734,  3017,  3023,  3420,
          3473,  3627,  4587,  5090,  5149,  5755,  5772,  6939,  7443,  7479,
          7490,  7669,  8002,  8043,  8657,  8662,  8863,  9817,  9964, 10150,
         10163, 10175, 10251, 10550, 10572, 10942, 11402, 11936, 12591, 12784,
         12908, 13100, 13442, 13749, 14186, 14325, 14402, 14651, 14684, 15034,
         15298, 15320, 15645, 15861, 15920, 16361]}

# Alphabet elephants: layer index (0-25) -> elephant (dense-latent) indices.
# Registered as 'Alphabet' in CATEGORY_TO_DICT. Only the final layer (25) is
# populated.
alphabet_elephants = {0: [],
 1: [],
 2: [],
 3: [],
 4: [],
 5: [],
 6: [],
 7: [],
 8: [],
 9: [],
 10: [],
 11: [],
 12: [],
 13: [],
 14: [],
 15: [],
 16: [],
 17: [],
 18: [],
 19: [],
 20: [],
 21: [],
 22: [],
 23: [],
 24: [],
 25: [30,
 783,
 1761,
 1917,
 2499,
 2651,
 3378,
 4664,
 5641,
 6596,
 7342,
 7476,
 10480,
 10836,
 11794,
 12030,
 12428,
 13531,
 14263,
 15287,
 15618]}

# Meaning elephants: layer index (0-25) -> elephant (dense-latent) indices.
# Registered as 'Meaning' in CATEGORY_TO_DICT. Lists are unsorted.
# Some indices also appear in higher-priority dicts for the same layer:
#   - layer 0: 15906 is also in context_elephants
#   - layer 2: 14056 is also in period_elephants
#   - layer 25: 15034 and 9817 are also in nullspace_elephants;
#     15287, 7342, 14263 and 7476 are also in alphabet_elephants
# Those indices resolve to the higher-ranked category in CATEGORY_PRIORITY.
meaning_elephants = {0: [16315, 12950, 5542, 8072, 13231, 10145, 12046, 1097, 4409, 2683, 11509, 15906, 16258, 2596, 3512], 
                     1: [9770, 13412, 246, 392, 7649, 740, 10424, 8035, 12817, 2358, 2991, 1851, 9201, 11433, 5292, 10736, 3440, 11461, 16136, 5771, 8725, 746], 
                     2: [15089, 7132, 11632, 7843, 14056, 4747, 14745, 8714, 1707, 3364, 6088, 5139, 7452, 5186, 405, 1147, 14371, 10673, 15438, 8297, 3804, 4643, 5750, 11413, 5942, 15134, 8294, 6464], 
                     3: [5757, 16178, 622, 15968, 7503, 13752, 14115, 6628, 8254, 4928], 
                     4: [12690, 14328, 11137, 1602, 15034, 2691, 8465, 11299, 16209, 14525, 6050, 7064, 770, 13610, 189, 10384, 2289, 9400, 8656, 1790, 8739, 5268, 10722, 9437, 7708, 3754, 14941, 10365, 5319, 1024, 57, 8114, 13299, 11368, 4258, 15738, 8265, 5417, 920, 295, 14325, 502, 8651, 16070, 10521, 12860, 14814, 11591, 8769, 16216, 15468, 10192, 15752, 4967, 8840, 10411, 11031, 9045, 14591, 10764, 14501, 11620, 853, 10380, 7221, 15626, 14542, 7088, 10626, 9274, 1724, 14599, 2003, 5076, 6448, 12207, 12400, 1435, 3793, 9869, 12857, 3602, 9863, 10893, 11679, 15095, 4825, 5212, 12843, 11884, 2851, 15961, 7086, 7238, 2969, 6439, 7587, 2894, 4110, 1205, 1733, 2739, 2346, 13701, 5585, 4490, 8100, 10105, 12002, 11891, 7119, 12108, 9090, 8942], 
                     5: [2373],
                     6: [14549, 554, 11184, 6092], 
                     7: [5269, 9698, 10077, 14211], 
                     8: [1869, 3308, 3254], 
                     9: [10185, 7421, 2225, 5204], 
                     10: [14660, 1205, 1833], 
                     11: [7732], 
                     12: [7507, 1322], 
                     13: [2755], 
                     14: [7214], 
                     15: [8610, 5114], 
                     16: [], 
                     17: [1386], 
                     18: [15275], 
                     19: [4742, 5320], 
                     20: [15509, 8820], 
                     21: [4435], 
                     22: [6359, 8648], 
                     23: [7655], 
                     24: [5845, 1842], 
                     25: [15034, 15287, 9817, 13580, 7342, 16103, 14263, 13720, 7476, 9887]}

# Context-Binding elephants: layer index (0-25) -> elephant (dense-latent)
# indices. Registered as 'Context-Binding' in CATEGORY_TO_DICT.
# Layers 10-13 include every index from newline_elephants at those layers.
# 'Context-Binding' ranks above 'Paragraph-Tracking' in CATEGORY_PRIORITY, so
# those shared indices are counted as Context-Binding.
binding_elephants = {0: [],
 1: [],
 2: [],
 3: [],
 4: [],
 5: [],
 6: [],
 7: [],
 8: [],
 9: [],
 10: [1740, 4472, 5717, 11500],
 11: [6934, 8082, 892, 10559],
 12: [14906, 14599, 2291, 13295, 7541, 2009],
 13: [3517, 46, 15275, 11449],
 14: [11575, 2411, 8515, 15297, 6699, 1802],
 15: [],
 16: [2889, 8811],
 17: [10495, 491],
 18: [],
 19: [],
 20: [],
 21: [],
 22: [],
 23: [],
 24: [],
 25: []}

# PCA elephants: layer index (0-25) -> elephant (dense-latent) indices.
# Registered as 'PCA' in CATEGORY_TO_DICT. PCA is the lowest-priority
# category in CATEGORY_PRIORITY, so an index shared with another dict is
# never counted as PCA. Shared indices:
#   - layer 2: 15089 is also in meaning_elephants
#   - layers 24 and 25: both entries are also in nullspace_elephants
pca_elephants = {0: [],
 1: [],
 2: [15089, 13092],
 3: [9134, 895],
 4: [],
 5: [],
 6: [9743, 15079],
 7: [12287, 14146],
 8: [9213, 302],
 9: [12102, 1435],
 10: [4392, 3031],
 11: [12945, 16123],
 12: [6810, 1041],
 13: [11248, 15887],
 14: [],
 15: [2234, 10716],
 16: [1033, 16028],
 17: [],
 18: [7373, 3851],
 19: [12025, 4346],
 20: [6631, 8684],
 21: [4138, 3461],
 22: [15056, 4384],
 23: [6961, 11271],
 24: [14448, 4680],
 25: [15298, 16361]}

from globals import *
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt

# Category precedence, strongest first. When one elephant is listed under
# several categories for the same layer, the earliest entry here decides
# which category it is counted under.
CATEGORY_PRIORITY = ['Context-Tracking', 'Sentence-Tracking', 'Alphabet', 'Nullspace',
                     'Context-Binding', 'Paragraph-Tracking', 'Meaning', 'PCA']

# Trace order for the stacked-area figure: the first entry is drawn at the
# bottom. 'Unknown' holds elephants that no category dict lists for their layer.
CATEGORY_ORDER = ['Context-Tracking', 'Paragraph-Tracking', 'Sentence-Tracking', 'PCA',
                  'Alphabet', 'Nullspace', 'Meaning', 'Context-Binding', 'Unknown']

# Category label -> its {layer index: [elephant indices]} lookup table.
CATEGORY_TO_DICT = dict([
    ('Sentence-Tracking', period_elephants),
    ('Context-Binding', binding_elephants),
    ('Paragraph-Tracking', newline_elephants),
    ('Alphabet', alphabet_elephants),
    ('Nullspace', nullspace_elephants),
    ('Context-Tracking', context_elephants),
    ('Meaning', meaning_elephants),
    ('PCA', pca_elephants),
])

# Hex fill colour used for each category's band in the stacked figure.
CATEGORY_COLORS = dict([
    ('Unknown', '#ececec'),
    ('Nullspace', '#a6cee3'),
    ('Alphabet', '#1f78b4'),
    ('Meaning', '#b2df8a'),
    ('Context-Binding', '#33a02c'),
    ('Context-Tracking', '#fb9a99'),
    ('Paragraph-Tracking', '#e31a1c'),
    ('Sentence-Tracking', '#ff7f00'),
    ('PCA', '#8c564b'),
])

#%%
layers = np.arange(26)  # layer indices 0..25 inclusive

# Per-category series, one value appended per layer by the aggregation loop:
#   category_counts   -> that category's share of the layer's elephants, in percent
#   category_absolute -> raw number of the layer's elephants in that category
category_counts = {name: [] for name in CATEGORY_ORDER}
category_absolute = {name: [] for name in CATEGORY_ORDER}

def get_all_categories(elephant, layer):
    """Return every category whose dict lists ``elephant`` at ``layer``.

    Matches come back in ``CATEGORY_PRIORITY`` order, so the first element is
    the highest-priority match. If no category lists the elephant for this
    layer, the single-element list ``['Unknown']`` is returned instead.
    """
    matches = [
        name
        for name in CATEGORY_PRIORITY
        if elephant in CATEGORY_TO_DICT[name].get(layer, [])
    ]
    return matches or ['Unknown']

def assign_category(elephant, layer):
    """Return the one category ``elephant`` is counted under at ``layer``.

    The highest-priority match from ``get_all_categories`` wins. Any overlap
    between categories is printed to stdout so it can be reviewed.
    """
    found = get_all_categories(elephant, layer)
    primary = found[0]
    if found[1:]:
        print(f"Layer {layer}, Elephant {elephant} belongs to multiple categories: {found}")
    return primary

for layer in layers:
    # Elephants for this layer: indices whose frequency passes the 0.1 threshold
    # in get_elephants_thres (semantics defined outside this file).
    freqs, = get_mydata(layer, freqs=True)
    elephants = get_elephants_thres(freqs, 0.1).cpu().numpy()

    # Keyed by elephant index, so an index that is returned twice is counted once.
    cat_for_elephant = {e: assign_category(e, layer) for e in elephants}

    counts = dict.fromkeys(CATEGORY_ORDER, 0)
    for cat in cat_for_elephant.values():
        counts[cat] += 1

    total_elephants = sum(counts.values())
    percentages = (
        {cat: (count/total_elephants)*100 for cat, count in counts.items()}
        if total_elephants > 0
        else dict.fromkeys(CATEGORY_ORDER, 0)
    )

    # Extend both per-category series by this layer's value.
    for cat in CATEGORY_ORDER:
        category_absolute[cat].append(counts[cat])
        category_counts[cat].append(percentages[cat])

#%%
def plot_elephants(mode, y_range=None, output_path='*PLOTS/master.pdf'):
    """Plot the per-layer category breakdown of elephants as stacked areas.

    One filled trace is added per entry of ``CATEGORY_ORDER``, with the first
    entry at the bottom. Traces are plotted over ``layers`` and filled with
    colours from ``CATEGORY_COLORS``. The figure is saved to ``output_path``
    and then displayed.

    Args:
        mode: Which series to plot.
            - ``'PERCENTAGE'`` plots ``category_counts``: each category's share
              of a layer's elephants.
            - ``'ABSOLUTE'`` plots ``category_absolute``: raw counts.
        y_range: Optional ``[ymin, ymax]`` for the y-axis. When omitted:
            - PERCENTAGE mode keeps the historical fixed range ``[0, 82]``.
            - ABSOLUTE mode auto-scales from zero. Raw counts are not bounded
              by 82, so the old fixed range could clip them.
        output_path: File passed to ``fig.write_image``. Both modes default to
            the same path, so a later call overwrites an earlier figure.

    Raises:
        ValueError: If ``mode`` is neither ``'PERCENTAGE'`` nor ``'ABSOLUTE'``.
    """
    if mode not in ['PERCENTAGE', 'ABSOLUTE']:
        raise ValueError("mode must be either 'PERCENTAGE' or 'ABSOLUTE'")

    is_percentage = mode == 'PERCENTAGE'
    data = category_counts if is_percentage else category_absolute

    # Percentages stack to 100, so [0, 82] crops the top of the stack
    # (presumably intentional). The same fixed range is wrong for raw counts,
    # which can exceed 82.
    if y_range is None and is_percentage:
        y_range = [0, 82]
    yaxis = dict(range=y_range) if y_range is not None else dict(rangemode='tozero')

    # NOTE(review): `go` is not imported in this file. Presumably
    # plotly.graph_objects is provided by `from globals import *`; confirm.
    fig = go.Figure()

    # Traces added to the same stackgroup are stacked in insertion order.
    for cat in CATEGORY_ORDER:
        fig.add_trace(go.Scatter(
            x=layers,
            y=data[cat],
            mode='none',
            stackgroup='one',
            name=cat,
            fillcolor=CATEGORY_COLORS.get(cat),
            opacity=0.9
        ))

    fig.update_layout(
        xaxis_title="Layer",
        yaxis_title="Percentage of Dense Latents" if is_percentage else "Number of Elephants",
        xaxis=dict(range=[0, 25]),
        yaxis=yaxis,
        width=700,
        height=260,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=1.27,
            font=dict(size=9)
        ),
        margin=dict(l=50, r=120, t=20, b=50)
    )
    fig.write_image(output_path, scale=20)
    fig.show()

# Render the percentage breakdown; pass 'ABSOLUTE' instead for raw counts.
plot_elephants('PERCENTAGE')

# %%
