import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import os
import matplotlib as mpl


# ---- Geometry of the dumped q/k tensors -----------------------------------
SEQ_LEN = 4 * 67320   # positions per dumped sequence
HEADS = 24            # attention heads per block
PAIRS = 8             # 2-D (cos, sin) pairs analysed per head
BLOCKS = range(60)    # block indices to load from qk_store/
BASE_B = 256
D = 16                # equals 2 * PAIRS

# ---- Offset range -----------------------------------------------------------
DMUL = 4              # extrapolation factor
Δmax = DMUL * 33
T = 2040              # positions are compared at integer multiples of T apart

# Fine Δ grid from 0 to Δmax (inclusive) in steps of 1/32.
FINE_STEP = 1 / 32
Δ_fine = np.arange(0.0, Δmax + FINE_STEP, FINE_STEP, dtype=np.float64)

# ======================== Statistics ========================
# For each sequence position s and pair i we take the head-averaged
#   cosine term  A_i = q0*k0 + q1*k1
#   sine term    B_i = q1*k0 - q0*k1
# between q at s and k at every position s + o*T (o integer) that lies inside
# the sequence. Per position we compute the mean and population variance over
# those offsets, then average over positions and over blocks.
dot_sum_var    = torch.zeros(PAIRS, dtype=torch.float64)
cross_sum_var  = torch.zeros(PAIRS, dtype=torch.float64)
dot_sum_mean   = torch.zeros(PAIRS, dtype=torch.float64)
cross_sum_mean = torch.zeros(PAIRS, dtype=torch.float64)
n_blk     = 0


def _load_pairs(path):
    """Load a dumped q/k tensor as [HEADS, SEQ_LEN, PAIRS, 2] float32.

    After squeezing a leading size-1 axis, the tensor may be laid out as
    [SEQ_LEN, HEADS, dim] or [HEADS, SEQ_LEN, dim]. The first layout is
    transposed to heads-first; reshaping it directly would interleave heads
    and positions.

    Raises:
        ValueError: if the leading two dims match neither layout.
    """
    x = torch.load(path, map_location="cpu").squeeze(0)[..., : 2 * PAIRS]
    lead = tuple(x.shape[:2])
    if lead == (SEQ_LEN, HEADS):
        x = x.transpose(0, 1)
    elif lead != (HEADS, SEQ_LEN):
        raise ValueError(
            f"{path}: expected leading dims ({SEQ_LEN}, {HEADS}) or "
            f"({HEADS}, {SEQ_LEN}), got {tuple(x.shape)}"
        )
    # Work in float32 in case the dumps are stored in half precision.
    return x.float().reshape(HEADS, SEQ_LEN, PAIRS, 2)


def _offset_stats(q16, k16):
    """Per-position mean/variance of the head-averaged cos/sin terms.

    Args:
        q16, k16: [HEADS, SEQ_LEN, PAIRS, 2] tensors (see `_load_pairs`).

    Returns:
        (dot_mean, cross_mean, dot_var, cross_var), each [SEQ_LEN, PAIRS]
        float64. Statistics are taken over all offsets o*T whose key position
        is in range. Variance is the population variance, forced to 0 where a
        position has at most one valid offset.
    """
    max_offset = SEQ_LEN // T
    dot_sum = torch.zeros(SEQ_LEN, PAIRS, dtype=torch.float64)
    cross_sum = torch.zeros_like(dot_sum)
    dot_sq = torch.zeros_like(dot_sum)
    cross_sq = torch.zeros_like(dot_sum)
    count = torch.zeros(SEQ_LEN, 1, dtype=torch.float64)

    # One offset at a time: gathering k for every offset at once would
    # materialise an [H, S, n_offsets, P, 2] tensor.
    for o in range(-max_offset, max_offset + 1):
        shift = o * T
        # Query positions s whose key position s + shift is in [0, SEQ_LEN).
        lo, hi = max(0, -shift), min(SEQ_LEN, SEQ_LEN - shift)
        if lo >= hi:
            continue
        qs = q16[:, lo:hi]
        ks = k16[:, lo + shift:hi + shift]
        # Average over heads first, then accumulate in float64.
        dot = (qs[..., 0] * ks[..., 0] + qs[..., 1] * ks[..., 1]).mean(dim=0).double()
        cross = (qs[..., 1] * ks[..., 0] - qs[..., 0] * ks[..., 1]).mean(dim=0).double()
        dot_sum[lo:hi] += dot
        cross_sum[lo:hi] += cross
        dot_sq[lo:hi] += dot * dot
        cross_sq[lo:hi] += cross * cross
        count[lo:hi] += 1

    n = count.clamp(min=1)  # offset 0 is always valid; guard anyway
    dot_mean = dot_sum / n
    cross_mean = cross_sum / n
    # Var = E[X^2] - E[X]^2 (float64); clamp tiny negative rounding residue.
    dot_var = (dot_sq / n - dot_mean ** 2).clamp(min=0)
    cross_var = (cross_sq / n - cross_mean ** 2).clamp(min=0)
    single = count <= 1
    dot_var = dot_var.masked_fill(single, 0.0)
    cross_var = cross_var.masked_fill(single, 0.0)
    return dot_mean, cross_mean, dot_var, cross_var


