import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import seaborn as sns
from scipy.stats import gaussian_kde

# Global seaborn theme; the large font scale goes with the 30x20 inch figure
# created further down.
sns.set_theme(style = "darkgrid", font_scale = 2.6)

# Directories that are searched recursively for files named "attn_probs".
# `base_model` feeds the "Base" curves, `length_model` the "Length" curves.
base_model =  "csv_logs/openllama_model_ver/"
length_model = "csv_logs/longllama_model_ver/"

# Destination of the final figure.
out_path = "generate/final_all_layers_longllama.png"

# Display names: index 0 labels the `length_model` data, index 1 the
# `base_model` data (see the legend in ridgeplot and the delta y-label).
model_names = ['LongLLaMA', 'OpenLLaMA']

# layer_lines() treats line i of a log file as belonging to layer
# i % num_layers -- presumably the number of layers dumped per step (confirm
# against the code that writes the logs).
num_layers = 26
# Index taken along the first axis of every loaded record.
batch_index = 0
# Layers to plot; each entry becomes one column of the figure.
plotting_layers = [6, 12, 18]
# Lower clip applied to probabilities before taking logarithms.
eps = 1e-12
# Window length of the moving average applied to the bottom-row curves.
window_size = 50

# Context lengths L shown in the ridge plots: for each L, attention row
# L - 1 restricted to its first L columns is sampled.
rows = [8, 16, 32, 64, 128, 256, 512]

# Range of the log-probability axis on which the KDEs are evaluated.
min_logp = -17.5
max_logp = 0.0

# Number of evaluation points on that axis.
kde_points = 500
# Vertical scale of a ridge after normalisation by the global maximum
# (consecutive ridges are spaced one unit apart).
ridge_height = 0.9

def attn_prob_files(attn_dir):
    """Return all paths named ``attn_probs`` below *attn_dir*, sorted.

    The directory tree is searched recursively. An empty list is returned
    when nothing matches.
    """
    root = Path(attn_dir)
    matches = list(root.rglob("attn_probs"))
    matches.sort()
    return matches

def layer_lines(attn_file, layer_index):
    """Yield the parsed JSON records of *attn_file* that belong to one layer.

    Lines are assigned to layers round-robin: the line at zero-based
    position ``i`` belongs to layer ``i % num_layers``. Blank lines still
    occupy a position but are not yielded.

    Args:
        attn_file: Path-like object providing ``open``.
        layer_index: Layer whose lines should be decoded.

    Yields:
        The object produced by ``json.loads`` for each selected line.
    """
    with attn_file.open('r') as handle:
        for position, raw in enumerate(handle):
            if position % num_layers == layer_index:
                text = raw.strip()
                if text:
                    yield json.loads(text)
                
def mv_average(x, w):
    """Trailing moving average of *x* with window length *w*.

    Element ``i`` of the result is the mean of ``x[max(0, i - w + 1) : i + 1]``,
    so the first ``w - 1`` outputs average over a shorter, partially filled
    window. For an empty input or ``w <= 1`` a copy of the input is returned.
    """
    values = np.asarray(x)
    length = values.size
    if length and w > 1:
        # Causal running sums: keep only the first `length` taps of the full
        # convolution with a box kernel.
        running_sum = np.convolve(values, np.ones(w), mode="full")[:length]
        # Number of samples actually inside the window at each position.
        samples_in_window = np.minimum(np.arange(1, length + 1), w)
        return running_sum / samples_in_window
    return values.copy()

def kl_uniform(p):
    """Return KL(p || uniform) divided by ``log(m)`` for a length-``m`` vector.

    Entries are clipped from below at ``eps`` before the logarithm is taken;
    the vector is not renormalised afterwards. Vectors with at most one entry
    yield ``0.0``.
    """
    probs = np.asarray(p)
    support = probs.size
    if support <= 1:
        return 0.0
    probs = np.clip(probs, eps, None)
    log_support = np.log(support)
    # log(p_i / (1/m)) == log(p_i) + log(m)
    divergence = np.sum(probs * (np.log(probs) + log_support))
    return divergence / log_support

