from copy import copy

import numpy as np
import torch

from theory.ideal_weights_utils import make_E_S_biweighted
from theory.theory import EmbeddingType, create_embedding_matrix
from utils.common import print_array_shapes


def _orthonormal_columns(m, n, rng):
    """Q in R^{m x n} with orthonormal columns (m >= n)."""
    A = rng.standard_normal((m, n))
    Q, _ = np.linalg.qr(A, mode="reduced")
    return Q  # m x n, Q^T Q = I_n

def _normalize_columns(A, eps=1e-12):
    """Scale each column of A to unit ℓ2 norm (no-op for near-zero)."""
    nrm = np.linalg.norm(A, axis=0)
    nrm = np.where(nrm < eps, 1.0, nrm)
    return A / nrm

def make_E_S(V, D, N, seed=None, as_tensors=True):
    """
    Build E in R^{D x V} and S in R^{N x D} with
      diag(E^T E) = 1 and diag((S E)^T (S E)) = 1.

    S has orthonormal rows and every column of E lies in range(S^T), so
    S @ E equals the unit-column coefficient matrix Z and norms are preserved.
    Requires D >= N; V is arbitrary (V >= 1). When ``as_tensors`` is True the
    outputs are converted to float32 torch tensors, otherwise numpy arrays
    are returned.
    """
    if N > D:
        raise ValueError(f"Need D >= N; got D={D}, N={N}.")
    generator = np.random.default_rng(seed)

    # Transpose of a D x N matrix with orthonormal columns -> S @ S.T == I_N.
    S = _orthonormal_columns(D, N, generator).T

    # Unit-norm coefficients Z (N x V); E = S^T Z keeps ||E[:, v]|| = ||Z[:, v]|| = 1.
    Z = _normalize_columns(generator.standard_normal((N, V)))
    E = S.T @ Z

    if not as_tensors:
        return E, S
    return torch.as_tensor(E, dtype=torch.float), torch.as_tensor(S, dtype=torch.float)


def create_scaled_QR_embedding_matrix(
        V: int, M: int,
        norm: float = 1,
        normalize_by_mean: bool = False,
):
    """
    Return a random (M, V) embedding built from a thin QR factorization.

    A Gaussian (V, M) matrix is orthonormalized by QR. Its rows, which become
    the embedding columns after the final transpose, are then rescaled:

    - ``normalize_by_mean=False``: every row gets ℓ2 norm sqrt(norm);
    - ``normalize_by_mean=True``: all rows share one factor chosen so that the
      mean row norm equals sqrt(norm).

    Column v of the returned tensor is the embedding of token v, so ``norm``
    is the target squared length. Samples from torch's global RNG.
    Raises AssertionError unless V >= M >= 2.
    """
    assert (V >= M) and (M >= 2)

    Q = torch.linalg.qr(torch.randn(V, M), mode='reduced')[0]  # (V, M), Q^T Q = I_M
    lengths = torch.linalg.norm(Q, dim=1)  # (V,)
    target = np.sqrt(norm)

    if normalize_by_mean:
        return (Q * (target / lengths.mean())).T
    return (Q * (target / lengths)[:, None]).T


def create_scaled_by_parts_QR_embedding_matrix(
        V: int, M: int,
        V_a: int,
        a_norm: float,
        b_norm: float,
        normalize_by_mean: bool = True,
):
    """
    Random QR embedding whose first ``V_a`` tokens and remaining tokens are scaled separately.

    A Gaussian (V, M) matrix is orthonormalized by QR. Rows ``[:V_a]`` are then
    rescaled toward norm sqrt(a_norm) and rows ``[V_a:]`` toward sqrt(b_norm).
    With ``normalize_by_mean`` each part shares one factor so its mean row norm
    hits the target; otherwise every row is scaled individually.

    Returns
    -------
    (E, E_a, E_b) where E = [E_a | E_b] has shape (M, V), E_a holds the first
    V_a token columns and E_b the rest (shapes assume 0 <= V_a <= V).
    Samples from torch's global RNG. Raises AssertionError unless V >= M >= 2.
    """
    assert (V >= M) and (M >= 2)

    Q = torch.linalg.qr(torch.randn(V, M), mode='reduced')[0]  # (V, M)
    lengths = torch.linalg.norm(Q, dim=1)  # (V,)

    def scaled(rows, row_lengths, squared_norm):
        # Rescale one contiguous block of rows toward norm sqrt(squared_norm).
        factor = np.sqrt(squared_norm)
        if normalize_by_mean:
            return rows * (factor / row_lengths.mean())
        return rows * (factor / row_lengths)[:, None]

    E_a = scaled(Q[:V_a], lengths[:V_a], a_norm)
    E_b = scaled(Q[V_a:], lengths[V_a:], b_norm)

    return torch.vstack([E_a, E_b]).T, E_a.T, E_b.T


