import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from scipy.special import softmax, logsumexp

# Display names of the compared methods, in plotting order.
method_name = ['No scale', 'LogN', 'Scale-invariant']
# Line colour for each method, keyed by its display name.
method_colors = dict(zip(method_name, ('#f46d43', '#313695', '#000')))

# Font sizes (in points) shared by every subplot.
title_font_size = 7
axis_label_font_size = ticks_font_size = 6

# Fix both random number generators so repeated runs draw the same samples.
np.random.seed(42)
torch.manual_seed(42)

# ===== Utility Functions =====
def entropy(P):
    """Return the Shannon entropy (in nats) of ``P`` along its last axis.

    Entries that are exactly zero contribute ``0`` to the sum (the limit of
    ``p * log(p)`` as ``p -> 0``) instead of turning the result into NaN.

    Args:
        P: Probabilities as a ``torch.Tensor`` or anything NumPy can convert
            to an array. The last axis is reduced.

    Returns:
        ``-sum(P * log(P))`` over the last axis, as a tensor for tensor input
        and as a NumPy array/scalar otherwise.
    """
    if isinstance(P, torch.Tensor):
        # xlogy(p, p) equals p * log(p) but is defined as 0 where p == 0,
        # whereas the naive product gives 0 * -inf = NaN.
        return -torch.xlogy(P, P).sum(-1)

    P = np.asarray(P)
    # Swap exact zeros for 1 inside the log only, so they yield 0 * log(1) = 0.
    # Every other entry (including negative or NaN input) is left untouched.
    log_P = np.log(np.where(P == 0, 1.0, P))
    return -(P * log_P).sum(-1)

def entropy_from_logits_no_softmax(logits, dim=None):
    """Compute the softmax entropy (in nats) directly from unnormalised logits.

    Uses the identity ``H = logsumexp(z) - sum(softmax(z) * z)``, so no
    explicit softmax call is needed.

    Args:
        logits: ``torch.Tensor`` or anything NumPy can convert to an array.
        dim: Axis to reduce. ``None`` reduces over every element and yields a
            scalar, for both tensor and array input.

    Returns:
        The entropy with ``dim`` removed, as a tensor for tensor input and as
        a NumPy array/scalar otherwise. Positions whose softmax probability
        is exactly zero (e.g. ``-inf`` masked logits) contribute nothing.
    """
    if isinstance(logits, torch.Tensor):
        if dim is None:
            # Flatten instead of forwarding dim=None to torch.logsumexp /
            # Tensor.squeeze, which not every torch version accepts.
            logits = logits.reshape(-1)
            dim = 0
        lse = torch.logsumexp(logits, dim=dim, keepdim=True)
        probs = torch.exp(logits - lse)
        # A zero-probability position with a -inf logit would give
        # 0 * -inf = NaN; zero the logit there so the term is exactly 0.
        safe_logits = torch.where(probs == 0, torch.zeros_like(logits), logits)
        expected_logit = torch.sum(probs * safe_logits, dim=dim)
        return lse.squeeze(dim=dim) - expected_logit

    logits = np.asarray(logits)
    lse = logsumexp(logits, axis=dim, keepdims=True)
    probs = np.exp(logits - lse)
    # Same zero-probability guard as in the tensor branch.
    safe_logits = np.where(probs == 0, 0.0, logits)
    expected_logit = np.sum(probs * safe_logits, axis=dim)
    return lse.squeeze(axis=dim) - expected_logit

def get_logit_stat_torch(S, local_window_size=100):
    """Summarise a tensor of logits along its last axis.

    Returns a pair ``(entropy, head_mass)``: the softmax entropy computed by
    ``entropy_from_logits_no_softmax`` and the softmax probability summed
    over the first ``local_window_size`` entries of the last axis.
    """
    head_mass = torch.softmax(S, dim=-1)[..., :local_window_size].sum(-1)
    return entropy_from_logits_no_softmax(S, -1), head_mass

def get_logit_stat_numpy(S, local_window_size=100):
    """Summarise a NumPy array of logits along its last axis.

    Returns a pair ``(entropy, head_mass)``: the softmax entropy computed by
    ``entropy_from_logits_no_softmax`` and the softmax probability summed
    over the first ``local_window_size`` entries of the last axis.
    """
    head_mass = softmax(S, axis=-1)[..., :local_window_size].sum(-1)
    return entropy_from_logits_no_softmax(S, -1), head_mass

