#%%
import torch
import pickle
from globals import *
from plotting_master import context_elephants
from tqdm.auto import tqdm
import numpy as np

# Indices of the layers analysed below: 0 through 25 inclusive.
# numpy is imported before this line so the assignment does not depend on
# `from globals import *` happening to export `np`.
layers = np.arange(0, 26)

def get_average_run_length(arr, context_length=1024):
    """Return the mean length of completed runs of active entries.

    The input is flattened and cut into consecutive documents of
    ``context_length`` entries; trailing entries that do not fill a whole
    document are ignored. Within each document a run is a maximal stretch of
    consecutive active (``> 0``) entries. Runs never span two documents.

    A run that is still active at the last position of its document is
    treated as incomplete (its true length is unknown) and is excluded from
    the average. A run that starts at the first position of a document is
    counted like any other.

    Args:
        arr: Array-like of activation indicators (0/1, bool, or any numeric
            values where ``> 0`` means active).
        context_length: Number of entries per document.

    Returns:
        The average completed-run length, or ``0.0`` when there is no
        completed run (including when the input holds no full document).
    """
    flat = np.asarray(arr).reshape(-1)
    num_full_docs = flat.size // context_length
    if num_full_docs == 0:
        return 0.0

    # Drop the trailing partial document so the reshape cannot fail.
    docs = flat[: num_full_docs * context_length].reshape(
        num_full_docs, context_length
    )
    active = (docs > 0).astype(np.int8)

    # Pad every document with an inactive entry on each side so that each run
    # produces exactly one +1 (start) and one -1 (exclusive end) in the diff.
    padded = np.pad(active, ((0, 0), (1, 1)))
    diffs = np.diff(padded, axis=1)

    # np.nonzero scans in row-major order, so the k-th start and the k-th end
    # always belong to the same run of the same document.
    _, starts = np.nonzero(diffs == 1)
    _, ends_exclusive = np.nonzero(diffs == -1)

    # An exclusive end equal to context_length means the run reached the end
    # of the document and is therefore incomplete.
    completed = ends_exclusive < context_length
    count_of_completed_runs = int(completed.sum())
    if count_of_completed_runs == 0:
        return 0.0

    total_of_completed_run_lengths = (
        ends_exclusive[completed] - starts[completed]
    ).sum()
    return total_of_completed_run_lengths / count_of_completed_runs
#%%
# Compute, for every layer, the average completed-run length of each elephant
# that belongs to a pair, and cache the result on disk.
# Document length in tokens: used both as the stride for masking the first
# token of each document and as the context length of the run-length helper.
doc_length = 1024

run_lengths = {}
for layer in tqdm(layers, desc="Layers"):
    freqs, saes, elephant_acts = get_mydata(layer, freqs=True, saes=True, elephant_acts=True)
    elephants = print_elephants_with_pairs(freqs,saes,thres=0.1,flat=False)
    # Keep only the groups that are exact pairs.
    elephants = [p for p in elephants if len(p) == 2]

    layer_run_lengths = {}
    for pair in elephants:
        for e in pair:
            acts = (elephant_acts[e].to_dense() > 0).int().numpy()
            # Force the first position of every document to inactive
            # (presumably a BOS/special token -- confirm), so that no run
            # is counted as starting there.
            acts[::doc_length] = 0
            layer_run_lengths[e] = get_average_run_length(acts, context_length=doc_length)
    run_lengths[layer] = layer_run_lengths

with open('final/chunky/run_lengths.pkl', 'wb') as f:
    pickle.dump(run_lengths, f)

#%%
# Normalise each elephant's measured run length by the run length expected
# from its activation frequency alone, and collect per-layer means.
run_lengths = load_pickle('final/chunky/run_lengths.pkl')
means_per_layer = []
scatter_x = []
scatter_y = []
# Denominator for the activation frequency (presumably 5000 documents of
# 1024 tokens each -- confirm against the data loader).
N = 1024*5000
for layer in tqdm(layers,desc='Layers'):
    elephant_acts_dict, = get_mydata(layer,elephant_acts=True)
    # Elephants listed in context_elephants for this layer are left out.
    excluded = context_elephants.get(layer,[])
    per_layer = []
    for e, run_length in run_lengths[layer].items():
        if e in excluded:
            continue
        # .item() turns the 0-dim tensor into a plain Python int so that the
        # values appended below are ordinary numbers rather than torch
        # tensors (which would otherwise leak into np.mean and the plot).
        active_count = (elephant_acts_dict[e].to_dense() > 0).int().sum().item()
        freq = active_count / N
        # Mean run length if every token were active independently with
        # probability `freq` (geometric distribution). freq == 1 maps to an
        # infinite expectation instead of raising ZeroDivisionError.
        expected_run_length = 1/(1-freq) if freq < 1 else float('inf')
        normalized_run_length = run_length / expected_run_length

        scatter_x.append(layer)
        scatter_y.append(normalized_run_length)
        per_layer.append(normalized_run_length)
    # A layer with no remaining elephants yields NaN explicitly rather than
    # through np.mean([]) and its RuntimeWarning.
    means_per_layer.append(np.mean(per_layer) if per_layer else float('nan'))

#%%
# Scatter of every elephant's normalised run length against its layer, with
# the per-layer mean drawn on top as a grey line.
individual_points = go.Scatter(
    x=scatter_x,
    y=scatter_y,
    mode='markers',
    marker={'color': '#636EFA', 'opacity': 0.8},
    name='Individual elephants',
)
layer_means = go.Scatter(
    x=layers,
    y=means_per_layer,
    mode='lines',
    line={'color': 'grey'},
    name='Mean',
)

fig = go.Figure(data=[individual_points, layer_means])

layout_options = {
    'xaxis_title': 'Layer',
    'yaxis_title': 'Normalized Run Length',
    # 'yaxis': dict(range=[0.7, 4.2]),
    'showlegend': False,
    'height': 400,
    'width': 600,
    'margin': {'l': 10, 'r': 10, 't': 10, 'b': 10},
}
fig.update_layout(**layout_options)

# fig.write_image('*PLOTS/run_lengths.pdf',scale=20)
fig.show()


#%%

#%%