with torch.no_grad():
    for idx in tqdm(BLOCKS):
        q16 = _load_pairs(f"qk_store/q_{idx}.pt")  # [H, S, P, 2]
        k16 = _load_pairs(f"qk_store/k_{idx}.pt")  # [H, S, P, 2]

        dots, crosses, dots_var, crosses_var = _offset_stats(q16, k16)  # each [S, P]

        # Average the per-position statistics over positions, sum over blocks.
        dot_sum_mean += dots.mean(dim=0)          # [P]
        cross_sum_mean += crosses.mean(dim=0)     # [P]
        dot_sum_var += dots_var.mean(dim=0)       # [P]
        cross_sum_var += crosses_var.mean(dim=0)  # [P]
        n_blk += 1

def _per_block_average(total):
    """Divide a [PAIRS] accumulator by n_blk and return it as a float64 array."""
    return (total / n_blk).float().numpy().astype(np.float64)


cos_term = _per_block_average(dot_sum_mean)
sin_term = _per_block_average(cross_sum_mean)
cos_var_term = _per_block_average(dot_sum_var)    # E[Var_seq(A_i)]
sin_var_term = _per_block_average(cross_sum_var)  # E[Var_seq(B_i)]
total_term = cos_var_term + sin_var_term

import csv

# One CSV row per pair: means, variances, and the summed variance.
csv_path = "result_hunyuan/rope_terms_and_vars.csv"
os.makedirs("result_hunyuan", exist_ok=True)
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["pair_index", "cos_term", "sin_term", "cos_var_term", "sin_var_term", "total_var_term"])
    writer.writerows(
        [
            pair,
            float(cos_term[pair]),
            float(sin_term[pair]),
            float(cos_var_term[pair]),
            float(sin_var_term[pair]),
            float(cos_var_term[pair] + sin_var_term[pair]),
        ]
        for pair in range(PAIRS)
    )
print(f"Saved cos_term, sin_term, cos_var_term, sin_var_term to {csv_path}")
dims = np.arange(PAIRS)

# ======================= Plotting ======================
# Times-style serif fonts for text and math; embed TrueType (type 42) fonts
# in PDF/PS output and keep SVG text as text.
mpl.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "STIXGeneral", "DejaVu Serif"],
        "mathtext.fontset": "custom",
        "mathtext.rm": "Times New Roman",
        "mathtext.it": "Times New Roman:italic",
        "mathtext.bf": "Times New Roman:bold",
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "svg.fonttype": "none",
    }
)

TITLE_FONTSIZE = 48
LABEL_FONTSIZE = 48
TICK_FONTSIZE = 36
LEGEND_FONTSIZE = 36

total_term = cos_var_term + sin_var_term

os.makedirs("hy", exist_ok=True)

# Y-range spanning all four series, padded by 5% of their span on each side.
all_values = np.concatenate([cos_term, sin_term, cos_var_term, sin_var_term])
y_min, y_max = all_values.min(), all_values.max()
y_range_padding = 0.05 * (y_max - y_min)
y_lim = [y_min - y_range_padding, y_max + y_range_padding]

# Per-pair variance (over Δt) of the cos term E_1 and sin term E_2.
var_fig, var_ax = plt.subplots(figsize=(12, 8))
for series, marker, label in (
    (cos_var_term, "o", r"Variance of $E_1^{(i)}$ with respect to $\Delta t$"),
    (sin_var_term, "s", r"Variance of $E_2^{(i)}$ with respect to $\Delta t$"),
):
    var_ax.plot(dims, series, label=label, linewidth=1.8, marker=marker)