def aitchison_norm(p):
    """Return the root-mean-square of the centred log-ratio transform of *p*.

    Entries are clipped from below at ``eps``, their logarithms are centred
    on their mean, and the Euclidean norm of the result is divided by the
    square root of the number of entries.
    """
    probs = np.clip(np.asarray(p), eps, None)
    log_probs = np.log(probs)
    centred = log_probs - log_probs.mean()
    return np.linalg.norm(centred) / np.sqrt(probs.size)

def padding(x, n):
    """Zero-pad *x* on the right to length *n*.

    An array that already holds at least *n* elements is returned unchanged
    (the very same object, not a copy).
    """
    missing = n - x.size
    if missing <= 0:
        return x
    return np.pad(x, (0, missing))

def layer_measures(attn_dir, layer_index):
    """Aggregate per-position attention statistics for one layer.

    Every record of the layer found below *attn_dir* is reduced to the slice
    at ``batch_index``; slices that are not 3-dimensional are skipped. For
    each head and each row ``r`` the first ``r + 1`` entries of that row are
    scored with ``kl_uniform`` and ``aitchison_norm``. Scores are pooled over
    heads and records separately for every row position.

    Returns:
        A tuple ``(kl_mean, kl_var, a_mean, a_var, length, counts)`` where the
        arrays have ``length`` entries (one per row position), the variances
        are population variances clamped at zero, and ``counts`` holds the
        number of pooled samples per position. When nothing was accumulated
        all arrays are empty and ``length`` is ``0``.
    """
    # Running sums and sums of squares, grown on demand to the longest
    # sequence seen so far.
    acc = {key: np.zeros(0) for key in ("kl", "kl_sq", "a", "a_sq", "count")}

    for path in attn_prob_files(attn_dir):
        for record in layer_lines(path, layer_index):
            heads = np.asarray(record[batch_index])
            if heads.ndim != 3:
                continue
            seq_len = heads.shape[1]
            acc = {key: padding(buf, seq_len) for key, buf in acc.items()}

            # Only the leading seq_len x seq_len part of each head is used.
            for head in heads[:, :seq_len, :seq_len]:
                for pos, attn_row in enumerate(head):
                    prefix = attn_row[: pos + 1]
                    kl_score = kl_uniform(prefix)
                    a_score = aitchison_norm(prefix)
                    acc["kl"][pos] += kl_score
                    acc["kl_sq"][pos] += kl_score * kl_score
                    acc["a"][pos] += a_score
                    acc["a_sq"][pos] += a_score * a_score
                    acc["count"][pos] += 1.0

    seen = acc["count"] > 0
    if not seen.any():
        empty = np.array([])
        return empty, empty, empty, empty, 0, empty

    # Trim everything to the last position that received a sample.
    length = int(np.flatnonzero(seen)[-1] + 1)
    counts = acc["count"][:length]
    seen = seen[:length]

    def _mean(totals):
        out = np.zeros(length)
        out[seen] = totals[:length][seen] / counts[seen]
        return out

    kl_mean = _mean(acc["kl"])
    a_mean = _mean(acc["a"])

    # E[x^2] - E[x]^2 can dip below zero through round-off, hence the clamp.
    kl_var = np.maximum(_mean(acc["kl_sq"]) - kl_mean ** 2, 0.0)
    a_var = np.maximum(_mean(acc["a_sq"]) - a_mean ** 2, 0.0)

    return kl_mean, kl_var, a_mean, a_var, length, counts

def layer_values(attn_dir, layer_index, name):
    """Collect log-probabilities and entropy weights for the lengths in ``rows``.

    For each record of the layer (slice at ``batch_index``, skipped unless
    3-dimensional), each head and each context length ``L`` in ``rows`` that
    fits the sequence, the first ``L`` entries of row ``L - 1`` are clipped
    at ``eps``. Their logarithms and the terms ``-p * log(p)`` are gathered.

    Args:
        attn_dir: Directory searched for attention logs.
        layer_index: Layer to read.
        name: Accepted for the caller's convenience; not used here.

    Returns:
        ``(logp, w)``: two dicts keyed by every ``L`` in ``rows`` holding the
        concatenated 1-D arrays (empty arrays when nothing was collected).
    """
    logp_chunks = {L: [] for L in rows}
    weight_chunks = {L: [] for L in rows}

    for path in tqdm(attn_prob_files(attn_dir)):
        for record in layer_lines(path, layer_index):
            heads = np.asarray(record[batch_index])
            if heads.ndim != 3:
                continue
            seq_len = heads.shape[1]
            usable = [L for L in rows if 1 <= L <= seq_len]
            if not usable:
                continue

            for head in heads[:, :seq_len, :seq_len]:
                for L in usable:
                    probs = np.clip(head[L - 1, :L], eps, None)
                    log_probs = np.log(probs)
                    logp_chunks[L].append(log_probs)
                    # Per-entry contribution to the Shannon entropy.
                    weight_chunks[L].append(-probs * log_probs)

    def _flatten(chunks):
        return np.concatenate(chunks) if chunks else np.array([])

    logp = {L: _flatten(logp_chunks[L]) for L in rows}
    w = {L: _flatten(weight_chunks[L]) for L in rows}
    return logp, w


