from sorcerun.git_utils import get_repo
import numpy as np
import torchvision.transforms as transforms, torchvision, torch
import sys
import torch
from transformers import AutoModelForCausalLM

sys.path.append(f"{get_repo().working_dir}")
from utils import export

# Registry of random-matrix generators. Functions below decorated with
# @export(MATRIX_DISTRIBUTIONS) are presumably inserted here by utils.export
# (not visible in this file — confirm). All generators share the call
# signature (c, d, eps).
MATRIX_DISTRIBUTIONS = {}


@export(MATRIX_DISTRIBUTIONS)
def unif(c, d, eps):
    """Random symmetric d x d matrix Q diag(lam) Q^T with lam_i ~ U[-1, 1).

    Q is the orthogonal factor of the QR decomposition of a standard Gaussian
    matrix. Eigenvalues with magnitude below ``eps`` are pushed out to
    ``+-eps`` (keeping their sign), so when ``eps <= 1`` every eigenvalue
    magnitude lies in [eps, 1].

    Parameters
    ----------
    c   : unused (kept for interface compatibility)
    d   : int, matrix dimension
    eps : float, minimum eigenvalue magnitude

    Returns
    -------
    torch.Tensor of shape (d, d), symmetric.
    """
    eigvals = torch.rand(d) * 2 - 1

    # Floor |lam_i| at eps, preserving sign. An exact 0 goes to +eps;
    # torch.sign(0) == 0 would otherwise leave the matrix singular.
    signs = torch.ones_like(eigvals)
    signs[eigvals < 0] = -1.0
    eigvals = torch.where(eigvals.abs() < eps, eps * signs, eigvals)

    Q, _ = torch.linalg.qr(torch.randn(d, d))

    # Q diag(lam) Q^T via column scaling. Flipping the sign of any column of Q
    # cancels in this product, so no det(Q) correction is needed.
    A = (Q * eigvals) @ Q.T
    # Remove the asymmetry introduced by floating-point rounding.
    return 0.5 * (A + A.T)


@export(MATRIX_DISTRIBUTIONS)
def wishart(c, d, eps):
    """Gaussian Wishart-type matrix G^T G / (3 d) + eps * I.

    G has int(c * d) rows and d columns of i.i.d. standard normal entries.
    G^T G is positive semidefinite, so every eigenvalue of the result is at
    least ``eps``.

    Parameters
    ----------
    c   : float, ratio of sample count to dimension
    d   : int, matrix dimension
    eps : float, diagonal shift

    Returns
    -------
    torch.Tensor of shape (d, d).
    """
    n_samples = int(c * d)
    samples = torch.randn(n_samples, d)
    gram = samples.T @ samples
    scaled = gram / 3 / d
    shift = eps * torch.eye(d)
    return scaled + shift


@export(MATRIX_DISTRIBUTIONS)
def wishart_unif(c, d, eps):
    """Wishart-type matrix built from uniform samples: 3 U^T U / (3 d) + eps * I.

    U has int(c * d) rows and d columns with entries uniform on [-1, 1).
    Such entries have variance 1/3; the factor 3 presumably rescales them to
    unit variance so the spectrum matches ``wishart``.

    Parameters
    ----------
    c   : float, ratio of sample count to dimension
    d   : int, matrix dimension
    eps : float, diagonal shift

    Returns
    -------
    torch.Tensor of shape (d, d).
    """
    n_samples = int(c * d)
    samples = torch.rand(n_samples, d) * 2 - 1
    gram = (3 * samples.T) @ samples
    return gram / 3 / d + eps * torch.eye(d)


@export(MATRIX_DISTRIBUTIONS)
def CIFAR(c, d, eps):
    """Frobenius-normalized pixel covariance of ``d`` random CIFAR-10 images.

    Images are drawn from the CIFAR-10 training set (downloaded to ./data if
    missing), normalized per channel to [-1, 1], flattened and centred.

    Parameters
    ----------
    c   : unused (kept for interface compatibility)
    d   : int, number of images sampled. The returned matrix is
          (pixels x pixels), i.e. 3072 x 3072 for CIFAR-10, independent of d.
    eps : float, multiple of the identity added after normalization

    Returns
    -------
    torch.Tensor of shape (p, p), where p is the flattened image size.
    """
    trainset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=d, shuffle=True)
    images, _ = next(iter(trainloader))

    # Use the batch size actually returned: it is smaller than d when d
    # exceeds the dataset size.
    n_images = images.shape[0]
    images = images.reshape(n_images, -1)  # [n_images, image_size]
    images = images - images.mean(dim=0, keepdim=True)

    A = (images.T @ images) / n_images
    A = A / torch.linalg.norm(A, "fro")
    # Feature dimension is taken from the data rather than hard-coded.
    return A + eps * torch.eye(images.shape[1], dtype=A.dtype)

