# This script computes the cosine similarity between the real training dynamics (the actual parameter updates) and the norm-minimization dynamics restricted to the zero-loss set.


import copy
import random
from collections import deque

import torch
from torch.func import functional_call, jacrev

# Modular Addition
# Experiment configuration: learn (a + b) mod P with a two-layer MLP.

# Logging period in steps. This is only the initial value: the training loop
# reassigns LOG_EVERY every step from a step-dependent schedule.
LOG_EVERY = 100
# Seeds torch's global RNG (below) and the `random`-based train/test shuffle.
SEED = 45
# Fraction of all unordered (a, b) pairs used for training; the rest is test.
FRAC_TRAIN = 0.9
# Modulus of the addition task.
P = 11
# Inputs are length-P count vectors of the two operands.
D_in = P
# Targets are one-hot vectors over the P residues.
D_out = P
D_hidden = 128
# L2 penalty for the main SGD optimizer (the fitting copy used inside
# gradient_estimate is built with weight_decay=0).
weight_decay = 1e-4
# NOTE(review): not referenced in the code shown here; training uses MSELoss
# on one-hot targets regardless of this flag.
classification = False
# Step size for the main optimizer and for the fitting copy in gradient_estimate.
LEARNING_RATE = 1
NUM_STEPS = 210_000
# Number of recent parameter snapshots kept. The logged "update" is the
# parameter change across this window.
QUEUE_SIZE = 10


# Use the GPU when available; data tensors and the model are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's global RNG before any layers are constructed further down.
torch.manual_seed(SEED)

def generate_data_modular_addition(*, p=None, frac_train=None, seed=None, target_device=None):
    """Build the modular-addition dataset and split it into train/test tensors.

    Every unordered pair ``(a, b)`` with ``0 <= a <= b < p`` is encoded as a
    length-``p`` count vector (``x[a] += 1; x[b] += 1``, so ``a == b`` gives a
    single entry equal to 2). Its label is the one-hot vector of ``(a + b) % p``.
    The pair list is shuffled with ``random.seed(seed)`` and the first
    ``int(frac_train * n_pairs)`` pairs form the training set.

    All keyword arguments default to the module-level ``P``, ``FRAC_TRAIN``,
    ``SEED`` and ``device``, so a call with no arguments behaves as before.

    Returns:
        ``(X, Y, X_test, Y_test)``: float tensors on ``target_device``, each
        with ``p`` columns.
    """
    # Globals are only looked up when the caller did not supply a value.
    p = P if p is None else p
    frac_train = FRAC_TRAIN if frac_train is None else frac_train
    seed = SEED if seed is None else seed
    target_device = device if target_device is None else target_device

    def encode_pairs(pairs):
        """Encode (a, b) pairs as count-vector inputs and one-hot sum labels."""
        X = torch.zeros((len(pairs), p))
        Y = torch.zeros((len(pairs), p))
        for i, (a, b) in enumerate(pairs):
            X[i, a] += 1
            X[i, b] += 1
            Y[i, (a + b) % p] = 1
        return X.to(target_device), Y.to(target_device)

    pairs = [(i, j) for i in range(p) for j in range(i, p)]
    # Reseeds the global `random` module, as the original code did, so the
    # split is reproducible across runs.
    random.seed(seed)
    random.shuffle(pairs)
    div = int(frac_train * len(pairs))
    train_pairs, test_pairs = pairs[:div], pairs[div:]

    # encode_pairs already places the tensors on target_device; no further
    # .to() calls are needed.
    X, Y = encode_pairs(train_pairs)
    X_test, Y_test = encode_pairs(test_pairs)
    return X, Y, X_test, Y_test