def kde_maps(logp_base, w_base, logp_length, w_length, x_grid):
    """Evaluate plain and entropy-weighted KDEs of log-probabilities.

    For both models ("Base" and "Length") and every context length in
    ``rows`` a Gaussian KDE of the log-probabilities is evaluated on
    *x_grid*. A second KDE uses the supplied weights and is multiplied by
    the weight total, so it is not normalised to unit area.

    Returns:
        ``(dens_logp, dens_ent, max_logp_val, max_ent_val)``. Both dicts are
        keyed by ``(model, L)``; a value is ``None`` when there were no
        samples (or, for ``dens_ent``, when the weights do not sum to a
        positive number). The two maxima are taken over all stored curves
        and start at ``0``.
    """
    sources = {
        "Base": (logp_base, w_base),
        "Length": (logp_length, w_length),
    }
    dens_logp = {}
    dens_ent = {}
    peak_logp = 0
    peak_ent = 0

    for model, (lp_by_len, w_by_len) in sources.items():
        for L in rows:
            key = (model, L)
            samples = lp_by_len[L]
            weights = w_by_len[L]

            if samples.size == 0:
                dens_logp[key] = None
                dens_ent[key] = None
                continue

            plain = gaussian_kde(samples)(x_grid)
            dens_logp[key] = plain
            peak_logp = max(peak_logp, plain.max())

            total = weights.sum()
            if total <= 0.0:
                dens_ent[key] = None
                continue

            # Rescale by the weight mass so curves keep their relative size.
            weighted = gaussian_kde(samples, weights=weights)(x_grid) * total
            dens_ent[key] = weighted
            peak_ent = max(peak_ent, weighted.max())

    return dens_logp, dens_ent, peak_logp, peak_ent

def ridgeplot(ax, dens, scale_max, x_grid, xlabel, title = None, show_ylabel = False, show_legend = False):
    """Draw stacked density ridges for both models on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        dens: Dict keyed by ``(model, L)`` with ``model`` in
            ``("Base", "Length")``; values are curves sampled on *x_grid*
            or ``None``.
        scale_max: Value every curve is divided by before plotting. When it
            is not positive the curves are drawn unscaled.
        x_grid: Shared x coordinates of all curves.
        xlabel: Label of the x axis.
        title: Optional axes title.
        show_ylabel: Whether to label the y axis.
        show_legend: Whether to add the two-entry model legend.

    Context lengths from ``rows`` that have at least one curve get one ridge
    each, spaced one unit apart and labelled with the length.
    """
    models = ("Base", "Length")
    contexts = [L for L in rows if any(dens.get((m, L)) is not None for m in models)]

    for j, L in enumerate(contexts):
        baseline = float(j)

        for i, model in enumerate(models):
            curve = dens.get((model, L))
            if curve is None:
                continue

            # Fall back to the unscaled *curve* when no positive maximum is
            # available (the previous code fell back to the whole dict, which
            # raised a TypeError in the multiplication below).
            scaled = curve / scale_max if scale_max > 0 else curve
            top = baseline + scaled * ridge_height

            ax.fill_between(x_grid, baseline, top, alpha = 0.35, color = f'C{i}')
            ax.plot(x_grid, top, lw = 1, color = f'C{i}')

    ax.set_yticks(np.arange(len(contexts)))
    ax.set_yticklabels([str(c) for c in contexts])
    ax.set_xlim(x_grid[0], x_grid[-1])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(r'Context Length $(L)$' if show_ylabel else '')
    if title is not None:
        ax.set_title(title)

    if show_legend:
        # "Base" is drawn in C0 and "Length" in C1; model_names lists the
        # length model first, hence the swapped indices.
        handles = [
            plt.Line2D([0], [0], color="C0", lw=2, label=model_names[1]),
            plt.Line2D([0], [0], color="C1", lw=2, label=model_names[0]),
        ]
        ax.legend(handles=handles, loc="upper right")
        
