# -*- coding: utf-8 -*-
"""
Code for the experiments of Towards Understanding Neural Collapse: The Effects
of Batch Normalization and Weight Decay

Written based on the demonstration code of [1] V. Papyan, X.Y. Han, and D.L. Donoho. "[Prevalence of Neural Collapse During the Terminal Phase of Deep Learning Training.](https://www.pnas.org/content/117/40/24652)" *Proceedings of the National Academy of Sciences (PNAS)* 117, no. 40 (2020): 24652-24663.

Below is the original README of the code of [1]:

# Neural Collapse Examples

Code demonstrating Neural Collapse on Cross-Entropy [1] and MSE Loss [2].
Notebook is designed to be short, easy-to-interpret, and executable
from the browser using Google Colab.

MNIST-ResNet18 was chosen because it ran most reliably within the in-browser
memory constraints of Google Colab.
If you are *still* getting out-of-memory errors, try clicking
"Runtime"->"Factory Reset Runtime" on the menu bar.

It should be clear how to adapt the code to other networks-dataset combinations
to be run on local clusters with more memory.

### References:

[1] V. Papyan, X.Y. Han, and D.L. Donoho. "[Prevalence of Neural Collapse During the Terminal Phase of Deep Learning Training.](https://www.pnas.org/content/117/40/24652)" *Proceedings of the National Academy of Sciences (PNAS)* 117, no. 40 (2020): 24652-24663.

[2] X.Y. Han, V. Papyan, and D.L. Donoho. [“Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path.”](https://openreview.net/forum?id=w1UbdvWH_R3) *International Conference on Learning Representations (ICLR)*, 2022.
"""

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.models as models

from tqdm.auto import tqdm
from collections import OrderedDict
from scipy.sparse.linalg import svds
from torchvision import datasets, transforms
import itertools

from tiny_imagenet import TinyImageNetDataset

# Debug switch: when True, training is shortened to 10 epochs and a single
# small configuration replaces the full hyperparameter sweep.
debug = False

# constants
# Maps each custom dataset name to the .npz file its data is loaded from.
custom_data_path = {
    'conic': './dataset/conic.npz',
    'mlp3': './mlp_3.npz',
    'mlp6': './mlp.npz',
    'mlp9': './mlp_9.npz',
    'mlp3_W20': './mlp3_W20.npz',
}


# Roughly geometrically spaced epoch indices (presumably the epochs at which
# metrics are recorded; the consumer is not in this part of the file -- TODO confirm).
epoch_list          = [1,   2,   3,   4,   5,   6,   7,   8,   9,   10,   11,
                       12,  13,  14,  16,  17,  19,  20,  22,  24,  27,   29,
                       32,  35,  38,  42,  45,  50,  54,  59,  65,  71,   77,
                       85,  92,  101, 110, 121, 132, 144, 158, 172, 188,  206,
                       225, 245, 268, 293, 320, 350]

# Output / input locations
model_save_path = './New_Models'
data_save_path = './Expt_Data'
figure_save_path = './figures'
tiny_imagenet_path = '../data/tiny-Imagenet/tiny-imagenet-200'

# When True, the hyperparameter sweep further below is built for training.
# When False, a single configuration and a `load_model` path are set instead.
train_models = True

# Tunable Hyperparameters
# Optimization Criterion
loss_name = 'CrossEntropyLoss'

# 'MLP' for a linear network with ReLU activation, 'ResNet' for ResNet18, 'CNN' for a vanilla CNN
model_type = 'MLP'

# 'conic' for conic hull data loaded from disk.
# 'conic_gen' to generate conic hull data and save it to conic_data_path
# 'MNIST' 'CIFAR10'
dataset = 'mlp6' 

# conic data only: dimension of the conic hull data input
input_dim = 16
train_samples = 8000
test_samples = 2000
num_classes = 4

# Linear network only: model depth
model_depth_MLP = 6
hidden_layer_width = 200

# Use random labels for dataset
random_labels = False

# BatchNormalization Hyperparameters
bn = True
bn_affine = True
bn_eps = 1e-5

# Linear network only: whether to use a bias in linear layers.
linear_bias = True

# Linear network only: normalize weight matrices after each iteration
weight_norm = False

# Linear Network Only: Use InstanceNorm instead of BatchNorm
inst_norm = False

# Linear Network Only: Use LayerNorm
layer_norm = False

# Linear network only: add a loss term to directly optimize NC measures at each level
nc_train = False
nc_coeff = 0.1

# GD instead of batched SGD
no_batch = False
adam = False

# Linear network only: Activation between layers ('ReLU' or 'tanh')
activation_MLP = 'ReLU'

# nc_train only: the coefficient of cosine similarity
lamb = 1

# Optimization hyperparameters
lr_decay            = 0.1

# Best lr after hyperparameter tuning
# NOTE: `lr` is only assigned for the two loss names below; any other
# loss_name leaves it undefined.
if loss_name == 'CrossEntropyLoss':
  lr = 0.0679
  #lr = 0.0184
elif loss_name == 'MSELoss':
  lr = 0.0184

# The auxiliary NC loss is trained with a 5x smaller learning rate.
if nc_train and model_type == 'MLP':
  lr /= 5

if debug:
  epochs            = 10
else:
  epochs            = 300
