import os
import pickle
import torch
from tqdm import tqdm
import argparse
from utils import seed_everything, create_mlp, \
    get_batch_sparse_parity, get_batch_staircase, \
    get_batch_gaussian, get_batch_gaussian_sanity, \
    generate_covariance, get_batch_balanced_logistic, create_balanced_logistic_net, \
    get_teacher_student_batch, Custom_BN
from torch.func import hessian, functional_call, vmap, grad, jacrev
from torch.utils._pytree import tree_flatten   # <- here
from functorch import make_functional_with_buffers
import time

import numpy as np
import wandb
import math
import copy

# Ask PyTorch's CUDA caching allocator to release any unused cached memory
# before the run starts.
torch.cuda.empty_cache()

def f(model, params, inputs, targets):
    """Loss of `model` on a whole batch, evaluated functionally at `params`.

    This is the base objective used for Hessian computation. The loss itself
    is the module-level `loss_fn` selected from the command line.
    """
    preds = functional_call(model, params, inputs).squeeze()
    # squeeze() also removes a batch axis of size 1; put it back in that case.
    preds = preds.unsqueeze(0) if inputs.shape[0] == 1 else preds
    return loss_fn(preds, targets)

def f2(model, params, inputs, targets):
    """Loss of `model` on a single sample, used for per-sample gradients
    (AdaGrad variants).

    `inputs` is one un-batched sample; a leading batch axis of size 1 is added
    before the forward pass. The loss is the module-level `loss_fn`.
    """
    single_batch = inputs.unsqueeze(0)
    preds = functional_call(model, params, single_batch).squeeze()
    # Same check as the original: when the first axis of the sample has size 1,
    # re-add a leading axis to the squeezed prediction.
    preds = preds.unsqueeze(0) if inputs.shape[0] == 1 else preds
    return loss_fn(preds, targets)

def f3(model, params, inputs):
    """Raw model output for a single sample (used by the Gauss-Newton helpers).

    `inputs` is one un-batched sample; a leading batch axis of size 1 is added
    before the forward pass and the result is squeezed.
    """
    out = functional_call(model, params, inputs.unsqueeze(0)).squeeze()
    if inputs.shape[0] != 1:
        return out
    # First axis of the sample has size 1: re-add a leading axis.
    return out.unsqueeze(0)


def per_sample_grad(model, params, inputs, targets):
    """Per-sample gradients of the loss, flattened, for AdaGrad variants.

    Returns a dict keyed like `params`; each entry stacks the flattened
    gradient of every sample along the leading axis.
    -  Return shape: (bs, mn)
    """
    loss_grad = grad(f2, argnums=(1))

    def flat_grad(model, params, sample, target):
        # Gradient w.r.t. `params` for one sample, each tensor flattened to 1-D.
        grads = loss_grad(model, params, sample, target)
        return {name: g.flatten() for name, g in grads.items()}

    batched = vmap(flat_grad, in_dims=(None, None, 0, 0))
    return batched(model, params, inputs, targets)

def per_sample_grad_2(model, params, inputs, targets):
    """Per-sample gradients of the loss, keeping each parameter's own shape.

    -  Return shape: (bs, m, n)
    """
    single_sample_grad = grad(f2, argnums=(1))
    batched_grad = vmap(single_sample_grad, in_dims=(None, None, 0, 0))
    return batched_grad(model, params, inputs, targets)


def per_output_grad(model, params, inputs):
    """Per-sample gradients of the model output, flattened, for Gauss-Newton variants.

    - Shape for each weight: (bs, mn)

    When the global `args.loss` is 'logistic', every sample's gradient is
    additionally multiplied by sqrt(p * (1 - p)), with p = sigmoid(model output).
    """
    def flat_output_grad(model, params, sample):
        # Gradient of the output w.r.t. `params` for one sample, flattened.
        grads = grad(f3, argnums=(1))(model, params, sample)
        return {name: g.flatten() for name, g in grads.items()}

    val = vmap(flat_output_grad, in_dims=(None, None, 0))(model, params, inputs)
    if args.loss != 'logistic':
        return val

    # Logistic loss: scale by sqrt(p(1-p)).
    logits = functional_call(model, params, inputs.unsqueeze(0)).squeeze(-1)
    probs = torch.sigmoid(logits)
    hessian_scaling = ((probs * (1 - probs))**0.5).squeeze()  # of shape (bs,)
    for key in val.keys():
        entry = val[key]
        if entry.dim() == 3:
            scale = hessian_scaling.unsqueeze(1).unsqueeze(1)
        elif entry.dim() > hessian_scaling.dim():
            scale = hessian_scaling.unsqueeze(1)
        else:
            scale = hessian_scaling
        val[key] = entry * scale
    return val



def per_output_jacobian(model, params, inputs):
    """Per-sample Jacobians of a multi-dimensional model output, for
    Gauss-Newton variants.

    - Shape for each weight: (bs, d_out, mn)

    When the global `args.loss` is 'logistic', the Jacobians are additionally
    multiplied by sqrt(p * (1 - p)), with p = sigmoid(model output).
    """
    def flat_jacobian(model, params, sample):
        # Jacobian of the output w.r.t. `params` for one sample; everything
        # after the first axis is flattened.
        jac = jacrev(f3, argnums=(1))(model, params, sample)
        return {name: j.flatten(start_dim=1) for name, j in jac.items()}

    val = vmap(flat_jacobian, in_dims=(None, None, 0))(model, params, inputs)
    if args.loss != 'logistic':
        return val

    logits = functional_call(model, params, inputs.unsqueeze(0)).squeeze(-1)
    probs = torch.sigmoid(logits)
    hessian_scaling = ((probs * (1 - probs))**0.5).squeeze()
    one_scale_per_sample = hessian_scaling.dim() == 1
    for key in val.keys():
        entry = val[key]
        if not one_scale_per_sample:
            # One scale per (sample, output) pair.
            val[key] = torch.einsum('bnd,bn->bnd', entry, hessian_scaling)
        elif entry.dim() > hessian_scaling.dim():
            val[key] = entry * hessian_scaling.unsqueeze(1)
        else:
            val[key] = entry * hessian_scaling
    return val


def per_output_grad_2(model, params, inputs):
    """Per-sample gradients of the model output, keeping each parameter's shape.

    - Shape for each weight: (bs, m, n)

    When the global `args.loss` is 'logistic', every sample's gradient is
    additionally multiplied by sqrt(p * (1 - p)), with p = sigmoid(model output).
    """
    output_grad = grad(f3, argnums=(1))
    val = vmap(output_grad, in_dims=(None, None, 0))(model, params, inputs)
    if args.loss != 'logistic':
        return val

    logits = functional_call(model, params, inputs.unsqueeze(0)).squeeze(-1)
    probs = torch.sigmoid(logits)
    hessian_scaling = ((probs * (1 - probs))**0.5).squeeze()
    for key in val.keys():
        entry = val[key]
        scale = hessian_scaling.unsqueeze(1)
        if entry.dim() == 3:
            # Matrix-shaped parameter: one more trailing axis to broadcast over.
            scale = scale.unsqueeze(1)
        val[key] = entry * scale
    return val