def create_scaled_QR_for_SE(E: torch.Tensor, N: int, norm: float = 1.0, eps: float = 1e-12):
    """
    Given E ∈ R^{D×V}, build S ∈ R^{N×D} so that S @ E approximates a
    QR-style embedding P ∈ R^{N×V} whose columns all have norm sqrt(norm).

    Construction:
      1. Take an orthonormal basis of row(E) from the SVD of E, keeping only
         directions with non-negligible singular values.
      2. Sample a random V×N matrix inside row(E) and orthonormalize it by QR.
      3. Rescale each row of Q to norm sqrt(norm) and transpose, giving P.
      4. Solve S = P E^+.

    Exactness: S @ E = P E^+ E, i.e. each row of P orthogonally projected onto
    row(E). When rank(E) == V (which needs D >= V) row(E) is all of R^V and
    S @ E == P up to round-off. Otherwise the per-row rescaling in step 3
    generally moves P's rows out of row(E). In that case S @ E is only that
    projection, so its column norms are not exactly sqrt(norm).

    Parameters
    ----------
    E : torch.Tensor, shape (D, V)
    N : int
        Target dimension; must satisfy N <= rank(E) (hence N <= min(D, V)).
    norm : float
        Target squared column norm of P.
    eps : float
        Added to Q's row norms before dividing, to avoid division by zero.

    Returns
    -------
    S : torch.Tensor, shape (N, D), same dtype/device as E.
    P : torch.Tensor, shape (N, V), the target that S @ E was solved for.
    """
    device, dtype = E.device, E.dtype
    D, V = E.shape
    assert N <= min(D, V), "Need N ≤ rank(E) ≤ min(D, V)."

    # 1) Row-space(E) via SVD. Vh has min(D, V) rows (not necessarily D), and
    #    rows paired with ~0 singular values are dropped because they are not
    #    in row(E). The tolerance matches torch.linalg.matrix_rank's default.
    _, svals, Vh = torch.linalg.svd(E, full_matrices=False)
    tol = svals.max() * max(D, V) * torch.finfo(dtype).eps
    rank = int((svals > tol).sum())
    assert N <= rank, f"Need N ≤ rank(E); got N={N}, rank(E)={rank}."
    basis = Vh[:rank]                                         # (rank, V), orthonormal rows

    # 2) Random V×N matrix inside row(E); match E's dtype so the matmul works
    #    for float64 inputs too. QR gives orthonormal columns.
    R = torch.randn(rank, N, device=device, dtype=dtype)
    Q, _ = torch.linalg.qr(basis.T @ R, mode='reduced')       # (V, N)

    # 3) Equalize row norms and transpose: every column of P has norm ≈ sqrt(norm).
    row_norms = torch.linalg.norm(Q, dim=1)                   # (V,)
    scale = (np.sqrt(norm) / (row_norms + eps)).unsqueeze(1)  # (V, 1)
    P = (Q * scale).T                                         # (N, V)

    # 4) Least-squares solve of S E = P (exact only when rank(E) == V; see docstring).
    S = P @ torch.pinverse(E)                                 # (N, D)

    return S, P


