# Sep 26, 2023
# Mert Pilanci
# Python code to produce Figure 7 of the paper
# "From Complexity to Clarity: Analytical Expressions of Deep Neural Network
# Weights via Clifford's Geometric Algebra and Convexity"
# Based on the open-sourced implementation of CIFAR10 hyperlightspeedbench
# https://github.com/tysam-code/hlb-CIFAR10
# 
# Note: The one change we need to make if we're in Colab is to uncomment this below block.
# If we are in an ipython session or a notebook, clear the state to avoid bugs
"""
try:
  _ = get_ipython().__class__.__name__
  ## we set -f below to avoid prompting the user before clearing the notebook state
  %reset -f
except NameError:
  pass ## we're still good
"""
import functools
from functools import partial
import os
import copy
import csv
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from model import * #mp modify
import numpy as np #mp modify
from datetime import datetime
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1" #mp modify
torch.manual_seed(1)
## <-- teaching comments
# <-- functional comments
# You can run 'sed -i.bak '/\#\#/d' ./main.py' to remove the teaching comments if they are in the way of your work. <3

# This can go either way in terms of actually being helpful when it comes to execution speed.
# torch.backends.cudnn.benchmark = True

# This code was built from the ground up to be directly hackable and to support rapid experimentation, which is something you might see
# reflected in what would otherwise seem to be odd design decisions. It also means that maybe some cleaning up is required before moving
# to production if you're going to use this code as such (such as breaking different section into unique files, etc). That said, if there's
# ways this code could be improved and cleaned up, please do open a PR on the GitHub repo. Your support and help is much appreciated for this
# project! :)


# This is for testing that certain changes don't exceed X% portion of the reference GPU (here an A100)
# so we can help reduce a possibility that future releases don't take away the accessibility of this codebase.
#torch.cuda.set_per_process_memory_fraction(fraction=8./40., device=0) ## 40. GB is the maximum memory of the base A100 GPU

# set global defaults (in this particular file) for convolutions
# Keyword defaults merged into every `Conv` layer unless the caller overrides them. #mp modify jan 30 2023
default_conv_kwargs = dict(kernel_size=1, padding='same', bias=True)

batchsize = 32
bias_scaler = 1  # upstream value was 32
# Upstream note: to replicate the ~95.77% accuracy in 188 seconds runs, change base_depth 64->128 and num_epochs 10->80.
hyp = dict(
    # Optimizer settings.
    opt=dict(
        bias_lr=0.05,      # the original comment says the learning rates are updated further down the file
        non_bias_lr=0.05,
        bias_decay=0*.85 * 4.8e-4 * batchsize/bias_scaler,  # leading 0* switches bias weight decay off
        non_bias_decay=1*.85 * 4.8e-4 * batchsize,
        scaling_factor=1./10,
        percent_start=.2,
    ),
    # Network / preprocessing settings.
    net=dict(
        whitening=dict(
            kernel_size=5,
            num_examples=50000,
        ),
        batch_norm_momentum=.8,
        cutout_size=0,
        pad_amount=0,    # mp modify: 3 -> 0 (sep 7, 2023)
        base_depth=128,  ## upstream advice: keep this a multiple of 8 to stay tensor core friendly
    ),
    # Run-level settings.
    misc=dict(
        ema=dict(
            epochs=0,
            decay_base=.986,
            every_n_steps=2,
        ),
        train_epochs=5,
        device='cuda',
        data_location='data.pt',
    ),
)
#############################################
#                Dataloader                 #
#############################################

if True: #not os.path.exists(hyp['misc']['data_location']):
    # NOTE: the cache-existence check above is disabled, so the dataset is rebuilt and re-saved on
    # every run and the `else` branch at the bottom is never taken.

    # Per-channel mean/std used to normalize the images, created on the configured device.
    cifar10_mean, cifar10_std = [
        torch.tensor([0.4913997551666284, 0.48215855929893703, 0.4465309133731618], device=hyp['misc']['device']),
        torch.tensor([0.24703225141799082, 0.24348516474564, 0.26158783926049628],  device=hyp['misc']['device'])
    ]

    transform = transforms.Compose([
        #transforms.CenterCrop(8), #mp modify sep 14 2023
        transforms.ToTensor()])

    cifar10      = torchvision.datasets.CIFAR10('cifar10/', download=True,  train=True,  transform=transform)
    cifar10_eval = torchvision.datasets.CIFAR10('cifar10/', download=True, train=False, transform=transform)

    # convert to binary classification mp modify
    convert_binary_one_vs_all = False
    classes_to_pick = [1, 8] # mp modify 1 vs 8
    # Exactly one of the two binary-conversion modes must be selected. The previous assert
    # (`not convert_binary_one_vs_all and classes_to_pick is not None`) rejected the one-vs-all mode
    # outright, which made the first branch below unreachable and contradicted its own message.
    if bool(convert_binary_one_vs_all) == (classes_to_pick is not None):
        raise ValueError("Please set only one of convert_binary_one_vs_all and classes_to_pick")
    if convert_binary_one_vs_all:
        # class 0 keeps label 0, every other class becomes label 1
        cifar10.targets = [int(x != 0) for x in cifar10.targets]
        cifar10_eval.targets = [int(x != 0) for x in cifar10_eval.targets]
    elif classes_to_pick is not None and len(classes_to_pick) == 2:
        # keep only the examples whose label is one of the two picked classes
        data_mask = [x in classes_to_pick for x in cifar10.targets]
        cifar10.data = [x for x, y in zip(cifar10.data, data_mask) if y]
        cifar10.targets = [x for x, y in zip(cifar10.targets, data_mask) if y]
        data_mask = [x in classes_to_pick for x in cifar10_eval.targets]
        cifar10_eval.data = [x for x, y in zip(cifar10_eval.data, data_mask) if y]
        cifar10_eval.targets = [x for x, y in zip(cifar10_eval.targets, data_mask) if y]
        print(f"Filtered to {len(cifar10.targets)} training and {len(cifar10_eval.targets)} testing examples.")
        # map targets to 0 or 1: classes_to_pick[0] -> 1, the other picked class -> 0
        cifar10.targets = [int(x == classes_to_pick[0]) for x in cifar10.targets]
        cifar10_eval.targets = [int(x == classes_to_pick[0]) for x in cifar10_eval.targets]
    # #crop the images to 8x8 #mp modif ysep 14 2023
    # cifar10.data = [x[:, 8:16, 8:16] for x in cifar10.data]
    # cifar10_eval.data = [x[:, 8:8, 8:8] for x in cifar10_eval.data]

    # use the dataloader to get a single batch of all of the dataset items at once.
    train_dataset_gpu_loader = torch.utils.data.DataLoader(cifar10, batch_size=len(cifar10), drop_last=True,
                                              shuffle=True, num_workers=2, persistent_workers=False)
    eval_dataset_gpu_loader = torch.utils.data.DataLoader(cifar10_eval, batch_size=len(cifar10_eval), drop_last=True,
                                              shuffle=False, num_workers=1, persistent_workers=False)

    train_dataset_gpu = {}
    eval_dataset_gpu = {}

    # Pull the single full-dataset batch from each loader and move images and targets to the configured device.
    train_dataset_gpu['images'], train_dataset_gpu['targets'] = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(train_dataset_gpu_loader))]
    eval_dataset_gpu['images'],  eval_dataset_gpu['targets']  = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(eval_dataset_gpu_loader)) ]

    def batch_normalize_images(input_images, mean, std):
        # mean/std are reshaped to (1, C, 1, 1) so they broadcast over the batch and spatial dims
        return (input_images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

    # preload with our mean and std
    batch_normalize_images = partial(batch_normalize_images, mean=cifar10_mean, std=cifar10_std)

    ## Batch normalize datasets, now. Wowie. We did it! We should take a break and make some tea now.
    train_dataset_gpu['images'] = batch_normalize_images(train_dataset_gpu['images'])
    eval_dataset_gpu['images']  = batch_normalize_images(eval_dataset_gpu['images'])

    data = {
        'train': train_dataset_gpu,
        'eval': eval_dataset_gpu
    }

    ## Convert dataset to FP16 now for the rest of the process....
    data['train']['images'] = data['train']['images'].half()
    data['eval']['images']  = data['eval']['images'].half()

    # Cache the preprocessed tensors at the configured location.
    torch.save(data, hyp['misc']['data_location'])

else:
    ## This is effectively instantaneous, and takes us practically straight to where the dataloader-loaded dataset would be. :)
    ## So as long as you run the above loading process once, and keep the file on the disc it's specified by default in the above
    ## hyp dictionary, then we should be good. :)
    data = torch.load(hyp['misc']['data_location'])


## As you'll note above and below, one difference is that we don't count loading the raw data to GPU since it's such a variable operation, and can sort of get in the way
## of measuring other things. That said, measuring the preprocessing (outside of the padding) is still important to us.

# Reflect-pad the training images (only the 'train' split) when a positive pad amount is configured.
if hyp['net']['pad_amount'] > 0:
    ## F.pad wants one amount per padded side; the same pad_amount is repeated for all four sides
    ## of the last two dimensions.
    data['train']['images'] = F.pad(
        data['train']['images'],
        (hyp['net']['pad_amount'],) * 4,
        'reflect',
    )

#############################################
#            Network Components             #
#############################################

# We might be able to fuse this weight and save some memory/runtime/etc, since the fast version of the network might be able to do without somehow....
class BatchNorm(nn.BatchNorm2d):
    """BatchNorm2d with scale reset to 1, shift reset to 0, and per-parameter trainability switches.

    Args:
        num_features: number of channels, forwarded to nn.BatchNorm2d.
        eps: forwarded to nn.BatchNorm2d.
        momentum: forwarded to nn.BatchNorm2d; defaults to hyp['net']['batch_norm_momentum'],
            read once when this class is defined.
        weight: whether the scale parameter receives gradients.
        bias: whether the shift parameter receives gradients.
    """
    def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'], weight=False, bias=True):
        super().__init__(num_features, eps=eps, momentum=momentum)
        # (parameter, initial value, trainable?) for the affine scale and shift
        for parameter, value, trainable in ((self.weight, 1.0, weight), (self.bias, 0.0, bias)):
            parameter.data.fill_(value)
            parameter.requires_grad = trainable

# Allows us to set default arguments for the whole convolution itself.
class Conv(nn.Conv2d):
    """nn.Conv2d that fills in `default_conv_kwargs` for every keyword the caller leaves out."""
    def __init__(self, *args, **kwargs):
        merged = dict(default_conv_kwargs)
        merged.update(kwargs)  # caller-supplied keywords win over the module-level defaults
        super().__init__(*args, **merged)
        # keep the effective keyword arguments around for later inspection
        self.kwargs = merged

# can hack any changes to each residual group that you want directly in here
class ConvGroup(nn.Module):
    """One network stage: conv -> optional 2x2 max-pool -> norm -> ReLU, optionally followed by a two-conv branch.

    Args:
        channels_in: input channels of the first convolution.
        channels_out: output channels of every convolution in the group.
        residual: add the stage-entry activations back onto the branch output.
        short: stop after the first conv/pool/norm/ReLU and create no branch layers.
        pool: apply the 2x2 max-pool after the first convolution.
        se: scale the third convolution's output by a per-channel sigmoid gate
            computed from the spatially averaged stage-entry activations.
    """
    def __init__(self, channels_in, channels_out, residual, short, pool, se):
        super().__init__()
        # behaviour flags consulted in forward()
        self.short, self.pool, self.se = short, pool, se
        self.residual = residual
        self.channels_in, self.channels_out = channels_in, channels_out

        # entry layers (construction order is kept so seeded initialization is unchanged)
        self.conv1 = Conv(channels_in, channels_out)
        self.pool1 = nn.MaxPool2d(2)
        self.norm1 = BatchNorm(channels_out)
        self.activ = nn.ReLU()
        if short:
            return

        # branch layers; the gate layers are created even when `se` is False
        self.conv2 = Conv(channels_out, channels_out)
        self.conv3 = Conv(channels_out, channels_out)
        self.norm2 = BatchNorm(channels_out)
        self.norm3 = BatchNorm(channels_out)

        bottleneck = channels_out//16
        self.se1 = nn.Linear(channels_out, bottleneck)
        self.se2 = nn.Linear(bottleneck, channels_out)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x) if self.pool else x
        x = self.activ(self.norm1(x))
        if self.short:
            return x

        skip = x
        if self.se:
            # per-channel gate from the spatial mean, reshaped to broadcast over dims 2 and 3
            pooled = torch.mean(skip, dim=(2,3))
            gate = torch.sigmoid(self.se2(self.activ(self.se1(pooled))))
            gate = gate.unsqueeze(-1).unsqueeze(-1)

        x = self.activ(self.norm2(self.conv2(x)))
        x = self.conv3(x)
        if self.se:
            x = x * gate
        x = self.activ(self.norm3(x))
        return x + skip if self.residual else x