def create_merged_figure():
    """Build the 2x3 entropy/attention comparison figure and save it as ``./entropy.pdf``.

    Row 0 is computed from synthetic IID Gaussian logits; row 1 from statistics
    returned by ``compute_real_data_stats`` (or its cached copy). The columns
    show, for each of the three methods in ``method_name``:

    * column 0: entropy over the whole context, against context size;
    * column 1: entropy inside consecutive position bins ``[10^i, 10^(i+1))``;
    * column 2: softmax mass on the first 100 positions of the last axis,
      against context size.

    Computed statistics are cached as ``.npz`` files under ``./data``. When a
    cache file exists it is loaded instead of recomputing; only the file's
    existence is checked, so the caches must be removed by hand after changing
    any of the parameters below.

    Returns:
        The matplotlib figure that was drawn and saved.
    """
    # Create data directory if it doesn't exist
    data_dir = Path('./data')
    data_dir.mkdir(parents=True, exist_ok=True)

    # ===== Figure Setup =====
    fig, axs = plt.subplots(2, 3, figsize=(5.5, 2.8), constrained_layout=True,
                           gridspec_kw={'wspace': 0.03, 'left': 0.08})

    # Define shared style parameters
    plot_config = {
        'marker': 'o',
        'markersize': 2,
        'alpha':1.0
    }

    # ===== Row 1: IID Gaussian data =====
    # Position-dependent affine transform used for the 'Scale-invariant' curve:
    # L = a(t) * S + m(t), with t the position index along the last axis.
    tau = 100
    m_func = lambda t: -2 * torch.log1p(t / tau)
    a_func = lambda t: torch.sqrt(2 * torch.log1p(t / tau) + 1)

    num_trials = 100  # independent rows of logits drawn per context size
    Ns = [100, 1000, 10000, 100000, 1000000]  # context sizes on the x-axis
    s_ssmax = 0.4  # multiplier of log(N) in the 'LogN' scaling S * s_ssmax * log(N)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Check if we have saved data
    gaussian_data_path = data_dir / 'gaussian_attention_stats.npz'
    if gaussian_data_path.exists():
        print("Loading saved Gaussian data...")
        data = np.load(gaussian_data_path)
        attention_score_stats = data['attention_score_stats']
        entropys_standard = data['entropys_standard']
        entropys_logn = data['entropys_logn']
        entropys_scale_free = data['entropys_scale_free']
    else:
        print("Computing Gaussian data...")
        # One entry per context size; each entry holds the (entropy, local mass)
        # pair for the three methods in the order of `method_name`.
        attention_score_stats = []
        for N in Ns:
            S = torch.randn(num_trials, N, device=device)  # Raw <q, k>
            t = torch.arange(N, dtype=torch.float32, device=device)
            m_s, a_s = m_func(t), a_func(t)  # parameter for the affine transformation
            L = S * a_s + m_s
            attention_score_stats.append((
                get_logit_stat_torch(S),
                get_logit_stat_torch(S * s_ssmax * torch.log(torch.tensor(N, dtype=torch.float32, device=device))),
                get_logit_stat_torch(L),
            ))

        # Move data to CPU and convert to NumPy for stacking and plotting
        attention_score_stats_cpu = []
        for stats_tuple in attention_score_stats:
            cpu_tuple = []
            for method_stats in stats_tuple:
                cpu_method = []
                for stat in method_stats:
                    cpu_method.append(stat.cpu().numpy())
                cpu_tuple.append(cpu_method)
            attention_score_stats_cpu.append(cpu_tuple)

        # Resulting axes: (context size, method, statistic, trial)
        attention_score_stats = np.array(attention_score_stats_cpu)

        # Prepare data for the middle column (entropy of bins)
        T = 1000000
        S = torch.randn(num_trials, T, device=device)
        t = torch.arange(T, dtype=torch.float32, device=device)
        m_s, a_s = m_func(t), a_func(t)
        L = S * a_s + m_s

        # The LogN scale uses the full length T, not the length of each bin
        S_logn = S * s_ssmax * torch.log(torch.tensor(T, dtype=torch.float32, device=device))

        # Bin edges: bin i covers positions [Ts[i], Ts[i + 1])
        Ts = [1, 10, 100, 1000, 10000, 100000, 1000000]
        entropys_scale_free = []
        entropys_logn = []
        entropys_standard = []

        for i in range(len(Ts) - 1):
            t_start, t_end = Ts[i], Ts[i + 1]
            # Softmax is renormalised inside each bin before taking the entropy
            local_prob_scale_free = F.softmax(L[..., t_start:t_end], dim=-1)
            local_prob_logn = F.softmax(S_logn[..., t_start:t_end], dim=-1)
            local_prob_standard = F.softmax(S[..., t_start:t_end], dim=-1)
            H_standard = entropy(local_prob_standard)
            H_logn = entropy(local_prob_logn)
            H_scale_free = entropy(local_prob_scale_free)
            # Average over trials
            entropys_standard.append(H_standard.mean(0).cpu().numpy())
            entropys_logn.append(H_logn.mean(0).cpu().numpy())
            entropys_scale_free.append(H_scale_free.mean(0).cpu().numpy())

        # Save the computed data
        np.savez(gaussian_data_path,
                 attention_score_stats=attention_score_stats,
                 entropys_standard=entropys_standard,
                 entropys_logn=entropys_logn,
                 entropys_scale_free=entropys_scale_free)

    # Plot the metrics for Row 1
    for i in range(3):
        # Column 0: global entropy, averaged over trials
        axs[0, 0].plot(Ns, attention_score_stats[:, i, 0].mean(-1), label=method_name[i],
                     color=method_colors[method_name[i]], **plot_config)

        # Column 2: softmax mass on the first 100 positions, averaged over trials
        axs[0, 2].plot(Ns, attention_score_stats[:, i, 1].mean(-1),
                     color=method_colors[method_name[i]], **plot_config)

    # Column 1: entropy of bins (x is the bin index)
    axs[0, 1].plot(
        entropys_standard,
        marker='o', markersize=2, color=method_colors['No scale']
    )
    axs[0, 1].plot(
        entropys_logn,
        marker='o', markersize=2, color=method_colors['LogN']
    )
    axs[0, 1].plot(
        entropys_scale_free,
        marker='o', markersize=2, color=method_colors['Scale-invariant']
    )

    # ===== Row 2: Real data =====
    # Run identifiers forwarded to compute_real_data_stats
    run_id_sfa = 'sfa_and_p_rope_4k_s5__60fd'
    run_id_logn_p_rope = 'logn_trick_and_p_rope_learnS_4k_s5__35b8'
    run_id_ntk = 'p_rope_4k_s5__989b'

    # Check if we have saved real data
    real_data_path = data_dir / 'real_attention_stats.npz'

    max_power = 6
    Ts_real = [10**i for i in range(0, max_power+1)]  # bin edges 10^0 .. 10^max_power
    sizes = [100, 1000, 10_000, 100_000]
    size_names = ['100', '1000', '10000', '100000']

    if real_data_path.exists():
        print("Loading saved real data...")
        attention_score_stats_real, entropys_standard_real, entropys_scale_free_real, entropys_logn_real = load_real_data_stats(data_dir)
    else:
        print("Computing real data...")
        first_n_tokens = 100
        head_ix = 3
        T = 100_000

        attention_score_stats_real, entropys_standard_real, entropys_scale_free_real, entropys_logn_real = compute_real_data_stats(
            run_id_sfa, run_id_logn_p_rope, run_id_ntk, sizes, size_names,
            first_n_tokens, head_ix, Ts_real, T
        )

        # Save the computed data
        save_real_data_stats(data_dir, attention_score_stats_real,
                            entropys_standard_real, entropys_scale_free_real,
                            entropys_logn_real)
        print(f"Saved processed statistics to {real_data_path}")

    # Plot the metrics for Row 2
    for i in range(3):  # Same three methods, in the order of `method_name`
        color = method_colors[method_name[i]]

        # Column 0: global entropy
        axs[1, 0].plot(sizes, attention_score_stats_real[:, i, 0].mean(-1),
                     label=method_name[i], color=color, **plot_config)

        # Column 2: softmax mass on the first positions of the last axis
        axs[1, 2].plot(sizes, attention_score_stats_real[:, i, 1].mean(-1),
                     label=method_name[i], color=color, **plot_config)

    # Column 1: entropy of bins (x is the bin index)
    axs[1, 1].plot(
        entropys_standard_real,
        marker='o', markersize=2, color=method_colors['No scale']
    )
    axs[1, 1].plot(
        entropys_scale_free_real,
        marker='o', markersize=2, color=method_colors['Scale-invariant']
    )
    axs[1, 1].plot(
        entropys_logn_real,
        marker='o', markersize=2, color=method_colors['LogN']
    )

    # ===== Format all plots =====
    # Row titles - placed on the right side and rotated 270 degrees
    fig.text(0.988, 0.74, "IID Gaussian", fontsize=axis_label_font_size, rotation=270, ha='center', va='center')
    fig.text(0.988, 0.32, "Real Logits", fontsize=axis_label_font_size, rotation=270, ha='center', va='center')

    # Column titles and formatting for all plots
    # NOTE(review): column 2 plots the softmax mass on indices [:100] of the
    # last axis; the "latest" wording presumably assumes index 0 is the most
    # recent token - confirm.
    titles = ['Global Entropy', 'Entropy of bins', 'Attention to\nlatest 100 tokens']

    # Define consistent axis limits for each column
    # Column 0: Global Entropy
    entropy_xlim = (100, 1000000)  # x-axis for both rows, column 0
    entropy_ylim = (0, 15)         # y-axis for both rows, column 0

    # Column 1: Entropy of bins
    bin_ylim = (0, 15)              # y-axis for both rows, column 1

    # Column 2: Attention to the 100-token window
    attn_xlim = (100, 1000000)     # x-axis for both rows, column 2
    attn_ylim = (1e-4, 1.0)        # y-axis for both rows, column 2

    for row in range(2):
        for col in range(3):
            ax = axs[row, col]

            # Set title only for top row
            if row == 0:
                ax.set_title(titles[col], fontsize=title_font_size, pad=2)
                # Remove x-axis tick labels on first row
                ax.set_xticklabels([])

            # Set log scale for x-axis on first and last plots
            if col == 0 or col == 2:
                # Both rows use base 10 for log scale
                ax.set_xscale('log', base=10)

                # Only add x-axis label on bottom row
                if row == 1:
                    ax.set_xlabel('Context size', fontsize=axis_label_font_size)
                else:
                    # Clear the top-row tick labels again, after the scale change
                    ax.set_xticklabels([])

                # Set consistent axis limits for columns 0 and 2
                if col == 0:
                    ax.set_xlim(entropy_xlim)
                    ax.set_ylim(entropy_ylim)
                elif col == 2:  # attention mass is shown on a log y-axis
                    ax.set_yscale('log', base=10)
                    ax.set_xlim(attn_xlim)
                    ax.set_ylim(attn_ylim)

            # Set consistent y-axis limits for column 1
            if col == 1:
                ax.set_ylim(bin_ylim)

            # Format tick labels
            ax.tick_params(axis='both', labelsize=ticks_font_size)

            if col == 1 or col == 0: # set entropy yticks
                ax.set_yticks([0, 5, 10, 15])
            # Format x-ticks for the middle column (bottom-row labels rotated 25 degrees)
            if col == 1:
                if row == 0:
                    # First row: ticks at the six bin indices, labels hidden
                    labels = [r'$[10^{%d}, 10^{%d})$' % (i, i+1) for i in range(6)]
                    ax.set_xticks(range(6))
                    # Set empty labels for first row
                    ax.set_xticklabels([])
                else:
                    # Second row: one label per bin [10^i, 10^(i+1))
                    labels = [r'$[10^{%d}, 10^{%d})$' % (i, i+1) for i in range(max_power)]
                    ax.set_xticks(range(len(labels)))
                    ax.set_xticklabels(labels, fontsize=ticks_font_size, rotation=25, ha='right')

    # Add legend to the top-left plot
    axs[0, 0].legend(fontsize=5, markerscale=0.8, loc='upper left')

    plt.savefig('./entropy.pdf')

    return fig

