# -*- coding: utf-8 -*-
"""Copy of ellipsoid_notebook_1.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1xpJ5HkM-_EhQ5frVV9PQshGjmphPuWCU

# Setup
"""

# Install dependencies
# !pip install einops

# @title Imports
import transformers
import einops
import torch
from tqdm.auto import trange

# Run on the GPU when one is available; everything below also works on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

"""# Model and dataset loading"""

# Candidate checkpoints. The last assignment wins; comment it out to run the
# Pythia model instead.
model_name = "EleutherAI/pythia-14m"
model_name = "roneneldan/TinyStories-1M"

# Weights are loaded in float64 because the SVD-based steps further down need
# double precision.
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float64).to(device)
print(model)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

USE_LAYER_NORM_EPS = True  # Set to False for debugging and idealized transformer.

if not USE_LAYER_NORM_EPS:
    # The final layer norm sits at a different attribute path per architecture;
    # use the same model_name dispatch as the verification section of this
    # file instead of assuming the Pythia (gpt_neox) layout.
    if "pythia" in model_name:
        final_layer_norm = model.gpt_neox.final_layer_norm
    elif "Tiny" in model_name:
        final_layer_norm = model.transformer.ln_f
    else:
        raise ValueError(
            f"Don't know where the final layer norm of {model_name!r} lives"
        )
    final_layer_norm.eps = 0.0

hidden_size = model.config.hidden_size
vocab_size = model.config.vocab_size

# Number of samples (equations) to collect:
#   * (hidden_size + 1) choose 2 for the entries of the symmetric matrix,
#   * hidden_size for the center term,
#   * 32 as a buffer for linear independence,
#   * 8000 additional samples on top of that.
# We assume that hidden_size was stolen correctly by the method presented in
# the paper.
symmetric_matrix_terms = hidden_size * (hidden_size + 1) // 2
center_terms = hidden_size
num_equations = symmetric_matrix_terms + center_terms + 32 + 8000

# One column per collected sample.
# Q holds the logits: $ Q \in \mathbb{R}^{ l \times n } $ (same as the Q matrix
# from the paper). H holds the matching final hidden states.
Q = torch.zeros(vocab_size, num_equations, dtype=torch.float64)
H = torch.zeros(hidden_size, num_equations, dtype=torch.float64)

# Notebook cell echo; has no effect when run as a script.
model.config.hidden_size

# Sampling plan for collecting logprobs/logits from random sequences.

tokenizer.pad_token = tokenizer.eos_token

batch_size = 4  # Can increase on big GPU
logprob_samples_per_sequence = 8  # Trailing positions kept from every sequence
sequence_length = 16

# The equation budget must split evenly into sequences, and the sequences into
# minibatches.
num_sequences, leftover_equations = divmod(num_equations, logprob_samples_per_sequence)
assert leftover_equations == 0
assert num_sequences % batch_size == 0

# Uniformly random token ids, staying away from both ends of the vocabulary:
# tokens 0-9 and the last ten could be weird.
inputs = torch.randint(10, vocab_size - 10, (num_sequences, sequence_length))

# Run the model over the random prompts and store, for the last
# `logprob_samples_per_sequence` positions of every sequence, the logits (into
# Q) and the last entry of `hidden_states` (into H), one sample per column.
#
# Nothing here is differentiated, so the forward passes run under
# torch.no_grad() to avoid building autograd graphs.
with torch.no_grad():
    for minibatch_idx in trange(0, num_sequences, batch_size):
        tokens = inputs[minibatch_idx:minibatch_idx + batch_size]
        outs = model(tokens.to(device), output_hidden_states=True)
        logits = einops.rearrange(
            outs.logits[:, -logprob_samples_per_sequence:],
            "Batch Seq Vocab -> (Batch Seq) Vocab"
        )
        hiddens = einops.rearrange(
            outs.hidden_states[-1][:, -logprob_samples_per_sequence:],
            "Batch Seq Hidden -> (Batch Seq) Hidden"
        )
        # minibatch_idx counts sequences, so each one owns
        # `logprob_samples_per_sequence` consecutive columns.
        col_start = minibatch_idx * logprob_samples_per_sequence
        col_stop = (minibatch_idx + batch_size) * logprob_samples_per_sequence
        # TODO(conmy): make this logprobs version
        Q[:, col_start:col_stop] = logits.T.cpu()
        H[:, col_start:col_stop] = hiddens.T.cpu()