def _se_of_difference(var_a, count_a, var_b, count_b):
    """Standard error of a difference of two independent sample means."""
    # Counts are floored at 1 so positions without samples do not divide by 0.
    term_a = np.maximum(var_a / np.maximum(count_a, 1.0), 0.0)
    term_b = np.maximum(var_b / np.maximum(count_b, 1.0), 0.0)
    return np.sqrt(term_a + term_b)


# Figure layout: one column per plotted layer.
#   row 0 - ridge plots of the log-probability densities
#   row 1 - ridge plots of the entropy-weighted densities
#   row 2 - smoothed difference (length model minus base model) of the
#           Aitchison and KL measures with a 1.96 * SE band
x_grid = np.linspace(min_logp, max_logp, kde_points)
fig, axs = plt.subplots(3, len(plotting_layers), figsize=(30, 20), constrained_layout=True, sharey="row")

for col, layer_index in enumerate(plotting_layers):
    logp_base, w_base = layer_values(base_model, layer_index, "Base")
    logp_length, w_length = layer_values(length_model, layer_index, "Length")

    dens_logp, dens_ent, max_logp_val, max_ent_val = kde_maps(logp_base, w_base, logp_length, w_length, x_grid)

    # Only the left-most column carries y labels; the legend goes on the
    # right-most ridge plot of the first row.
    show_ylabel = col == 0

    ridgeplot(axs[0, col], dens_logp, max_logp_val, x_grid, xlabel = r'$\log(p_i)$', title = f'Layer {layer_index}', show_ylabel = show_ylabel, show_legend = (col == len(plotting_layers) -1))

    ridgeplot(axs[1, col], dens_ent, max_ent_val, x_grid, xlabel=r'Entropy Weighted $\log(p_i)$', title = None, show_ylabel = show_ylabel, show_legend = False)

    kl_base, kl_var_base, a_base, a_var_base, T_base, c_base = layer_measures(base_model, layer_index)
    kl_length, kl_var_length, a_length, a_var_length, T_length, c_length = layer_measures(length_model, layer_index)
    # Compare only the positions that both models cover.
    T = min(T_base, T_length)

    d_kl = kl_length[:T] - kl_base[:T]
    d_a = a_length[:T] - a_base[:T]

    x = np.arange(1, T+1)
    ax = axs[2, col]

    se_kl = _se_of_difference(kl_var_length[:T], c_length[:T], kl_var_base[:T], c_base[:T])
    se_a = _se_of_difference(a_var_length[:T], c_length[:T], a_var_base[:T], c_base[:T])

    # The curves and their standard errors are smoothed with the same window.
    m_a = mv_average(d_a, window_size)
    m_kl = mv_average(d_kl, window_size)
    s_a = mv_average(se_a, window_size)
    s_kl = mv_average(se_kl, window_size)

    ax.plot(x, m_a, label=r"$\Delta \bar d_A(\mathbf{p},\mathbf{u})$")
    ax.fill_between(x, m_a - 1.96 * s_a, m_a + 1.96 * s_a, alpha=0.15)

    ax.plot(x, m_kl, label=r"$\Delta \overline{\mathrm{KL}}(\mathbf{p}\|\mathbf{u})$")
    ax.fill_between(x, m_kl - 1.96 * s_kl, m_kl + 1.96 * s_kl, alpha=0.15)

    ax.set_xlabel(r"Context Length $(L)$")
    if show_ylabel:
        ax.set_ylabel(rf"$\Delta$({model_names[0]} - {model_names[1]})")
    ax.grid(True)
    ax.legend()

# savefig does not create missing directories, so make sure the target
# folder exists before writing the figure.
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=300)
plt.close(fig)