import math
import random
import torch

# Constants
SEED = 42  # Seed shared by every random number generator the script uses
TRAIN_SPLIT_RATIO = 0.7  # Fraction of data used for training
P = 37  # Prime number used for modular arithmetic
# Hidden layer width.
# NOTE(review): the original comment required ">= input size", but the closed-form
# output layer can only fit the training targets exactly when the hidden width is
# at least the number of training samples -- confirm which bound is intended.
HIDDEN_DIM = 512
LOG_INTERVAL = 10  # Logging frequency, in iterations
SAVE_INTERVAL = 10  # Model saving frequency, in iterations

# Set random seeds for reproducibility.
# `random` drives the train/test shuffle; torch's generator drives the weight
# initialisation, so both must be seeded for a run to be repeatable.
random.seed(SEED)
torch.manual_seed(SEED)


def encode_pairs(pairs, p=None):
    """
    Encode pairs (a, b) into frequency vectors X and one-hot vectors Y.

    Args:
        pairs: Sequence of (a, b) integer pairs; each value is used directly
            as a column index, so it is expected to satisfy 0 <= value < p.
        p: Modulus, which is also the width of both encodings. When omitted,
            the module-level constant P is used.

    Returns:
        A tuple (X, Y) of tensors of shape (len(pairs), p), created with
        torch's default dtype:
        X: Frequency encoding of a and b (a pair with a == b yields a single 2).
        Y: One-hot encoding of (a + b) % p.
    """
    if p is None:
        # Resolved at call time so the function can be defined before P exists.
        p = P
    X = torch.zeros((len(pairs), p))
    Y = torch.zeros((len(pairs), p))
    for i, (a, b) in enumerate(pairs):
        # Two separate increments so that a == b accumulates to 2.
        X[i, a] += 1
        X[i, b] += 1
        Y[i, (a + b) % p] = 1
    return X, Y


def generate_train_test_pairs(p=None, train_ratio=None):
    """
    Generate pairs (a, b) where 0 <= a <= b < p and split into training and testing sets.

    The pairs are shuffled with the module-level `random` generator, so the
    split is reproducible only if that generator was seeded beforehand.

    Args:
        p: Exclusive upper bound for a and b. When omitted, the module-level
            constant P is used.
        train_ratio: Fraction of the pairs placed in the training set. When
            omitted, the module-level constant TRAIN_SPLIT_RATIO is used.

    Returns:
        A tuple (train_pairs, test_pairs) of lists of (a, b) tuples. The
        training list holds the first int(train_ratio * total) shuffled pairs
        and the testing list holds the remainder.
    """
    if p is None:
        p = P
    if train_ratio is None:
        train_ratio = TRAIN_SPLIT_RATIO
    # Only a <= b is enumerated: addition is commutative, so (b, a) would
    # duplicate (a, b).
    pairs = [(i, j) for i in range(p) for j in range(i, p)]
    random.shuffle(pairs)
    split_idx = int(train_ratio * len(pairs))
    return pairs[:split_idx], pairs[split_idx:]


# Build the shuffled train/test split, then encode each side as (inputs, targets).
train_pairs, test_pairs = generate_train_test_pairs()
(X_train, Y_train), (X_test, Y_test) = map(encode_pairs, (train_pairs, test_pairs))

# Problem sizes read off the encoded training set:
# N rows of X_train, INPUT_DIM columns of X_train, OUTPUT_DIM columns of Y_train.
N, INPUT_DIM = X_train.shape
OUTPUT_DIM = Y_train.size(1)

# encode_pairs already returns torch tensors; this cast only pins their dtype
# to float32 for the matrix algebra that follows.
X_train, Y_train, X_test, Y_test = (
    tensor.to(torch.float32) for tensor in (X_train, Y_train, X_test, Y_test)
)

# Initialize weights for first layer (INPUT_DIM x HIDDEN_DIM), scaled by 1/sqrt(INPUT_DIM).
# Autograd tracking is deliberately left off: the training loop applies a
# hand-derived gradient and never calls backward(). With requires_grad=True the
# division would make W1 a non-leaf tensor, and every update would chain another
# autograd graph onto it, holding all intermediates in memory for the whole run.
W1 = torch.randn(INPUT_DIM, HIDDEN_DIM) / math.sqrt(INPUT_DIM)