# Set to 1 for now just to debug a few things....
class TemperatureScaler(nn.Module):
    """Upcast the input to float32 and multiply it by a fixed scalar.

    Args:
        init_val: the scalar multiplier; stored as a plain tensor attribute, not as a
            Parameter or buffer, so it is not trained and not part of the state dict.

    Returns (forward): a float32 tensor with the same shape as the input.
    """
    def __init__(self, init_val):
        super().__init__()
        self.scaler = torch.tensor(init_val)

    def forward(self, x):
        ## save precision for the gradients in the backwards pass.
        ## Tensor.float() returns a new tensor and leaves `x` untouched, so the result has to be
        ## bound back to `x`; the previous bare `x.float()` call discarded it and the scaling ran
        ## in the input's original (e.g. half) precision.
        x = x.float()
        return x.mul(self.scaler)

class FastGlobalMaxPooling(nn.Module):
    """Global max pooling: collapse dims 2 and 3 of the input to their maximum.

    Upstream note: a single amax call replaced chained torch.max calls and was measured
    to be faster than adaptive max pooling over a full run.
    """
    def forward(self, x):
        return x.amax(dim=(2, 3))

#############################################
#          Init Helper Functions            #
#############################################

def get_patches(x, patch_shape=(3, 3), dtype=torch.float32):
    """Extract every stride-1 sliding window of size `patch_shape` from dims 2 and 3 of `x`.

    Args:
        x: 4-D tensor; dim 1 is kept as the channel dim of each patch.
        patch_shape: (height, width) of the window.
        dtype: dtype the patches are converted to.

    Returns:
        Tensor of shape (num_patches, channels, height, width).
    """
    patch_h, patch_w = patch_shape
    channels = x.shape[1]
    windows = x.unfold(2, patch_h, 1)        # slide along dim 2 -> window rows appended as a new last dim
    windows = windows.unfold(3, patch_w, 1)  # slide along dim 3 -> window columns appended as a new last dim
    windows = windows.transpose(1, 3)        # put the channel dim directly in front of the two window dims
    return windows.reshape(-1, channels, patch_h, patch_w).to(dtype)

def get_whitening_parameters(patches):
    """Eigendecompose the covariance of flattened patches.

    Args:
        patches: tensor of shape (n, c, h, w).

    Returns:
        (eigenvalues, eigenvectors): eigenvalues reshaped to (c*h*w, 1, 1, 1) and
        eigenvectors reshaped to (c*h*w, c, h, w), both flipped along dim 0 relative
        to the order torch.linalg.eigh returns them in.
    """
    n, c, h, w = patches.shape
    features = c * h * w
    flat = patches.view(n, features)
    # torch.cov expects variables in rows and observations in columns, hence the transpose
    est_covariance = torch.cov(flat.t())
    # UPLO='U': only the upper-triangular part of the (symmetric) covariance is read
    eigenvalues, eigenvectors = torch.linalg.eigh(est_covariance, UPLO='U')
    flipped_values = eigenvalues.flip(0).view(-1, 1, 1, 1)
    # eigh returns eigenvectors as columns; transpose so each row is one eigenvector before reshaping
    flipped_vectors = eigenvectors.t().reshape(features, c, h, w).flip(0)
    return flipped_values, flipped_vectors

# Run this over the training set to calculate the patch statistics, then set the initial convolution as a non-learnable 'whitening' layer
def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block_data=None, pad_amount=None, freeze=True, whiten_splits=None): #mp modify freeze false
    """Initialise `layer` as a whitening convolution from patch statistics and return its output.

    Args:
        layer: convolution whose weights are overwritten; its kernel size sets the patch size.
        train_set: image tensor used when `previous_block_data` is not given.
        num_examples: number of leading examples of `train_set` to use (None keeps all).
        previous_block_data: input data to compute the statistics on; takes precedence over `train_set`.
        pad_amount: border width cropped from each spatial side of `train_set`; None or 0 means no crop.
        freeze: forwarded to set_whitening_conv (stop training the layer's weight).
        whiten_splits: optional split size along dim 0; the eigen-decompositions of the
            splits are averaged.

    Returns:
        The output of `layer` applied to the (cropped) input data.

    Raises:
        ValueError: if neither `train_set` nor `previous_block_data` is provided.
    """
    if previous_block_data is None:
        if train_set is None:
            raise ValueError("init_whitening_conv needs either train_set or previous_block_data")
        # pad_amount defaults to None; treat that as "nothing to crop" instead of failing on `None > 0`
        if pad_amount is not None and pad_amount > 0:
            # center crop to remove padding
            previous_block_data = train_set[:num_examples,:,pad_amount:-pad_amount,pad_amount:-pad_amount]
        else:
            previous_block_data = train_set[:num_examples,:,:,:]

    if whiten_splits is None:
        previous_block_data_split = [previous_block_data] # list of length 1 so we can reuse the splitting code down below
    else:
        previous_block_data_split = previous_block_data.split(whiten_splits, dim=0)

    eigenvalue_list, eigenvector_list = [], []
    for data_split in previous_block_data_split:
        eigenvalues, eigenvectors = get_whitening_parameters(get_patches(data_split, patch_shape=layer.weight.data.shape[2:]))
        eigenvalue_list.append(eigenvalues)
        eigenvector_list.append(eigenvectors)

    # average the per-split decompositions
    eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
    eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0)
    # the decomposition comes back in float32 (get_patches converts to float32 by default),
    # so cast to the layer's dtype before writing the weights
    set_whitening_conv(layer, eigenvalues.to(dtype=layer.weight.dtype), eigenvectors.to(dtype=layer.weight.dtype), freeze=freeze)
    whitened = layer(previous_block_data.to(dtype=layer.weight.dtype))
    return whitened

def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True): #mp modify freeze false
    """Write eigenvectors scaled by 1/sqrt(eigenvalue + eps) into the trailing filters of `conv_layer`.

    Args:
        conv_layer: convolution whose weight tensor is modified in place.
        eigenvalues: tensor broadcastable against `eigenvectors`.
        eigenvectors: tensor whose dim 0 indexes filters.
        eps: added to the eigenvalues before the square root.
        freeze: when True, the layer's weight stops requiring gradients.
    """
    out_channels = conv_layer.weight.data.shape[0]
    num_vectors = eigenvectors.shape[0]
    whitening_filters = eigenvectors / torch.sqrt(eigenvalues + eps)
    # last `num_vectors` filters of the layer <- last `out_channels` scaled eigenvectors
    conv_layer.weight.data[-num_vectors:, :, :, :] = whitening_filters[-out_channels:, :, :, :]
    if not freeze:
        return
    ## We don't want to train this, since this is implicitly whitening over the whole dataset
    ## For more info, see David Page's original blogposts (link in the README.md as of this commit.)
    conv_layer.weight.requires_grad = False