import os

import numpy as np

# Sanity-check the buffers, then persist the collected hidden states (one
# sample per row) under the key "hidden".
assert Q.shape == (vocab_size, num_equations)
assert H.shape == (hidden_size, num_equations)
# np.savez does not create missing parent directories.
os.makedirs("data/carlini", exist_ok=True)
np.savez("data/carlini/logits.npz", hidden=H.T.numpy())
print("Saved")

# Express every logit relative to the logit of token 0, then drop that
# (now all-zero) row.
Q = Q[1:] - Q[0]
# Q = Q[:, :-8000] # Remove extra logprobs

# Important to be double for SVD precision
# We shift the ellipsoid to pass through the origin: subtract the first sample
# from all others and drop it.
Q_origin = (Q[:, 1:] - Q[:, :1]).to(device)

# The SVD only needs slightly more than hidden_size samples (32 extra as a
# buffer).
svd_input = Q_origin[:, :hidden_size + 32]
# driver='gesvd' gives high precision, but the `driver=` keyword is only
# supported on CUDA inputs with the cuSOLVER backend.
driver_kwargs = {"driver": "gesvd"} if device == 'cuda' else {}
raw_U, raw_S, raw_Vh = torch.linalg.svd(svd_input, full_matrices=False, **driver_kwargs)

# See the paper, main text Algorithm 1: the dimension is read off from the
# largest jump between consecutive log singular values.
log_gaps = torch.diff(torch.log(raw_S)).abs()
pred_dim = int(torch.argmax(log_gaps)) + 1

assert pred_dim in (hidden_size - 1, hidden_size)
if pred_dim == hidden_size:
    print("You should only see this if the model is using RMSNorm rather than LayerNorm")

# Keep only the leading pred_dim singular triplets.
U = raw_U[:, :pred_dim]
S = torch.diag(raw_S[:pred_dim])
Vh = raw_Vh[:pred_dim]

assert U.shape == (vocab_size - 1, pred_dim)
assert S.shape == (pred_dim, pred_dim)
assert Vh.shape == (pred_dim, hidden_size + 32)

reconstruction_error = (U @ S @ Vh - svd_input).abs().max().item()
assert reconstruction_error < 1e-7, "Truncated SVD should reconstruct the logits matrix!"

# Coordinates of every sample in the basis given by the columns of U.
X = U.T @ Q_origin
assert X.shape == (pred_dim, num_equations - 1)

# only use if GPU is out of VRAM!
# X = X.cpu()

# Construct equations

# General ellipsoid equation: (x-c)^T A (x-c) = 1, where A is a symmetric positive semidefinite matrix.
# Expanded: x^T A x - 2 x^T A c + c^T A c = 1
# Trick 1: Linearization -- we fit A and d = A c as if they were independent.
# Trick 2: By subtracting the first sample (column 0 of Q), our ellipsoid passes through the origin, hence c^T A c = 1.
# So x^T A x - 2 x^T d = 0 imposes a set of linear equations over A and d, with C(h+1, 2) + h variables (h = number of rows of X).
# We find the nullspace of this system of equations using SVD.
# Later we enforce c^T A c = 1 by scaling the fitted/predicted A and d.

# Assemble and solve the linear system that determines the ellipsoid.
#
# Every column x of X contributes one equation  x^T A x - 2 x^T d = 0, which is
# linear in the unknowns once the symmetric matrix A and the vector d are
# treated as independent. Unknown layout of each equation row:
#   [ strict upper triangle of A | diagonal of A | d ]
# The off-diagonal coefficient is x_i * x_j (not 2 * x_i * x_j), so the solved
# upper-triangle unknowns equal 2 * A_ij; the symmetrisation below halves them.
num_dims = len(X)
samples = X.T  # One projected sample per row.