# Learning-rate decay milestones at 1/4, 1/2 and 3/4 of the epoch budget.
epochs_lr_decay     = [epochs//4, epochs*2//4, epochs*3//4]

batch_size          = 128

momentum            = 0.9
weight_decay        = 1e-2
rand_seed           = 12138
# Filled with a snapshot of the module globals once the sweep is set up.
default_hypers      = None

def combine_dicts(base_dict, enum_dict):
    """Expand a grid of hyperparameter choices into a list of full configs.

    Args:
        base_dict: settings shared by every configuration.
        enum_dict: maps a key to the list of values it should sweep over.

    Returns:
        One copy of ``base_dict`` per element of the Cartesian product of the
        value lists in ``enum_dict``, each updated with that combination.
        Keys of ``enum_dict`` override equal keys of ``base_dict``.
    """
    swept_keys = list(enum_dict.keys())
    configs = []
    for choice in itertools.product(*enum_dict.values()):
        config = base_dict.copy()
        config.update(zip(swept_keys, choice))
        configs.append(config)
    return configs


# Build the list of experiment configurations (training mode) or the single
# configuration used when not training.
if train_models:
  if default_hypers is None:
    # Snapshot of every module-level name at this point; restore_hypers reads
    # it to reset hyperparameters that an experiment has overridden.
    default_hypers = globals().copy()
  # The hyperparameter combinations to run. Uncomment the corresponding line to run that experiment.
  # hyperparameter_tables = combine_dicts({"dataset": "MNIST", "model_type" : "vgg11", "lr": 1e-3, "epochs": 100, "bn": False}, {"weight_decay": [1e-4, 3e-4, 5e-4, 7e-4, 1e-3, 3e-3, 5e-3, 7e-3, 1e-2, 2e-2, 3e-2], "rand_seed": [314159, 265358, 979323, 846264, 338327]})
  # hyperparameter_tables = [{"dataset": "CIFAR10", "model_type" : "vgg19", "bn": True, "lr": 1e-3, "epochs": 100, "weight_decay": 5e-3, "train_samples": None, "test_samples": None}]
  # hyperparameter_tables = combine_dicts({"dataset": "MNIST", "lr": 1e-2, "epochs": 200, "train_samples": None, "test_samples": None}, {"model_type": ["vgg11", "vgg19"], "weight_decay": [1e-4, 3e-4, 5e-4, 1e-3, 3e-3, 5e-3, 1e-2, 2e-2, 3e-2], "bn": [True, False], "rand_seed": [314159, 265358, 979323]})
  # hyperparameter_tables = combine_dicts({"dataset": "CIFAR10", "lr": 1e-2, "epochs": 200, "train_samples": None, "test_samples": None}, {"model_type": ["vgg11", "vgg19"], "weight_decay": [1e-4, 3e-4, 5e-4, 1e-3, 3e-3, 5e-3, 1e-2, 2e-2, 3e-2], "bn": [True, False], "rand_seed": [314159, 265358, 979323]})
  if not debug:
    # Full sweep: 2 models x 9 weight decays x 2 BN settings x 3 seeds on CIFAR100.
    hyperparameter_tables = combine_dicts({"dataset": "CIFAR100", "lr": 1e-2, "epochs": 200, "train_samples": None, "test_samples": None}, {"model_type": ["vgg11", "vgg19"], "weight_decay": [1e-4, 3e-4, 5e-4, 1e-3, 3e-3, 5e-3, 1e-2, 2e-2, 3e-2], "bn": [True, False], "rand_seed": [314159, 265358, 979323]})
  else:
    # Debug: one short run on a small subset.
    hyperparameter_tables = [{"dataset": "CIFAR100", "lr": 1e-2, "epochs": 10, "train_samples": 200, "test_samples": 100, "model_type": "vgg11"}]
  print("Total number of Experiments:" + str(len(hyperparameter_tables)))
else:
  # Not training: a single configuration is defined, plus the path of the
  # model to load (empty here). Note the singular name `hyperparameter_table`.
  load_model = ''
  hyperparameter_table = {"dataset": "mlp6", "model_depth_MLP": 15, "train_samples": 4000, "test_samples": 1000, "weight_decay": 5e-3}

# Set module-level hyperparameters according to a {name: value} dict
def set_hypers(hypers_dict):
  """Overwrite module-level hyperparameters with the values in ``hypers_dict``.

  Each key must already exist as a module-level name; an unknown key raises
  ``Exception``. Keys are applied in iteration order, so entries that precede
  an unknown key have already been applied when the exception is raised.
  """
  module_vars = globals()
  for name, value in hypers_dict.items():
    if name not in module_vars:
      raise Exception(f'Cannot set hyperparameter named {name}')
    module_vars[name] = value

def restore_hypers(hypers_dict):
  """Reset the hyperparameters named in ``hypers_dict`` to their defaults.

  Only the keys of ``hypers_dict`` are used; each one is reset to the value
  stored under the same key in the module-level ``default_hypers`` snapshot.
  A key that is not a module-level name raises ``Exception``.
  """
  module_vars = globals()
  for name in hypers_dict:
    if name not in module_vars:
      raise Exception(f'Cannot set hyperparameter named {name}')
    module_vars[name] = default_hypers[name]

def dict_to_file_string(d):
  """Serialise a dict as ``key_value`` pairs joined by underscores.

  Keys are emitted in sorted order so that the same configuration always
  produces the same string. An empty dict gives an empty string.
  """
  return '_'.join(f'{key}_{d[key]}' for key in sorted(d.keys()))

import os
import random
import numpy as np

# Reproducibility
def set_all_seeds(seed):
  """Seed Python, NumPy and PyTorch (CPU and current CUDA device) with ``seed``.

  Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.
  """
  os.environ['PYTHONHASHSEED'] = str(seed)
  for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)
  torch.backends.cudnn.deterministic = True

"""# Dataset Generation"""

import numpy as np