var_ax.set_xlabel(r"Frequency index $i$", fontsize=LABEL_FONTSIZE)
var_ax.set_ylabel(r"Variance with respect to $\Delta t$", fontsize=LABEL_FONTSIZE)
var_ax.set_ylim(0, y_max + y_range_padding)
var_ax.set_title(r"Variance of $E_1^{(i)}, E_2^{(i)}$ with respect to $\Delta t$", fontsize=LABEL_FONTSIZE)
var_ax.legend(fontsize=LEGEND_FONTSIZE, loc="best", framealpha=0.6)
var_ax.grid(True, alpha=0.3)
var_ax.tick_params(axis='both', labelsize=TICK_FONTSIZE)
var_fig.tight_layout()
var_fig.savefig("hy/qk_dot_cross_var_curve.png", dpi=150)
plt.close(var_fig)

# Per-pair means of the cos term (E_1) and sin term (E_2).
# Layout: tight_layout only. The original also passed constrained_layout=True,
# but the later tight_layout() call replaced that engine (emitting a
# "layout has changed to tight" warning), so the flag had no effect.
fig, ax = plt.subplots(figsize=(10,8))
ax.plot(dims, cos_term, label=r"$\hat{E}_1^{(i)}$", linewidth=1.8, marker='o')
ax.plot(dims, sin_term, label=r"$\hat{E}_2^{(i)}$", linewidth=1.8, marker='s')
ax.set_xlabel(r"Frequency index $i$", fontsize=LABEL_FONTSIZE)
ax.set_ylabel(r"$\hat{E}_1^{(i)},\hat{E}_2^{(i)}$", fontsize=LABEL_FONTSIZE)
ax.set_title(r"Value of $\hat{E}_1^{(i)},\hat{E}_2^{(i)}$", fontsize=LABEL_FONTSIZE)
ax.set_ylim(y_lim)
ax.legend(fontsize=LEGEND_FONTSIZE, loc="best", framealpha=0.6)
ax.grid(True, alpha=0.3)
ax.tick_params(axis='both', labelsize=TICK_FONTSIZE)
fig.tight_layout()
fig.savefig("hy/qk_dot_cross_mean_curve.png", dpi=150)
plt.close(fig)

# a_i: Euclidean norm of the (cos_term, sin_term) mean for each pair.
magnitude = np.sqrt(np.square(cos_term) + np.square(sin_term))  # [PAIRS]

amp_fig, amp_ax = plt.subplots(figsize=(10, 8))
amp_ax.plot(dims, magnitude, marker="o", linewidth=1.8)
amp_ax.set_xticks(dims)
amp_ax.set_xlabel(r"frequency index $i$", fontsize=LABEL_FONTSIZE)
amp_ax.set_ylabel(r"$a_i$", fontsize=LABEL_FONTSIZE)
amp_ax.set_title(r"Amplitudes of RoPE frequencies", fontsize=LABEL_FONTSIZE)
amp_ax.grid(True, alpha=0.3, linewidth=0.8)
amp_ax.tick_params(axis='both', labelsize=TICK_FONTSIZE)
amp_fig.tight_layout()
amp_fig.savefig("hy/rope_pair_percentage.png", dpi=150)
plt.close(amp_fig)

# b_i: bias angle of each pair, atan2(-sin_term, cos_term), in [-pi, pi].
angles_rad = np.arctan2(-sin_term, cos_term)
print(angles_rad)

# Default view is [-0.05, pi]. arctan2 can return negative angles, which that
# fixed range would silently clip, so lower the bottom limit only when a
# finite angle falls below it (the plot is unchanged otherwise).
angle_ylim_low = -0.05
finite_angles = angles_rad[np.isfinite(angles_rad)]
if finite_angles.size and finite_angles.min() < angle_ylim_low:
    angle_ylim_low = float(finite_angles.min()) - 0.05

plt.figure(figsize=(10,8))
plt.ylim(angle_ylim_low, np.pi)
plt.plot(dims, angles_rad, marker="o", linewidth=1.8)
plt.xticks(dims)
plt.xlabel(r"Frequency index $i$", fontsize=LABEL_FONTSIZE)
plt.ylabel(r"$b_i$", fontsize=LABEL_FONTSIZE)
plt.title(r"Bias angles", fontsize=LABEL_FONTSIZE)
plt.grid(True, alpha=0.3, linewidth=0.8)
plt.tick_params(axis='both', labelsize=TICK_FONTSIZE)
plt.tight_layout()
plt.savefig("hy/rope_pair_angles_rad.png", dpi=150)
plt.close()