upper_diag_indices = tuple(torch.triu_indices(num_dims, num_dims, 1))
upper_rows, upper_cols = (index.to(samples.device) for index in upper_diag_indices)
num_upper = upper_rows.numel()

# Built in one broadcasted pass instead of one outer product per sample; each
# coefficient is the same product a per-sample loop would compute.
eqs = torch.cat(
    [
        samples[:, upper_rows] * samples[:, upper_cols],  # A, strict upper triangle
        samples * samples,  # A, diagonal
        -2 * samples,  # d
    ],
    dim=1,
)
print(f"Constraints x Variables: {tuple(eqs.shape)}")

# This is the method's bottleneck ~ O(hidden_dim ^ 6).
# Only S and Vh are used, so the (constraints x constraints) U factor is
# skipped whenever there are at least as many equations as unknowns. In the
# under-determined case the full decomposition is kept so that the last row of
# Vh still lies in the null space.
num_constraints, num_unknowns = eqs.shape
sol_space = torch.linalg.svd(eqs, full_matrices=num_constraints < num_unknowns)
actual_sol = sol_space.Vh[-1]  # expected to be the nullspace basis vector

# Gap between the two smallest singular values on a log scale. A nullspace of
# dim=1 means the smallest one is much smaller than the next (a disabled check
# used to require a gap above 3).
delta_log_singular_values = torch.abs(torch.diff(torch.log(sol_space.S[-2:]))).item()

actual_sol = actual_sol.cpu()

# Organize the solution per variable, taking into account the symmetry of A.
sol_A_upper = actual_sol[:num_upper]
sol_A_diag = actual_sol[num_upper:num_upper + num_dims]
sol_d = actual_sol[-num_dims:]

sol_A_reshaped = torch.zeros(num_dims, num_dims, dtype=torch.float64)
sol_A_reshaped[upper_diag_indices] = sol_A_upper
diag_index = torch.arange(num_dims)
sol_A_reshaped[diag_index, diag_index] = sol_A_diag
sol_A_reshaped = (sol_A_reshaped + sol_A_reshaped.T) / 2

# Recover the center c = A^-1 d, then rescale A and d so that c^T A c = 1.
# The overall scale (and sign) of a null-space vector is arbitrary; dividing
# by the factor removes it.
sol_A_reshaped_inv = torch.linalg.inv(sol_A_reshaped)
sol_centered = sol_A_reshaped_inv @ sol_d
factor = sol_d @ sol_A_reshaped_inv @ sol_d  # d^T*A^-1*d = c^T*A*c = factor

factored_d = sol_d / factor
factored_A = sol_A_reshaped / factor

"""## Verify the solution at hand"""

# Ground-truth parameters of the model's final layer norm and unembedding.
# The attribute paths differ per architecture, so dispatch on the model name.
if "pythia" in model_name:
    true_W = model.embed_out.weight.double().detach()
    gamma = model.gpt_neox.final_layer_norm.weight.double().detach()
    beta = model.gpt_neox.final_layer_norm.bias.double().detach()
elif "Tiny" in model_name:
    true_W = model.lm_head.weight.double().detach()
    gamma = model.transformer.ln_f.weight.double().detach()
    beta = model.transformer.ln_f.bias.double().detach()
else:
    raise ValueError(
        f"Don't know where the unembedding and final layer norm of {model_name!r} live"
    )

# NOTE(review): the reference computations below are disabled, so true_A,
# true_ellipsoid_center and O are never assigned here even though later checks
# read them. U has vocab_size - 1 rows (asserted after the truncated SVD) while
# true_W is the unshifted unembedding weight, so `U.T @ true_W_gamma`
# presumably needs true_W shifted the same way as Q -- confirm before
# re-enabling.
# true_dim = model.config.hidden_size
# true_W_gamma = true_W @ torch.diag(gamma * true_dim**0.5)
# inv_true_A = U.T @ true_W_gamma @ true_W_gamma.T @ U # follows from the equations of Appendix G.