def compute_jacobian(model, X):
    """Return the Jacobian of ``model(X)`` w.r.t. all model parameters as a 2-D matrix.

    The model output is flattened row-major, so for a 2-D output of width
    ``d_out``, row ``n * d_out + i`` corresponds to ``model(X)[n, i]``.
    Columns are the flattened parameters, concatenated in
    ``model.named_parameters()`` order.

    Args:
        model: ``nn.Module``. It is evaluated through ``functional_call``; its
            parameters are not modified.
        X: input batch passed to ``model``.

    Returns:
        Tensor of shape ``(model(X).numel(), total number of parameter entries)``.
        It does not require grad.
    """
    # Detached copies: jacrev differentiates w.r.t. these inputs itself, and
    # detaching stops the outer autograd from recording a graph through the
    # Jacobian (which would otherwise also flow into later pinv calls).
    params = {name: p.detach() for name, p in model.named_parameters()}
    buffers = dict(model.named_buffers())

    def f_of_params(params_dict):
        # functional_call runs model with *these* params (no mutation).
        y = functional_call(model, (params_dict, buffers), (X,))
        return y.reshape(-1)

    # The Jacobian is a dict keyed like `params`; each entry has shape
    # (num_outputs, *p.shape).
    J_tree = jacrev(f_of_params)(params)

    # The row count comes from the Jacobian itself rather than
    # X.shape[0] * D_out, so models whose output width differs from the global
    # D_out are handled too.
    flat_cols = []
    for name, p in params.items():
        Jp = J_tree[name]
        flat_cols.append(Jp.reshape(Jp.shape[0], p.numel()))
    return torch.cat(flat_cols, dim=1)

# Build the train/test split once; the functions and loop below read these globals.
X, Y, X_test, Y_test = generate_data_modular_addition()

print(f"Train data shape: {X.shape}, Train labels shape: {Y.shape}, Test data shape: {X_test.shape}, Test labels shape: {Y_test.shape}")

# Bias-free MLP: D_in -> D_hidden -> LeakyReLU -> D_out. The layers are created
# in the same order as before, so torch's seeded RNG yields the same initial weights.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, D_hidden, bias=False),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(D_hidden, D_out, bias=False),
).to(device)

# Mean-squared error against the one-hot targets, optimized with SGD + weight decay.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=weight_decay,
)

model.train()

print(f"Training on {device}")
print(f"X shape: {X.shape}, Y shape: {Y.shape}, W1 shape: {model[0].weight.shape}")


def gradient_estimate(model, num_steps=100):
    """Estimate the norm-minimization direction restricted to the (approximate) zero-loss set.

    1. Copy ``model`` and run ``num_steps`` of full-batch SGD with momentum 0.9
       and no weight decay on the training data ``X``/``Y``. This moves the
       copy towards zero training loss. Reaching zero loss is not verified;
       the final loss is printed.
    2. Compute the Jacobian ``J`` of the copy's outputs on ``X`` w.r.t. its
       parameters.
    3. Project ``-theta`` onto the null space of ``J``. ``-theta`` is the
       negative gradient of ``0.5 * ||theta||^2``, and ``theta`` is the
       flattened parameter vector of ``model``.

    Args:
        model: the model being trained. It is not modified.
        num_steps: number of fitting steps for the copy (default 100, as before).

    Returns:
        1-D tensor with one entry per parameter, in ``model.parameters()`` order.
    """
    # deepcopy duplicates architecture and weights without re-running layer
    # initialization. It therefore does not consume the global torch RNG, and
    # it stays correct if the architecture above changes.
    model_overfit = copy.deepcopy(model)
    optimizer_overfit = torch.optim.SGD(model_overfit.parameters(), lr=LEARNING_RATE, weight_decay=0, momentum=0.9)

    for _ in range(num_steps):
        optimizer_overfit.zero_grad()
        loss_overfit = criterion(model_overfit(X), Y)
        loss_overfit.backward()
        optimizer_overfit.step()

    # Re-evaluate after the last update. The loss computed inside the loop
    # predates the final optimizer step, so it is not the loss "after num_steps steps".
    with torch.no_grad():
        loss_overfit = criterion(model_overfit(X), Y)
    print(f"Overfit loss after {num_steps} steps: {loss_overfit.item():.06f}")

    # Jacobian at the fitted point.
    J_theta = compute_jacobian(model_overfit, X)

    # Flatten in parameters() order, the same order compute_jacobian uses for
    # its columns (named_parameters()).
    # NOTE(review): theta is taken from the original `model`, not from the
    # fitted copy whose Jacobian is used. Confirm this is intended.
    theta = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
    return project_to_nullspace(J_theta, -theta)

