import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from tqdm import tqdm
import matplotlib as mpl


# Geometry of the dumped q/k tensors and of the analysed RoPE slice.
SEQ_LEN = 4 * 32760        # token positions per dumped q/k tensor
HEADS = 12                 # attention heads per dumped tensor
BLOCKS = range(30)         # indices idx of the qk_store/{q,k}_{idx}.pt files
BASE_B = 10000             # presumably the RoPE frequency base; not referenced below
D = 44                     # leading head dims analysed
PAIRS = D // 2             # (even, odd) dim pairs inside those D dims
DMUL = 4                   # extrapolation factor
T = 1560                   # position stride between compared q and k tokens

# Fine grid of distances in [0, Δmax] with spacing 1/32 (not referenced below).
FINE_STEP = 1.0 / 32.0
Δmax = DMUL * 21
Δ_fine = np.arange(0.0, Δmax + FINE_STEP, FINE_STEP, dtype=np.float64)

# ======================== Statistics ========================
# Per-pair accumulators (length PAIRS), summed over the processed blocks.
dot_sum, cross_sum = torch.zeros(PAIRS), torch.zeros(PAIRS)  # not updated anywhere in this script as shown
dot_sum_mean, cross_sum_mean = torch.zeros(PAIRS), torch.zeros(PAIRS)  # sums of per-block cos / sin means
dot_sum_var, cross_sum_var = torch.zeros(PAIRS), torch.zeros(PAIRS)  # sums of per-block cos / sin variances
n_blk = 0  # number of blocks accumulated so far

# For every dumped block: for each query position s, the cos/sin pair products
# of q[s] with k[s + o*T] are averaged over heads; their mean and (population)
# variance are then taken over all in-range offsets o, averaged over s, and
# accumulated per RoPE pair.


def _paired_products(q_pairs, k_pairs, shift):
    """Head-averaged cosine/sine pair products for one signed position shift.

    Query position ``s`` is paired with key position ``s + shift`` for every
    ``s`` where both indices lie inside the sequence.

    Args:
        q_pairs: query pairs, shape [H, S, P, 2].
        k_pairs: key pairs, shape [H, S, P, 2].
        shift: signed offset (in token positions) added to the query index.

    Returns:
        ``(q_slice, dot, cross)``: ``q_slice`` selects the query positions that
        have an in-range partner; ``dot`` and ``cross`` have shape
        [n_valid, P] and are averaged over heads:
            dot   = q0 * k0 + q1 * k1   (cosine term)
            cross = q1 * k0 - q0 * k1   (sine term)
    """
    seq_len = q_pairs.shape[1]
    n_valid = max(0, seq_len - abs(shift))
    q_start = max(0, -shift)
    k_start = max(0, shift)
    q_slice = slice(q_start, q_start + n_valid)
    qs = q_pairs[:, q_slice]
    ks = k_pairs[:, k_start:k_start + n_valid]
    dot = (qs[..., 0] * ks[..., 0] + qs[..., 1] * ks[..., 1]).mean(dim=0)
    cross = (qs[..., 1] * ks[..., 0] - qs[..., 0] * ks[..., 1]).mean(dim=0)
    return q_slice, dot, cross


def _offset_stats(q_pairs, k_pairs, stride):
    """Per-position mean and population variance over all stride offsets.

    For each query position ``s`` the statistics run over every key position
    ``s + o * stride`` (integer ``o``, including ``o = 0``) inside the sequence.
    Offsets are processed one at a time, so peak memory stays at a few
    [H, S, P] tensors instead of a full [H, S, n_offsets, P, 2] gather.

    Args:
        q_pairs: query pairs, shape [H, S, P, 2].
        k_pairs: key pairs, shape [H, S, P, 2].
        stride: positive position stride between compared tokens.

    Returns:
        ``(dot_mean, cross_mean, dot_var, cross_var)``, each of shape [S, P].
    """
    _, seq_len, n_pairs, _ = q_pairs.shape
    # Largest |o| with at least one in-range pair. ``seq_len // stride`` would
    # add offsets that are never valid whenever stride divides seq_len.
    max_offset = (seq_len - 1) // stride
    shifts = [o * stride for o in range(-max_offset, max_offset + 1)]

    count = torch.zeros(seq_len, 1)
    dot_acc = torch.zeros(seq_len, n_pairs)
    cross_acc = torch.zeros(seq_len, n_pairs)
    for shift in shifts:
        q_slice, dot, cross = _paired_products(q_pairs, k_pairs, shift)
        count[q_slice] += 1
        dot_acc[q_slice] += dot
        cross_acc[q_slice] += cross
    # Offset 0 always contributes, so count >= 1; the clamp is purely defensive.
    count.clamp_(min=1)
    dot_mean = dot_acc / count
    cross_mean = cross_acc / count

    # Second pass: two-pass variance E[(X - mu)^2] for numerical stability.
    dot_acc.zero_()
    cross_acc.zero_()
    for shift in shifts:
        q_slice, dot, cross = _paired_products(q_pairs, k_pairs, shift)
        dot_acc[q_slice] += (dot - dot_mean[q_slice]) ** 2
        cross_acc[q_slice] += (cross - cross_mean[q_slice]) ** 2
    return dot_mean, cross_mean, dot_acc / count, cross_acc / count