@export(MATRIX_DISTRIBUTIONS)
def Erdos_Renyi(c, d, eps):
    """Shifted, rescaled normalized Laplacian of an Erdos-Renyi graph G(d, c).

    Parameters
    ----------
    c   : float, edge probability; each unordered pair {i, j} with i != j is
          joined independently with probability c
    d   : int, number of vertices (matrix dimension)
    eps : float, shift added to the Laplacian before rescaling

    Returns
    -------
    torch.Tensor of shape (d, d): (L + eps * I) / (2 + eps), where
    L = I - D^{-1/2} A D^{-1/2}. Isolated vertices get D^{-1/2} = 0, so their
    row of L is the unit vector e_i. L has eigenvalues in [0, 2], so the
    result's spectrum lies in [eps / (2 + eps), 1].
    """
    # Threshold FIRST, then keep only the strict upper triangle. Doing it the
    # other way round turns every zeroed lower/diagonal entry into an edge,
    # since 0 < c.
    coin_flips = (torch.rand(d, d) < c).float()
    adjacency = torch.triu(coin_flips, diagonal=1)
    # Mirror to make the graph undirected (no self-loops).
    adjacency = adjacency + adjacency.T

    degrees = adjacency.sum(dim=1)
    inv_sqrt_deg = torch.where(
        degrees > 0, 1.0 / torch.sqrt(degrees), torch.zeros_like(degrees)
    )

    # D^{-1/2} A D^{-1/2} by broadcasting instead of two dense d x d matmuls.
    normalized_adj = inv_sqrt_deg[:, None] * adjacency * inv_sqrt_deg[None, :]

    identity = torch.eye(d, device=adjacency.device)
    laplacian = identity - normalized_adj
    return (laplacian + eps * identity) / (2 + eps)


def rectangular(c, d, eps):
    """Gaussian int(c * d) x d matrix with i.i.d. entries of variance 1 / (4 d).

    Note: not decorated with @export, so it is not added to
    MATRIX_DISTRIBUTIONS. ``eps`` is accepted only for signature
    compatibility and is ignored.
    """
    n_rows = int(c * d)
    scale = np.sqrt(4 * d)
    return torch.randn(n_rows, d) / scale


@export(MATRIX_DISTRIBUTIONS)
def lm_head(c, d, eps):
    """Normalized covariance of randomly sampled rows of GPT-2's LM-head weight.

    Parameters
    ----------
    c   : float, sampling ratio; int(c * hidden_dim) distinct vocabulary rows
          are drawn
    d   : unused; the matrix size is fixed by the model's hidden dimension
    eps : float, multiple of the identity added after normalization

    Returns
    -------
    torch.Tensor of shape (hidden_dim, hidden_dim), with no autograd history.
    """
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    # The weight is a trainable parameter. Detach so the returned matrix does
    # not carry an autograd graph (and a reference) back into the model.
    lm_head_weight = model.lm_head.weight.detach()  # shape: [vocab_size, hidden_dim]

    vocab_size, hidden_dim = lm_head_weight.shape
    n_rows = int(c * hidden_dim)

    # Sample n_rows distinct rows and centre them.
    indices = torch.randperm(vocab_size)[:n_rows]
    X = lm_head_weight[indices]
    X = X - X.mean(dim=0)

    cov = X.T @ X
    # Normalized covariance (Frobenius norm), adding eps for stability.
    return cov / torch.linalg.norm(cov) + eps * torch.eye(hidden_dim)


# %% more interesting for computing sign
@export(MATRIX_DISTRIBUTIONS)
def spiked_wishart(c, d, eps, rank=5, spike_var=10.0):
    """Sample covariance plus a low-rank spike, Frobenius-normalized, plus eps * I.

    The noise part is G^T G / n, where G is n x d standard normal and
    n = int(c * d). The spike is (spike_var / rank) * U U^T, where U is
    d x rank standard normal.

    Parameters
    ----------
    c         : float, ratio of sample count to dimension
    d         : int, matrix dimension
    eps       : float, diagonal shift added after normalization
    rank      : int, rank of the spike
    spike_var : float, overall spike strength

    Returns
    -------
    torch.Tensor of shape (d, d).
    """
    noise = torch.randn(int(c * d), d)
    sample_cov = (noise.T @ noise) / noise.shape[0]

    factors = torch.randn(d, rank)
    weight = spike_var / rank
    combined = sample_cov + weight * (factors @ factors.T)

    fro = torch.linalg.norm(combined, "fro")
    return combined / fro + eps * torch.eye(d)