# Save initial weights
torch.save(W1, "W1_init.pth")

# Training log: one comma-separated line per logged iteration.
log_file = open("log.txt", "w")

# To resume from a checkpoint instead of a fresh initialisation, replace W1 with
# the tensor stored in "W1.pth" (via torch.load) at this point.

# Training hyperparameters
step_size = 1e-3  # Gradient descent step size for W1
num_iterations = 5_000  # Number of W1 updates

# Training loop. For the current W1 the output layer W2 is solved in closed form
# (minimum-norm least squares), then W1 takes one gradient descent step on the
# weight-norm objective ||W1||^2 + ||W2||^2.
#
# `with log_file` flushes and closes the log however the loop ends.
# `torch.no_grad()` stops autograd from recording the loop: the gradient is
# computed by hand, and recording would chain every iteration's graph through W1
# and keep its intermediates alive until the script exits.
with log_file, torch.no_grad():
    for iteration in range(num_iterations):
        # Only the linear algebra is guarded, so that an unrelated RuntimeError
        # (for example from checkpoint I/O) is not misreported as a pseudo-inverse
        # failure and answered by perturbing the weights.
        try:
            # Forward pass: compute hidden layer activations
            A = X_train @ W1  # Linear transformation
            H = torch.relu(A)  # Apply ReLU activation

            # Compute Moore-Penrose pseudo-inverse of H
            H_pinv = torch.linalg.pinv(H)

            # Compute optimal output weights W2 (least-norm solution)
            W2 = H_pinv @ Y_train

            # Hand-derived W1 gradient. With G = H @ H.T and W2 = H.T @ G^-1 @ Y
            # (which holds when H has full row rank), this expression is half the
            # gradient of ||W1||^2 + ||W2||^2; the factor 2 is absorbed into
            # step_size. The (H > 0) mask is the ReLU derivative.
            G = H @ H.T
            G_inv = torch.linalg.pinv(G)
            relu_mask = (H > 0).to(torch.float32)
            W1_grad = W1 - X_train.T @ (relu_mask * (G_inv @ Y_train @ Y_train.T @ G_inv @ H))
        except RuntimeError:
            print(f"Failed to compute pseudo-inverse at iteration {iteration}, skipping update.")
            W1 += torch.randn_like(W1) * 1e-4  # Add noise to recover numerical stability
            continue

        # Log every LOG_INTERVAL iterations, plus each of the first 20.
        if iteration % LOG_INTERVAL == 0 or iteration < 20:
            # Compute predictions
            Y_pred_train = H @ W2
            Y_pred_test = torch.relu(X_test @ W1) @ W2

            # Compute loss and accuracy
            train_loss = torch.nn.functional.mse_loss(Y_pred_train, Y_train)
            train_accuracy = (Y_pred_train.argmax(dim=1) == Y_train.argmax(dim=1)).float().mean()
            test_loss = torch.nn.functional.mse_loss(Y_pred_test, Y_test)
            test_accuracy = (Y_pred_test.argmax(dim=1) == Y_test.argmax(dim=1)).float().mean()

            # Compute weight norms
            W1_norm = torch.norm(W1)
            W2_norm = torch.norm(W2)

            # Print progress
            print(
                f"Iter {iteration+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
                f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.4f}, "
                f"W1 norm: {W1_norm:.4f}, W2 norm: {W2_norm:.4f}"
            )

            # Log progress
            log_file.write(f"{iteration+1},{train_loss:.4f},{train_accuracy:.4f},{test_loss:.4f},{test_accuracy:.4f},{W1_norm:.4f},{W2_norm:.4f}\n")

        if iteration % SAVE_INTERVAL == 0:
            # Save model parameters (W1 as it was before this iteration's update,
            # together with the W2 solved for it)
            torch.save(W1, "W1.pth")
            torch.save(W2, "W2.pth")

        # Gradient descent step
        W1 = W1 - step_size * W1_grad