def per_output_jacobian_2(model, params, inputs):
    """
       Computes per-sample Jacobians of the model output for Gauss-Newton
       variants, for a function with multi-dimensional output, keeping each
       parameter's own shape.
    - Shape for each weight: (bs, d_out, m, n)

    When the global `args.loss` is 'logistic', each Jacobian is multiplied by
    sqrt(p * (1 - p)), with p = sigmoid(model output). The scaling has one
    entry per sample (or per sample and output) and is broadcast over all
    remaining trailing axes.
    """
    # vmap must wrap the Jacobian *of f3* w.r.t. `params`; vmapping the bare
    # `jacrev` transform would call jacrev(model, params, sample) and fail.
    val = vmap(jacrev(f3, argnums=1), in_dims=(None, None, 0))(model, params, inputs)
    if args.loss == 'logistic': # multiply by sqrt(p(1-p))
        y = functional_call(model, params, inputs.unsqueeze(0)).squeeze(-1)
        # run through sigmoid
        p_y = torch.sigmoid(y)
        hessian_scaling = ((p_y * (1 - p_y))**0.5).squeeze()
        for key in val.keys():
            # Align the scaling with the leading axes of the Jacobian by
            # appending singleton axes, so it broadcasts over (m, n) etc.
            extra_dims = max(val[key].dim() - hessian_scaling.dim(), 0)
            scale = hessian_scaling.reshape(tuple(hessian_scaling.shape) + (1,) * extra_dims)
            val[key] = val[key] * scale
    return val


def torch_logm_orthogonal(Q):
    """
    Principal matrix logarithm of a real orthogonal matrix, obtained from its
    eigendecomposition.

    Returns a (complex) skew-Hermitian K such that exp(K)=Q.
    """
    # Eigendecomposition; the outputs are complex.
    lams, V = torch.linalg.eig(Q)

    # Principal logarithm of every eigenvalue, placed on a diagonal.
    log_diag = torch.diag(torch.log(lams))

    # K = V diag(log λ) V^H, using V^H as the inverse (V is treated as unitary).
    K_raw = V @ log_diag @ V.conj().transpose(-2, -1)

    # Numerical cleanup: keep only the skew-Hermitian part.
    K_adjoint = K_raw.conj().transpose(-2, -1)
    return 0.5 * (K_raw - K_adjoint)

def geodesic_interpolate(basis: torch.Tensor, alpha: float) -> torch.Tensor:
    """
    Geodesic interp on O(n): γ(α)=exp(α log(basis)).

    Args:
        basis: square matrix to interpolate towards (treated as orthogonal).
        alpha: interpolation parameter; 0 gives the identity, 1 gives `basis`.

    Returns:
        A real matrix on the same device and with the same dtype as `basis`.
        For alpha == 1 the input tensor itself is returned (no copy).
    """
    dim_basis = basis.shape[0]
    if alpha == 0:
        # Match the dtype of `basis` so every alpha yields the same dtype.
        return torch.eye(dim_basis, dtype=basis.dtype, device=basis.device)
    if alpha == 1:
        return basis

    # Matrix log via eigendecomposition (complex skew-Hermitian).
    K = torch_logm_orthogonal(basis)

    # Exponentiate the scaled log; the imaginary part is expected to be
    # numerical noise only, so it is dropped.
    M_complex = torch.linalg.matrix_exp(alpha * K)
    return M_complex.real


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Hessian Calculation')
# Task to run
parser.add_argument('--task', '-task', default='parity', type=str, help='Task to run')
# Input dimension for parity task
parser.add_argument('--n', '-n', default=50, type=int, help='Input dimension for parity task')
# Number of relevant input bits for parity task
parser.add_argument('--k', '-k', default=5, type=int, help='Number of relevant input bits for parity task')
parser.add_argument('--multidim_output', '-multidim_output', default=False, action='store_true',
                    help='Whether the function has a multi-dimensional output (e.g. balanced logistic w/ n>1.)')
parser.add_argument('--noise_rank', '-noise_rank', default=0, type=int, help='Number of noise dimensions for Gaussian data')
parser.add_argument('--noise_sigma', '-noise_sigma', default=0, type=float, help='Noise sigma for Gaussian data')
parser.add_argument('--act_fn', '-act_fn', default='relu', type=str,
                    help='Activation function to use in Gaussian data')
parser.add_argument('--poly_degree', '-poly_degree', default=2, type=int,
                      help='Degree of polynomial for Gaussian data, only used if act_fn is poly_degree')
parser.add_argument('--alphas', '-alphas', default="none", type=str,
                      help='Alpha parameter for balanced_logistic task')
parser.add_argument('--init_scale', '-init_scale', default=1.0, type=float,
                      help='Scale factor for weight initialization in balanced_logistic task')
### Model specific parameters
parser.add_argument('--model', '-model', default='mlp', type=str, help='Model to use')
parser.add_argument('--depth', '-depth', default=1, type=int, help='Depth of MLP')
parser.add_argument('--width', '-width', default=200, type=int,
                    help='Width of hidden layers in MLP')
# `args.mean_red_factor` is read further down and forwarded to create_mlp, so
# it has to be declared here; without it the script dies with AttributeError.
# NOTE(review): the default of 1.0 is a guess at a neutral value — confirm
# against utils.create_mlp.
parser.add_argument('--mean_red_factor', '-mean_red_factor', default=1.0, type=float,
                    help='mean_red_factor forwarded to create_mlp')
parser.add_argument('--input_weight_scale', '-input_weight_scale', default=0, type=float,
                    help='Weight scale for the 1st layer')
parser.add_argument('--weight_scale', '-weight_scale', default=0, type=float,
                    help='Weight scale for all layers')
parser.add_argument('--logistic_net_vector_version', default=0, type=int, help='Version of vectorized balanced logistic task')
parser.add_argument('--logistic_proj_diagonal', default=0, type=int, help='Whether to project weight matrices to be diagonal')

# Random seed for reproducibility
parser.add_argument('--seed', '-seed', default=17, type=int, help='Random seed for reproducibility')
# Batch size for training
parser.add_argument('--bs', '-bs', default=128, type=int, help='Batch size for training')
parser.add_argument('--val_batch_size', '-val_batch_size', default=5000, type=int, help='Batch size for validation')
# Minimum batch size for Hessian computation
parser.add_argument('--min_hess_bs', '-min_hess_bs', default=1024, type=int, help='Minimum batch size for Hessian computation')
# Activation function to use in MLP
parser.add_argument('--act', '-act', default='relu', type=str, help='Activation function to use in MLP')
# Loss function to use
parser.add_argument('--loss', '-loss', default='mse', type=str, help='Loss function to use')
# Learning rate
parser.add_argument('--lr', '-lr', default=0.01, type=float, help='Learning rate')
parser.add_argument("--lr_schedule", default="constant", type=str, help="Learning rate schedule to use")
parser.add_argument("--step_constant_until", default=0, type=int, help="Step constant until this step")
parser.add_argument("--lr_min", default=0.0001, type=float, help="Minimum learning rate")
parser.add_argument("--lr_max", default=1, type=float, help="Maximum learning rate")
parser.add_argument("--step_size", default=10, type=int, help="Step size for step schedule")
parser.add_argument("--step_size_factor", default=0.5, type=float, help="Step size for step schedule")
parser.add_argument("--step_size_increase_factor", default=1, type=float, help="Step size increase factor for step schedule")
parser.add_argument("--decay_inverse_power", default=1, type=float, help="Decay inverse power for inverse schedule")
# Number of training iterations
parser.add_argument('--iters', '-iters', default=10000, type=int, help='Number of training iterations')
# Regularization coefficient for preconditioner
parser.add_argument('--reg_coeff', '-reg_coeff', default=1e-4, type=float, help='Regularization coefficient for preconditioner')
parser.add_argument('--adaptive_reg_quantile', default=0, type=float, help='Adaptive regularization coefficient based on quantile of eigenvalues')
parser.add_argument('--adaptive_reg_condNum', default=0, type=float, help='Adaptive regularization coefficient based on condition number of Hessian')
# Power scaling for eigenvalues in preconditioner
parser.add_argument('--power', '-power', default=-0.5, type=float, help='Power scaling for eigenvalues in preconditioner')
parser.add_argument('--batch_norm_power', type=float, default=1)
# Optimization method to use.
# Every choice needs its own trailing comma: adjacent string literals are
# concatenated by Python, which previously merged 'GN-diag-v1' and 'GN-layer'
# into one bogus choice and made both real options unselectable.
parser.add_argument('--opt', '-opt', default='sgd', type=str,
                   choices=['sgd', 'adam', 'GN-diag-v1',
                            'GN-layer', 'GN-adam-layer',
                            'GN-layer-kron', 'GN-adam-layer-kron'],
                   help='Optimization method to use')