def optimize_E_and_S_embeddings(
        V, D, N,
        seed=0,
        s_k=0.8, s_v=1.2,
        c_s=3.2,
        noise_alpha=0.3,
        return_as_tensors=True,
        device='cpu',
        dtype=torch.float,
):
    """
    Construct embedding (E) and projection (S) matrices with keys and values
    living in orthogonal subspaces of R^D.

    Nothing is optimized iteratively: E and S are built in closed form from a
    random orthonormal basis of R^D.

    Parameters
    ----------
    V : int
        Vocabulary size. The first V // 2 columns of E are keys and the
        remaining V - V // 2 columns are values.
    D : int
        Embedding dimension.
    N : int
        Projected dimension. Must satisfy N < D so that the value subspace
        (dimension D - N) is non-empty; with N == D all value embeddings
        would be zero.
    seed : int, optional
        Seed for ``np.random.RandomState``.
    s_k, s_v : float, optional
        ℓ2 norm of every key column / value column of E.
    c_s : float, optional
        With ``noise_alpha == 0``, every projected key column S @ E_k has
        norm c_s and S @ E_v == 0.
    noise_alpha : float, optional
        Multiplier of i.i.d. standard-normal noise added to S. Values <= 0
        disable the noise.
    return_as_tensors : bool, optional
        Convert all outputs to torch tensors on ``device`` with ``dtype``.
    device, dtype : optional
        Target device / dtype for the tensor outputs.

    Returns
    -------
    E : shape (D, V)
        Embedding matrix; first V // 2 columns are keys, the rest values.
    S : shape (N, D)
        Projection matrix.
    G_E : shape (V, V)
        Gram matrix EᵀE.
    G_SE : shape (V, V)
        Gram matrix (S E)ᵀ (S E).
    All four are torch tensors if ``return_as_tensors`` else float64 ndarrays.

    Raises
    ------
    ValueError
        If the dimensions do not satisfy 2 <= N < D <= V.
    """

    if not (2 <= N < D <= V):
        raise ValueError("Dimensions must satisfy 2 <= N < D <= V.")

    rnd = np.random.RandomState(seed)

    V_k = V // 2
    V_v = V - V_k

    # 1. Random orthonormal basis of R^D, split into key (first N columns)
    #    and value (remaining D - N columns) subspaces.
    Q, _ = np.linalg.qr(rnd.randn(D, D))
    W_k = Q[:, :N]                   # (D, N)
    W_v = Q[:, N:]                   # (D, D - N), non-empty since N < D

    # 2. Random unit coefficient vectors in each subspace
    Z_k = rnd.randn(N, V_k)
    Z_k /= np.linalg.norm(Z_k, axis=0, keepdims=True)
    Z_v = rnd.randn(D - N, V_v)
    Z_v /= np.linalg.norm(Z_v, axis=0, keepdims=True)

    # 3. W_k / W_v have orthonormal columns, so key columns get norm s_k and
    #    value columns norm s_v; keys and values are mutually orthogonal.
    E_k = (W_k @ Z_k) * s_k          # (D, V_k)
    E_v = (W_v @ Z_v) * s_v          # (D, V_v)
    E = np.concatenate([E_k, E_v], axis=1)

    # 4. S aligned with the key subspace: without noise S @ E_k = c_s * Z_k
    #    and S @ E_v = 0 (W_kᵀ W_v = 0).
    S = (c_s / s_k) * W_k.T          # (N, D)
    if noise_alpha > 0:
        S = S + noise_alpha * rnd.randn(N, D)

    # 5. Gram matrices
    G_E = E.T @ E
    G_SE = (S @ E).T @ (S @ E)

    # 6. Tensors
    if return_as_tensors:
        E = torch.tensor(E, device=device, dtype=dtype)
        S = torch.tensor(S, device=device, dtype=dtype)
        G_E = torch.tensor(G_E, device=device, dtype=dtype)
        G_SE = torch.tensor(G_SE, device=device, dtype=dtype)

    return E, S, G_E, G_SE