with torch.no_grad():
    for idx in tqdm(BLOCKS):
        q = torch.load(f"qk_store/q_{idx}.pt", map_location="cpu").squeeze(0).to(torch.float32).reshape(SEQ_LEN, HEADS, 128)  # [S,H,128], cast to float32
        k = torch.load(f"qk_store/k_{idx}.pt", map_location="cpu").squeeze(0).to(torch.float32).reshape(SEQ_LEN, HEADS, 128)  # [S,H,128], cast to float32

        # [S, H, D] -> [H, S, P, 2]. The axes must be permuted before reshaping:
        # reshape alone only reinterprets the S-major memory and would pair
        # unrelated (position, head) entries. Assumes the saved tensors are
        # sequence-major, as the reshape(SEQ_LEN, HEADS, 128) above implies.
        q_pairs = q[:, :, :D].permute(1, 0, 2).reshape(HEADS, SEQ_LEN, PAIRS, 2)
        k_pairs = k[:, :, :D].permute(1, 0, 2).reshape(HEADS, SEQ_LEN, PAIRS, 2)

        # Each is [S, P]: per-position mean / variance over valid offsets.
        dots, crosses, dots_var, crosses_var = _offset_stats(q_pairs, k_pairs, T)

        # Fail loudly on nan/inf instead of letting them reach the plots.
        if not (torch.isfinite(dots).all() and torch.isfinite(crosses).all()):
            raise ValueError("There are non-finite values in dots or crosses")
        if not (torch.isfinite(dots_var).all() and torch.isfinite(crosses_var).all()):
            raise ValueError("There are non-finite values in dots_var or crosses_var")

        # Average over sequence positions and accumulate per pair ([P]).
        dot_sum_mean += dots.mean(dim=0)
        cross_sum_mean += crosses.mean(dim=0)
        dot_sum_var += dots_var.mean(dim=0)
        cross_sum_var += crosses_var.mean(dim=0)

        n_blk += 1

def _block_average(acc):
    """Divide a per-pair accumulator by the block count; return a float64 array."""
    return (acc / n_blk).float().numpy().astype(np.float64)


cos_term = _block_average(dot_sum_mean)        # mean cos term per pair
sin_term = _block_average(cross_sum_mean)      # mean sin term per pair
cos_var_term = _block_average(dot_sum_var)     # E[Var_seq(E_1)]
sin_var_term = _block_average(cross_sum_var)   # E[Var_seq(E_2)]
total_term = cos_var_term + sin_var_term
dims = np.arange(PAIRS)                        # pair (frequency) indices for the x-axis

# ======================= Plotting ======================
# Text: Times-like serif fonts, with mathtext rendered in the same face.
_text_rc = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "STIXGeneral", "DejaVu Serif"],
    "mathtext.fontset": "custom",
    "mathtext.rm": "Times New Roman",
    "mathtext.it": "Times New Roman:italic",
    "mathtext.bf": "Times New Roman:bold",
}
# Export: Type 42 (TrueType) fonts in PDF/PS output, SVG text kept as text.
_export_rc = {"pdf.fonttype": 42, "ps.fonttype": 42, "svg.fonttype": "none"}
mpl.rcParams.update({**_text_rc, **_export_rc})

# Font sizes shared by the figures below.
TITLE_FONTSIZE, LABEL_FONTSIZE = 48, 48
TICK_FONTSIZE, LEGEND_FONTSIZE = 36, 36

# Output directory for every saved figure.
os.makedirs("wan", exist_ok=True)

# One y-range spanning every plotted statistic, padded by 5% on each side.
all_values = np.concatenate([cos_term, sin_term, cos_var_term, sin_var_term])
y_min, y_max = all_values.min(), all_values.max()
y_range_padding = 0.05 * (y_max - y_min)
y_lim = [y_min - y_range_padding, y_max + y_range_padding]