# Beta2 parameter for Adam-style methods
parser.add_argument('--beta2', '-beta2', default=0.95, type=float, help='Beta2 parameter for Adam-style methods')
# Whether to rotate input data
parser.add_argument('--rotate', '-rotate', default=False, action='store_true', help='Whether to rotate input data')
# Whether to use 1D preconditioner
parser.add_argument('--precond_1d', '-precond_1d', default=False, action='store_true', help='Whether to use 1D preconditioner')
# Whether to use per-sample gradient computation
parser.add_argument('--per_sample_grad', '-per_sample_grad', default=False, action='store_true', help='Whether to use per-sample gradient computation')
# Whether to use per-sample basis computation
parser.add_argument('--per_sample_basis', '-per_sample_basis', default=False, action='store_true', help='Whether to use per-sample basis computation')
parser.add_argument('--loss_scaled', type=int, default=0, help='Whether to scale the preconditioner by the loss**power')
# Whether to use batch normalization
parser.add_argument('--use_bn', '-use_bn', default=False, action='store_true', help='Whether to use batch normalization')
# Whether to use hidden bias
parser.add_argument('--use_hidden_bias', type=int, default=1, help='Whether to use hidden bias')
# Whether to skip output layer
parser.add_argument('--skip_output_layer', type=int, default=0, help='Whether to skip output layer')
# Whether to reuse batch
parser.add_argument('--reuse_batch', type=int, default=0, help='Whether to reuse the same batch for hessian & gradient estimation.')
# Number of samples for gradient/basis estimation (-1 means use batch size)
parser.add_argument('--num_samples', '-num_samples', default=-1, type=int, help='Number of samples for gradient/basis estimation (-1 means use batch size)')
parser.add_argument('--use_samples', '-use_samples', default=0, type=int, help='Whether to use samples (vs full population) for balanced logistic task')
parser.add_argument("--label_pm1", default=1, type=int, help="Whether to use the binary label of -1/1 (label_pm1=1) or 1/0 (label_pm1=0)")
# Whether to add spike to input data
parser.add_argument('--add_spike', '-add_spike', default=0, type=float, help='Whether to add spike to the Gaussian input data')
# Whether to freeze initial preconditioner
parser.add_argument('--freeze_init', '-freeze_init', default=False, action='store_true', help='Whether to freeze initial preconditioner')
# Whether to estimate Kronecker factorization gap
parser.add_argument('--est_kron_gap', '-est_kron_gap', default=False, action='store_true', help='Whether to estimate Kronecker factorization gap')
parser.add_argument('--log_weights', '-log_weights', default=False, action='store_true', help='Whether to log weights')
parser.add_argument('--log_spectra', '-log_spectra', default=False, action='store_true', help='Whether to log spectra')
parser.add_argument("--interpolate_basis", default=0, type=float, help="Interpolating between the current basis and the identity.")
parser.add_argument("--geodesic_interpolation", default=1, type=float, help="Geodesic interpolation parameter between identity and GN basis (0=identity, 1=full GN).")
parser.add_argument("--random_basis", default=0, type=int, help="Random basis.")
parser.add_argument("--balance_p", default=0.5, type=float, help="Balance parameter for balanced logistic task.")
parser.add_argument("--fix_x", default=0, type=int, help="Whether to fix the input for eigenbasis_attention_like task.")
parser.add_argument("--H_sqrt_filename", default='', type=str, help="Filename for H_sqrt for gaussian task.")
parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden dimension for eigenbasis network')
parser.add_argument('--output_dim', default=1, type=int, help='Output dimension for eigenbasis network')
parser.add_argument('--spectrum_type', default='log_decay', type=str,
                    choices=['random', 'log_decay', 'poly_decay', 'exponential_decay'], help='Type of eigenvalue decay for eigenbasis task')
parser.add_argument('--poly_decay_degree', default=2, type=int, help='Degree for polynomial decay for eigenbasis task')
parser.add_argument('--exponential_decay_base', default=2, type=int, help='Base for exponential decay for eigenbasis task')
parser.add_argument('--target_scales', default=None, type=str, help='Target scales for eigenbasis network')
parser.add_argument('--wandb_tag', '-wandb_tag', default='', type=str, help='Wandb tag')
parser.add_argument('--wandb_log_interval', default=1, type=int, help='Wandb log interval')

args = parser.parse_args()
opt_name = args.opt

# Token describing the task in the run name. For the gaussian task it is built
# from the target activation (plus optional suffixes) instead of the task name.
if args.task != 'gaussian':
    task_token = args.task
else:
    task_token = str(args.act_fn)
    if args.act_fn == 'id':
        task_token += "_id"
    if args.add_spike > 0:
        task_token += f"_spike{args.add_spike}"

# Loss used by the functional loss helpers defined above.
if args.loss == 'mse':
    loss_fn = torch.nn.functional.mse_loss
elif args.loss == 'logistic':
    loss_fn = torch.nn.functional.binary_cross_entropy_with_logits

# Token describing the regularisation mode: quantile-based first, then
# condition-number-based, otherwise the fixed coefficient.
if args.adaptive_reg_quantile > 0:
    reg_token = f"regQuan{args.adaptive_reg_quantile}"
elif args.adaptive_reg_condNum > 0:
    reg_token = f"regCond{args.adaptive_reg_condNum}"
else:
    reg_token = f"reg_coeff{args.reg_coeff}"

# Assemble the wandb run name piece by piece.
name_parts = [
    f"{opt_name}_{task_token}_n{args.n}_k{args.k}_width{args.width}_depth{args.depth}",
    f"_lr{args.lr}",
]
if args.lr_schedule != "":
    name_parts.append(f"_{args.lr_schedule}")
name_parts.append(
    f"_power{args.power}_seed{args.seed}_bs{args.bs}_iters{args.iters}_{reg_token}_beta2{args.beta2}"
)
wandb_name = "".join(name_parts)

# Start the wandb run; the project name encodes the task and the optional tag.
wandb_project = f"second-order-basis-{args.task}{args.wandb_tag}"
wandb.init(project=wandb_project, name=wandb_name, allow_val_change=True)

if args.wandb_tag != '':
    wandb.config.update({'wandb_tag': args.wandb_tag}, allow_val_change=True)

# Log every CLI argument plus the process id of this run.
pid = os.getpid()
args_wandb = copy.deepcopy(vars(args))
args_wandb['pid'] = pid
wandb.config.update(args_wandb)