def project_to_nullspace(M, v):
    """Return the orthogonal projection of ``v`` onto the null space of ``M``.

    ``pinv(M) @ M`` is the orthogonal projector onto the row space of ``M``.
    Subtracting that component from ``v`` leaves its null-space component.

    Args:
        M: 2-D tensor of shape ``(m, n)``.
        v: tensor of shape ``(n,)``, or ``(n, k)`` to project ``k`` vectors at once.

    Returns:
        Tensor with the same shape as ``v``.
    """
    # Compute M @ v first. This avoids materializing the n x n matrix
    # pinv(M) @ M that left-to-right evaluation would build.
    return v - torch.linalg.pinv(M) @ (M @ v)

def _log_interval(step):
    """Return the logging period at `step`: every step early on, then progressively sparser."""
    if step < 100:
        return 1
    if step < 1000:
        return 10
    if step < 10_000:
        return 100
    if step < 100_000:
        return 1000
    return 10_000


def _flat_params(net):
    """Return all parameters of `net` as one new, detached 1-D tensor in parameters() order."""
    # torch.cat allocates fresh storage. The result is therefore a snapshot that
    # later in-place optimizer updates do not touch, so no extra .clone() is needed.
    return torch.cat([p.detach().reshape(-1) for p in net.parameters()])


old_W_flat = _flat_params(model)


print("Num parameters:", old_W_flat.shape[0])
print("Dataset size:", Y.shape[0] * Y.shape[1])

# Parameter snapshots taken before each of the last QUEUE_SIZE steps.
# maxlen makes the deque drop the oldest entry automatically.
old_W_queue = deque(maxlen=QUEUE_SIZE)

# `with` guarantees the log file is flushed and closed even if training is interrupted.
with open(f"seed_{SEED}.txt", "w") as fout:
    fout.write("Step\tLoss\tTrainAcc\tTestLoss\tTestAcc\tWnorm\tCosineSim\n")

    # The upper bound is NUM_STEPS + 1 so exactly NUM_STEPS steps run.
    # range(1, NUM_STEPS) ran one step short and skipped the final log point.
    for step in range(1, NUM_STEPS + 1):
        old_W_flat = _flat_params(model)
        old_W_queue.append(old_W_flat)

        # One full-batch SGD step on the training set.
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, Y)  # loss before this step's update
        loss.backward()
        optimizer.step()

        LOG_EVERY = _log_interval(step)
        if step % LOG_EVERY != 0:
            continue

        # Evaluation only: no_grad avoids building an unused autograd graph.
        model.eval()
        with torch.no_grad():
            Y_pred = model(X)
            Y_pred_test = model(X_test)
            test_loss = criterion(Y_pred_test, Y_test).item()
        model.train()

        train_acc = (Y_pred.argmax(dim=1) == Y.argmax(dim=1)).float().mean().item()
        test_acc = Y_pred_test.argmax(dim=1).eq(Y_test.argmax(dim=1)).float().mean().item()

        # Sum of the two layers' Frobenius norms, not the norm of the
        # concatenated parameter vector.
        wnorm = model[0].weight.data.norm().item() + model[2].weight.data.norm().item()

        print(f"Step {step}, Loss: {loss.item():.06f}, Train: {train_acc:.04f}, Test: {test_acc:.04f}, ||W||: {wnorm:.04f}, queue size: {len(old_W_queue)}")

        grad_est = gradient_estimate(model)

        # Actual parameter change over the window held in the queue
        # (QUEUE_SIZE steps once the queue is full, fewer at the start).
        W_flat = _flat_params(model)
        update = W_flat - old_W_queue[0]
        cosine_sim = torch.nn.functional.cosine_similarity(update, grad_est, dim=0)
        print(f"Cosine similarity between grad and projected grad: {cosine_sim.item():.4f}")
        fout.write(f"{step}\t{loss.item():.6f}\t{train_acc:.4f}\t{test_loss:.4f}\t{test_acc:.4f}\t{wnorm:.4f}\t{cosine_sim.item():.4f}\n")
        fout.flush()