def save_real_data_stats(data_dir, attention_score_stats_real, entropys_standard_real, entropys_scale_free_real, entropys_logn_real):
    """Write the real-data statistics to ``real_attention_stats.npz`` inside ``data_dir``.

    Each argument after ``data_dir`` is stored under a key equal to its own
    parameter name, which is what ``load_real_data_stats`` reads back.
    """
    archive_path = data_dir / 'real_attention_stats.npz'
    payload = {
        'attention_score_stats_real': attention_score_stats_real,
        'entropys_standard_real': entropys_standard_real,
        'entropys_scale_free_real': entropys_scale_free_real,
        'entropys_logn_real': entropys_logn_real,
    }
    np.savez(archive_path, **payload)

def load_real_data_stats(data_dir):
    """Load the real-data statistics from ``real_attention_stats.npz`` in ``data_dir``.

    Args:
        data_dir: Directory (a ``pathlib.Path``) containing the archive.

    Returns:
        A 4-tuple of arrays, in this order: ``attention_score_stats_real``,
        ``entropys_standard_real``, ``entropys_scale_free_real``,
        ``entropys_logn_real``.
    """
    keys = (
        'attention_score_stats_real',
        'entropys_standard_real',
        'entropys_scale_free_real',
        'entropys_logn_real',
    )
    # np.load keeps the .npz file open until the archive object is closed;
    # read every array inside the `with` so the handle is released on return.
    with np.load(data_dir / 'real_attention_stats.npz') as data:
        return tuple(data[key] for key in keys)