# Run on the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Short aliases for frequently used CLI values.
n, k = args.n, args.k
depth, width = args.depth, args.width
mean_red_factor = args.mean_red_factor
act = args.act

seed_everything(args.seed)

# -1 means "use the training batch size".
if args.num_samples == -1:
    args.num_samples = args.bs

# Integer CLI flags converted to booleans.
use_hidden_bias = (args.use_hidden_bias == 1)
skip_output_layer = (args.skip_output_layer == 1)

def _build_mlp(hidden_width):
    """Create an MLP of the given hidden width; all other settings come from the CLI."""
    return create_mlp(
        n=n,
        width=hidden_width,
        depth=depth,
        output_dim=args.output_dim,
        mean_red_factor=args.mean_red_factor,
        act=args.act,
        use_bn=args.use_bn,
        use_hidden_bias=use_hidden_bias,
        skip_output_layer=skip_output_layer,
    )


if args.task == 'teacher_student':
    # Covariance matrix and eigen-decomposition for the teacher-student task,
    # generated under a fixed seed.
    seed_everything(72)
    H, eigvals, eigvecs = generate_covariance(args.n, args.spectrum_type)
    H, eigvals, eigvecs = H.to(device), eigvals.to(device), eigvecs.to(device)

    # Teacher (ground-truth) network, still created under the fixed seed.
    teacher_model = _build_mlp(width)

    # Student network: created under the run seed and four times as wide.
    seed_everything(args.seed)
    model = _build_mlp(4*width)

    # Parameters handed to the data generation function.
    X_DATATYPE = torch.float32
    Y_DATATYPE = torch.float32
    problem_params = dict(
        covariance=H,
        teacher_model=teacher_model,
        device=device,
        n=args.n,
        hidden_dim=args.hidden_dim,
        output_dim=args.output_dim,
        X_DATATYPE=X_DATATYPE,
    )
elif args.task == 'balanced_logistic':
    assert args.alphas != "none", "alphas must be specified for balanced_logistic task"
    # --alphas arrives as a comma-separated string; convert it to floats.
    args.alphas = [float(x) for x in args.alphas.split(',')]
    model = create_balanced_logistic_net(
        n=n,
        alphas=args.alphas,
        initialization_scale=args.init_scale,
        activation=args.act,
        vector_version=args.logistic_net_vector_version,
    )
    X_DATATYPE = torch.float32
    Y_DATATYPE = torch.float32
elif args.model == 'mlp':
    model = _build_mlp(width)
    X_DATATYPE = torch.float32
    Y_DATATYPE = torch.float32

model = model.to(device)

# Optional rescaling of the initial weights: either every Linear layer, or
# (only when no global scale is given) just the first layer.
if args.weight_scale > 0:
    for module in model.modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        module.weight.data = module.weight.data * args.weight_scale
elif args.input_weight_scale > 0:
    print('input weight scale:', args.input_weight_scale, flush=True)
    first_layer = model[0]
    first_layer.weight.data = first_layer.weight.data * args.input_weight_scale
model = model.to(device)

# State shared with the forward/backward hooks defined below.
global_acts = {}
global_samples = args.num_samples

def hook_fn_fwd(module, input, output):
    """Forward hook: remember the first positional input of `module`.

    The tensor is stored in the module-level `global_acts` dict, keyed by the
    module itself, so the backward hook can read it later.
    """
    first_input = input[0]
    global_acts[module] = first_input

def hook_fn_bwd(module, grad_in, grad_out):
    """Backward hook that attaches gradient statistics to a module's parameters.

    Combines `grad_out[0]` with the input cached in `global_acts[module]` by
    the forward hook and stores the results as attributes on `module.weight`
    and `module.bias` (`sample_grad`, `hess`, `grad1`, `grad2`, `len1`,
    `len2`). Nothing is returned.

    NOTE(review): `sample_grad` and `grad1` are read with `is None`, so they
    must already exist on the parameters (presumably reset to None elsewhere
    before each backward pass) — confirm.
    NOTE(review): every product is multiplied by `grad_out[0].shape[0]`;
    presumably the first axis is the batch and this undoes a mean reduction
    of the loss — confirm.
    """
    if isinstance(module, Custom_BN):
        # Custom batch-norm layer: normalise the cached input with its own
        # batch statistics before combining it with the output gradient.
        # The prints are debugging output.
        print('global acts shape', global_acts[module].shape, flush=True)
        mean_acts = global_acts[module].mean(dim=0, keepdims=True)
        var_acts = global_acts[module].var(dim=0, keepdims=True, unbiased=False)
        std_acts = torch.sqrt(var_acts + 1e-5)
        print('mean', mean_acts, flush=True)
        print('var', var_acts, flush=True)
        temp_acts = (global_acts[module] - mean_acts)/std_acts
        print('temp acts', temp_acts.shape, flush=True)
        print('grad out', grad_out[0].shape, flush=True)
        print('post mean', temp_acts.mean(dim=0, keepdims=True), flush=True)
        print('post var', temp_acts.var(dim=0, keepdims=True), flush=True)
        module.weight.sample_grad = (grad_out[0]*temp_acts)*grad_out[0].shape[0]
        module.weight.hess = torch.matmul((grad_out[0]*temp_acts).T, (grad_out[0]*temp_acts))*grad_out[0].shape[0]
        module.bias.sample_grad = (grad_out[0])*grad_out[0].shape[0]
        module.bias.hess = torch.matmul((grad_out[0]).T, (grad_out[0]))*grad_out[0].shape[0]
    if len(module.weight.squeeze().shape) == 2:
        # Matrix-shaped weight (after squeezing).
        module.bias.hess = torch.matmul(grad_out[0].T, grad_out[0])*grad_out[0].shape[0]
        # sample_grad is only filled in when it has not been set yet.
        if module.bias.sample_grad is None:
            module.bias.sample_grad = grad_out[0]*grad_out[0].shape[0]
        if module.weight.sample_grad is None:
            # Outer product of output gradient and cached input along the first axis.
            module.weight.sample_grad = (torch.matmul(grad_out[0].unsqueeze(-1), global_acts[module].unsqueeze(-2))*grad_out[0].shape[0])
        #module.weight.sample_grad = module.weight.sample_grad[:global_samples]
        #print(module.weight.sample_grad.shape, flush=True)
        # Split along the first axis into the first third and the remainder,
        # and keep the mean gradient of each part.
        ind1 = grad_out[0].shape[0]//3
        if module.weight.grad1 is None:
            module.weight.grad1 = torch.mean(module.weight.sample_grad[:ind1], dim=0, keepdim=True)
            module.weight.grad2 = torch.mean(module.weight.sample_grad[ind1:], dim=0, keepdim=True)
        else:
            # Later calls append their means along dim 0.
            module.weight.grad1 = torch.cat([module.weight.grad1, torch.mean(module.weight.sample_grad[:ind1], dim=0, keepdim=True)], dim=0)
            module.weight.grad2 = torch.cat([module.weight.grad2, torch.mean(module.weight.sample_grad[ind1:], dim=0, keepdim=True)], dim=0)
        # Sizes of the two parts.
        module.weight.len1 = ind1
        module.weight.len2 = grad_out[0].shape[0] - ind1
        
    if len(module.weight.squeeze().shape) == 1:
        # Vector-shaped weight (after squeezing): elementwise product of the
        # cached input and the output gradient.
        # NOTE(review): if a Custom_BN weight is 1-D this branch also runs for
        # it and overwrites weight.hess using the un-normalised input —
        # confirm that is intended.
        temp_grad = global_acts[module]*grad_out[0]
        if module.weight.sample_grad is None:
            module.weight.sample_grad = temp_grad*grad_out[0].shape[0]
        module.weight.hess = torch.matmul(temp_grad.T, temp_grad)*temp_grad.shape[0]