def create_ideal_model_weights(
        V: int, D: int, N: int,
        print_output_shapes: bool = False,
        device='cpu',
        kv_ratio=4.0,
):
    """
    Build the hand-constructed weights (E_in, P_in, W, S_B, S_C, P_out, E_out).

    E is a random (D, V) QR embedding with unit column norms. S is a random
    (N, D) QR compression with unit column norms. The 2D-dimensional
    state is the embedding duplicated by P_in. S_B applies S to its first D
    coordinates and S_C to its last D. P_out reads the last D coordinates.

    Parameters
    ----------
    V, D, N : int
        Vocabulary, embedding and compressed dimensions; requires V >= D >= N >= 2.
    print_output_shapes : bool
        If True, call ``print_array_shapes`` on the output names.
    device : torch device or str
        Every returned tensor is moved to this device.
    kv_ratio : float
        Currently unused; only referenced by the commented-out
        ``make_E_S_biweighted`` alternative below.

    Returns
    -------
    E_in (D, V), P_in (2D, D), W (2D, 2), S_B (N, 2D), S_C (N, 2D),
    P_out (D, 2D), E_out (V, D).
    """

    assert (V >= D) and (D >= N) and (N >= 2)

    # dims (V_k is referenced by the commented-out by-parts alternative)
    V_k = V // 2
    V_v = V - V_k

    # baseline E, S embeddings
    E = create_scaled_QR_embedding_matrix(V=V, M=D, norm=1)  # (D, V)
    S = create_scaled_QR_embedding_matrix(V=D, M=N, norm=1)  # (N, D)

    # # E, S = make_E_S(V, D, N, seed=0)
    # E, S = make_E_S_biweighted(V, D, N, r=kv_ratio, seed=0)


    # # E embedding
    # # k_diag = 0.5  # desired squared norm for keys
    # # v_diag = 1.0  # desired squared norm for values
    # k_diag = 1
    # v_diag = 2
    # E, E_k, E_v = create_scaled_by_parts_QR_embedding_matrix(
    #     V=V, M=D,
    #     V_a=V_k,
    #     a_norm=k_diag, b_norm=v_diag,
    # )
    #
    # # S embedding
    # # S = create_scaled_QR_embedding_matrix(V=D, M=N, norm=1)  # (N, D)
    # s_diag = 1
    # S = s_diag * create_scaled_QR_embedding_matrix(V=D, M=N, norm=1)  # (N, D)
    # # S, P_target = create_scaled_QR_for_SE(E=E, N=N, norm=1.0)
    # # S, P_target = create_scaled_QR_for_SE(E=E_k, N=N, norm=1.0)
    # # S_full, S_high_half, S_low_half = create_scaled_by_parts_QR_embedding_matrix(
    # #     V=2*D, M=N,
    # #     V_a=D,
    # #     a_norm=1.0, b_norm=0.1,
    # # )

    # # E, S embeddings
    # E, S, G_E, G_SE = optimize_E_and_S_embeddings(
    #     V, D, N,
    #     seed=0,
    #     s_k=0.8, s_v=1.2,
    #     c_s=3.2,
    #     # noise_alpha=0.3,
    #     noise_alpha=0.,
    #     return_as_tensors=True,
    #     device=device,
    # )

    # in/out embeddings
    E_in = copy(E)  # (D, V)
    E_out = E.T  # (V, D)

    # helpers
    I_D = torch.eye(D)
    Z_D = torch.zeros((D, D))
    Z_ND = torch.zeros((N, D))
    I_D_col = torch.ones(D)  # all-ones column (not an identity)
    Z_D_col = torch.zeros(D)

    # in/out projections
    P_in = torch.hstack([I_D, I_D]).T  # (2D, D): duplicate the embedding
    P_out = torch.hstack([Z_D, I_D])  # (D, 2D): read the last D coordinates

    # conv1d weights: column 0 marks the first D coordinates, column 1 the last D
    W_p = torch.hstack([I_D_col, Z_D_col])
    W_c = torch.hstack([Z_D_col, I_D_col])
    W = torch.vstack([W_p, W_c]).T  # (2D, 2)

    # write/read/output (k/q/v) projectors
    S_B = torch.hstack([S, Z_ND])  # (N, 2D): S on the first D coordinates
    S_C = torch.hstack([Z_ND, S])  # (N, 2D): S on the last D coordinates

    # Honor `device` (previously accepted but ignored: everything stayed on CPU).
    # For the default 'cpu', .to() returns the same tensors, so nothing changes.
    E_in, P_in, W, S_B, S_C, P_out, E_out = (
        t.to(device) for t in (E_in, P_in, W, S_B, S_C, P_out, E_out)
    )

    # print
    if print_output_shapes:
        print_array_shapes(['E_in', 'P_in', 'W', 'S_B', 'S_C', 'P_out', 'E_out'])

    return E_in, P_in, W, S_B, S_C, P_out, E_out