# true_A = torch.linalg.inv(inv_true_A).cpu()
# true_ellipsoid_center = (true_W @ beta).cpu()

Ut = U.T.cpu()
Xt = X.T.cpu()

# Lower-triangular factor with M @ M.T == factored_A.
M = torch.linalg.cholesky(factored_A)
# O = M.T @ Ut @ true_W_gamma.cpu()

# Map the fitted center back to (shifted) logit space: undo the projection
# onto U, then add back the first sample that was subtracted to move the
# ellipsoid through the origin.
pred_ellipsoid_center = (sol_centered @ Ut) + Q[:, 0].cpu()
# sol_centered is already a float64 tensor; use it directly instead of passing
# it through the legacy torch.Tensor(...) constructor.
x_min_c_T = Xt - sol_centered

# class Ellipse:
#     up_proj: Num[Array, "emb-1 vocab-1"]
#     bias: Num[Array, "vocab-1"]
#     rot1: None | Num[Array, "emb-1 emb-1"]
#     stretch: Num[Array, "emb-1"]
#     rot2: Num[Array, "emb-1 emb-1"]
import os

# Export the fitted ellipse: center (bias), the SVD factors of the Cholesky
# factor M (rot2 @ diag(stretch) @ rot1 == M), and the map from projected
# coordinates back to logit space.
bias = pred_ellipsoid_center.numpy()
rot2, stretch, rot1 = np.linalg.svd(M.numpy())
# Compute the pseudo-inverse with torch on the tensors' own device and convert
# once at the end. np.linalg.pinv cannot consume a CUDA tensor.
up_proj = (Q_origin @ torch.linalg.pinv(X).to(Q_origin.device)).cpu().numpy()

# np.savez does not create missing parent directories.
os.makedirs("data/carlini", exist_ok=True)
np.savez("data/carlini/carlini_pred.npz", bias=bias, rot1=rot1, rot2=rot2, stretch=stretch, up_proj=up_proj.T)

# Compare the exported prediction against the model's true parameters.
# NOTE(review): true_ellipsoid_center is not assigned anywhere in the visible
# source (its only definition is a commented-out line), so the first check
# raises NameError as written.
np.testing.assert_allclose(bias, true_ellipsoid_center)
# NOTE(review): true_W is created from the model's weights and may live on the
# GPU; np.linalg.svd also defaults to full_matrices=True, which makes true_U
# square in the row dimension of true_W -- confirm this is intended.
true_U, true_S, true_Vh = np.linalg.svd(true_W)
# NOTE(review): the two sides look shape-incompatible (up_proj derives from the
# row-shifted, rank-truncated Q_origin/X, true_U from the unshifted weight) --
# TODO confirm what this is meant to compare.
np.testing.assert_allclose(up_proj @ rot2 @ np.diag(stretch) , true_U @ np.diag(true_S))




# @title Success criteria
# We think the numeric instability marginally affects the precision of the results.
# For example the matrix O approximates an orthonormal matrix, since O^TO is roughly Identity, but some entries differ by 0.01.

# NOTE(review): true_ellipsoid_center, true_A and O are only defined in
# commented-out lines of the visible source, so the last three checks raise
# NameError until those reference values are restored.

# Every sample must satisfy (x-c)^T A (x-c) = 1. Evaluate the quadratic form
# row by row instead of building the full samples x samples product just to
# read its diagonal.
quadratic_form = ((x_min_c_T @ factored_A) * x_min_c_T).sum(dim=1)
assert torch.allclose(quadratic_form, torch.ones(1).double())
assert torch.mean(torch.abs(pred_ellipsoid_center - true_ellipsoid_center)) < 1
assert torch.allclose(factored_A, torch.as_tensor(true_A, dtype=torch.float64))
# torch.dist takes the two tensors to compare as separate arguments.
assert torch.dist(O @ O.T, torch.eye(pred_dim, dtype=O.dtype)).item() < 1e-2