def _load_layer1_samples(run_id, sizes, size_names):
    """Map each size to the ``'layer_1'`` entry of that run's saved attention samples."""
    # NOTE(security): weights_only=False unpickles arbitrary objects, so this
    # must only be pointed at trusted log files.
    # map_location='cpu' because the samples are converted with .numpy() later,
    # which requires CPU tensors regardless of the device they were saved from.
    return {
        size: torch.load(f'logs/{run_id}/attn_samples_{size_name}.pt',
                         weights_only=False, map_location='cpu')['layer_1']
        for size, size_name in zip(sizes, size_names)
    }


def compute_real_data_stats(run_id_sfa, run_id_logn_p_rope, run_id_ntk, sizes, size_names, first_n_tokens, head_ix, Ts_real, T):
    """Compute real data statistics from attention samples.

    Reads ``logs/<run_id>/attn_samples_<size_name>.pt`` for each of the three
    runs and each size, and uses the ``'layer_1'`` entry of every file.
    Assumes each entry is a tensor indexed as ``[row, head, position]`` -
    TODO confirm against the code that writes these files.

    Args:
        run_id_sfa: Run whose logits are reported as the third method.
        run_id_logn_p_rope: Run whose logits are reported as the second method.
        run_id_ntk: Run whose logits are reported as the first method.
        sizes: Sizes used as dictionary keys and as the per-size loop values.
        size_names: File-name fragment for each entry of ``sizes``.
        first_n_tokens: Number of leading rows kept from each sample, also
            passed as ``local_window_size`` to ``get_logit_stat_numpy``.
        head_ix: Index selected along the second axis of each sample.
        Ts_real: Increasing bin edges; bin ``i`` is ``[Ts_real[i], Ts_real[i + 1])``.
        T: Key into ``sizes`` selecting the sample used for the binned
            entropies; bins whose upper edge exceeds ``T`` are not computed.

    Returns:
        ``(attention_score_stats_real, entropys_standard_real,
        entropys_scale_free_real, entropys_logn_real)``. The first is a NumPy
        array stacking, per size, the ``get_logit_stat_numpy`` output for the
        three runs in the order (ntk, logn, sfa); the other three are lists
        with one row-averaged entropy per computed bin.
    """
    samples_rope = _load_layer1_samples(run_id_ntk, sizes, size_names)
    samples_logn = _load_layer1_samples(run_id_logn_p_rope, sizes, size_names)
    samples_sfa = _load_layer1_samples(run_id_sfa, sizes, size_names)

    def head_logits(samples, size, n_rows):
        # First n_rows rows of the chosen head, as a float32 NumPy array.
        return samples[size][:n_rows, head_ix, :].float().numpy()

    # Per-size statistics, methods ordered as (No scale, LogN, Scale-invariant).
    per_size_stats = [
        tuple(
            get_logit_stat_numpy(head_logits(samples, N, first_n_tokens),
                                 local_window_size=first_n_tokens)
            for samples in (samples_rope, samples_logn, samples_sfa)
        )
        for N in sizes
    ]
    attention_score_stats_real = np.array(per_size_stats)

    # Compute entropy bins
    # NOTE(review): the row count is hard-coded to 100 here rather than using
    # first_n_tokens; the visible caller passes 100 - confirm this is intended.
    S = head_logits(samples_rope, T, 100)
    L = head_logits(samples_sfa, T, 100)
    S_logn = head_logits(samples_logn, T, 100)

    entropys_scale_free_real = []
    entropys_standard_real = []
    entropys_logn_real = []

    for t_start, t_end in zip(Ts_real, Ts_real[1:]):
        if t_end > T:
            # Remaining bins would reach past the T available positions.
            break

        # Entropy of the softmax renormalised inside the bin, averaged over rows.
        entropys_standard_real.append(
            entropy_from_logits_no_softmax(S[..., t_start:t_end], -1).mean(0))
        entropys_scale_free_real.append(
            entropy_from_logits_no_softmax(L[..., t_start:t_end], -1).mean(0))
        entropys_logn_real.append(
            entropy_from_logits_no_softmax(S_logn[..., t_start:t_end], -1).mean(0))

    return (attention_score_stats_real, entropys_standard_real,
            entropys_scale_free_real, entropys_logn_real)

# Build the merged figure when the module is executed or imported. As a side
# effect this writes ./entropy.pdf and may create cache files under ./data.
fig = create_merged_figure()
# plt.show()  # uncomment to display the figure interactively