def _create_ideal_model_weights(
        V: int, D: int, N: int,
        print_output_shapes: bool = False,
):
    """
    Variant of the ideal-weight construction that routes the key/value split
    through the embedding: the vocabulary-space masks M_k / M_v are compressed
    to E M_k Eᵀ / E M_v Eᵀ (D, D) and folded into S_B, S_C and P_out.

    ``print_output_shapes`` is accepted for signature compatibility but is
    not used by this variant.

    Returns E_in (D, V), P_in (2D, D), W (2D, 2), S_B (N, 2D), S_C (N, 2D),
    P_out (D, 2D), E_out (V, D).
    """

    assert (V >= D) and (D >= N) and (N >= 2)

    # dims: first Vk tokens are keys, the remaining Vv are values
    Vk = V // 2
    Vv = V - Vk

    # generate embeddings
    embedding_type = EmbeddingType.SCALED_GAUSSIAN_QR
    E = create_embedding_matrix(V, D, embedding_type).T  # (D, V)
    S = create_embedding_matrix(D, N, embedding_type).T  # (N, D)

    # in/out embeddings
    E_in = copy(E)
    E_out = E.T

    # helpers
    I_D = torch.eye(D)
    Z_D = torch.zeros((D, D))
    I_D_col = torch.ones(D)  # all-ones column (not an identity)
    Z_D_col = torch.zeros(D)

    # vocabulary-space masks: M_k keeps key tokens, M_v keeps value tokens
    M_k = torch.diag(torch.hstack([torch.ones(Vk), torch.zeros(Vv)]))  # (V, V)
    M_v = torch.diag(torch.hstack([torch.zeros(Vk), torch.ones(Vv)]))  # (V, V)

    # selectors on the 2D state: M_p_ keeps the first D coordinates, M_c_ the last D.
    # NOTE(review): the names suggest prev/curr; which half holds the previous
    # token is not established in this function — confirm.
    M_p_ = torch.hstack([I_D, Z_D])  # (D, 2D)
    M_c_ = torch.hstack([Z_D, I_D])  # (D, 2D)

    # masks compressed into embedding space
    M_k_ = E @ M_k @ E.T  # (D, D)
    M_v_ = E @ M_v @ E.T  # (D, D)

    # write/read/output (k/q/v) projectors
    S_B = S @ M_k_ @ M_p_  # (N, 2D)
    S_C = S @ M_k_ @ M_c_  # (N, 2D)
    P_out = M_v_ @ M_c_  # (D, 2D)

    # rest of the projections are more straight-forward
    P_in = torch.hstack([I_D, I_D]).T  # (2D, D)
    W_p = torch.hstack([I_D_col, Z_D_col])
    W_c = torch.hstack([Z_D_col, I_D_col])
    W = torch.vstack([W_p, W_c]).T  # (2D, 2)

    return E_in, P_in, W, S_B, S_C, P_out, E_out


def __create_ideal_model_weights(V: int, D: int, N: int):
    """
    Build (E_in, P_in, W, S_B, S_C, P_out, E_out) from random QR embeddings.

    The D-dimensional embedding is duplicated into a 2D state by P_in. W marks
    the first D coordinates (column 0) and the last D (column 1). S_B / S_C
    compress the first / last D coordinates to N, and P_out reads the last D.
    Key/value masks are built for reference only and feed no output.

    Shapes: E_in (D, V), P_in (2D, D), W (2D, 2), S_B (N, 2D), S_C (N, 2D),
    P_out (D, 2D), E_out (V, D).
    """
    assert (V >= D) and (D >= N) and (N >= 2), "Require V >= D >= N >= 2"

    # key/value token counts and their vocabulary-space masks (reference only)
    n_keys = V // 2
    n_values = V - n_keys
    key_mask = torch.diag(torch.cat([torch.ones(n_keys), torch.zeros(n_values)]))
    value_mask = torch.diag(torch.cat([torch.zeros(n_keys), torch.ones(n_values)]))

    # random QR embedding (D, V) and compression (N, D)
    kind = EmbeddingType.SCALED_GAUSSIAN_QR
    E = create_embedding_matrix(V, D, kind).T
    S = create_embedding_matrix(D, N, kind).T

    E_in = copy(E)
    E_out = E.T

    eye = torch.eye(D)
    zeros_dd = torch.zeros((D, D))
    zeros_nd = torch.zeros((N, D))

    P_in = torch.cat([eye, eye], dim=0)        # (2D, D): duplicate the embedding
    P_out = torch.cat([zeros_dd, eye], dim=1)  # (D, 2D): read the last D coordinates

    W = torch.zeros((2 * D, 2))
    W[:D, 0] = 1.0
    W[D:, 1] = 1.0

    S_B = torch.cat([S, zeros_nd], dim=1)  # (N, 2D): compress the first D coordinates
    S_C = torch.cat([zeros_nd, S], dim=1)  # (N, 2D): compress the last D coordinates

    return E_in, P_in, W, S_B, S_C, P_out, E_out