# Select the batch generator (`get_batch`) and its `problem_params` for the
# requested task.
# When using the parity task
if args.task == 'parity':
    # use_pm1 is switched off only for the transformer model.
    use_pm1 = int(args.model != 'transformer')
    problem_params = {'k': k, 'use_pm1': use_pm1, 'label_pm1': args.label_pm1}
    get_batch = get_batch_sparse_parity
# When using the staircase task
elif args.task == 'staircase':
    num_groups = n//k
    # One (start, end) index range of width k per group.
    problem_params = {f'k_{i}': (i*k, (i+1)*k) for i in range(num_groups)}
    get_batch = get_batch_staircase
# When using the teacher-student task
elif args.task == 'teacher_student':
    # `problem_params` (covariance, teacher model, ...) was already assembled
    # when the teacher/student networks were created; only the generator is
    # missing. Without this branch the task fell through to the ValueError.
    # NOTE(review): assumes get_teacher_student_batch follows the same calling
    # convention as the other get_batch_* functions — confirm in utils.
    get_batch = get_teacher_student_batch
elif 'gaussian' in args.task:
    problem_params = {'n': n, 'k': k}
    # Nonlinearity applied by the data-generating function.
    if args.act_fn == 'relu':
        act_fn = torch.nn.ReLU()
    elif args.act_fn == 'sigmoid':
        act_fn = torch.nn.Sigmoid()
    elif args.act_fn == 'sine':
        act_fn = torch.sin
    elif args.act_fn == 'cosine':
        act_fn = torch.cos
    elif args.act_fn == 'poly_degree':
        act_fn = lambda x: x**args.poly_degree
    elif args.act_fn == 'id':
        act_fn = lambda x: x
    else:
        # Fail with a clear message instead of a NameError on `act_fn` below.
        raise ValueError(f"Activation function {args.act_fn} not supported")

    if args.task == 'gaussian':
        ws = torch.randn(n, k, requires_grad=False) / np.sqrt(n)
        problem_params = {'ws': ws, 'act_fn': act_fn, 'add_spike': args.add_spike}
        if args.H_sqrt_filename != '' and os.path.exists(args.H_sqrt_filename):
            # NOTE: torch.load unpickles the file; only point --H_sqrt_filename
            # at files you trust.
            H_sqrt = torch.load(args.H_sqrt_filename)
            problem_params['H_sqrt'] = H_sqrt.float()

        get_batch = get_batch_gaussian
    elif args.task == 'get_batch_gaussian_sanity':
        w = torch.randn(k, requires_grad=False) / np.sqrt(k)
        problem_params = {'w': w, 'noise_rank': args.noise_rank, 'noise_sigma': args.noise_sigma, 'act_fn': act_fn, 'noise': 0}
        get_batch = get_batch_gaussian_sanity
    else:
        # Any other task name containing 'gaussian' would leave `get_batch`
        # undefined; reject it here like every other unknown task.
        raise ValueError(f"Task {args.task} not supported")

elif args.task == 'balanced_logistic':
    problem_params = {'alphas': args.alphas, 'p': args.balance_p, 'use_samples': args.use_samples}
    get_batch = get_batch_balanced_logistic

else:
    raise ValueError(f"Task {args.task} not supported")

@torch.no_grad()
def class_acc(yhat, y):
    """Fraction of predictions whose sign-based label equals the target `y`.

    With the global `args.label_pm1` set, the sign of `yhat` is compared
    directly; otherwise negative signs are mapped to 0 before comparing.
    """
    labels = yhat.sign()
    if not args.label_pm1:
        # Map -1 to 0; 0 and 1 are left unchanged.
        labels = torch.clamp(labels, min=0)
    matches = labels == y
    return matches.float().mean().item()

@torch.no_grad()
def continuous_acc(yhat, y):
    """Fraction of entries where prediction and target have the same sign.

    Both tensors are squeezed before the comparison.
    """
    signs_match = yhat.sign().squeeze() == y.sign().squeeze()
    return signs_match.float().mean().item()

@torch.no_grad()
def balanced_logistic_acc(yhat, y):
    """Fraction of predictions that equal the target after rounding `yhat`."""
    matches = yhat.round() == y
    return matches.float().mean().item()

@torch.no_grad()
def staircase_acc_general(yhat, y, num_groups=None):
    """
    Computes accuracy for staircase task by checking if prediction is closest to the correct value
    among all possible outcomes based on the number of groups.

    Predictions and targets are flattened first, so `(bs,)` and `(bs, 1)`
    shaped tensors are handled the same way.

    Args:
        yhat: Model predictions
        y: True targets
        num_groups: Number of groups in the staircase task. If None, will be inferred from data.

    Returns:
        Accuracy as a float between 0 and 1
    """
    # Without flattening, a (bs, 1) prediction broadcasts to (bs, 1, V) below
    # and argmin(dim=1) would reduce over the wrong axis.
    yhat = yhat.reshape(-1)
    y = y.reshape(-1)

    # If num_groups is not provided, infer it from the maximum absolute value in y
    if num_groups is None:
        max_abs_value = y.abs().max().item()
        # Round rather than truncate, so a value just below an integer
        # (e.g. 2.9999998) still yields the intended group count.
        num_groups = int(round(max_abs_value))

        # Verify that max_abs_value is an integer (or very close to one)
        if abs(max_abs_value - num_groups) > 1e-5:
            print(f"Warning: Maximum absolute value in targets ({max_abs_value}) is not an integer. "
                  f"Using {num_groups} as the number of groups.")

    # Generate all possible outcomes
    # For n groups, possible values are -n, -n+2, -n+4, ..., n-2, n
    possible_values = torch.arange(-num_groups, num_groups+1, 2, dtype=y.dtype, device=y.device)

    # Find the closest possible value for each prediction
    diff = (yhat.unsqueeze(1) - possible_values.unsqueeze(0)).abs()
    closest_value = possible_values[diff.argmin(dim=1)]

    # Check if the closest value matches the true value
    return (closest_value == y).float().mean().item()

# Select the accuracy metric for the task; anything not matched below falls
# back to sign-based classification accuracy.
get_acc = (
    staircase_acc_general if args.task == 'staircase'
    else continuous_acc if 'gaussian' in args.task
    else balanced_logistic_acc if args.task == 'balanced_logistic'
    else class_acc
)

# Draw one fixed, noise-free validation batch and keep it on the training device.
val_batch_size = args.val_batch_size
val_set = get_batch(n, problem_params, batch_size=val_batch_size, noise=0.0)
x_val, y_val = val_set[0].to(device), val_set[1].to(device)
val_set = (x_val, y_val)

# Record the name, shape and element count of every state_dict entry, in
# state_dict order.
shapes = []
lengths = []
keys = []

# Build the state_dict once instead of once per entry.
for key, param in model.state_dict().items():
    shapes.append(param.shape)
    lengths.append(param.numel())
    keys.append(key)

# Build the stepping optimizer at the initial learning rate. Every choice of
# args.opt other than 'adam' steps with plain SGD.
if args.opt != 'adam':
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
else:
    # First-moment averaging is disabled (beta1 = 0); reg_coeff is used as Adam's eps.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.0, args.beta2),
        eps=args.reg_coeff,
    )