#############################################
#            Network Definition             #
#############################################

scaler = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
# Channel widths per stage: round(scaler**exponent * base_depth).
# With scaler = 2 and the base_depth of 128 configured above this gives
# init = 64, block1 = 256, block2 = 512, block3 = 1024.
depths = {
    stage: round(scaler**exponent*hyp['net']['base_depth'])
    for stage, exponent in (('init', -1), ('block1', 1), ('block2', 2), ('block3', 3))
}
depths['num_classes'] = 2 #mp modify 10 to 2

class SpeedyResNet(nn.Module):
    """Runs the modules of `network_dict` in a fixed order, with left-right flip averaging in eval mode.

    Args:
        network_dict: mapping providing 'initial_block' (itself a mapping with 'whiten', 'project',
            'norm', 'activation') plus 'residual1', 'residual2', 'residual3', 'pooling', 'linear'
            and 'temperature'.
    """
    def __init__(self, network_dict):
        super().__init__()
        self.net_dict = network_dict # flexible, defined in the make_net function

    # Edit the two name tuples below to customize/change the execution order of the network.
    def forward(self, x):
        evaluating = not self.training
        if evaluating:
            # append a copy of the batch flipped along the last dim
            x = torch.cat((x, torch.flip(x, (-1,))))

        stem = self.net_dict['initial_block']
        for name in ('whiten', 'project', 'norm', 'activation'):
            x = stem[name](x)
        for name in ('residual1', 'residual2', 'residual3', 'pooling', 'linear', 'temperature'):
            x = self.net_dict[name](x)

        if evaluating:
            # Average the predictions from the lr-flipped inputs during eval
            orig, flipped = x.split(x.shape[0]//2, dim=0)
            x = .5 * orig + .5 * flipped
        return x
class NonConvexNet(nn.Module):
    """Conv(5x5) -> BatchNorm -> ReLU, optional 1x1-conv stack, global average pooling, linear head with 2 outputs.

    Layers conv2..conv6 / bn2..bn6 are always created; forward() only uses conv2..conv4
    (and only when `self.deep` is True).
    """
    def __init__(self):
        super().__init__()
        self.filters = 512
        self.kernel_size = 5
        self.conv1 = nn.Conv2d(3, self.filters, self.kernel_size, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(self.filters, affine=False)
        # five 1x1 conv + batch-norm pairs, registered as conv2/bn2 ... conv6/bn6 in that order
        for index in range(2, 7):
            setattr(self, f'conv{index}', nn.Conv2d(self.filters, self.filters, 1, stride=1, padding=0, bias=False))
            setattr(self, f'bn{index}', nn.BatchNorm2d(self.filters, affine=False))
        self.fc1 = nn.Linear(self.filters, 2)
        self.deep = False

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        if self.deep:
            # two residual 1x1 stages followed by one plain 1x1 stage
            x = F.relu(self.bn2(self.conv2(x))) + x
            x = F.relu(self.bn3(self.conv3(x))) + x
            x = F.relu(self.bn4(self.conv4(x)))
        # global average pooling over all spatial positions
        x = x.view(x.size(0), x.size(1), -1).mean(dim=2)
        return self.fc1(x)
class MyConvexNet(nn.Module):
    """Gated convolutional model: a trainable convolution masked by the sign pattern of a second convolution.

    With the hard-coded ``deep_patterns = False`` the forward pass computes
    ``conv1(x) * (conv2(x) >= 0)``, averages over all spatial positions, and sums the
    first and the second half of the channels into two output scores.

    NOTE(review): the ``deep_patterns`` code paths look unfinished; see the notes in
    ``__init__`` and ``forward`` before enabling them.
    """
    def __init__(self):
        super().__init__()
        self.filters = 256
        self.kernel_size = 22
        self.deep_patterns = False
        self.depth = 4
        self.input_size = 3
        # conv1 carries the trainable weights; conv2 only provides the gating pattern (its weight is frozen below)
        self.conv1 = nn.Conv2d(self.input_size, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(self.input_size, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        # bn0 and bn1 are created unconditionally but are not used by the deep_patterns == False forward path
        self.bn0 = nn.BatchNorm2d(self.input_size, affine=True)
        self.bn1 = nn.BatchNorm2d(self.filters, affine=True)
        # Extra layers for the deep gating patterns; skipped while deep_patterns is hard-coded to False above.
        if self.deep_patterns == True:
            # NOTE(review): with kernel_size = 22 this gives padding 10, so conv3/conv4 would return maps one
            # pixel smaller per side than their input -- check the additions in forward() against that.
            self.padding_size = int((self.kernel_size-1)/2)
            self.bn1 = nn.BatchNorm2d(self.filters, affine=True) # re-creates the bn1 already assigned above
            self.bn2 = nn.BatchNorm2d(self.filters, affine=True)
            self.bn3 = nn.BatchNorm2d(self.filters, affine=True)
            self.bn4 = nn.BatchNorm2d(self.filters, affine=True)
            #self.bn5 = nn.BatchNorm2d(self.filters, affine=True)
            #self.bn6 = nn.BatchNorm2d(self.filters, affine=True)
            #self.bn7 = nn.BatchNorm2d(self.filters, affine=True)
            #self.bn8 = nn.BatchNorm2d(self.filters, affine=True)
            self.conv3 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            self.conv4 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            #self.conv5 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            #self.conv6 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            ##self.conv7 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            #self.conv8 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            # the pattern-generating convolutions keep their initial weights
            self.conv3.weight.requires_grad = False
            self.conv4.weight.requires_grad = False
            #self.conv5.weight.requires_grad = False
            #self.conv6.weight.requires_grad = False
        # Only the weight of conv2 is frozen; its bias (bias=True above) is left as created.
        self.conv1.weight.requires_grad = True
        self.conv2.weight.requires_grad = False
    def forward(self, x):
        """Return a (batch, 2) tensor of scores for input `x`."""
        if self.deep_patterns == False:
            #x = self.conv1(self.bn0(x))*(self.conv2(self.bn0(x))>=0)
            #x = self.conv1(x)*torch.sign(self.conv2(x))#(self.conv2(x)>=0)#
            #x = self.bn0(x)
            # gate conv1's output with the boolean mask "conv2 output is non-negative"
            x = self.conv1(x)*(self.conv2(x)>=0)# 
            #x = self.bn1(self.conv1(x))*(self.conv2(x)>=0)#
            #x = x1*(x2>=torch.median(x2))
        else:
            # NOTE(review): this branch needs the layers that __init__ only creates when deep_patterns is True.
            if self.depth<=3:
                #x = self.conv1(x)*(self.bn2(self.conv2(x)+self.conv3(F.relu(self.bn1(self.conv2(x)))))>=torch.median(self.bn2(self.conv2(x)+self.conv3(F.relu(self.bn1(self.conv2(x)))))))
                x = self.conv1(x)*torch.sign(self.bn2(self.conv2(x)+self.conv3(F.relu(self.bn1(self.conv2(x))))))
            elif self.depth==4: 
                # NOTE(review): unlike the other branches, the gated signal here is x3 (derived from conv2), not conv1's output.
                x2 = self.conv2(self.bn0(x))
                x3=(x2+self.bn2(self.conv3(F.relu(self.bn1(x2)))))
                x = x3*(x3+self.bn4(self.conv4(F.relu(self.bn3(x3))))>=0)
            elif self.depth==5:
                # NOTE(review): x1, x2, bn5, bn6 and conv6 are not defined anywhere in this class (the layers are
                # commented out in __init__), so this branch cannot run as written.
                x3 = (x2+self.bn2(self.conv3(F.relu(self.bn1(x2)))))
                x4 = (x3+self.bn4(self.conv4(F.relu(self.bn3(x3)))))  
                x = x1*(x4+self.bn6(self.conv6(F.relu(self.bn5(x4))))>=0) 
            elif self.depth==6:
                # NOTE(review): x1, x2, bn5-bn8, conv5 and conv8 are not defined anywhere in this class, so this
                # branch cannot run as written.
                x3 = (x2+self.bn2(self.conv3(F.relu(self.bn1(x2)))))
                x4 = (x3+self.bn4(self.conv4(F.relu(self.bn3(x3)))))  
                x5 = (x4+self.bn6(self.conv5(F.relu(self.bn5(x4)))))  
                x = x1*(x5+self.bn8(self.conv8(F.relu(self.bn7(x5))))>=0)
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2) # global average pooling
        #sum the columns of x over the first half and the second half
        xsum1 = torch.sum(x[:,:x.shape[1]//2],dim=1)
        xsum2 = torch.sum(x[:,x.shape[1]//2:],dim=1)
        #stack two sums as a new tensor 
        x = torch.stack((xsum1,xsum2),axis=1)
        return x
class MyNonConvexNet(nn.Module):
    """Single 22x22 convolution -> BatchNorm -> ReLU -> global average pooling -> linear head with 2 outputs."""
    def __init__(self):
        super().__init__()
        self.filters, self.kernel_size = 128, 22
        self.conv1 = nn.Conv2d(3, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(self.filters, 2, bias=True)
        self.bn1 = nn.BatchNorm2d(self.filters, affine=True)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        # average each channel over all spatial positions
        pooled = features.view(features.size(0), features.size(1), -1).mean(dim=2)
        return self.fc1(pooled)
class MyNonConvexNet3(nn.Module):
    """22x22 convolution then a 1x1 convolution (each followed by BatchNorm + ReLU), global average pooling, linear head."""
    def __init__(self):
        super().__init__()
        self.filters = 128
        self.kernel_size, self.kernel_size_rest = 22, 1
        width = self.filters
        self.conv1 = nn.Conv2d(3, width, self.kernel_size, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(width, width, self.kernel_size_rest, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(width, 2, bias=True)
        self.bn1 = nn.BatchNorm2d(width, affine=True)
        self.bn2 = nn.BatchNorm2d(width, affine=True)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(norm(conv(x)))
        # average each channel over all spatial positions
        pooled = x.view(x.size(0), x.size(1), -1).mean(dim=2)
        return self.fc1(pooled)
class MyNonConvexNet4(nn.Module):
    """Input BatchNorm, a 22x22 convolution and two 1x1 convolutions (each with BatchNorm + ReLU), global average pooling, linear head."""
    def __init__(self):
        super().__init__()
        self.filters = 512
        self.kernel_size, self.kernel_size_rest = 22, 1
        width = self.filters
        self.conv1 = nn.Conv2d(3, width, self.kernel_size, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(width, width, self.kernel_size_rest, stride=1, padding=0, bias=True)
        self.conv3 = nn.Conv2d(width, width, self.kernel_size_rest, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(width, 2, bias=True)
        self.bn0 = nn.BatchNorm2d(3, affine=True)
        self.bn1 = nn.BatchNorm2d(width, affine=True)
        self.bn2 = nn.BatchNorm2d(width, affine=True)
        self.bn3 = nn.BatchNorm2d(width, affine=True)

    def forward(self, x):
        x = self.bn0(x)  # normalize the raw input channels first
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        # average each channel over all spatial positions
        pooled = x.view(x.size(0), x.size(1), -1).mean(dim=2)
        return self.fc1(pooled)
class MyNonConvexNet5(nn.Module):
    """Four unpadded 9x9 convolutions with ReLU (no normalization), global average pooling, linear head with 2 outputs.

    NOTE(review): each unpadded 9x9 convolution shrinks every spatial side by 8, so four of them
    need inputs of at least 33x33; a 32x32 input would reach conv4 as an 8x8 map -- confirm the
    intended input size.
    """
    def __init__(self):
        super().__init__()
        self.filters = 512
        self.kernel_size = 9
        width, kernel = self.filters, self.kernel_size
        self.conv1 = nn.Conv2d(3, width, kernel, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(width, width, kernel, stride=1, padding=0, bias=True)
        self.conv3 = nn.Conv2d(width, width, kernel, stride=1, padding=0, bias=True)
        self.conv4 = nn.Conv2d(width, width, kernel, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(width, 2, bias=True)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))
        # average each channel over all spatial positions
        pooled = x.view(x.size(0), x.size(1), -1).mean(dim=2)
        return self.fc1(pooled)
class LinearModel(nn.Module):
    """Bias-free 1x1 convolution from 3 channels to 1, global average pooling, then a linear layer to 2 outputs."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 1, 1, stride=1, padding=0, bias=False)
        self.fc1 = nn.Linear(1, 2, bias=True)

    def forward(self, x):
        mixed = self.conv1(x)
        # average the single output channel over all spatial positions
        pooled = mixed.view(mixed.size(0), mixed.size(1), -1).mean(dim=2)
        return self.fc1(pooled)
class MyDeepFCModel(nn.Module):
    """Fully connected net on flattened 3*32*32 inputs: Linear -> BatchNorm -> ReLU -> Linear -> ReLU -> Linear (1 output).

    Args:
        params: accepted for interface compatibility; currently not read (the width is fixed at 128).
    """
    def __init__(self, params):
        super().__init__()
        self.filters = 128 #params['neurons']
        self.filtersinitial = 128
        first_width, second_width = self.filtersinitial, self.filters
        self.fc1 = nn.Linear(3*32*32, first_width, bias=True)
        self.fc2 = nn.Linear(first_width, second_width, bias=True)
        self.fc3 = nn.Linear(second_width, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(first_width, affine=True)
        self.bn2 = nn.BatchNorm1d(second_width, affine=True)  # created but not applied in forward()

    def forward(self, x):
        # flip-averaging at eval time (as in MyScalarFCModel) is disabled for this model
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.bn1(self.fc1(flat)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
class MyFCModel(nn.Module):
    """Two-layer fully connected net on flattened 3*32*32 inputs: Linear(8) -> ReLU -> Linear(2).

    The two BatchNorm1d layers are created but not applied in forward().
    """
    def __init__(self):
        super().__init__()
        self.filters = 8
        hidden_width = self.filters
        self.fc1 = nn.Linear(3*32*32, hidden_width, bias=True)
        self.fc2 = nn.Linear(hidden_width, 2, bias=True)
        self.bn1 = nn.BatchNorm1d(hidden_width, affine=True)
        self.bn2 = nn.BatchNorm1d(2, affine=True)

    def forward(self, x):
        flat = x.view(x.size(0), -1)
        return self.fc2(F.relu(self.fc1(flat)))
    
class MyScalarFCModel(nn.Module):
    """Two-layer fully connected net with a single output: Linear(params['neurons']) -> ReLU -> bias-free Linear(1).

    In eval mode the batch is concatenated with a copy flipped along its last dim and the two
    predictions are averaged.

    Args:
        params: mapping whose 'neurons' entry sets the hidden width.
    """
    def __init__(self, params):
        super().__init__()
        self.filters = params['neurons']
        self.fc1 = nn.Linear(3*32*32, self.filters, bias=True)
        self.fc2 = nn.Linear(self.filters, 1, bias=False)

    def forward(self, x):
        flip_average = not self.training
        if flip_average:
            x = torch.cat((x, torch.flip(x, (-1,))))
        flat = x.view(x.size(0), -1)
        x = self.fc2(F.relu(self.fc1(flat)))
        if flip_average:
            # Average the predictions from the lr-flipped inputs during eval
            orig, flipped = x.split(x.shape[0]//2, dim=0)
            x = .5 * orig + .5 * flipped
        return x

def make_net():
    """Build the model for this run and put it in half-precision training mode.

    A convolutional ``network_dict`` is assembled first, but it is not
    referenced again inside this function; the model that is actually
    returned is ``MyDeepFCModel(params)``.

    NOTE(review): ``hyp``, ``depths`` and ``params`` are neither arguments nor
    locals here, so they are looked up in an outer (module) scope when the
    function runs -- confirm where ``params`` is defined before calling this
    from a new context.

    Returns:
        The model after being moved to ``hyp['misc']['device']``, converted to
        the channels_last memory format, switched to train mode and cast to half.
    """
    # TODO: A way to make this cleaner??
    # Note, you have to specify any arguments overlapping with defaults (i.e. everything but in/out depths) as kwargs so that they are properly overridden (TODO cleanup somehow?)
    whiten_conv_depth = 3*hyp['net']['whitening']['kernel_size']**2
    # NOTE(review): network_dict is unused below. Constructing it presumably still consumes random numbers
    # for layer initialization, so deleting it may change the seeded initial weights of the returned model -- verify first.
    network_dict = nn.ModuleDict({
        'initial_block': nn.ModuleDict({
            'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0), #mp modify jan 30 2023 was 3 changed to 1
            'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1),
            'norm': BatchNorm(depths['init'], weight=False),
            'activation': nn.ReLU(), #mp modify 
        }),
        'residual1': ConvGroup(depths['init'], depths['block1'], residual=True, short=False, pool=True, se=False), # mp short was false se was true
        # MP modify. removed the whitening block and changed depths['init'] to 3 in the first residual1 block
        #'residual1': ConvGroup(3, depths['block1'], residual=True, short=False, pool=True, se=False), # mp short was false se was true residual was true
        'residual2': ConvGroup(depths['block1'], depths['block2'], residual=True, short=False, pool=True, se=False), #mp modify se was true residual was true
        'residual3': ConvGroup(depths['block2'], depths['block3'], residual=True, short=False, pool=True, se=False), #mp modify  mp short was false se was true residual was true
        'pooling': FastGlobalMaxPooling(),
        'linear': nn.Linear(depths['block3'], depths['num_classes'], bias=False), #mp modify block3 to block1
        'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
    })


    # The fully connected model is what gets trained; `params` comes from an outer scope (see docstring).
    net = MyDeepFCModel(params)

    
    #convolutional neural network model
    #net = NonConvexNet()
    net = net.to(hyp['misc']['device'])
    net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
    net.train()
    net.half() # Convert network to half before initializing the initial whitening layer. #mp modify

    # The whitening initialization below is disabled; it refers to `net.net_dict`, which is not set up in this function.
    # Initialize the whitening convolution #mp modify
    # with torch.no_grad():
    #     # Initialize the first layer to be fixed weights that whiten the expected input values of the network be on the unit hypersphere. (i.e. their...average vector length is 1.?, IIRC)
    #     init_whitening_conv(net.net_dict['initial_block']['whiten'],
    #                         data['train']['images'].index_select(0, torch.randperm(data['train']['images'].shape[0], device=data['train']['images'].device)),
    #                         num_examples=hyp['net']['whitening']['num_examples'],
    #                         pad_amount=hyp['net']['pad_amount'],
    #                         whiten_splits=5000) ## Hardcoded for now while we figure out the optimal whitening number
    #                                             ## If you're running out of memory (OOM) feel free to decrease this, but
    #                                             ## the index lookup in the dataloader may give you some trouble depending
    #                                             ## upon exactly how memory-limited you are

    return net

#############################################
#            Data Preprocessing             #
#############################################

## This is actually (I believe) a pretty clean implementation of how to do something like this, since shifted-square masks unique to each example in the batch can actually be rather
## tricky in practice. That said, if there's a better way, please do feel free to submit it! This can be one of the harder parts of the code to understand (though I personally get
## stuck on the fold/unfold process for the lower-level convolution calculations).
def make_random_square_masks(inputs, mask_size):
    """Sample one random square boolean mask per example in the batch.

    Args:
        inputs: tensor whose first dimension is the batch and whose last two
            dimensions are the spatial (height, width) axes. Only its shape
            and device are used.
        mask_size: side length of the square in pixels; ``0`` disables masking.

    Returns:
        ``None`` when ``mask_size`` is 0. Otherwise a boolean tensor of shape
        ``(batch, 1, height, width)`` that is ``True`` inside a
        ``mask_size`` x ``mask_size`` square placed uniformly at random,
        fully inside the image, independently for every example.
    """
    if mask_size == 0:
        return None # no need to cutout or do anything like that since the patch_size is set to 0
    is_even = int(mask_size % 2 == 0)
    half = mask_size // 2
    batch, height, width = inputs.shape[0], inputs.shape[-2], inputs.shape[-1]
    device = inputs.device

    # A square "centred" at c covers the offsets [-half + is_even, +half], so the valid centres are
    # (half - is_even) .. (dim - half - 1). `random_` treats its upper bound as exclusive, hence `dim - half`.
    # (Subtracting is_even from the upper bound as well would skip the last valid placement of even-sized
    # squares and leave an empty range when mask_size equals the image size.)
    lowest_center = half - is_even
    mask_center_y = torch.empty(batch, dtype=torch.long, device=device).random_(lowest_center, height - half)
    mask_center_x = torch.empty(batch, dtype=torch.long, device=device).random_(lowest_center, width - half)

    # Signed distance of every row / column from the sampled centre, per example.
    to_mask_y_dists = torch.arange(height, device=device).view(1, 1, height, 1) - mask_center_y.view(-1, 1, 1, 1)
    to_mask_x_dists = torch.arange(width, device=device).view(1, 1, 1, width) - mask_center_x.view(-1, 1, 1, 1)

    to_mask_y = (to_mask_y_dists >= (-half + is_even)) & (to_mask_y_dists <= half)
    to_mask_x = (to_mask_x_dists >= (-half + is_even)) & (to_mask_x_dists <= half)

    # Broadcasting the (y by 1) and (1 by x) bands against each other gives their square intersection.
    return to_mask_y & to_mask_x

def batch_cutout(inputs, patch_size):
    """Zero out one random ``patch_size`` x ``patch_size`` square per example.

    Returns ``inputs`` itself when the mask helper yields ``None`` (which it
    does for a patch size of 0); otherwise returns a new tensor in which the
    masked pixels are replaced by zeros. Runs without gradient tracking.
    """
    with torch.no_grad():
        erase_mask = make_random_square_masks(inputs, patch_size)
        if erase_mask is not None:
            # TODO: Could be fused with the crop operation for speed.
            blank = torch.zeros_like(inputs)
            return torch.where(erase_mask, blank, inputs)
        # No mask means cutout is switched off: hand the batch back untouched.
        return inputs

def batch_crop(inputs, crop_size):
    """Take one random ``crop_size`` x ``crop_size`` crop from every example.

    A per-example square mask selects the pixels to keep; the selected values
    are reshaped to ``(batch, channels, crop_size, crop_size)``. Runs without
    gradient tracking.
    """
    with torch.no_grad():
        keep_mask = make_random_square_masks(inputs, crop_size)
        # masked_select flattens the kept pixels; restore the 4-D layout afterwards.
        kept_pixels = torch.masked_select(inputs, keep_mask)
        num_examples, num_channels = inputs.shape[0], inputs.shape[1]
        return kept_pixels.view(num_examples, num_channels, crop_size, crop_size)

def batch_flip_lr(batch_images, flip_chance=.5):
    """Mirror each image along its last axis with probability ``flip_chance``.

    One uniform random number is drawn per example; examples whose draw falls
    below ``flip_chance`` are replaced by their flipped version. Runs without
    gradient tracking.
    """
    with torch.no_grad():
        # One value per example, shaped so it broadcasts over the remaining three axes.
        per_example = batch_images[:, 0, 0, 0].view(-1, 1, 1, 1)
        should_flip = torch.rand_like(per_example) < flip_chance
        mirrored = torch.flip(batch_images, (-1,))
        return torch.where(should_flip, mirrored, batch_images)


########################################
#          Training Helpers            #
########################################

class NetworkEMA(nn.Module):
    """Keeps an exponential moving average (EMA) copy of a network.

    The copy is a deep copy of ``net`` taken at construction time, put in
    eval mode with gradients disabled.
    """

    def __init__(self, net, decay):
        super().__init__() # init the parent module so this module is registered properly
        frozen_copy = copy.deepcopy(net)
        self.net_ema = frozen_copy.eval().requires_grad_(False)
        self.decay = decay ## can be changed later for update-scheduling purposes

    def update(self, current_net):
        """Blend ``current_net``'s state into the EMA copy, in place.

        Every half/float entry of the state dict becomes
        ``decay * ema + (1 - decay) * current``; entries of any other dtype
        are left untouched. The two state dicts are paired positionally, so
        both networks must have the same architecture.
        """
        with torch.no_grad():
            ema_entries = self.net_ema.state_dict().values()
            live_entries = current_net.state_dict().values()
            for ema_tensor, live_tensor in zip(ema_entries, live_entries):
                if live_tensor.dtype not in (torch.half, torch.float):
                    continue
                incoming = live_tensor.detach().mul(1. - self.decay)
                ema_tensor.mul_(self.decay).add_(incoming)

    def forward(self, inputs):
        """Evaluate the EMA copy on ``inputs`` without tracking gradients."""
        with torch.no_grad():
            return self.net_ema(inputs)

# TODO: Could we jit this in the (more distant) future? :)
@torch.no_grad()
def get_batches(data_dict, key, batchsize):
    """Yield shuffled, full-size minibatches for one pass over a split.

    Args:
        data_dict: mapping such that ``data_dict[key]['images']`` and
            ``data_dict[key]['targets']`` are tensors indexed by example
            along dimension 0.
        key: name of the split to iterate (e.g. ``'train'`` or ``'eval'``).
        batchsize: number of examples per yielded batch.

    Yields:
        ``(images, targets)`` pairs of ``batchsize`` examples each, drawn
        according to a fresh random permutation. A trailing partial batch is
        dropped, so each example is used at most once per pass.
    """
    images = data_dict[key]['images']
    targets = data_dict[key]['targets']
    num_epoch_examples = len(images)

    # Build the permutation on the data's own device (instead of a hard-coded 'cuda') so index_select
    # always receives indices that live next to the tensors they index.
    shuffled = torch.randperm(num_epoch_examples, device=images.device)

    ## The train-time augmentations (batch_crop / batch_flip_lr / batch_cutout) are disabled
    ## (mp modify sep 14 2023), so every split uses the stored images unchanged.

    # Send the images to an (in beta) channels_last to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training
    images = images.to(memory_format=torch.channels_last)
    for idx in range(num_epoch_examples // batchsize):
        ## Use the shuffled randperm to assemble individual items into a minibatch
        batch_indices = shuffled[idx * batchsize:(idx + 1) * batchsize]
        yield images.index_select(0, batch_indices), targets.index_select(0, batch_indices)


def init_split_parameter_dictionaries(network):
    """Split a network's trainable parameters into two optimizer groups.

    Parameters whose name contains ``'bias'`` go into the bias group; all
    other trainable parameters go into the non-bias group. Learning rates
    and weight decays are read from ``hyp['opt']``.

    Returns:
        ``(params_non_bias, params_bias)``: two dicts with the keys
        ``'params'``, ``'lr'``, ``'momentum'``, ``'nesterov'`` and
        ``'weight_decay'``.
    """
    opt_hyp = hyp['opt']
    params_non_bias = {
        'params': [],
        'lr': opt_hyp['non_bias_lr'],
        'momentum': .85,
        'nesterov': True,
        'weight_decay': opt_hyp['non_bias_decay'],
    }
    params_bias = {
        'params': [],
        'lr': opt_hyp['bias_lr'],
        'momentum': .85,
        'nesterov': True,
        'weight_decay': opt_hyp['bias_decay'],
    }

    for name, p in network.named_parameters():
        # Frozen parameters are left out of both groups.
        if not p.requires_grad:
            continue
        group = params_bias if 'bias' in name else params_non_bias
        group['params'].append(p)
    return params_non_bias, params_bias


## Loss configuration. Earlier variants (kept below, commented out, for reference) were cross-entropy and
## BCE-with-logits; the active choice is a plain mean-squared-error loss with default arguments.
## NOTE(review): the trailing "(label_smoothing=0.2, reduction='none')" fragments are comment text only --
## no such arguments are passed to any of these losses.
#loss_fn = nn.CrossEntropyLoss()#(label_smoothing=0.2, reduction='none') #mp modify
#loss_fn = nn.BCEWithLogitsLoss()#(label_smoothing=0.2, reduction='none') #mp modify sep 7 2023
#squared loss
loss_fn = nn.MSELoss()#(label_smoothing=0.2, reduction='none') #mp modify sep 7 2023
# Column names of the per-epoch log table, in the order they are printed.
logging_columns_list = ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'ema_val_acc', 'total_time_seconds']
# define the printing function and print the column heads
def print_training_details(columns_list, separator_left='|  ', separator_right='  ', final="|", column_heads_only=False, is_final_entry=False):
    """Print one row of the training log table.

    Args:
        columns_list: strings to print, one per column.
        separator_left: text placed before every cell.
        separator_right: text placed after every cell.
        final: text appended once at the end of the row.
        column_heads_only: when true, the row is framed by a dashed bar
            above and below (used for the header row).
        is_final_entry: when true, a dashed bar is printed after the row.
    """
    cells = [separator_left + cell + separator_right for cell in columns_list]
    row = "".join(cells) + final
    bar = '-' * len(row)

    if column_heads_only:
        print(bar) # top bar
        print(row)
        print(bar) # bottom bar
    else:
        print(row)

    if is_final_entry:
        print(bar) # closing bar of the table
def transform_U(U, X):
    """
    Replace every row of U by a hyperplane normal determined by the data.

    For each row u of U, the (d-1) rows of X with the smallest inner-product
    magnitude |<u, x>| are selected, and u is replaced by the unit normal of
    the hyperplane through the origin and those rows.

    Args:
        U: 2-D array with one vector per row and d columns.
        X: 2-D array of data points, one per row, with the same d columns.

    Returns:
        An array with the shape and dtype of U holding the replacement normals.
    """
    d = U.shape[1]  # Dimensions
    transformed_U = np.zeros_like(U)
    for i, u_row in tqdm(enumerate(U), total=U.shape[0], desc="Processing neurons"):
        # |<u, x>| for every row of X in a single matrix-vector product
        # (instead of a Python-level loop over all data rows).
        inner_products = np.abs(X @ u_row)

        # Indices of the (d-1) data rows closest to orthogonal to u
        smallest_indices = np.argsort(inner_products)[:d-1]

        # Normal of the hyperplane passing through the origin and these rows
        transformed_U[i] = find_normal_of_hyperplane(X[smallest_indices])

    return transformed_U
def find_normal_of_hyperplane(points, method='eigh'):
    """
    Find the unit normal of the hyperplane through the origin and the given points.

    Args:
        points: 2-D array of shape (d-1, d); each row is a point in
            d-dimensional space.
        method: ``'eigh'`` (default) takes the eigenvector of
            ``points.T @ points`` belonging to the smallest eigenvalue;
            ``'svd'`` takes the last right-singular vector of ``points``.

    Returns:
        1-D array of length d with unit Euclidean norm. Its sign is whatever
        the underlying decomposition returns.

    Raises:
        ValueError: if ``method`` is neither ``'eigh'`` nor ``'svd'``.
    """
    if method == 'svd':
        # MP: change this to rank k svd not the full svd.
        # problem: smallest singular value is zero for MNIST.
        # With full_matrices=True, vh is d x d and its last row corresponds to the smallest singular value.
        _, _, vh = np.linalg.svd(points, full_matrices=True)
        normal = vh[-1]
    elif method == 'eigh':
        gram = np.dot(points.T, points)
        # eigh returns eigenvalues in ascending order, so column 0 belongs to the smallest one.
        _, eigenvectors = np.linalg.eigh(gram)
        normal = eigenvectors[:, 0]
    else:
        raise ValueError(f"unknown method {method!r}; expected 'eigh' or 'svd'")
    return normal / np.linalg.norm(normal)
def relu(x):
    """Return the elementwise maximum of 0 and ``x`` (NumPy rectified linear unit)."""
    rectified = np.maximum(0, x)
    return rectified
# Module-level side effect: the table header is printed once when this file is executed, ahead of the per-epoch rows that main() prints.
print_training_details(logging_columns_list, column_heads_only=True) ## print out the training column heads before we print the actual content for each run.

########################################
#           Train and Eval             #
########################################

def main(params):
    hyp['opt']['non_bias_lr'] = params['learning_rate']
    hyp['opt']['bias_lr'] = params['learning_rate']
    # Initializing constants for the whole run.
    net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training
                   ## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)

    total_time_seconds = 0.
    current_steps = 0.
    
    # TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize)....
    num_steps_per_epoch      = len(data['train']['images']) // batchsize
    total_train_steps        = num_steps_per_epoch * hyp['misc']['train_epochs']
    ema_epoch_start          = hyp['misc']['train_epochs'] - hyp['misc']['ema']['epochs']
    num_cooldown_before_freeze_steps = 0
    num_low_lr_steps_for_ema = hyp['misc']['ema']['epochs'] * num_steps_per_epoch

    ## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps
    ## to somewhat accomodate for whatever the expected information intake rate is. The tradeoff I believe, though, is that this is to some degree noisier as we
    ## are intaking fewer samples of our distribution-over-time, with a higher individual weight each. This can be good or bad depending upon what we want.
    projected_ema_decay_val  = hyp['misc']['ema']['decay_base'] ** hyp['misc']['ema']['every_n_steps']

    # Adjust pct_start based upon how many epochs we need to finetune the ema at a low lr for
    pct_start = hyp['opt']['percent_start'] * (total_train_steps/(total_train_steps - num_low_lr_steps_for_ema))

    # Get network
    net = make_net()
    # Display neural network architecture
    #print(net)
    # Get the number of parameters in the network
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('Number of parameters: {}'.format(num_params))

    ## Stowing the creation of these into a helper function to make things a bit more readable....
    non_bias_params, bias_params = init_split_parameter_dictionaries(net)
    print('learning rate:')
    print(non_bias_params['lr'])
    #print(bias_params['lr'])
    # One optimizer for the regular network, and one for the biases. This allows us to use the superconvergence onecycle training policy for our networks....
    opt = torch.optim.SGD(**non_bias_params)
    opt_bias = torch.optim.SGD(**bias_params)

    #opt = torch.optim.SGD(**non_bias_params)
    #opt_bias = torch.optim.SGD(**bias_params)

    ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to .1 * max_lr at the end (since 1e16 * 1e-15 = .1 --
    ##   This quirk is because the final lr value is calculated from the starting lr value and not from the maximum lr value set during training)
    initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
    final_lr_ratio = .135
    
    # constant step size scheduler
    lr_sched = torch.optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.1)
    lr_sched_bias = torch.optim.lr_scheduler.StepLR(opt_bias, step_size=100, gamma=0.1)

    # lr_sched      = torch.optim.lr_scheduler.OneCycleLR(opt,  max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)
    # lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)

    ## For accurately timing GPU code
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    ## There's another repository that's mainly reorganized David's code while still maintaining some of the functional structure, and it
    ## has a timing feature too, but there's no synchronizes so I suspect the times reported are much faster than they may be in actuality
    ## due to some of the quirks of timing GPU operations.
    torch.cuda.synchronize() ## clean up any pre-net setup operations
    
    if True: ## Sometimes we need a conditional/for loop here, this is placed to save the trouble of needing to indent
        for epoch in range(hyp['misc']['train_epochs']):
          #################
          # Training Mode #
          #################
          torch.cuda.synchronize()
          starter.record()
          net.train()

          loss_train = None
          accuracy_train = None

          for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=batchsize)):
            ## Run everything through the network
            outputs = net(inputs)
            outputs= outputs.squeeze()  #mp modify sep 7 2023
            targets = targets.to(torch.float16) #mp modify sep 7 2023
            loss_scale_scaler = 1./16 # Hardcoded for now, preserves some accuracy during the loss summing process, balancing out its regularization effects
            ## If you want to add other losses or hack around with the loss, you can do that here.
            loss = loss_fn(outputs, targets).mul(loss_scale_scaler).sum().div(loss_scale_scaler) ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling
                                                    ## (and is thus batchsize dependent as a result). This can be somewhat good or bad, depending...

            # we only take the last-saved accs and losses from train
            if epoch_step % 50 == 0:
                train_acc = (outputs.detach().round() == targets).float().mean().item() #mp modify sep 13 2023
                #train_acc = (outputs.detach().argmax(-1) == targets).float().mean().item()
                train_loss = loss.detach().cpu().item()/batchsize

            loss.backward()

            ## Step for each optimizer, in turn.
            opt.step()
            opt_bias.step()

            if current_steps < total_train_steps - num_low_lr_steps_for_ema - 1: # the '-1' is because the lr scheduler tends to overshoot (even below 0 if the final lr is ~0) on the last step for some reason.
                # We only want to step the lr_schedulers while we have training steps to consume. Otherwise we get a not-so-friendly error from PyTorch
                lr_sched.step()
                lr_sched_bias.step()

            ## Using 'set_to_none' I believe is slightly faster (albeit riskier w/ funky gradient update workflows) than under the default 'set to zero' method
            opt.zero_grad(set_to_none=True)
            opt_bias.zero_grad(set_to_none=True)
            current_steps += 1

            if epoch >= ema_epoch_start and current_steps % hyp['misc']['ema']['every_n_steps'] == 0:          
                ## Initialize the ema from the network at this point in time if it does not already exist.... :D
                if net_ema is None or epoch_step < num_cooldown_before_freeze_steps: # don't snapshot the network yet if so!
                    net_ema = NetworkEMA(net, decay=projected_ema_decay_val)
                    continue
                net_ema.update(net)
          ender.record()
          torch.cuda.synchronize()
          total_time_seconds += 1e-3 * starter.elapsed_time(ender)

          ####################
          # Evaluation  Mode #
          ####################
          net.eval()

          eval_batchsize = 1000
          assert data['eval']['images'].shape[0] % eval_batchsize == 0, "Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)."
          loss_list_val, acc_list, acc_list_ema = [], [], []
          
          with torch.no_grad():
              for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
                  if epoch >= ema_epoch_start:
                      outputs = net_ema(inputs)
                      outputs = outputs.squeeze() #mp modify sep 7 2023
                      targets = targets.to(torch.float16) #mp modify sep 7 2023
                      acc_list_ema.append((outputs.round() == targets).float().mean()) #mp modify sep 13 2023
                      #acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
                  outputs = net(inputs)
                  outputs = outputs.squeeze() #mp modify sep 7 2023
                  targets = targets.to(torch.float16) #mp modify sep 7 2023
                  loss_list_val.append(loss_fn(outputs, targets).float().mean())
                  acc_list.append((outputs.round() == targets).float().mean()) #mp modify sep 13 2023
                  #acc_list.append((outputs.argmax(-1) == targets).float().mean())
                  
              val_acc = torch.stack(acc_list).mean().item()
              ema_val_acc = None
              # TODO: We can fuse these two operations (just above and below) all-together like :D :))))
              if epoch >= ema_epoch_start:
                  ema_val_acc = torch.stack(acc_list_ema).mean().item()

              val_loss = torch.stack(loss_list_val).mean().item()
          # We basically need to look up local variables by name so we can have the names, so we can pad to the proper column width.
          ## Printing stuff in the terminal can get tricky and this used to use an outside library, but some of the required stuff seemed even
          ## more heinous than this, unfortunately. So we switched to the "more simple" version of this!
          format_for_table = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \
                                                    if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \
                                                if locals[x] is not None \
                                                else " "*len(x)

          # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
          ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
          print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == hyp['misc']['train_epochs'] - 1))

    # calculate the accuracy before affine distance model
    net.eval()
    eval_batchsize = 1000
    assert data['eval']['images'].shape[0] % eval_batchsize == 0, "Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)."
    loss_list_val, acc_list, acc_list_ema = [], [], []
    with torch.no_grad():
        for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
            #if epoch >= ema_epoch_start:
            #    outputs = net_ema(inputs)
            #    # round outputs to 0 or 1
            #    #calculate accuracy
            #    acc_list_ema.append(((outputs > 0.5) == targets).float().mean()) #mp modify sep 13 2023
            #    #acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
            outputs = net(inputs)
            outputs = outputs.squeeze() #mp modify sep 7 2023
            targets = targets.to(torch.float16) #mp modify sep 7 2023
            loss_list_val.append(loss_fn(outputs, targets).float().mean())
            acc_list.append(((outputs >= 0.5) == targets).float().mean())
            #acc_list.append((outputs.argmax(-1) == targets).float().mean())
            
        val_acc = torch.stack(acc_list).mean().item()
    # calculate the training accuracy after affine distance model
    train_accuracy_torch = 0
    net.eval()
    eval_batchsize = 1000
    assert data['train']['images'].shape[0] % eval_batchsize == 0, "Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)."
    loss_list_val, train_acc_list, acc_list_ema = [], [], []
    with torch.no_grad():
        for inputs, targets in get_batches(data, key='train', batchsize=eval_batchsize):
            # if epoch >= ema_epoch_start:
            #     outputs = net_ema(inputs)
            #     #acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
            #     acc_list_ema.append(( (outputs>0.5) == targets).float().mean()) #mp modify sep 13 2023
            outputs = net(inputs)
            outputs = outputs.squeeze() #mp modify sep 7 2023
            loss_list_val.append(loss_fn(outputs, targets).float().mean())
            train_acc_list.append(( (outputs >= 0.5) == targets).float().mean()) #mp modify sep 13 2023
            #acc_list.append((outputs.argmax(-1) == targets).float().mean())
        train_accuracy_torch = torch.stack(train_acc_list).mean().item()
    # affine distance model
    # move net to cpu
    net = net.to('cpu')
    # Get the weights of the first and second layer
    Uorg = net.fc2.weight.detach().numpy().T
    bias = net.fc2.bias.detach().numpy()
    Uorg = np.vstack((Uorg, bias))
    w2 = net.fc3.weight.detach().numpy().squeeze()
    # Stack training data. Flatten the image data if required (from (N, C, H, W) to (N, C*H*W))
    # Assuming you already have the get_batches function and other necessary imports
    flattened_images_list = []
    labels_list = []
    for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=10000)):
        # Flatten the images in the current batch: batch size x image size
        flattened_batch = inputs.view(inputs.size(0), -1).cpu().numpy()  # Convert to numpy immediately after flattening
        flattened_images_list.append(flattened_batch)

        # Append labels of the current batch to the labels list and convert to numpy
        labels_list.append(targets.cpu().numpy())

    # Stack all the flattened images and labels to create numpy matrices
    Xdata = np.vstack(flattened_images_list)
    #pass through the network layer 1
    net = net.to('cuda')
    #import pdb; pdb.set_trace()
    Xdata = F.relu(net.bn1(net.fc1(torch.tensor(np.copy(Xdata)).cuda()))).detach().cpu().numpy()
    #append a column of ones to Xdata
    onestrain = np.ones((Xdata.shape[0], 1))
    Xdata = np.hstack((Xdata, onestrain))
    ydata = np.concatenate(labels_list)

    flattened_images_list = []
    labels_list = []
    for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='eval', batchsize=1000)):
        # Flatten the images in the current batch: batch size x image size
        flattened_batch = inputs.view(inputs.size(0), -1).cpu().numpy()  # Convert to numpy immediately after flattening
        flattened_images_list.append(flattened_batch)

        # Append labels of the current batch to the labels list and convert to numpy
        labels_list.append(targets.cpu().numpy())
    Xdatatest = np.vstack(flattened_images_list)
    #append a column of ones to Xdata
    Xdatatest = F.relu(net.bn1(net.fc1(torch.tensor(np.copy(Xdatatest)).cuda()))).detach().cpu().numpy()
    onestest = np.ones((Xdatatest.shape[0], 1))
    Xdatatest = np.hstack((Xdatatest, onestest))
    ydatatest = np.concatenate(labels_list)        
    # Xdata = train_dataset_gpu['images'].cpu().numpy()
    # ydata = train_dataset_gpu['targets'].cpu().numpy()
    # Xdata = Xdata.reshape(Xdata.shape[0], -1)
    # Stack the images and targets into one matrix
    #ydata = ydata[:, None]
    #import pdb; pdb.set_trace()
    new_U = transform_U(np.copy(Uorg).T,Xdata).T
    #new_U = np.copy(Uorg)
    #new_U = U #debug
    # reoptimize second layer weights using least squares
    #w_ls = np.linalg.pinv(relu(X@new_U))@y
    # regularized pseudo-inverse
    A = relu(Xdata@new_U)
    AtA = A.T@A
    #w_ls = np.linalg.pinv(AtA + A.shape[0]*beta*np.identity(AtA.shape[0]))@(A.T@y)
    beta = hyp['opt']['non_bias_decay']
    beta = 1e-5
    ls_method = 'pseudoinverse'
    energy_percentage = 1
    if ls_method == 'pseudoinverse':
        w_ls = np.linalg.pinv(AtA + beta*np.identity(AtA.shape[0]))@(A.T@ydata)
        #w_ls = np.linalg.pinv(relu(Xdata@new_U))@ydata
        #w_ls = w2 #debug
    #w_ls_convex = np.linalg.lstsq(AtA + A.shape[0]*beta*np.identity(AtA.shape[0]), A.T@y, rcond='warn')[0]
    else: #ls_method == 'truncated': #truncate the svd to 99 percent of the energy
        u, s, vh = np.linalg.svd(AtA, full_matrices=True)
        # Normalize the singular values
        sn = s/s.sum()
        # Truncate singular values
        idx_trunc = np.where(sn.cumsum() <= energy_percentage)[0]
        # Form truncated AtA
        AtA_trunc = u[:,idx_trunc] @ (np.diag(s[idx_trunc]) @ u[:,idx_trunc].T)
        w_ls = np.linalg.pinv(AtA_trunc + A.shape[0]*beta*np.identity(AtA.shape[0]))@(A.T@ydata)
    #apply optimal scaling to new U
    for col in range(new_U.shape[1]):
        if np.abs(w_ls[col]) > 1e-10:
            scale = np.sqrt(np.linalg.norm(new_U[:,col])/np.abs(w_ls[col]))
            new_U[:,col] = new_U[:,col]/scale
            w_ls[col] = w_ls[col]*scale 
        else:
            new_U[:,col] = np.zeros_like(new_U[:,col])
            w_ls[col] = 0
            print('dropped neuron')
    w_ls_final = np.linalg.pinv(relu(Xdata@new_U))@ydata
    #apply optimal scaling to old U
    U2 = np.copy(Uorg)
    U = np.copy(Uorg)
    w_ls2 = np.copy(w_ls)
    for col in range(U.shape[1]):
        if np.abs(w_ls[col]) > 1e-10:
            scale = np.sqrt(np.linalg.norm(U[:,col])/np.abs(w_ls[col]))
            U2[:,col] = U2[:,col]/scale
            w_ls2[col] = w_ls2[col]*scale 
        else:
            U2[:,col] = np.zeros_like(new_U[:,col])
            w_ls2[col] = 0
            print('dropped neuron')        
#    w_ls_old_U = np.linalg.pinv(relu(Xdata@U))@ydata
#    w_ls_old_U2 = np.linalg.pinv(relu(Xdata@U2))@ydata
    net = net.to('cuda')
    #import pdb; pdb.set_trace()
    # Xdatagpu = torch.tensor(Xdata[:,:-1]).half().to('cuda')
    # ydatagpu = torch.tensor(ydata).half().to('cuda')
    # Xdatatestgpu = torch.tensor(Xdatatest[:,:-1]).half().to('cuda')
    # ydatatestgpu = torch.tensor(ydatatest).half().to('cuda')
    #train_accuracy_torch = torch.sum(((net(Xdatagpu)>0.5).T==ydatagpu))/Xdata.shape[0]
    # test_accuracy_torch = torch.sum(((net(Xdatatestgpu)>0.5).T==ydatatestgpu))/Xdatatest.shape[0]
 ##   train_accuracy_torch = np.sum(((relu(Xdata@U)@w2)>0.5)==ydata)/Xdata.shape[0]
    # test_accuracy_direct = np.sum(((relu(Xdatatest@U)@w2)>0.5)==ydatatest)/Xdatatest.shape[0]
    # train_accuracy_direct_old_ls = np.sum(((relu(Xdata@U)@w_ls_old_U)>0.5)==ydata)/Xdata.shape[0]
    # test_accuracy_direct_old_ls = np.sum(((relu(Xdatatest@U)@w_ls_old_U)>0.5)==ydatatest)/Xdatatest.shape[0]
    # train_accuracy_direct_old_scale = np.sum(((relu(Xdata@U2)@w_ls_old_U2)>0.5)==ydata)/Xdata.shape[0]
    # test_accuracy_direct_old_scale = np.sum(((relu(Xdatatest@U2)@w_ls_old_U2)>0.5)==ydatatest)/Xdatatest.shape[0]
 ##   train_accuracy_torch_polished = np.sum(((relu(Xdata@new_U)@w_ls)>0.5)==ydata)/Xdata.shape[0]
    #test_accuracy_direct_polished = np.sum(((relu(Xdatatest@new_U)@w_ls)>0.5)==ydatatest)/Xdatatest.shape[0]
    train_accuracy_direct_polished_resolvels = np.sum(((relu(Xdata@new_U)@w_ls)>0.5)==ydata)/Xdata.shape[0]
    #test_accuracy_direct_polished_resolvels = np.sum(((relu(Xdatatest@new_U)@w_ls_final)>0.5)==ydatatest)/Xdatatest.shape[0]
    #print('test accuracy torch: ', test_accuracy_torch)
    # print('train accuracy direct: ', train_accuracy_direct)
    # print('test accuracy direct: ', test_accuracy_direct)
    # print('train accuracy direct old ls: ', train_accuracy_direct_old_ls)
    # print('test accuracy direct old ls: ', test_accuracy_direct_old_ls)
    # print('train accuracy direct old scale: ', train_accuracy_direct_old_scale)
    # print('test accuracy direct old scale: ', test_accuracy_direct_old_scale)
    # print('train accuracy direct updated: ', train_accuracy_direct_polished)
    # print('test accuracy direct updated: ', test_accuracy_direct_polished)
    # print('train accuracy direct updated resolve ls: ', train_accuracy_direct_polished_resolvels)
    # print('test accuracy direct updated resolve ls: ', test_accuracy_direct_polished_resolvels)
    #update the network net: fc2 weights are all rows of new_U except the last (transposed), fc2 biases are the last row of new_U, and fc3 weights are w_ls
    net.fc2.weight.data = torch.tensor(new_U[:-1,:].T).half()
    net.fc2.bias.data = torch.tensor(new_U[-1,:]).half()
    net.fc3.weight.data = torch.tensor(w_ls).half()
    #train_accuracy_torch_polished = torch.sum(((net(Xdatagpu)>0.5).T==ydatagpu))/Xdata.shape[0]
    #move the network to gpu
    net = net.to(hyp['misc']['device'])
    # calculate the test accuracy after affine distance model
    val_acc_new = 0
    net.eval()
    eval_batchsize = 1000
    assert data['eval']['images'].shape[0] % eval_batchsize == 0, "Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)."
    loss_list_val, acc_list, acc_list_ema = [], [], []
    with torch.no_grad():
        for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
            # if epoch >= ema_epoch_start:
            #     outputs = net_ema(inputs)
            #     #acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
            #     acc_list_ema.append(( (outputs>0.5) == targets).float().mean()) #mp modify sep 13 2023
            outputs = net(inputs)
            outputs = outputs.squeeze() #mp modify sep 7 2023
            loss_list_val.append(loss_fn(outputs, targets).float().mean())
            acc_list.append(( (outputs >= 0.5) == targets).float().mean()) #mp modify sep 13 2023
            #acc_list.append((outputs.argmax(-1) == targets).float().mean())
        val_acc_new = torch.stack(acc_list).mean().item()
    # calculate the training accuracy after affine distance model
    train_accuracy_torch_polished = 0
    net.eval()
    eval_batchsize = 1000
    assert data['train']['images'].shape[0] % eval_batchsize == 0, "Error: The eval batchsize must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)."
    loss_list_val, train_acc_list, acc_list_ema = [], [], []
    with torch.no_grad():
        for inputs, targets in get_batches(data, key='train', batchsize=eval_batchsize):
            # if epoch >= ema_epoch_start:
            #     outputs = net_ema(inputs)
            #     #acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
            #     acc_list_ema.append(( (outputs>0.5) == targets).float().mean()) #mp modify sep 13 2023
            outputs = net(inputs)
            outputs = outputs.squeeze() #mp modify sep 7 2023
            loss_list_val.append(loss_fn(outputs, targets).float().mean())
            train_acc_list.append(( (outputs >= 0.5) == targets).float().mean()) #mp modify sep 13 2023
            #acc_list.append((outputs.argmax(-1) == targets).float().mean())
        train_accuracy_torch_polished = torch.stack(train_acc_list).mean().item()
    #display accuracy
    print('train accuracy: ', train_accuracy_torch)
    print('train accuracy polished: ', train_accuracy_torch_polished)
    print('test accuracy: ', val_acc)
    print('test accuracy polished: ', val_acc_new)

    return train_accuracy_torch, val_acc, train_accuracy_torch_polished, val_acc_new
if __name__ == "__main__":
    # Learning-rate sweep driver.
    #
    # For every learning rate in `lr_list`, `main(params)` is called
    # `num_trials` times. Each call returns the 4-tuple
    # (train accuracy, test accuracy, polished train accuracy,
    # polished test accuracy); the per-learning-rate means are collected in
    # `acc_list`, plotted against the learning rate and saved to a
    # timestamped PDF.
    #
    # Everything lives under the __main__ guard so that importing this module
    # does not launch the experiment.

    # Mean accuracies, one entry per learning rate.
    acc_list = {
        'train_accuracy_torch': [],
        'test_accuracy_torch': [],
        'train_accuracy_torch_polished': [],
        'test_accuracy_torch_polished': [],
    }

    # 20 log-spaced learning rates in [1e-3, 1e-1].
    lr_list = np.logspace(-3, -1, num=20, dtype=float)
    num_trials = 20
    print(lr_list)

    # tqdm over the learning rates; lridx is the 1-based position used only
    # for the progress messages.
    for lridx, lr in enumerate(tqdm(lr_list), start=1):
        # NOTE(review): how `main` consumes `params` is not visible here;
        # presumably it overrides the optimizer learning rate -- confirm.
        params = {'learning_rate': lr}

        # Per-trial accuracies for this learning rate.
        train_accs, test_accs = [], []
        train_accs_polished, test_accs_polished = [], []

        # Run exactly num_trials trials, so the "i+1 out of num_trials"
        # message below is accurate.
        for i in range(num_trials):
            print('trial ', i+1, ' out of ', num_trials)
            print('learning rate ', lridx, ' out of ', len(lr_list))
            (train_accuracy_torch_trial,
             test_acc_torch_trial,
             train_accuracy_torch_polished_trial,
             test_accuracy_torch_polished_trial) = main(params)

            train_accs.append(train_accuracy_torch_trial)
            test_accs.append(test_acc_torch_trial)
            train_accs_polished.append(train_accuracy_torch_polished_trial)
            test_accs_polished.append(test_accuracy_torch_polished_trial)

        # Average over the trials and record one point per learning rate.
        acc_list['train_accuracy_torch'].append(np.mean(train_accs))
        acc_list['test_accuracy_torch'].append(np.mean(test_accs))
        acc_list['train_accuracy_torch_polished'].append(np.mean(train_accs_polished))
        acc_list['test_accuracy_torch_polished'].append(np.mean(test_accs_polished))

    # Extract values from the acc_list dictionary
    test_acc_torch = acc_list['test_accuracy_torch']
    test_acc_polished = acc_list['test_accuracy_torch_polished']

    # Create the plots: dashed = validation, solid = train;
    # blue = original network, red = polished network.
    plt.figure(figsize=(10, 6))
    plt.plot(lr_list, test_acc_torch, color='blue', label='Validation Accuracy - Original', linestyle='--')
    plt.plot(lr_list, test_acc_polished, color='red', label='Validation Accuracy - Polished', linestyle='--')
    plt.plot(lr_list, acc_list['train_accuracy_torch'], color='blue', label='Train Accuracy - Original', linestyle='-')
    plt.plot(lr_list, acc_list['train_accuracy_torch_polished'], color='red', label='Train Accuracy Polished', linestyle='-')

    # Adding title, legend, and labels
    plt.title("Accuracy vs. learning rate")
    plt.xlabel("Learning rate")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Save the plot to a PDF file whose name carries the local wall-clock
    # time (MMDDYYYY-HHMMSS) so repeated runs do not overwrite each other.
    timestamp = datetime.now().strftime('%m%d%Y-%H%M%S')
    plt.savefig("ubuntu_3_layer_accuracy_vs_lr_"+str(timestamp)+".pdf")

    # Optionally display the plot
    plt.show()
    