def generate_data_plane(num_classes, num_points, dim):
    r"""
    Generate conic hull data through plane separation.

    Draws ``ceil(log2(num_classes))`` random hyperplane normals and
    ``num_points`` standard-normal points of dimension ``dim``. The side of
    each plane a point falls on gives one bit of a binary code; codes are then
    folded so that exactly ``num_classes`` labels remain.

    Returns ``(datapoints, labels)`` where ``labels`` is an int array.
    Prints the label counts as a side effect.
    """
    n_planes = int(np.ceil(np.log2(num_classes)))
    normals = np.random.randn(n_planes, dim)
    points = np.random.randn(num_points, dim)

    # One bit per plane, most significant bit first.
    side_bits = ((points @ normals.T) > 0).astype(int)
    bit_weights = 2 ** np.arange(side_bits.shape[1])[::-1]
    codes = (side_bits * bit_weights).sum(axis=1).astype('float')

    # 2**n_planes codes exist but only num_classes are wanted: the lowest
    # 2*surplus codes are merged pairwise, the rest are shifted down.
    surplus = 2 ** n_planes - num_classes
    merged = codes < 2 * surplus
    labels = np.where(merged, codes // 2.0, codes - surplus).astype(int)

    print('Data count:', np.unique(labels, return_counts=True))
    return points, labels

def generate_MLP_data(num_points, dim):
  """Generate a synthetic dataset labelled by a randomly initialised MLP.

  Args:
    num_points: number of samples to draw.
    dim: dimensionality of each standard-normal sample.

  Returns:
    ``(datapoints, labels_numpy, data_model)``: the sampled inputs as a NumPy
    array, the arg-max class predicted for each input by ``data_model``, and
    the labelling network itself.
  """
  # Fixed seed so that the labelling network and the samples are reproducible.
  set_all_seeds(989996)
  # NOTE(review): `layer_width` is read from module scope and is not defined in
  # the visible part of this file -- confirm it exists before calling.
  data_model = MLP(layer_width).to(device)
  datapoints = np.random.randn(num_points, dim)
  data_tensor = torch.tensor(datapoints).to(device).float()
  # Label with the freshly built network (the original called an unrelated
  # global `model` here, leaving `data_model` unused).
  outputs = data_model(data_tensor)
  _, labels = torch.max(outputs, 1)
  labels_numpy = labels.cpu().detach().numpy()
  # Return to the experiment's own seed for everything that follows.
  set_all_seeds(rand_seed)
  return datapoints, labels_numpy, data_model

"""TODO: put in separate file"""

import imageio

def _add_channels(img, total_channels=3):
  while len(img.shape) < 3:  # third axis is the channels
    img = np.expand_dims(img, axis=-1)
  while(img.shape[-1]) < 3:
    img = np.concatenate([img, img[:, :, -1:]], axis=-1)
  return img


"""# Linear Network Definition"""

import torch
from torch import nn

class MLP(nn.Module):
    """Fully connected classifier with optional normalisation layers.

    For every hidden width the network stacks ``Linear -> activation ->
    [BatchNorm1d | InstanceNorm1d | LayerNorm]`` (at most one normalisation,
    chosen in that priority order) inside ``self.feature``; a final
    ``Linear`` layer, ``self.last_layer``, maps to the output width.

    Args:
        layer_widths: sequence ``[input, hidden..., output]`` of layer sizes.
        bn: add ``BatchNorm1d`` after each activation.
        weight_norm: wrap hidden linear layers in ``nn.utils.weight_norm``.
        inst_norm: add ``InstanceNorm1d`` (only used when ``bn`` is false).
        layer_norm: add ``LayerNorm`` (only used when ``bn`` and ``inst_norm``
            are false).

    The activation, bias flag and BatchNorm settings are read from the
    module-level ``activation_MLP``, ``linear_bias``, ``bn_eps`` and
    ``bn_affine`` at construction time.
    """

    # NOTE: the keyword defaults below are evaluated once, when the class is
    # defined; later changes to the module-level flags do not affect them.
    def __init__(self, layer_widths, bn=bn, weight_norm=weight_norm, inst_norm=inst_norm, layer_norm=layer_norm):
        super().__init__()
        act_layer = {'ReLU': nn.ReLU, 'tanh': nn.Tanh}.get(activation_MLP)
        self.layer_widths = layer_widths
        self.weight_norm = weight_norm
        self.inst_norm = inst_norm
        self.layer_norm = layer_norm
        self.bn = bn

        num_hidden = len(layer_widths) - 2
        if num_hidden > 0 and act_layer is None:
            # Previously this surfaced as an UnboundLocalError in the loop.
            raise ValueError(f'Unsupported activation_MLP: {activation_MLP!r}')

        layers = []
        for i in range(num_hidden):
            linear = nn.Linear(layer_widths[i], layer_widths[i+1], bias=linear_bias)
            if self.weight_norm:
                linear = nn.utils.weight_norm(linear, dim=None)
            layers.append(linear)
            layers.append(act_layer())
            if self.bn:
                layers.append(torch.nn.BatchNorm1d(layer_widths[i+1], eps=bn_eps, affine=bn_affine))
            elif self.inst_norm:
                layers.append(torch.nn.InstanceNorm1d(layer_widths[i+1]))
            elif self.layer_norm:
                layers.append(torch.nn.LayerNorm(layer_widths[i+1], elementwise_affine=False))
        self.last_layer = nn.Linear(layer_widths[-2], layer_widths[-1], bias=linear_bias)
        self.feature = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(batch, -1)`` and return the output of the last layer."""
        x = x.view(x.shape[0], -1)
        features = self.feature(x)
        return self.last_layer(features)

    def last_layer_feat(self, x):
        """Return the input to the last layer (``x`` is not flattened here)."""
        return self.feature(x)

    def all_features(self, x):
        """Return the output of every linear layer, ending with the final output."""
        features = []
        for layer in self.feature:
            x = layer(x)
            if isinstance(layer, nn.Linear):
                features.append(x)
        x = self.last_layer(x)
        features.append(x)
        return features

    def nc_loss(self, x, y):
        """Sum ``dist_from_etf`` over all linear-layer outputs.

        Layer weights grow geometrically with depth: the first layer is
        weighted by ``0.75 ** num_features`` and each later layer by 1/0.75
        times the previous weight.
        """
        loss = torch.tensor(0)
        decay_fac = 0.75
        all_feats = self.all_features(x)
        cur_fac = decay_fac ** len(all_feats)
        for feature in all_feats:
            loss = loss + cur_fac * dist_from_etf(feature, y)
            cur_fac /= decay_fac
        return loss

    def _layers(self):
        """Return every layer in forward order: the feature stack, then the last layer."""
        return list(self.feature) + [self.last_layer]

    def num_layers(self):
        """Return the number of layers, counting activations, norms and the last layer."""
        # The original read `self.model`, which this class never defines.
        return len(self.feature) + 1

    def normalize_weight(self):
        """Rescale every weight and bias tensor in place to unit norm."""
        with torch.no_grad():
            for name, p in self.named_parameters():
                if 'weight' in name or 'bias' in name:
                    p /= torch.norm(p)

    def layer_feat(self, x, i):
        """Return ``x`` after it has passed through the first ``i`` layers."""
        for layer in self._layers()[:i]:
            x = layer(x)
        return x

    def last_layer_weight(self):
        """Return the weight matrix of the last (classifier) layer."""
        return self.last_layer.weight

# Compute how close to ETF is a set of features w.r.t. labels
def dist_from_etf(x, y, lamb=lamb):
  """Measure how far the features ``x`` are from a simplex-ETF arrangement.

  The value is the mean within-class variance plus ``lamb`` times the mean
  squared deviation of the pairwise cosine between class means from
  ``-1 / (K - 1)``, where ``K`` is the number of distinct labels in ``y``.

  ``K`` is taken from ``y`` rather than from the module-level ``num_classes``
  so that a batch that lacks some class no longer fails with a shape
  mismatch; when every class is present the result is unchanged.

  Args:
    x: feature tensor indexed by sample along the first dimension.
    y: label tensor aligned with the first dimension of ``x``.
    lamb: weight of the cosine term.

  Returns:
    A scalar tensor on the device of ``x``.
  """
  class_means = []
  total_var = x.new_zeros(())
  for label in torch.unique(y):
    x_c = x[y == label]
    mu_c = x_c.mean(dim=0)
    class_means.append(mu_c)
    total_var = total_var + ((x_c - mu_c) ** 2).sum() / x_c.shape[0]

  num_present = len(class_means)
  total_var = total_var / num_present
  if num_present < 2:
    # A single class has no pair of means to compare.
    return total_var

  means = torch.stack(class_means)
  means = means / torch.linalg.norm(means, dim=1).view(-1, 1)
  cos_sim = means @ means.T
  off_diagonal = ~torch.eye(num_present, dtype=torch.bool, device=x.device)
  cos_sim_flat = cos_sim.masked_select(off_diagonal)
  cos_diff = ((cos_sim_flat + 1 / (num_present - 1)) ** 2).mean()
  return total_var + lamb * cos_diff

"""# Training & Analysis Function"""

def train(model, criterion, device, num_classes, train_loader, optimizer, epoch, callback=None):
    """Run one training epoch over ``train_loader``.

    Batches whose size differs from the module-level ``batch_size`` are
    skipped. The loss is chosen from the string form of ``criterion``
    (cross-entropy on class indices, or MSE on one-hot targets); when
    ``nc_train`` is set and the model is an MLP, ``nc_coeff`` times the
    model's ``nc_loss`` is added. ``callback(model, data, target)`` is invoked
    after every optimisation step if given.
    """
    model.train()

    num_batches = len(train_loader)
    criterion_name = str(criterion)
    status_fmt = ('Train\tEpoch: {} [({:.0f}%)] \t'
                  'Loss: {:.6f} \t'
                  'Accuracy: {:.6f}')

    progress = tqdm(total=num_batches, position=0, leave=True, ncols=120)
    for step, (data, target) in enumerate(train_loader, start=1):
        # Incomplete batches are not trained on.
        if data.shape[0] != batch_size:
            continue

        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        out = model(data)
        if criterion_name == 'CrossEntropyLoss()':
            loss = criterion(out, target)
        elif criterion_name == 'MSELoss()':
            one_hot_target = F.one_hot(target, num_classes=num_classes).float()
            loss = criterion(out, one_hot_target)
        if nc_train and model_type == 'MLP':
            loss = loss + nc_coeff * model.nc_loss(data, target)

        loss.backward()
        optimizer.step()

        correct = (torch.argmax(out, dim=1) == target).float()
        accuracy = torch.mean(correct).item()

        progress.update(1)
        progress.set_description(
            status_fmt.format(epoch, 100. * step / num_batches, loss.item(), accuracy))
        if callback:
            callback(model, data, target)
    progress.close()

def analysis(graphs, model, classifier, criterion_summed, device, num_classes, loader):
  """Run ``analysis_wrapper`` and append each metric to its list on ``graphs``.

  ``graphs`` must expose one list attribute per metric name used below.
  """
  (loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss,
   norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W) = analysis_wrapper(
       model, classifier, criterion_summed, device, num_classes, loader)

  recorded = {
      'loss': loss,
      'accuracy': accuracy,
      'reg_loss': reg_loss,
      'norm_M_CoV': norm_M_CoV,
      'norm_W_CoV': norm_W_CoV,
      'Sw_invSb': Sw_invSb,
      'NCC_mismatch': NCC_mismatch,
      'W_M_dist': W_M_dist,
      'cos_M': cos_M,
      'cos_W': cos_W,
  }
  for attr_name, value in recorded.items():
    getattr(graphs, attr_name).append(value)

def analysis_wrapper(model, classifier, criterion_summed, device, C, loader):
    """Compute loss, accuracy and Neural Collapse statistics over ``loader``.

    The data is traversed twice. The 'Mean' pass accumulates per-class feature
    sums and the loss; the 'Cov' pass accumulates the within-class covariance,
    the network accuracy and the agreement between the network prediction and
    the nearest class mean.

    Features are read from the module-level ``features.value`` after each
    forward pass, so a forward hook that stores them there must already be
    registered (the registration is not part of this function).

    Args:
        model: network to evaluate; switched to eval mode.
        classifier: module whose ``weight`` is used as the classifier matrix W.
        criterion_summed: loss module, dispatched on its string form. Its total
            is divided by the sample count, so it presumably uses
            ``reduction='sum'`` -- TODO confirm at the call site.
        device: device the batches are moved to.
        C: number of classes.
        loader: iterable of ``(data, target)`` batches.

    Returns:
        Tuple ``(loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV,
        norm_W_CoV, W_M_dist, cos_M, cos_W)``.
    """
    model.eval()

    N             = [0 for _ in range(C)]   # samples seen per class
    mean          = [0 for _ in range(C)]   # per-class feature sum, later mean
    Sw            = 0                       # within-class covariance accumulator
    dim_feats     = 0

    loss          = 0
    net_correct   = 0
    NCC_match_net = 0

    for computation in ['Mean','Cov']:
        pbar = tqdm(total=len(loader), position=0, leave=True, ncols=120)
        for batch_idx, (data, target) in enumerate(loader, start=1):

            data, target = data.to(device), target.to(device)
            
            output = model(data)
            # Features captured by the forward hook, flattened per sample.
            h = features.value.data.view(data.shape[0],-1) # B CHW
            dim_feats = h.shape[1]

            # during calculation of class means, calculate loss
            if computation == 'Mean':
                if str(criterion_summed) == 'CrossEntropyLoss()':
                  loss += criterion_summed(output, target).item()
                elif str(criterion_summed) == 'MSELoss()':
                  loss += criterion_summed(output, F.one_hot(target, num_classes=C).float()).item()

            for c in range(C):
                # features belonging to class c
                idxs = (target == c).nonzero(as_tuple=True)[0]

                if len(idxs) == 0: # no class-c sample in this batch
                  continue

                h_c = h[idxs,:] # B CHW
                if computation == 'Mean':
                    # update class means
                    mean[c] += torch.sum(h_c, dim=0) # CHW
                    N[c] += h_c.shape[0]
                    
                elif computation == 'Cov':
                    # update within-class cov

                    # mean[c] is the final class mean here (set after the 'Mean' pass)
                    z = h_c - mean[c].unsqueeze(0) # B CHW
                    # per-sample outer products z z^T
                    cov = torch.matmul(z.unsqueeze(-1), # B CHW 1
                                       z.unsqueeze(1))  # B 1 CHW
                    Sw += torch.sum(cov, dim=0)

                    # during calculation of within-class covariance, calculate:
                    # 1) network's accuracy
                    net_pred = torch.argmax(output[idxs,:], dim=1)
                    net_correct += sum(net_pred==target[idxs]).item()

                    # 2) agreement between prediction and nearest class center
                    # M (features x classes) is built at the end of the 'Mean' pass
                    NCC_scores = torch.stack([torch.norm(h_c[i,:] - M.T,dim=1) \
                                              for i in range(h_c.shape[0])])
                    NCC_pred = torch.argmin(NCC_scores, dim=1)
                    NCC_match_net += sum(NCC_pred==net_pred).item()

            pbar.update(1)
            pbar.set_description(
                'Analysis {}\t'
                'Epoch: [{}/{} ({:.0f}%)]'.format(
                    computation,
                    batch_idx,
                    len(loader),
                    100. * batch_idx/ len(loader)))
        pbar.close()
        
        if computation == 'Mean':
            # Turn the accumulated sums into means; a class that never
            # appeared gets a zero vector.
            for c in range(C):
                if N[c] == 0:
                  mean[c] = torch.zeros(dim_feats, device=device)
                else:
                  mean[c] /= N[c]
            M = torch.stack(mean).T
            loss /= sum(N)
        elif computation == 'Cov':
            Sw /= sum(N)
    accuracy = net_correct/sum(N)
    NCC_mismatch = 1 - NCC_match_net/sum(N)

    # loss with weight decay (uses the module-level `weight_decay`)
    reg_loss = loss
    for param in model.parameters():
        reg_loss += 0.5 * weight_decay * torch.sum(param**2).item()
    

    # global mean
    muG = torch.mean(M, dim=1, keepdim=True) # CHW 1
    
    # between-class covariance
    M_ = M - muG
    Sb = torch.matmul(M_, M_.T) / C

    # avg norm
    W  = classifier.weight
    M_norms = torch.norm(M_,  dim=0)
    W_norms = torch.norm(W.T, dim=0)

    # coefficient of variation of the class-mean / classifier-row norms
    norm_M_CoV = (torch.std(M_norms)/torch.mean(M_norms)).item()
    norm_W_CoV = (torch.std(W_norms)/torch.mean(W_norms)).item()

    # tr{Sw Sb^-1}
    Sw = Sw.cpu().numpy()
    Sb = Sb.cpu().numpy()
    # Sb is built from centred means, so only C-1 singular directions are
    # requested and its pseudo-inverse is assembled from them.
    eigvec, eigval, _ = svds(Sb, k=C-1)
    inv_Sb = eigvec @ np.diag(eigval**(-1)) @ eigvec.T 
    Sw_invSb = np.trace(Sw @ inv_Sb)

    # ||W^T - M_||
    normalized_M = M_ / torch.norm(M_,'fro')
    normalized_W = W.T / torch.norm(W.T,'fro')
    W_M_dist = (torch.norm(normalized_W - normalized_M)**2).item()

    # mutual coherence
    def coherence(V): 
        # Gram matrix of the columns of V, shifted by 1/(C-1), with the
        # diagonal removed; returns the mean absolute off-diagonal entry.
        G = V.T @ V
        G += torch.ones((C,C),device=device) / (C-1)
        G -= torch.diag(torch.diag(G))
        return torch.norm(G,1).item() / (C*(C-1))

    cos_M = coherence(M_/M_norms)
    cos_W = coherence(W.T/W_norms)

    return loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W

def analysis_str(loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W):
    """Format the metrics of ``analysis_wrapper`` as a multi-line report.

    ``reg_loss`` is accepted so the tuple returned by ``analysis_wrapper`` can
    be unpacked straight into this function, but it is not printed. The report
    ends with a blank line.
    """
    report_lines = [
        f"Average Loss: {loss}",
        f"Accuracy: {accuracy}",
        f"NC1 Within Class Collapse: {Sw_invSb}",
        f"NC2 Equinorm: Features: {norm_M_CoV}, Weights: {norm_W_CoV}",
        f"NC2 Equiangle: Features: {cos_M}, Weights: {cos_W}",
        f"NC3 Self-Duality: {W_M_dist}",
        f"NC4 NCC Mismatch: {NCC_mismatch}\n",
    ]
    return "\n".join(report_lines) + "\n"

def matrix_rank(m, eps=0.01):
  """Return the effective rank of the matrix ``m``.

  The effective rank is the smallest number of leading singular values whose
  share of the total singular-value mass exceeds ``1 - eps``.

  Only the singular values are computed (``torch.linalg.svdvals``); the
  deprecated ``torch.svd`` used before also built the singular vectors, which
  were never used.

  Raises:
    Exception: if no prefix exceeds the threshold, e.g. when the singular
      values sum to zero or are not finite.
  """
  singular_values = torch.linalg.svdvals(m)
  mass = singular_values / singular_values.sum()
  total = 0
  # Accumulate in Python floats, largest singular value first.
  for rank, share in enumerate(mass.detach().tolist(), start=1):
    total += share
    if total > 1 - eps:
      return rank
  raise Exception("SVD Invalid")

# custom analysis function
def full_analysis(model, modules, loader, num_classes, output_layer=True):
    """Compute per-layer collapse statistics for the inputs of ``modules``.

    Forward hooks are installed on ``modules`` (after clearing existing hooks
    through ``remove_all_hooks``, which is defined outside this section) so
    that the input of each module is available in ``features.values``. The
    data is then traversed twice: once for global and class means, once for
    cosine similarities, norms and nearest-class-centre accuracy.

    Args:
        model: network to evaluate; switched to eval mode.
        modules: modules whose inputs are analysed, in order.
        loader: iterable of ``(data, target)`` batches.
        num_classes: number of classes.
        output_layer: when true, the model output is analysed as one extra
            "layer" after the hooked modules.

    Returns:
        Tuple ``(intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms,
        nccs, ranks)``. ``intra_cos``, ``inter_cos``, ``qmean_norms`` and
        ``nccs`` have one entry per analysed layer; ``weight_norms`` and
        ``ranks`` have one entry per ``nn.Linear`` in ``modules`` and
        ``bn_norms`` one per ``nn.BatchNorm1d``.
    """
    model.eval()
    num_modules = len(modules)
    # Number of analysed layers, including the model output if requested
    num_modules_o = num_modules + 1 if output_layer else num_modules
    remove_all_hooks(model)
    hook_group(modules)

    # Output variables, features are all centered
    weight_norms = []
    bn_norms = []
    intra_cos = []
    inter_cos = []
    qmean_norms = []
    # Weight Matrix Rank
    ranks = []
    # NOTE: feature_ranks is never filled or returned.
    feature_ranks = []

    # Since no update here anyway
    with torch.no_grad():
      # Data-independent measures
      for m in modules:
        if isinstance(m, nn.Linear):
          weight_norms.append(torch.norm(m.weight).cpu().item())
          ranks.append(matrix_rank(m.weight))
        if isinstance(m, nn.BatchNorm1d):
          bn_norms.append(torch.norm(m.weight).cpu().item())

      # First iteration to get global means and class means
      
      # Global means for each layer
      means = [0 for _ in range(num_modules_o)]

      # Class Means for each layer for each class
      class_means = [[0 for _ in range(num_classes)] for _ in range(num_modules_o)]

      # Total number of data
      cnt = 0

      # Number of data for each class
      class_cnt = [0 for _ in range(num_classes)]
      pbar = tqdm(total=len(loader), position=0, leave=True, ncols=120)
      for batch_idx, (data, target) in enumerate(loader, start=1):
        data, target = data.to(device), target.to(device)
        output = model(data)
        for i in range(num_modules_o):
          # Index num_modules denotes the model output; every other index
          # reads the hooked input of modules[i], flattened per sample.
          if i == num_modules:
            class_features_i = output.view(output.shape[0], -1)
          else:
            class_features_i = features.values[i].view(output.shape[0], -1)
          for c in range(num_classes):
            c_indices = (target == c)
            # Sample counts are layer-independent, so count them once (layer 0)
            if i == 0:
              class_cnt[c] += c_indices.int().sum().item()
            class_means[i][c] =  class_means[i][c] + class_features_i[c_indices].sum(axis=0)
          means[i] += class_features_i.sum(axis=0)
          if i == 0:
            cnt += class_features_i.shape[0]
        pbar.update(1)
      # change sums to means
      for i in range(num_modules_o):
        means[i] /= cnt
        for c in range(num_classes):
          class_means[i][c] /= class_cnt[c]
        class_means[i] = torch.stack(class_means[i])

      # Second iteration computes cos similarities

      # Num of vecs for each class
      cnts = [0 for _ in range(num_classes)]

      # Sum (later mean) of unit-norm centered feature vectors, per layer and class
      normed_vecs = [[0 for _ in range(num_classes)] for _ in range(num_modules_o)]

      # Sum of squared norms of centered features, per layer
      norms = [0 for _ in range(num_modules_o)]

      # Nearest Class Center Classification Accuracy
      nccs = [0 for _ in range(num_modules_o)]

      pbar = tqdm(total=len(loader), position=0, leave=True, ncols=120)
      for batch_idx, (data, target) in enumerate(loader, start=1):
        # NOTE: num_samples is assigned but not used below.
        num_samples = 0
        data, target = data.to(device), target.to(device)
        output = model(data)

        for i in range(num_modules_o):
          if i == num_modules:
            class_features_i = output.view(output.shape[0], -1)
          else:
            class_features_i = features.values[i].view(output.shape[0], -1)

          # Global and class means were already computed in the first iteration

          # Center feature relative to global mean
          centered_features =  class_features_i - means[i]

          # Quadratic Average of Vector Norms
          norms[i] += torch.norm(centered_features) ** 2
          # Distance of every (uncentered) sample to every class mean
          pairwise_dists = torch.norm(class_features_i[:, None, :] - class_means[i][None, :, :], dim=-1)

          # Number of correct NCC predictions
          ncc_pred = torch.argmin(pairwise_dists, dim=1)
          nccs[i] += (ncc_pred == target).int().sum().item()
          for c in range(num_classes):
            # Features for class c
            centered_features_c = centered_features[target == c]
            # Normalize each feature to norm 1
            centered_features_c_normed = centered_features_c / torch.norm(centered_features_c, dim=1).reshape(-1, 1)
            normed_vecs[i][c] += centered_features_c_normed.sum(dim=0)
            if i == 0:
              cnts[c] += centered_features_c.shape[0]
          
        pbar.update(1)
      cnts = torch.tensor(cnts).to(device)
      # Compute Inter and Intra cos using mean normalized vectors
      for i in range(num_modules_o):
        nccs[i] /= cnt
        qmean_norms.append(torch.sqrt(norms[i] / cnt))
        for c in range(num_classes):
          normed_vecs[i][c] /= cnts[c]
        # Gram matrix of the per-class mean unit vectors
        full_cos = torch.stack(normed_vecs[i]) @ torch.stack(normed_vecs[i]).T
        # The diagonal holds the squared norm of each class's mean unit vector.
        # Rescaling by n/(n-1) and subtracting 1/(n-1) removes the
        # self-similarity terms, leaving the average cosine over distinct
        # pairs of samples within the class.
        intra_cos_vals = torch.diag(full_cos)
        intra_cos_vals *= (cnts / (cnts - 1))
        intra_cos_vals -= (1 / (cnts - 1))
        # Worst case over classes: smallest intra-class, largest inter-class value
        intra_cos.append(intra_cos_vals.min().item())
        inter_cos.append(full_cos[torch.eye(full_cos.shape[0], dtype=int) == 0].max().item())
    return intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks

def full_analysis_str(intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks, num_modules):
  """Format the per-layer results of ``full_analysis`` as a text report.

  Layers ``0 .. num_modules - 1`` are described using the module-level
  ``hooked_modules`` list; the final section is the output layer. Linear and
  batch-norm statistics are consumed in order from ``weight_norms``/``ranks``
  and ``bn_norms`` respectively.
  """
  chunks = []
  next_linear = 0
  next_bn = 0
  for layer_idx in range(num_modules + 1):
    if layer_idx != num_modules:
      module = hooked_modules[layer_idx]
      chunks.append(f"Layer {layer_idx}: {module}\n")
      if isinstance(module, nn.Linear):
        chunks.append(f"Linear Weight Norm: {weight_norms[next_linear]}\n")
        chunks.append(f"Linear Weight Rank: {ranks[next_linear]}\n")
        next_linear += 1
      elif isinstance(module, nn.BatchNorm1d):
        chunks.append(f"Batch Normalization Weight Norm: {bn_norms[next_bn]}\n")
        next_bn += 1
    else:
      chunks.append(f"Output Layer:\n")
    chunks.append(f"Intra Cos: {intra_cos[layer_idx]}\n")
    chunks.append(f"Inter Cos: {inter_cos[layer_idx]}\n")
    chunks.append(f"Norm Quadratic Average: {qmean_norms[layer_idx]}\n")
    chunks.append(f"Nearest Class Center Accuracy: {nccs[layer_idx]}\n")
    chunks.append("\n")
  return "".join(chunks)

"""# Load dataset"""

def _torchvision_loaders(dataset_cls, root, transform, target_trans, train_samples, test_samples):
  """Build shuffled train and test DataLoaders for a torchvision dataset class.

  Each split is optionally truncated to its first ``*_samples`` items.
  Returns ``[train_loader, test_loader]``.
  """
  loaders = []
  for is_train, max_samples in ((True, train_samples), (False, test_samples)):
    split = dataset_cls(root, train=is_train, download=True, transform=transform, target_transform=target_trans)
    if max_samples:
      split = torch.utils.data.Subset(split, range(max_samples))
    loaders.append(torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True, num_workers=8))
  return loaders