# Variance over offsets of the per-pair cos (E_1) and sin (E_2) terms.
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(dims, cos_var_term, label=r"Variance of $E_1^{(i)}$ with respect to $\Delta t$", linewidth=1.8, marker='o')
ax.plot(dims, sin_var_term, label=r"Variance of $E_2^{(i)}$ with respect to $\Delta t$", linewidth=1.8, marker='s')
ax.set_xlabel(r"Frequency index $i$", fontsize=LABEL_FONTSIZE)
ax.set_ylabel(r"Variance with respect to $\Delta t$", fontsize=LABEL_FONTSIZE)
# Variances are non-negative; top of the range shared with the other figures.
ax.set_ylim(0, y_max + y_range_padding)
ax.set_title(r"Variance of $E_1^{(i)}, E_2^{(i)}$ with respect to $\Delta t$", fontsize=TITLE_FONTSIZE)
ax.legend(fontsize=LEGEND_FONTSIZE, loc="best", framealpha=0.6)
ax.grid(True, alpha=0.3)
ax.tick_params(axis='both', labelsize=TICK_FONTSIZE)
fig.tight_layout()
fig.savefig("wan/qk_dot_cross_var_curve.png", dpi=150)
plt.close(fig)

# Mean cos (E_1) and sin (E_2) terms per pair, on the shared y-range.
fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True)
ax.plot(dims, cos_term, label=r"$\hat{E}_1^{(i)}$", linewidth=1.8, marker='o')
ax.plot(dims, sin_term, label=r"$\hat{E}_2^{(i)}$", linewidth=1.8, marker='s')
ax.set_xlabel(r"Frequency index $i$", fontsize=LABEL_FONTSIZE)
ax.set_ylabel(r"$\hat{E}_1^{(i)},\hat{E}_2^{(i)}$", fontsize=LABEL_FONTSIZE)
ax.set_title(r"Value of $\hat{E}_1^{(i)},\hat{E}_2^{(i)}$", fontsize=TITLE_FONTSIZE)
ax.set_ylim(y_lim)
ax.legend(fontsize=LEGEND_FONTSIZE, loc="best", framealpha=0.6)
ax.grid(True, alpha=0.3)
ax.tick_params(axis='both', labelsize=TICK_FONTSIZE)
fig.savefig("wan/qk_dot_cross_mean_curve.png", dpi=150)
# Release the figure; pyplot otherwise keeps it registered and in memory.
plt.close(fig)

def sparse_xticks(dims, step=3):
    """Return every ``step``-th entry of ``dims``, always ending with the last one.

    Thins out tick positions on a dense index axis while guaranteeing that
    the final index is still labelled.

    Args:
        dims: 1-D indexable sequence of tick positions (list, range, or
            NumPy array).
        step: keep every ``step``-th position, starting from the first.

    Returns:
        A list of tick positions; an empty list when ``dims`` is empty.
    """
    # len() rather than truthiness: NumPy arrays raise on bool(array).
    if len(dims) == 0:
        return []
    ticks = list(dims[::step])
    if dims[-1] not in ticks:
        ticks.append(dims[-1])
    return ticks

# Amplitude a_i = sqrt(E_1^2 + E_2^2) of each pair's mean cos/sin terms.
magnitude = np.sqrt(cos_term ** 2 + sin_term ** 2)       # [PAIRS]

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(dims, magnitude, marker="o", linewidth=1.8)
ax.set_xticks(sparse_xticks(dims, step=3))
ax.set_xlabel(r"Frequency index $i$", fontsize=LABEL_FONTSIZE)
ax.set_ylabel(r"$a_i$", fontsize=LABEL_FONTSIZE)
ax.set_title(r"Amplitudes of RoPE frequencies", fontsize=TITLE_FONTSIZE)
ax.grid(True, alpha=0.3, linewidth=0.8)
ax.tick_params(axis='both', labelsize=TICK_FONTSIZE)
fig.tight_layout()
# NOTE: file name kept as-is for downstream consumers, although it shows amplitudes.
fig.savefig("wan/rope_pair_percentage.png", dpi=150)
plt.close(fig)

# Bias angle b_i = atan2(-E_2, E_1) of each pair's mean cos/sin terms
# (computed once; the previous version recomputed it identically).
angles_rad = np.arctan2(-sin_term, cos_term)
print(angles_rad)

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(dims, angles_rad, marker="o", linewidth=1.8)
# Fixed y-range. arctan2 returns values in (-pi, pi], so any angle below
# -0.65 falls outside this view.
ax.set_ylim(-0.65, np.pi)
ax.set_xticks(sparse_xticks(dims, step=3))
ax.set_xlabel(r"Frequency index $i$", fontsize=LABEL_FONTSIZE)
ax.set_ylabel(r"$b_i$", fontsize=LABEL_FONTSIZE)
ax.set_title(r"Bias angles", fontsize=TITLE_FONTSIZE)
ax.grid(True, alpha=0.3, linewidth=0.8)
ax.tick_params(axis='both', labelsize=TICK_FONTSIZE)
fig.tight_layout()
fig.savefig("wan/rope_pair_angles_rad.png", dpi=150)
plt.close(fig)