@export(MATRIX_DISTRIBUTIONS)
def logreg_hessian(c, d, eps):
    """Hessian of the mean logistic loss at a random weight vector, plus eps * I.

    H = (1 / n) X^T diag(p * (1 - p)) X, with X an n x d standard Gaussian
    feature matrix (n = int(c * d)), p = sigmoid(X w) and w ~ N(0, I_d).
    H does not depend on the labels, so none are generated.

    Parameters
    ----------
    c   : float, ratio of sample count to dimension
    d   : int, matrix dimension
    eps : float, diagonal shift

    Returns
    -------
    torch.Tensor of shape (d, d).
    """
    n = int(c * d)
    X = torch.randn(n, d)
    w = torch.randn(d)  # off-optimum weights
    p = torch.sigmoid(X @ w)
    curvature = p * (1 - p)  # per-sample variances
    # Scale the rows of X directly instead of materializing the dense
    # n x n diag(curvature) matrix, which costs O(n^2) memory.
    H = X.T @ (curvature[:, None] * X) / n
    return H + eps * torch.eye(d)


@export(MATRIX_DISTRIBUTIONS)
def block_pm_identity(c, d, eps):
    """Diagonal d x d matrix of alternating +I / -I blocks, plus eps * I.

    Block sizes are drawn uniformly from {2, ..., 7}. The first block is +I,
    the next -I, and so on. The final block is truncated so the matrix is
    exactly d x d; previously the size was sum(block sizes), which is random
    and generally != d.

    Parameters
    ----------
    c   : unused (kept for interface compatibility)
    d   : int, matrix dimension
    eps : float, shift added to the diagonal

    Returns
    -------
    torch.Tensor of shape (d, d).
    """
    # d // 2 + 1 blocks of size >= 2 always cover at least d + 1 entries,
    # so truncating to d never runs short (also handles small d).
    sizes = torch.randint(2, 8, (d // 2 + 1,))
    signs = torch.ones(sizes.numel())
    signs[1::2] = -1.0
    diagonal = torch.repeat_interleave(signs, sizes)[:d]
    return torch.diag(diagonal) + eps * torch.eye(d)

@export(MATRIX_DISTRIBUTIONS)
def quartic_saddle(c, d, eps, a=1.0):
    """
    Symmetric indefinite matrix derived from the Hessian of
        F(z) = Σ_i (z_i^4 / 4 - a z_i^2 / 2)
    at a random point z ~ N(0, I_d), then randomly rotated.

    The Hessian of F is diag(3 z_i^2 - a). Entries with magnitude below eps
    are pushed out to ±eps, keeping their sign. The result Q diag(h) Q^T,
    with Q a random orthogonal matrix, is divided by its Frobenius norm, so
    every eigenvalue has magnitude at most 1.

    Parameters
    ----------
    c   : unused (kept for interface compatibility)
    d   : int, matrix dimension
    eps : float, minimum magnitude of the Hessian eigenvalues (before normalization)
    a   : float, quadratic coefficient; bigger a → more negative eigenvalues
    """
    # 1. sample a random point z ~ N(0, 1)
    z = torch.randn(d)

    # 2. Hessian eigenvalues in the z-basis: h_i = 3 z_i^2 - a.
    #    Floor |h_i| at eps, preserving sign. An exact 0 goes to +eps;
    #    torch.sign(0) == 0 would leave it at zero.
    h = 3.0 * z.pow(2) - a
    signs = torch.ones_like(h)
    signs[h < 0] = -1.0
    h = torch.where(h.abs() < eps, eps * signs, h)

    # 3. Random orthogonal similarity transform so eigenvectors are mixed.
    #    Column sign flips of Q cancel in Q diag(h) Q^T, so no det(Q) fix is needed.
    Q, _ = torch.linalg.qr(torch.randn(d, d))
    H = (Q * h) @ Q.T
    H = 0.5 * (H + H.T)  # remove rounding asymmetry

    # 4. Scale by the Frobenius norm so the spectral norm is at most 1.
    return H / torch.linalg.norm(H)