# Iteration budget and progress bar.
iters = args.iters
bar = tqdm(range(iters))

# The curvature estimate is accumulated in chunks of at most 1024 samples
# until at least min_hess_bs samples have been processed.
min_hess_bs = args.min_hess_bs
max_bs_per_hess_iter = min(1024, args.min_hess_bs)

# Running second-moment state, keyed by parameter name.
exp_avg_final_hess = {}
exp_avg_sq = {}

# Optional fixed random orthogonal rotation of the inputs; the validation
# inputs are rotated once here.
if args.rotate:
    Q = torch.nn.init.orthogonal_(torch.randn(n, n)).to(device)
    x_val = torch.matmul(x_val, Q)

# Per-iteration metric histories.
losses, accs = [], []
spectrums = {'loss': [], 'acc': [], 'positive_prop': []}

# Output directory for the pickled per-step snapshots.
if args.log_weights:
    os.makedirs('spectrums', exist_ok=True)
    os.makedirs(f'spectrums/{wandb_name}', exist_ok=True)


# Main training loop. Each iteration:
#   1. (GN optimizers only) re-estimates the per-parameter Gauss-Newton
#      preconditioner, unless args.freeze_init pins it to the i == 0 estimate;
#   2. evaluates accuracy on the fixed validation set;
#   3. computes the training loss and its gradient;
#   4. (GN optimizers only) replaces param.grad with the preconditioned gradient;
#   5. applies the learning-rate schedule and takes an optimizer step;
#   6. logs metrics, and optionally weights/spectra, to wandb and to disk.
# The loop stops early on Ctrl-C or once a NaN parameter has been detected.
HAS_NAN = False

# Optimizers that precondition the gradient with a Gauss-Newton estimate.
opt_lst = ['GN-layer', 'GN-adam-layer',
           'GN-adam-layer-kron', 'GN-layer-kron',
           'GN-diag-v1',
          ]