def ___create_ideal_model_weights(
        V: int, D: int, N: int,

        lambda_vv: float = 8, lambda_kv: float = 4,
        lambda_kk: float = 2, lambda_vk: float = 1,

        # lambda_vv: float = 16, lambda_kv: float = 8,
        # lambda_kk: float = 4, lambda_vk: float = 2,

        # lambda_vv: float = 2, lambda_kv: float = 1,
        # lambda_kk: float = 1, lambda_vk: float = 0,
):
    """
    Variant in which key and value tokens get different embedding norms.

    E_in is the sum of two independent unit-column QR embeddings, each (D, V).
    Their columns are rescaled per token:

    - key tokens (first V // 2 columns): sqrt(lambda_kk) on the first base and
      sqrt(lambda_kv - lambda_kk) on the second;
    - value tokens (remaining columns): sqrt(lambda_vk) on the first base and
      sqrt(lambda_vv - lambda_vk) on the second.

    Ignoring the random cross term between the two bases, key columns thus have
    squared norm lambda_kv and value columns lambda_vv.

    S_B / S_C apply S, restricted by a diagonal mask to the first D // 2
    coordinates, to the first / last D entries of the 2D state.

    Raises AssertionError on invalid dimensions, or when the lambdas would
    make a scale the square root of a negative number.

    Returns E_in (D, V), P_in (2D, D), W (2D, 2), S_B (N, 2D), S_C (N, 2D),
    P_out (D, 2D), E_out (V, D).
    """

    # validate before consuming any RNG
    assert (V >= D) and (D >= N) and (N >= 2)
    assert lambda_kv >= lambda_kk
    assert lambda_vv >= lambda_vk
    assert lambda_kk >= 0 and lambda_vk >= 0  # otherwise np.sqrt yields NaN

    # split D into key/value dims (only the mask on S uses this split)
    k_dim = D // 2
    v_dim = D - k_dim
    half_V = V // 2

    # two independent unit-column bases and the compression; creation order
    # is kept so seeded runs reproduce the same matrices
    E_k_base = create_scaled_QR_embedding_matrix(V, D)  # (D, V)
    E_v_base = create_scaled_QR_embedding_matrix(V, D)  # (D, V)
    S = create_scaled_QR_embedding_matrix(D, N)  # (N, D)

    # per-token column scales for each base
    scale_k_in = torch.ones(V)
    scale_v_in = torch.ones(V)
    scale_k_in[:half_V] = np.sqrt(lambda_kk)               # key tokens, first base
    scale_k_in[half_V:] = np.sqrt(lambda_vk)               # value tokens, first base
    scale_v_in[:half_V] = np.sqrt(lambda_kv - lambda_kk)   # key tokens, second base
    scale_v_in[half_V:] = np.sqrt(lambda_vv - lambda_vk)   # value tokens, second base

    # broadcasting over the last dim scales column v of each base by scale[v]
    E_in = E_k_base * scale_k_in + E_v_base * scale_v_in  # (D, V)
    E_out = E_in.T  # (V, D)

    # remaining projections, as in the other variants
    I_D = torch.eye(D)
    Z_D = torch.zeros((D, D))
    P_in = torch.hstack([I_D, I_D]).T  # (2D, D)
    W_p = torch.hstack([torch.ones(D), torch.zeros(D)])
    W_c = torch.hstack([torch.zeros(D), torch.ones(D)])
    W = torch.vstack([W_p, W_c]).T  # (2D, 2)

    k_proj = torch.diag(torch.tensor([1] * k_dim + [0] * v_dim)).float()  # keep first k_dim coords
    c_proj = torch.hstack([I_D, Z_D])  # (D, 2D): first D entries of the state
    p_proj = torch.hstack([Z_D, I_D])  # (D, 2D): last D entries of the state
    S_B = S @ k_proj @ c_proj  # (N, 2D)
    S_C = S @ k_proj @ p_proj  # (N, 2D)
    P_out = torch.hstack([Z_D, I_D])  # (D, 2D)

    return E_in, P_in, W, S_B, S_C, P_out, E_out