def get_dataset(dataset_name, train_samples, test_samples, random_labels=False):
  """Create the train and test loaders for the requested dataset.

  Args:
    dataset_name: 'MNIST', 'CIFAR10', 'CIFAR100', 'tImageNet', 'conic', or a
      name starting with 'mlp' that is a key of ``custom_data_path``.
    train_samples: number of training samples to keep; a falsy value keeps
      the whole split for the torchvision datasets.
    test_samples: same, for the test split.
    random_labels: replace the labels with uniformly random ones.

  Returns:
    ``(train_loader, test_loader, num_classes, in_channels)``;
    ``in_channels`` is ``None`` for the custom (non-image) datasets.

  Raises:
    ValueError: if ``dataset_name`` is not recognised (previously this ended
      in an ``UnboundLocalError`` at the return statement).

  NOTE(review): ``C``, ``im_size`` and ``padded_im_size`` are read from module
  scope and are not defined in the visible part of this file -- confirm they
  exist, and that ``C`` matches the dataset when ``random_labels`` is used.
  """
  target_trans = (lambda y: torch.randint(0, C, (1,)).item()) if random_labels else None

  if dataset_name == 'MNIST':
    transform = transforms.Compose([transforms.Pad((padded_im_size - im_size)//2),
                                    transforms.ToTensor(),
                                    transforms.Normalize(0.1307,0.3081)])
    train_loader, test_loader = _torchvision_loaders(
        datasets.MNIST, '../data/MNIST', transform, target_trans, train_samples, test_samples)
    in_channels = 1
    num_classes = 10

  elif dataset_name in ('CIFAR10', 'CIFAR100'):
    transform = transforms.Compose(
        [transforms.Pad((padded_im_size - im_size)//2),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    if dataset_name == 'CIFAR10':
      dataset_cls, root, num_classes = datasets.CIFAR10, './data/CIFAR10', 10
    else:
      dataset_cls, root, num_classes = datasets.CIFAR100, './data/CIFAR100', 100
    train_loader, test_loader = _torchvision_loaders(
        dataset_cls, root, transform, target_trans, train_samples, test_samples)
    in_channels = 3

  elif dataset_name == 'conic' or dataset_name[:3] == 'mlp':
    print(f'Using custom dataset at {custom_data_path[dataset_name]}')
    npz = np.load(custom_data_path[dataset_name])
    X, y = npz['X'], npz['y']
    if random_labels:
      # One random label per stored sample (the original referenced an
      # undefined `num_samples` here).
      y = torch.randint(0, C, (X.shape[0], ))

    tensor_X, tensor_y = torch.tensor(X).float(), torch.as_tensor(y).to(torch.int64)
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(tensor_X[:train_samples, :], tensor_y[:train_samples]),
        batch_size=batch_size, shuffle=True)

    # The test split is the block of samples directly after the training block.
    test_end = train_samples + test_samples
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(tensor_X[train_samples: test_end, :], tensor_y[train_samples: test_end]),
        batch_size=batch_size, shuffle=True)
    num_classes = C
    in_channels = None

  elif dataset_name == 'tImageNet':
    # Unlike the other image datasets, padding is applied after ToTensor here.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Pad((padded_im_size - im_size)//2),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = TinyImageNetDataset(tiny_imagenet_path, mode='train', transform=transform, max_samples=train_samples)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    test_dataset = TinyImageNetDataset(tiny_imagenet_path, mode='val', transform=transform, max_samples=test_samples)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    in_channels = 3
    num_classes = 200

  else:
    raise ValueError(f'Unknown dataset: {dataset_name!r}')

  return train_loader, test_loader, num_classes, in_channels

"""# Actual Run"""

class Features:
    """Container that forward hooks write captured layer inputs into."""

    def __init__(self):
      # Inputs captured by the indexed hooks (see hook_helper), keyed by index.
      self.values = {}
      # Input captured by the single `hook` function; None until it first runs.
      self.value = None

# Shared instance used by the hooks and the analysis functions.
features = Features()

def hook(self, input, output):
    """Forward hook: store a copy of the module's first input in ``features.value``."""
    first_input = input[0]
    features.value = first_input.clone()

def hook_helper(module, i):
    """Register a forward hook on ``module`` that records its input under key ``i``.

    After each forward pass of ``module``, a copy of its first input is
    stored in ``features.values[i]``.
    """
    def _store_input(self, input, output):
      features.values[i] = input[0].clone()

    module.register_forward_hook(_store_input)

def hook_group(modules):
  """Register an input-recording hook on every module, keyed by its position."""
  for position in range(len(modules)):
    hook_helper(modules[position], position)

import torch
import torchvision.models as models

def load_vgg_model(model_type, bn=False, num_classes=10, input_channels=3):
    """Build an untrained torchvision VGG and list its conv / linear layers.

    Args:
        model_type: One of ``'vgg11'``, ``'vgg13'``, ``'vgg16'``, ``'vgg19'``.
        bn: When truthy, the ``*_bn`` torchvision constructor is used.
        num_classes: Forwarded to the torchvision constructor.
        input_channels: Input channel count of the replacement first conv.

    Returns:
        ``(model, modules)`` where ``modules`` holds every ``Conv2d`` found
        in ``model.features`` followed by every ``Linear`` found in
        ``model.classifier``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    supported = ('vgg11', 'vgg13', 'vgg16', 'vgg19')
    if model_type not in supported:
        raise ValueError('Invalid VGG model type: ' + model_type)

    # torchvision exposes the batch-norm variants as "<name>_bn".
    builder_name = (model_type + '_bn') if bn else model_type
    builder = getattr(models, builder_name)
    model = builder(pretrained=False, num_classes=num_classes)

    # Swap the first layer so the network accepts `input_channels` channels.
    model.features[0] = torch.nn.Conv2d(
        input_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # Convolutions of the feature extractor first, then the classifier's
    # linear layers.
    conv_layers = [
        layer for layer in model.features.modules()
        if isinstance(layer, torch.nn.Conv2d)
    ]
    linear_layers = [
        layer for layer in model.classifier.modules()
        if isinstance(layer, torch.nn.Linear)
    ]
    return model, conv_layers + linear_layers

def get_model(model_type, num_classes, in_channels):
  """Build the network selected by ``model_type`` and choose the layers to hook.

  Besides its arguments, this function reads module-level settings:
  ``dataset`` (ResNet / CNN branches), ``bn`` (MLP / VGG branches) and
  ``weight_norm``, ``input_dim``, ``hidden_layer_width``,
  ``model_depth_MLP`` (MLP branch).

  Args:
    model_type: ``'ResNet'``, ``'MLP'``, ``'CNN'`` or a name starting with
      ``'vgg'`` (forwarded to ``load_vgg_model``).
    num_classes: Size of the output layer.
    in_channels: Number of input channels for the convolutional models.

  Returns:
    ``(model, classifier, hooked_modules)``: the model moved to ``device``,
    its final layer (which gets the ``hook`` forward hook registered), and
    the list of modules selected for per-layer analysis.

  Raises:
    ValueError: If ``model_type`` is unknown, or an image model is requested
      for the ``'conic'`` / ``'conic_gen'`` datasets.
  """
  if model_type == 'ResNet':
    if dataset == 'conic' or dataset == 'conic_gen':
      raise ValueError("Can't use Resnet with none-image dataset")
    model = models.resnet18(pretrained=False, num_classes=num_classes)
    model.conv1 = nn.Conv2d(in_channels, model.conv1.weight.shape[0], 3, 1, 1, bias=False) # Small dataset filter size used by He et al. (2015)
    model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)
    classifier = model.fc
    hooked_modules = [model.conv1, model.layer1, model.layer2, model.layer3, model.layer4, model.fc]

  elif model_type == 'MLP':
    layer_width = [input_dim] + [hidden_layer_width] * (model_depth_MLP - 1) + [num_classes]
    model = MLP(layer_width, bn=bn, weight_norm=weight_norm)
    classifier = model.last_layer
    # Every linear layer of the feature stack, then the final layer.
    hooked_modules = [m for m in model.feature if isinstance(m, nn.Linear)]
    hooked_modules.append(model.last_layer)

  elif model_type.startswith('vgg'):
    model, hooked_modules = load_vgg_model(model_type, bn, num_classes, in_channels)
    classifier = model.classifier[-1]

  elif model_type == 'CNN':
    if dataset == 'conic' or dataset == 'conic_gen':
      raise ValueError("Can't use CNN with none-image dataset")
    # Use the channel count passed by the caller (as the ResNet branch does)
    # rather than the module-level `input_ch`.
    model = VanillaCNN(in_channels)
    classifier = model.fc3
    hooked_modules = [model.conv1, model.conv2, model.fc1, model.fc3]

  else:
    # Fail with a clear message instead of an UnboundLocalError below.
    raise ValueError(f"Unknown model type: {model_type}")

  # Save the last layer's input into `features.value` on every forward pass.
  classifier.register_forward_hook(hook)

  model = model.to(device)
  return model, classifier, hooked_modules

def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is false are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

class Graphs:
  """Bundle of metric histories; every attribute starts as its own empty list.

  The NC1-NC4 groupings below follow the labels used in this file.
  """

  def __init__(self):
    self.accuracy, self.loss, self.reg_loss = [], [], []

    # NC1
    self.Sw_invSb = []

    # NC2
    self.norm_M_CoV, self.norm_W_CoV = [], []
    self.cos_M, self.cos_W = [], []

    # NC3
    self.W_M_dist = []

    # NC4
    self.NCC_mismatch = []

from collections import OrderedDict
from typing import Dict, Callable
import torch

def remove_all_hooks(model: torch.nn.Module) -> None:
    """Recursively reset a hook registry on every descendant of ``model``.

    For each non-None child, only the FIRST attribute found among
    ``_forward_hooks``, ``_forward_pre_hooks`` and ``_backward_hooks`` is
    replaced by an empty ``OrderedDict``. A child that has
    ``_forward_hooks`` therefore keeps its forward-pre and backward hooks.
    Hooks registered on ``model`` itself are left untouched; only its
    descendants are visited.
    """
    registry_names = ("_forward_hooks", "_forward_pre_hooks", "_backward_hooks")
    for child in model._modules.values():
        if child is None:
            continue
        for registry in registry_names:
            if hasattr(child, registry):
                setattr(child, registry, OrderedDict())
                break  # mirror the original if/elif chain: reset one registry only
        remove_all_hooks(child)

def cos_matrix(model, modules, analysis_loader, sample_per_class, C, single=False, analysis=False, output_layer=False):
  """Compare per-layer features of a class-balanced batch drawn from a loader.

  The first ``sample_per_class`` inputs of each class ``0..C-1`` are gathered
  from ``analysis_loader``, stacked class by class and pushed through
  ``model`` once. The inputs captured at each module in ``modules`` (plus the
  model output when ``output_layer`` is true) are flattened per sample and
  centred by their mean over samples before being analysed.

  Existing forward hooks on the model's sub-modules are cleared via
  ``remove_all_hooks`` and replaced by the per-layer capture hooks.

  Args:
    model: Network to evaluate; ``model.last_layer.weight`` is always read.
    modules: Modules whose first positional input is treated as a layer's
      features.
    analysis_loader: Iterable of ``(data, target)`` batches.
    sample_per_class: Number of samples kept per class.
    C: Number of classes.
    single: Only used when ``analysis`` is false. If true, cosine matrices
      are computed between the ``C`` class-mean features, otherwise between
      all samples.
    analysis: If true, return summary statistics instead of matrices.
    output_layer: If true, the model output is analysed as an extra layer.

  Returns:
    If ``analysis`` is true, the tuple ``(cos_vals, var_vals, inter_cos,
    feat_norm, last_weight_norm, sv)``, where every element except
    ``last_weight_norm`` is a list with one entry per layer. Otherwise a
    list with one cosine-similarity matrix per layer.
  """
  per_class_inputs = [None] * C
  num_modules = len(modules)
  remove_all_hooks(model)
  hook_group(modules)
  # NOTE(review): if the loader runs out before every class has
  # `sample_per_class` samples, the code continues with fewer samples while
  # the slicing below still assumes full classes - confirm callers avoid this.
  for data, target in analysis_loader:
    enough = True
    for c in range(C):
      picked = data[target == c]
      if per_class_inputs[c] is None:
        per_class_inputs[c] = picked
      else:
        # Stay in torch: no numpy round trip, no tensor re-wrapping warning.
        per_class_inputs[c] = torch.cat((per_class_inputs[c], picked), dim=0)
      if per_class_inputs[c].shape[0] < sample_per_class:
        enough = False
    if enough:
      break
  structured_input = torch.cat([class_input_c[:sample_per_class] for class_input_c in per_class_inputs], dim=0).to(device)
  print(structured_input.shape)
  # Only activations are needed, so skip building the autograd graph.
  with torch.no_grad():
    output = model(structured_input)
  layer_matrices = []
  cos_vals = []
  var_vals = []
  inter_cos = []
  feat_norm = []
  sv = []
  last_weight_norm = torch.norm(model.last_layer.weight).item()
  num_layers = num_modules + 1 if output_layer else num_modules
  for i in range(num_layers):
    # Index `num_modules` stands for the model output itself.
    layer_tensor = output if i == num_modules else features.values[i]
    class_features_i = layer_tensor.detach().view(output.shape[0], -1).cpu().numpy()
    # Centre features by their mean over all samples.
    class_features_i -= np.mean(class_features_i, axis=0)
    if analysis:
      cos_i = []
      var_i = []
      class_feature_means = []
      feat_norm.append(np.linalg.norm(class_features_i, axis=1).mean())
      for c in range(C):
        class_features_ic = class_features_i[sample_per_class * c: sample_per_class * (c + 1)]
        # Variance Analysis
        class_feature_mean = np.mean(class_features_ic, axis=0)
        class_feature_means.append(class_feature_mean)
        class_features_centered = class_features_ic - class_feature_mean
        var_ic = (class_features_centered ** 2).sum(axis=1).mean(axis=0)
        # Cosine Similarity Analysis
        # NOTE(review): `class_features_ic` is a view, so this in-place
        # normalisation also rescales the rows of `class_features_i` that the
        # SVD below consumes - confirm that is intended.
        class_features_ic /= np.linalg.norm(class_features_ic, axis=1)[:, np.newaxis]
        # Off-diagonal entries only (pairs of distinct samples)
        cos_ic = (class_features_ic @ class_features_ic.T)[~np.eye(class_features_ic.shape[0], dtype=bool)].mean()
        cos_i.append(cos_ic)
        var_i.append(var_ic)
      cos_vals.append(min(cos_i))
      var_vals.append(max(var_i))
      class_feature_means = np.array(class_feature_means)
      class_feature_means /= np.linalg.norm(class_feature_means, axis=1)[:, np.newaxis]
      inter_cos.append(((class_feature_means @ class_feature_means.T)[~np.eye(C, dtype=bool)]).max())
      u, s, vh = np.linalg.svd(class_features_i)
      # Share of singular-value mass in the leading C - 1 directions; uses
      # the `C` argument rather than the module-level `num_classes`.
      sv.append(sum(s[:C - 1]) / sum(s))
    elif not single:
      class_features_i /= np.linalg.norm(class_features_i, axis=1)[:, np.newaxis]
      layer_matrices.append(class_features_i @ class_features_i.T)
    else:
      class_features_single = np.stack([class_features_i[sample_per_class * c: sample_per_class * (c + 1)].mean(axis=0) for c in range(C)], axis=0)
      class_features_single /= np.linalg.norm(class_features_single, axis=1)[:, np.newaxis]
      layer_matrices.append(class_features_single @ class_features_single.T)
  if analysis:
    return cos_vals, var_vals, inter_cos, feat_norm, last_weight_norm, sv
  else:
    return layer_matrices

# Build the loss twice: mean-reduced (`criterion`, passed to train) and
# sum-reduced (`criterion_summed`, passed to analysis_wrapper). Any other
# `loss_name` leaves both names undefined, exactly as before.
if loss_name in ('CrossEntropyLoss', 'MSELoss'):
  criterion, criterion_summed = (
      getattr(nn, loss_name)(reduction=mode) for mode in ('mean', 'sum'))

# Per-dataset image settings: (im_size, padded_im_size, C, input_ch).
# Shared by the training and load-model branches so they cannot drift apart.
_IMAGE_DATASET_PARAMS = {
  'MNIST':     (28, 32, 10, 1),
  'CIFAR10':   (32, 32, 10, 3),
  'tImageNet': (64, 64, 200, 3),
  'CIFAR100':  (32, 32, 100, 3),  # CIFAR-100 has 100 classes (was 10)
}

if train_models:
  # Training: one full train + analysis run per hyper-parameter setting.
  for params in hyperparameter_tables:
    set_hypers(params)
    set_all_seeds(rand_seed)

    # Dataset parameters are module-level names that get_dataset reads.
    if dataset in _IMAGE_DATASET_PARAMS:
      im_size, padded_im_size, C, input_ch = _IMAGE_DATASET_PARAMS[dataset]

    if dataset == 'conic' or dataset.startswith('mlp'):
      C = num_classes
    train_loader, test_loader, num_classes, in_channels = get_dataset(dataset, train_samples, test_samples, random_labels)
    graphs = Graphs()

    cur_epochs = []
    dict_str = dict_to_file_string(params)
    model_fn = os.path.join(model_save_path, f"{dict_str}.pth.tar")
    data_fn = os.path.join(data_save_path, f"{dict_str}.txt")
    figure_fn = os.path.join(figure_save_path, f"{dict_str}.png")
    if os.path.exists(data_fn):
        print(f"{data_fn} already exists!")
        # NOTE(review): this early `continue` skips restore_hypers(params) at
        # the end of the loop body - confirm that is intended.
        continue
    model, classifier, hooked_modules = get_model(model_type, num_classes, in_channels)
    print([name for name, p in model.named_parameters()])
    if adam:
      print('Using Adam Optimizer!')
      optimizer = optim.Adam(model.parameters(),
                    lr=lr,
                    weight_decay=weight_decay)
    else:
      optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)

    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                milestones=epochs_lr_decay,
                                                gamma=lr_decay)
    callback = None
    for epoch in range(1, epochs + 1):
        train(model, criterion, device, num_classes, train_loader, optimizer, epoch, callback)
        if model_type == 'toy':
          temp = np.array(dot_prods)
          print(f'epoch: {epoch}, mean: {temp.mean()}, prob: {(temp > 0).sum() / len(temp)}')
          dot_prods = []
        lr_scheduler.step()
    remove_all_hooks(model)
    if not debug:
      #torch.save(model, model_fn)
      pass
    full_data_str = ""
    full_data_str += f"Model save path: {model_fn}\n"
    full_data_str += "Training Set:\n"
    # Use the class count reported by get_dataset, as the test-set call does.
    intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks = full_analysis(model, hooked_modules, train_loader, num_classes)
    #loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W = analysis_wrapper(model, model.last_layer, criterion_summed, device, C, train_loader)
    #full_data_str += analysis_str(loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W)
    full_data_str += full_analysis_str(intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks, len(hooked_modules))
    full_data_str += "Test Set:\n"
    # Start a fresh figure per run so curves from earlier hyper-parameter
    # settings do not pile up in later saved images.
    plt.figure()
    plt.plot(intra_cos, label='min intra class')
    plt.plot(inter_cos,  label='max inter class')
    plt.plot(nccs, label='NCC accuracy')
    plt.ylim([-0.5, 1])
    plt.ylabel('value')
    plt.xlabel('layer')
    plt.legend()
    plt.savefig(figure_fn)
    plt.close()
    intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks = full_analysis(model, hooked_modules, test_loader, num_classes)
    classifier.register_forward_hook(hook)
    loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W = analysis_wrapper(model, classifier, criterion_summed, device, num_classes, test_loader)
    full_data_str += analysis_str(loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W)
    full_data_str += full_analysis_str(intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks, len(hooked_modules))
    print(full_data_str)
    with open(data_fn, 'w') as df:
      df.write(full_data_str)
    restore_hypers(params)
elif load_model:
  # Analysis only: evaluate a previously saved model on train and test data.
  set_hypers(hyperparameter_table)
  if dataset in _IMAGE_DATASET_PARAMS:
    im_size, padded_im_size, C, input_ch = _IMAGE_DATASET_PARAMS[dataset]

  if dataset == 'conic' or dataset.startswith('mlp'):
    C = num_classes

  # get_dataset returns (train_loader, test_loader, num_classes, in_channels);
  # only the loaders are needed here.
  train_loader, test_loader, *_ = get_dataset(dataset, train_samples, test_samples, random_labels)
  # NOTE(review): torch.load unpickles arbitrary objects - only load
  # checkpoints from a trusted source.
  model = torch.load(load_model)
  raw_name = os.path.basename(load_model)[:-len('.pth.tar')]
  data_fn = os.path.join(data_save_path, f"{raw_name}.txt")
  figure_fn = os.path.join(figure_save_path, f"{raw_name}.png")
  if model_type == 'MLP':
    hooked_modules = [m for m in model.feature if isinstance(m, nn.Linear)]
    hooked_modules += [model.last_layer]
    last_layer = model.last_layer
  elif model_type == 'ResNet':
    hooked_modules = [model.conv1, model.layer1, model.layer2, model.layer3, model.layer4, model.fc]
    last_layer = model.fc
  elif model_type == 'CNN':
    hooked_modules = [model.conv1, model.conv2, model.fc1, model.fc3]
    # The CNN's final layer is fc3 (see get_model), not fc.
    last_layer = model.fc3
  elif model_type == 'toy':
    hooked_modules = [model.linear1, model.linear2]
    last_layer = model.linear2
  else:
    raise ValueError(f"Unsupported model type for load_model: {model_type}")
  full_data_str = ""
  full_data_str += f"Model save path: {load_model}\n"
  full_data_str += "Training Set:\n"
  intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks = full_analysis(model, hooked_modules, train_loader, C)
  last_layer.register_forward_hook(hook)
  loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W = analysis_wrapper(model, last_layer, criterion_summed, device, C, train_loader)
  full_data_str += analysis_str(loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W)
  full_data_str += full_analysis_str(intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks, len(hooked_modules))
  full_data_str += "Test Set:\n"
  plt.plot(intra_cos, label='min intra class')
  plt.plot(inter_cos,  label='max inter class')
  plt.plot(nccs, label='NCC accuracy')
  plt.ylim([-0.5, 1])
  plt.ylabel('value')
  plt.xlabel('layer')
  # Draw the legend before saving so it appears in the written image.
  plt.legend()
  plt.savefig(figure_fn)
  # Added after the save: this curve only shows up in the live figure.
  plt.plot(weight_norms, label='Linear Weight Norm')
  plt.ylabel('value')
  plt.xlabel('layer')
  plt.legend()
  intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks = full_analysis(model, hooked_modules, test_loader, C)
  remove_all_hooks(model)
  last_layer.register_forward_hook(hook)
  loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W = analysis_wrapper(model, last_layer, criterion_summed, device, C, test_loader)
  full_data_str += analysis_str(loss, accuracy, Sw_invSb, NCC_mismatch, reg_loss, norm_M_CoV, norm_W_CoV, W_M_dist, cos_M, cos_W)
  full_data_str += full_analysis_str(intra_cos, inter_cos, qmean_norms, bn_norms, weight_norms, nccs, ranks, len(hooked_modules))
  print(f"Weight Norm Array: {' '.join(map(str, weight_norms))}")
  print(full_data_str)
  with open(data_fn, 'w') as df:
    df.write(full_data_str)