for i in bar:
    try: ## Allow keyboard interrupt
        # ------------------------------------------------------------------
        # 1. Gauss-Newton preconditioner estimation
        # ------------------------------------------------------------------
        if args.opt in opt_lst and (not args.freeze_init or i == 0):
            total_iters = 0
            curr_bs = 0
            final_hess = None
            # Accumulate over chunks of max_bs_per_hess_iter samples until at
            # least min_hess_bs samples have been used.
            while(True):
                x, y = get_batch(n, problem_params, batch_size=max_bs_per_hess_iter)
                x, y = x.type(X_DATATYPE).to(device), y.type(Y_DATATYPE).to(device)
                if args.rotate:
                    x = torch.matmul(x, Q)
                optimizer.zero_grad()

                # Gauss-Newton preconditioner computation. `hess` maps parameter
                # name -> derivatives of the model output (presumably one slice
                # per sample along dim 0; the helpers are defined outside this
                # section -- confirm there).
                curr_final_hess = {}
                if args.opt in ['GN-adam-layer-kron', 'GN-layer-kron']:
                    # Kronecker-factored Gauss-Newton
                    if args.multidim_output:
                        hess = per_output_jacobian_2(model, model.state_dict(), x)
                    else:
                        hess = per_output_grad_2(model, model.state_dict(), x)

                    for key in hess.keys():
                        # Drop a trailing singleton dimension, if any.
                        hess[key] = hess[key].squeeze(-1)
                        if not args.multidim_output:
                            if len(hess[key].shape) > 2: # Kronecker-factored Hessian
                                # Store [left, right] factors: batch means of
                                # G G^T and G^T G over dim 0.
                                curr_final_hess[key] = []
                                curr_final_hess[key].append(torch.mean(torch.matmul(hess[key], hess[key].transpose(1,2)), dim=0))
                                curr_final_hess[key].append(torch.mean(torch.matmul(hess[key].transpose(1,2), hess[key]), dim=0))
                                # Only the left factor is normalised by its trace.
                                curr_final_hess[key][0] = curr_final_hess[key][0]/torch.trace(curr_final_hess[key][0])
                            else:
                                if args.precond_1d:
                                    curr_final_hess[key] = torch.matmul(hess[key].T, hess[key])/hess[key].shape[0]
                                else:
                                    # No preconditioning for this parameter.
                                    curr_final_hess[key] = torch.eye(hess[key].shape[1]).to(device)
                        else:
                            if len(hess[key].shape) > 3:
                                raise NotImplementedError("Multidimensional output not supported for Kronecker-factored Hessian")
                            else:
                                # Per-output outer products averaged over the
                                # batch, then averaged over outputs.
                                curr_final_hess[key] = torch.einsum('bki,bkj->kij', hess[key], hess[key]) / hess[key].shape[0]
                                curr_final_hess[key] = torch.mean(curr_final_hess[key], dim=0)
                else:
                    # Full Gauss-Newton
                    if args.multidim_output:
                        hess = per_output_jacobian(model, model.state_dict(), x)
                    else:
                        hess = per_output_grad(model, model.state_dict(), x)

                    if args.opt == 'GN-diag-v1':
                        # Diagonal approximation: mean of squared derivatives.
                        for key in hess.keys():
                            if args.multidim_output:
                                curr_final_hess[key] = torch.diag(torch.mean(hess[key]**2, dim=0).mean(dim=0))
                            else:
                                curr_final_hess[key] = torch.diag(torch.mean(hess[key]**2, dim=0))
                    else:
                        curr_final_hess = {}
                        for key in hess.keys():
                            if args.multidim_output:
                                curr_final_hess[key] = torch.einsum('bki,bkj->kij', hess[key], hess[key]) / hess[key].shape[0]
                                curr_final_hess[key] = torch.mean(curr_final_hess[key], dim=0)
                            else:
                                curr_final_hess[key] = torch.matmul(hess[key].T, hess[key])/(hess[key].shape[0]**args.batch_norm_power)

                if final_hess is None:
                    final_hess = curr_final_hess
                else:
                    for key in final_hess.keys():
                        if args.opt in ['GN-layer', 'GN-adam-layer']:
                            final_hess[key] += curr_final_hess[key]
                        # A [left, right] factor pair is only stored when the
                        # output is scalar; with multidim_output the 3-D Jacobian
                        # must not be mistaken for one (same test as the
                        # eigendecomposition below).
                        elif not args.multidim_output and len(hess[key].shape) > 2:
                            final_hess[key][0] += curr_final_hess[key][0]
                            final_hess[key][1] += curr_final_hess[key][1]
                        else:
                            final_hess[key] += curr_final_hess[key]

                total_iters += 1
                curr_bs += max_bs_per_hess_iter
                if curr_bs >= min_hess_bs:
                    # Turn the accumulated sums into an average over chunks.
                    for key in final_hess.keys():
                        if not args.multidim_output and len(hess[key].shape) > 2:
                            final_hess[key][0] = final_hess[key][0]/total_iters
                            final_hess[key][1] = final_hess[key][1]/total_iters
                        else:
                            final_hess[key] = final_hess[key]/total_iters
                    break

            reg_coeff = args.reg_coeff
            if 'diag' in args.opt and args.beta2 > 0.0:
                # Bias-corrected exponential moving average of the diagonal estimate.
                for key in final_hess.keys():
                    if key not in exp_avg_final_hess.keys():
                        exp_avg_final_hess[key] = torch.zeros_like(final_hess[key])
                    exp_avg_final_hess[key] = args.beta2*exp_avg_final_hess[key] + (1-args.beta2)*final_hess[key]
                    final_hess[key] = exp_avg_final_hess[key]/(1-args.beta2**(i+1))

            precond = {}
            eig_vecs_layer = {}
            eig_vals = {}
            eig_vals_unreg = {}

            for key in final_hess.keys():
                if not args.multidim_output and len(hess[key].shape) > 2:
                    # For matrix parameters, compute eigendecomposition of both Kronecker factors
                    eig_vecs_layer[key] = []
                    # Left factor eigendecomposition
                    eigs, eig_vecs = torch.linalg.eigh(final_hess[key][0] + reg_coeff*torch.eye(final_hess[key][0].shape[0]).to(device))

                    # Apply geodesic interpolation if requested, then re-read the
                    # spectrum in the (possibly changed) basis.
                    eig_vecs = geodesic_interpolate(eig_vecs, args.geodesic_interpolation)
                    eigs = torch.diag(torch.matmul(torch.matmul(eig_vecs.T, final_hess[key][0]), eig_vecs)) + reg_coeff

                    eigs = eigs.clamp(min=reg_coeff)  # Ensure numerical stability
                    eigs = eigs**args.power  # Apply power scaling to eigenvalues

                    precond[key] = [torch.matmul(eig_vecs, torch.matmul(torch.diag(eigs), eig_vecs.T))]
                    eig_vecs_layer[key].append(eig_vecs)

                    # Right factor eigendecomposition
                    eigs, eig_vecs = torch.linalg.eigh(final_hess[key][1] + reg_coeff*torch.eye(final_hess[key][1].shape[0]).to(device))

                    # Apply geodesic interpolation if requested
                    eig_vecs = geodesic_interpolate(eig_vecs, args.geodesic_interpolation)
                    eigs = torch.diag(torch.matmul(torch.matmul(eig_vecs.T, final_hess[key][1]), eig_vecs)) + reg_coeff

                    eigs = eigs.clamp(min=reg_coeff)
                    eigs = eigs**args.power

                    precond[key].append(torch.matmul(eig_vecs, torch.matmul(torch.diag(eigs), eig_vecs.T)))
                    eig_vecs_layer[key].append(eig_vecs)

                    if args.log_spectra:
                        # Only the right factor's spectrum is recorded here.
                        eig_vals[key] = eigs.detach().cpu().numpy()
                        eig_vals_unreg[key] = torch.linalg.eigh(final_hess[key][1])[0].detach().cpu().numpy()
                else:
                    # For vector parameters, compute single eigendecomposition
                    # NOTE(review): reg_coeff is overwritten below and is not
                    # reset per key, so an adaptive value carries over to the
                    # following keys -- confirm this is intended.
                    if args.adaptive_reg_quantile > 0:
                        base_reg_coeff = reg_coeff / 2
                        eigs_unreg, _ = torch.linalg.eigh(final_hess[key] + base_reg_coeff*torch.eye(final_hess[key].shape[0]).to(device))
                        reg_coeff = torch.quantile(eigs_unreg, args.adaptive_reg_quantile)
                    elif args.adaptive_reg_condNum > 0:
                        top_eig_val = torch.lobpcg(final_hess[key], k=1, tol=1e-5)[0]
                        reg_coeff = top_eig_val / args.adaptive_reg_condNum

                    if len(final_hess[key].shape) == 0:
                        # Scalar curvature: nothing to diagonalise.
                        # NOTE(review): unlike the other branches, args.power is
                        # not applied here -- confirm.
                        eigs = final_hess[key].detach()
                        eigs = eigs.clamp(min=reg_coeff)
                        eig_vecs = torch.tensor([1.0]).to(device)

                        precond[key] = eig_vecs * eigs

                    else:
                        eigs, eig_vecs = torch.linalg.eigh(final_hess[key] + reg_coeff*torch.eye(final_hess[key].shape[0]).to(device))

                        # Apply geodesic interpolation if requested
                        eig_vecs = geodesic_interpolate(eig_vecs, args.geodesic_interpolation)
                        eigs = torch.diag(torch.matmul(torch.matmul(eig_vecs.T, final_hess[key]), eig_vecs)) + reg_coeff
                        eigs = eigs.clamp(min=reg_coeff)
                        precond[key] = torch.matmul(eig_vecs, torch.matmul(torch.diag(eigs**args.power), eig_vecs.T))

                    eig_vecs_layer[key] = eig_vecs
                    if args.log_spectra:
                        eig_vals[key] = eigs.detach().cpu().numpy()
                        if len(final_hess[key].shape) == 0:
                            eig_vals_unreg[key] = eigs.detach().cpu().numpy()
                        else:
                            eig_vals_unreg[key] = torch.linalg.eigh(final_hess[key])[0].detach().cpu().numpy()

        # ------------------------------------------------------------------
        # 2. Validation (measured before this iteration's update)
        # ------------------------------------------------------------------
        y_hat_val = model(x_val).squeeze(-1).detach()
        if args.model == 'transformer':
            # Keep the last position along dim 1 and take the argmax over the
            # remaining dimension.
            y_hat_val = y_hat_val[:,-1, :]
            y_hat_val = y_hat_val.argmax(dim=1)

        acc = get_acc(y_hat_val, y_val)
        accs.append(acc)
        positive_prop = (y_val > 0).float().mean().item()

        # ------------------------------------------------------------------
        # 3. Training loss and gradient
        # ------------------------------------------------------------------
        # NOTE(review): zero_grad() runs inside the loop, so only the gradient of
        # the final pass reaches the optimizer. With args.reuse_batch the (x, y)
        # left over from an earlier statement is reused.
        for _ in range(1, args.num_samples+1):
            optimizer.zero_grad()
            if not args.reuse_batch:
                x, y = get_batch(n, problem_params, batch_size=args.bs)
                x, y = x.type(X_DATATYPE).to(device), y.type(Y_DATATYPE).to(device)
                y = y.squeeze(-1)
                if args.rotate:
                    x = torch.matmul(x, Q)
            y_hat = model(x).squeeze(-1)
            loss = loss_fn(y_hat, y)
            loss.backward()

        losses.append(loss.item())

        # ------------------------------------------------------------------
        # 4. Gradient preconditioning (GN optimizers only)
        # ------------------------------------------------------------------
        if args.opt in opt_lst:
            grad_curr = {}
            param_numels = {}
            for name, param in model.named_parameters():
                # Kronecker-factored parameters keep their matrix shape; all
                # others are flattened to a vector.
                if not args.multidim_output and len(hess[name].shape) > 2:
                    grad_curr[name] = param.grad
                else:
                    grad_curr[name] = param.grad.flatten()

                if args.loss_scaled and 'adam' in args.opt: # Adjusting Adam into GN
                    loss_reg = max(loss.item(), args.reg_coeff)
                    grad_curr[name] *= loss_reg**(-1.0*args.power)
                elif args.loss_scaled and 'GN' in args.opt: # Adjusting GN into Adam
                    loss_reg = max(loss.item(), args.reg_coeff)
                    grad_curr[name] *= loss_reg**(args.power)


            if args.opt in ['GN-adam-layer', 'GN-adam-layer-kron']:
                # Adam-style variants: Apply eigendecomposition and adaptive scaling
                # NOTE(review): the 'adam-random-basis' branches below cannot be
                # reached, because this block only runs for the two optimizers
                # listed above (and param_numels is never filled).
                if args.opt == 'adam-random-basis':
                    eig_vecs_layer = {}
                for key in grad_curr.keys():
                    if args.opt == 'adam-random-basis':
                        dim = param_numels[key]
                        # sample a random orthogonal matrix
                        rand_mtrx = torch.randn(dim, dim).to(device)
                        R_rand, _ = torch.linalg.qr(rand_mtrx)
                        eig_vecs_layer[key] = R_rand
                        grad_curr[key] = torch.matmul(R_rand.T, grad_curr[key])
                    else:
                        if args.multidim_output:
                            raise NotImplementedError("Multidimensional output not supported for Adam-style variants")
                        else:
                            if len(hess[key].shape) > 2:
                                # Transform gradients using eigenvectors
                                grad_curr[key] = torch.matmul(eig_vecs_layer[key][0].T, torch.matmul(grad_curr[key], eig_vecs_layer[key][1]))
                            else:
                                grad_curr[key] = torch.matmul(eig_vecs_layer[key].T, grad_curr[key])

                    # Compute Adam-style moving averages
                    if key not in exp_avg_sq.keys():
                        exp_avg_sq[key] = torch.zeros_like(grad_curr[key])
                    exp_avg_sq[key] = args.beta2*exp_avg_sq[key] + (1-args.beta2)*(grad_curr[key]**2)
                    bias_correction = 1 - (args.beta2**(i+1))
                    denom = (exp_avg_sq[key]/bias_correction).sqrt().clamp(min=args.reg_coeff)
                    grad_curr[key] = grad_curr[key]/denom

                    # Transform back to original space
                    if len(hess[key].shape) > 2:
                        grad_curr[key] = torch.matmul(eig_vecs_layer[key][0], torch.matmul(grad_curr[key], eig_vecs_layer[key][1].T))
                    elif len(grad_curr[key].shape) == 0:
                        grad_curr[key] = eig_vecs_layer[key] * grad_curr[key]
                    else:
                        grad_curr[key] = torch.matmul(eig_vecs_layer[key], grad_curr[key])
            else:
                for key in grad_curr.keys():
                    if not args.multidim_output:
                        if len(hess[key].shape) > 2: # Kron-variants
                            grad_curr[key] = torch.matmul(precond[key][0], torch.matmul(grad_curr[key], precond[key][1]))
                        else:
                            grad_curr[key] = torch.matmul(precond[key], grad_curr[key]).reshape(-1,1)
                    else:
                        if len(hess[key].shape) > 3:
                            raise NotImplementedError("Multidimensional output not supported for Kronecker-factored Hessian")
                        else:
                            grad_curr[key] = torch.matmul(precond[key], grad_curr[key]).reshape(-1,1)

            # Update model parameters with preconditioned gradients
            for name, param in model.named_parameters():
                param.grad = grad_curr[name].reshape(param.shape)

        # Build this iteration's log record before the scheduler runs, so the
        # learning rate it writes below is kept and the dict always exists.
        wandb_dict = {'loss': loss.item(), 'acc': acc, 'positive_prop': positive_prop}

        # ------------------------------------------------------------------
        # 5. Learning-rate schedule and optimizer step
        # ------------------------------------------------------------------
        if args.lr_schedule == "constant":
            optimizer.step()
        elif args.lr_schedule == "cosine":
            lr = args.lr * 0.5 * (1 + math.cos(math.pi * i / args.iters))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            optimizer.step()
        elif args.lr_schedule == "step":
            if i % args.step_size == 0 and i >= args.step_constant_until:
                # NOTE: clip at 20 for numerical stability.
                lr = args.lr * (args.step_size_factor)**(min(20, i//args.step_size))
                if args.step_size_increase_factor > 1:
                    # Stretch the interval until the next learning-rate change.
                    args.step_size = int(args.step_size * args.step_size_increase_factor)
                lr = max(lr, args.lr_min)
                lr = min(lr, args.lr_max)
                wandb_dict['lr'] = lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            optimizer.step()

        if args.logistic_proj_diagonal:
            # Project every square 2-D parameter onto its diagonal.
            for name, param in model.named_parameters():
                if param.dim() == 2 and param.shape[0] == param.shape[1]:
                    diagnal_entries = param.data.diag()
                    param.data = torch.diag(diagnal_entries)

        # ------------------------------------------------------------------
        # 6. Logging
        # ------------------------------------------------------------------
        if args.multidim_output:
            # Log the loss per-dimension. A separate name is used so the
            # training loss recorded below is not overwritten.
            with torch.no_grad():
                n_dims = y_hat.shape[1]
                for ndim_idx in range(n_dims):
                    dim_loss = loss_fn(y_hat[:, ndim_idx], y[:, ndim_idx])
                    wandb_dict['loss_ndim_' + str(ndim_idx)] = dim_loss.item()

        spectrums['loss'].append(loss.item())
        spectrums['acc'].append(acc)
        spectrums['positive_prop'].append(positive_prop)

        if args.log_weights and i % args.wandb_log_interval == 0:
            print(f"Logging weights at step {i}")
            norms = {}
            weights = {}
            named_params = list(model.named_parameters())
            for name, param in named_params:
                l1_norm = torch.norm(param, p=1).item()
                l2_norm = torch.norm(param, p=2).item()
                lfro_norm = torch.norm(param, p='fro').item()
                linf_norm = torch.norm(param, p=float('inf')).item()
                norms[name + '_l1'] = l1_norm
                norms[name + '_l2'] = l2_norm
                norms[name + '_lfro'] = lfro_norm
                norms[name + '_linf'] = linf_norm
                weights[name] = param.detach().cpu().numpy()
                # NaNs are only checked on iterations that log weights.
                if torch.isnan(param).any():
                    HAS_NAN = True
            print(f"Logging norms: {norms}")
            wandb_dict.update(norms)

            if args.task == 'balanced_logistic':
                # Log every weight entry and the entries of the combined
                # product of the (transposed) weight matrices.
                matrices = []
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        param = param.detach().cpu().numpy()
                        weights[name] = param.reshape(-1)
                        for wi in range(len(weights[name])):
                            wandb_dict[f'{name}_entry{wi}'] = weights[name][wi]
                        matrices.append(param.T)
                if args.logistic_net_vector_version == 0:
                    matrix_prod = np.linalg.multi_dot(matrices)
                    matrix_prod = matrix_prod.reshape(-1)
                else:
                    matrix_prod = np.prod(matrices, axis=0)
                for wi in range(len(matrix_prod)):
                    wandb_dict[f'matrix_prod_entry{wi}'] = matrix_prod[wi]

            spectrums['weights'] = weights
            if args.log_spectra and args.opt != 'adam' and args.opt != 'sgd':
                spectrums['eig_vals'] = eig_vals
                spectrums['eig_vals_unreg'] = eig_vals_unreg
            # The handle must not be called `f`: that would rebind the
            # module-level loss function `f` to a closed file object.
            with open(f'spectrums/{wandb_name}/step_{i}.pkl', 'wb') as pkl_file:
                pickle.dump(spectrums, pkl_file)

        if i % args.wandb_log_interval == 0:
            wandb.log(wandb_dict, step=i)

        # Update progress bar description based on task
        bar.set_description(f"loss: {loss.item():.3f}, acc: {acc:.3f}")

        if HAS_NAN:
            print(f"Step {i} has NaN")
            break

    except KeyboardInterrupt:
        print(f"Interrupted at step {i}")
        break
