{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import torch\n",
    "import os\n",
    "from torch import nn\n",
    "import pandas as pd\n",
    "import math\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "from IPython.display import set_matplotlib_formats\n",
    "from pyhessian import hessian \n",
    "\n",
    "from types import SimpleNamespace\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"source_code/\")\n",
    "from architectures.vit import ViT\n",
    "from utils import get_width, get_depth, get_model, set_parametr_args, load_data\n",
    "from architectures.utils import CenteredModel\n",
    "\n",
    "set_matplotlib_formats('retina')\n",
    "%autoreload 2\n",
    "\n",
    "def load_data(directory, epoch):\n",
    "    df = []\n",
    "    experiments = os.listdir(directory)\n",
    "    for exp in experiments:\n",
    "        \n",
    "        path = os.path.join(directory, exp, 'args.json')\n",
    "        parts = exp.split('-')\n",
    "        d = json.load(open(path))\n",
    "\n",
    "        exp_path = os.path.join(directory, exp, f\"ckpt_N_{d['width_mult']}_batches_{epoch}_.pth\")\n",
    "        if not os.path.exists(exp_path):\n",
    "            continue\n",
    "        data = torch.load(exp_path)\n",
    "        metrics = data['metrics']\n",
    "        n_steps = len(metrics['train_loss'])\n",
    "        # print(metrics.keys())\n",
    "        for step in range(n_steps):\n",
    "            d2 = d.copy()\n",
    "            d2['step'] = step\n",
    "            d2['Epoch'] = (step / n_steps) * d['epochs']\n",
    "            d2['train_loss'] = metrics['train_loss'][step]\n",
    "            d2['train_acc'] = metrics['train_acc'][step]\n",
    "            # d2['test_acc'] = metrics['test_acc'][step]\n",
    "            # d2['test_loss'] = metrics['test_loss'][step]\n",
    "            df.append(d2)\n",
    "    \n",
    "    df = pd.DataFrame(df)\n",
    "    return df\n",
    "experiments_dir = \"\"\n",
    "df = load_data(experiments_dir, epoch=9).reset_index()\n",
    "\n",
    "temps = sorted(df['temperature'].unique())\n",
    "print(temps, flush=True)\n",
    "df['Learning Rate'] = df['lr']\n",
    "lrs = sorted(df['lr'].unique())\n",
    "gamma_zeros = sorted(df['gamma_zero'].unique())\n",
    "df['step'] += 1  # steps are 1-based from here on\n",
    "display(df)\n",
    "\n",
    "\n",
    "# For every gamma_zero record (a) the lr with the best final train accuracy and\n",
    "# (b) the largest lr that has a row at the final step.\n",
    "gamma_vals = np.sort(df.gamma_zero.unique())\n",
    "best_dfs = []\n",
    "best_lr_for_gamma = {}\n",
    "maximal_lr_for_gamma = {}\n",
    "# Loop-invariant: the rows logged at the last step, computed once instead of per gamma.\n",
    "final_step_df = df[df['step'] == df['step'].max()].reset_index()\n",
    "for gv in gamma_vals:\n",
    "    df2 = final_step_df[final_step_df['gamma_zero'] == gv]\n",
    "    if df2.empty:\n",
    "        # Previously this case died with an IndexError on `.values[0]`.\n",
    "        print(f\"no final-step rows for gamma_zero={gv}; skipping\")\n",
    "        continue\n",
    "    best_acc = df2['train_acc'].max()\n",
    "    best_lr = df2[df2['train_acc'] == best_acc].lr.values[0]\n",
    "    print(best_acc, best_lr, gv)\n",
    "    best_lr_for_gamma[gv] = best_lr\n",
    "    # Largest lr present for this gamma at the final step.\n",
    "    maximal_lr_for_gamma[gv] = df2['lr'].max()\n",
    "    best_dfs.append(df[(df['lr'] == best_lr) & (df['gamma_zero'] == gv)])\n",
    "best_dfs = pd.concat(best_dfs, ignore_index=True).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the figures below; width / skip scaling come from the first loaded row.\n",
    "model_type, param, dataset = 'conv', 'mup_sqrt_depth', 'cifar10'\n",
    "depth_mult = 4.0\n",
    "\n",
    "width_mult = df['width_mult'].to_numpy()[0]\n",
    "skip_scaling = df['skip_scaling'].to_numpy()[0]\n",
    "# Tag that identifies this configuration.\n",
    "file_name = (\n",
    "    f\"{model_type}_{param}_{dataset}\"\n",
    "    f\"_width-mult_{width_mult}_depth-mult_{depth_mult}\"\n",
    "    f\"_residuals_{skip_scaling}\"\n",
    ")\n",
    "print(file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate phase plot: final train accuracy over the (log10 gamma_zero, log10 lr) grid\n",
    "df2 = df[df['step'] == df['step'].max()].reset_index()\n",
    "df2 = df2[df2['parametr'] == param]\n",
    "df2 = df2[df2['depth_mult'] == depth_mult]\n",
    "df2 = df2[df2['temperature'] == temps[0]]\n",
    "# Prepare the data. `.copy()` makes the column assignments below act on an independent\n",
    "# frame rather than on a slice of df2 (avoids SettingWithCopyWarning).\n",
    "data = df2[['gamma_zero', 'lr', 'train_acc']].copy()\n",
    "data['gamma_zero'] = np.log10(data['gamma_zero']).round(2)\n",
    "data['lr'] = np.log10(data['lr']).round(2)\n",
    "display(data)\n",
    "lrs = np.sort(data['lr'].unique())\n",
    "gamma_zeros = np.sort(data['gamma_zero'].unique())\n",
    "\n",
    "data['train_acc'] = data['train_acc'] / 100  # Normalize accuracy to [0, 1]\n",
    "data_pivot = data.pivot(index='lr', columns='gamma_zero', values='train_acc')\n",
    "\n",
    "X, Y = np.meshgrid(gamma_zeros, lrs)\n",
    "# `plt.cm.get_cmap` is deprecated/removed in recent Matplotlib; work on a copy so that\n",
    "# `set_bad` does not modify the registered 'rocket' colormap.\n",
    "cmap = plt.get_cmap('rocket').copy()\n",
    "cmap.set_bad(color='black')  # Set NaN color\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "mesh = ax.pcolormesh(X, Y, data_pivot.values, cmap=cmap, shading='auto')\n",
    "# Reference line of slope 2 in log-log coordinates.\n",
    "ax.plot(gamma_zeros[3:10], gamma_zeros[3:10]*2 - 3, linestyle='--', color='gold', label=r'$\\eta \\sim \\gamma^2$')\n",
    "\n",
    "# Set axis labels and title (raw strings keep the LaTeX backslashes without escape warnings)\n",
    "ax.set_xlabel(r'$\\log \\gamma$')\n",
    "ax.set_ylabel(r'$\\log \\eta$')\n",
    "if param == 'mup':\n",
    "    param_text = r'$\\mu P$'\n",
    "elif param == 'mup_sqrt_depth':\n",
    "    param_text = r'$Depth-\\mu P$'\n",
    "else:\n",
    "    # Fall back to the raw name instead of hitting a NameError in the title below.\n",
    "    param_text = param\n",
    "ax.set_title(f'{param_text} ViT Temp=1e-3 Data Augmentation CIFAR-10 \\n N={int(width_mult * 64)}, L={int(depth_mult*6)}, {int(depth_mult)} blocks, B=128, loss=xent')\n",
    "\n",
    "cbar = fig.colorbar(mesh, ax=ax)\n",
    "legend = ax.legend(loc='lower right', frameon=True)\n",
    "legend.get_frame().set_alpha(0.9)  # Make the legend slightly transparent if needed\n",
    "plt.show()\n",
    "\n",
    "# Save gvs and lrvs: maximal lr per gamma, restricted to gamma >= 1\n",
    "gvs = []\n",
    "lrvs = []\n",
    "for k, v in maximal_lr_for_gamma.items():\n",
    "    if k < 1.0:\n",
    "        continue\n",
    "    gvs.append(k)\n",
    "    lrvs.append(v)\n",
    "gvs = np.array(gvs)\n",
    "lrvs = np.array(lrvs)\n",
    "# Restore the un-logged grids for the cells below (they were overwritten with log10 values above).\n",
    "lrs = sorted(df['lr'].unique())\n",
    "gamma_zeros = sorted(df['gamma_zero'].unique())\n",
    "print(lrs, gamma_zeros)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss curves along the lr ~ gamma diagonal (lr = 0.1 * gamma_zero)\n",
    "hue_col = r'$\\eta \\sim \\gamma$'\n",
    "selected = []\n",
    "for lr in lrs:\n",
    "    for gv in gamma_zeros:\n",
    "        # Relative-tolerance match: the former `0.99 <= lr/(gv*1e-1) <= 1.00` window\n",
    "        # dropped exact grid points whenever rounding pushed the ratio just above 1.\n",
    "        if np.isclose(lr, gv * 1e-1, rtol=1e-2, atol=0.0):\n",
    "            selected.append(df[(df['lr'] == lr) & (df['gamma_zero'] == gv) & (df['parametr'] == 'mup_sqrt_depth') & (df['depth_mult'] == 4)])\n",
    "if not selected:\n",
    "    raise ValueError('no (lr, gamma_zero) pair lies on the lr = 0.1 * gamma_zero diagonal')\n",
    "pd.set_option('display.float_format', '{:.2e}'.format)\n",
    "\n",
    "results = pd.concat(selected)\n",
    "# Compact scientific notation for the legend entries.\n",
    "results['gamma_zero'] = results['gamma_zero'].map('{:.0e}'.format)\n",
    "results['lr'] = results['lr'].map('{:.0e}'.format)\n",
    "\n",
    "results[hue_col] = r' $\\gamma=$' + results['gamma_zero'] + r' $\\eta=$' + results['lr']\n",
    "results = results.reset_index()\n",
    "# Overwrite the step-1 loss with log(10), the cross-entropy of a uniform 10-class guess.\n",
    "results.loc[results['step'] == 1, 'train_loss'] = np.log(10)\n",
    "\n",
    "sns.set_style('whitegrid')\n",
    "results['t'] = results['step']\n",
    "results['Train Loss'] = results['train_loss']\n",
    "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
    "# One colour per curve instead of a hard-coded palette length.\n",
    "n_curves = results[hue_col].nunique()\n",
    "sns.lineplot(results, x='t', y='Train Loss', hue=hue_col, legend='full', ax=ax, palette=sns.color_palette('plasma_r', n_curves))\n",
    "ax.set_yscale('log')\n",
    "ax.set_xscale('log')\n",
    "# NOTE(review): the plotted column is the training loss but the axis says \"Test Loss\" - confirm this is intended.\n",
    "ax.set_ylabel('Test Loss')\n",
    "ax.set_title(r'Depth-$\\mu P$ Residual ConvNet for Data Augmentation CIFAR10 ' + f'\\n N={128}, L=12, 4 blocks B=128, loss=xent')\n",
    "sns.move_legend(ax, 'upper left', bbox_to_anchor=(1, 1.02))\n",
    "# plt.show() instead of fig.show(): the latter only warns on the inline backend.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss curves along the lr ~ gamma^2 diagonal (lr = gamma_zero**2, gamma_zero <= 1e-2)\n",
    "hue_col = r'$\\eta \\sim \\gamma$'\n",
    "selected = []\n",
    "for lr in lrs:\n",
    "    for gv in gamma_zeros:\n",
    "        if gv > 1e-2:\n",
    "            continue\n",
    "        # Relative-tolerance match: the former `0.99 <= lr/(gv**2) <= 1.00` window\n",
    "        # dropped exact grid points whenever rounding pushed the ratio just above 1.\n",
    "        if np.isclose(lr, gv**2, rtol=1e-2, atol=0.0):\n",
    "            selected.append(df[(df['lr'] == lr) & (df['gamma_zero'] == gv) & (df['temperature'] == 1)])\n",
    "if not selected:\n",
    "    raise ValueError('no (lr, gamma_zero) pair lies on the lr = gamma_zero**2 diagonal')\n",
    "pd.set_option('display.float_format', '{:.2e}'.format)\n",
    "\n",
    "results = pd.concat(selected)\n",
    "# Compact scientific notation for the legend entries.\n",
    "results['gamma_zero'] = results['gamma_zero'].map('{:.0e}'.format)\n",
    "results['lr'] = results['lr'].map('{:.0e}'.format)\n",
    "\n",
    "results[hue_col] = r' $\\gamma=$' + results['gamma_zero'] + r' $\\eta=$' + results['lr']\n",
    "results = results.reset_index()\n",
    "# Overwrite the step-1 loss with log(10), the cross-entropy of a uniform 10-class guess.\n",
    "results.loc[results['step'] == 1, 'train_loss'] = np.log(10)\n",
    "\n",
    "sns.set_style('whitegrid')\n",
    "results['t'] = results['step']\n",
    "results['Train Loss'] = results['train_loss']\n",
    "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
    "# One colour per curve instead of a hard-coded palette length.\n",
    "n_curves = results[hue_col].nunique()\n",
    "sns.lineplot(results, x='t', y='Train Loss', hue=hue_col, legend='full', ax=ax, palette=sns.color_palette('plasma_r', n_curves))\n",
    "ax.set_yscale('log')\n",
    "ax.set_xscale('log')\n",
    "# NOTE(review): the plotted column is the training loss but the axis says \"Test Loss\" - confirm this is intended.\n",
    "ax.set_ylabel('Test Loss')\n",
    "# (An earlier set_title call was dead code: it was immediately overwritten by this one.)\n",
    "ax.set_title(r'Online Loss for CIFAR-5M Depth-$\\mu P$ ViT ' + f'\\n N={64*8}, L=3, 1 block, B=128, loss=xent')\n",
    "sns.move_legend(ax, 'upper left', bbox_to_anchor=(1, 0.8))\n",
    "# plt.show() instead of fig.show(): the latter only warns on the inline backend.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "experiments = os.listdir(experiments_dir)\n",
    "np.set_printoptions(suppress=True)\n",
    "def get_config_and_ckpt_path(gamma, lr, batch_seen, temperature):\n",
    "    \"\"\"Locate the checkpoint file and `args.json` of one experiment.\n",
    "\n",
    "    The experiment directory name is produced by `filter_lr_gamma` and the\n",
    "    checkpoint is the first file accepted by `filter_batches_seen`; both helpers\n",
    "    are defined in a later cell and must exist before this function is called.\n",
    "\n",
    "    Args:\n",
    "        gamma, lr, temperature: forwarded to `filter_lr_gamma`.\n",
    "        batch_seen: batch count forwarded to `filter_batches_seen`.\n",
    "\n",
    "    Returns:\n",
    "        (ckpt_path, config_path) tuple of strings.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: the experiment directory does not exist.\n",
    "        FileNotFoundError: no checkpoint matches `batch_seen`.\n",
    "    \"\"\"\n",
    "    exp_path = os.path.join(experiments_dir, filter_lr_gamma(experiments, lr=lr, gamma=gamma, temperature=temperature))\n",
    "    assert os.path.exists(exp_path), exp_path\n",
    "    matches = filter_batches_seen(os.listdir(exp_path), batch_seen)\n",
    "    if not matches:\n",
    "        # Explicit message instead of the bare IndexError that `[0]` used to raise.\n",
    "        raise FileNotFoundError(f\"no checkpoint for batches_{batch_seen} in {exp_path}\")\n",
    "    ckpt_path = os.path.join(exp_path, matches[0])\n",
    "    config = os.path.join(exp_path, \"args.json\")\n",
    "    return ckpt_path, config\n",
    "\n",
    "def load_model(ckpt_path, config, model_string):\n",
    "    \"\"\"Rebuild the network described by `config` and load the weights stored in `ckpt_path`.\n",
    "\n",
    "    Args:\n",
    "        ckpt_path: path of the state-dict checkpoint.\n",
    "        config: path of the experiment's `args.json`.\n",
    "        model_string: architecture key passed to `get_width`, `get_depth` and `get_model`.\n",
    "\n",
    "    Returns:\n",
    "        The `CenteredModel`-wrapped network, moved to CUDA and switched to eval mode.\n",
    "    \"\"\"\n",
    "    # Context manager so the config file handle is closed.\n",
    "    with open(config) as config_file:\n",
    "        args = json.load(config_file, object_hook=lambda d: SimpleNamespace(**d))\n",
    "    width = get_width(model_string, args.width_mult)\n",
    "    depth = get_depth(model_string, args.depth_mult)\n",
    "    # NOTE(review): the item assignments below assume set_parametr_args returns a\n",
    "    # dict-like, JSON-serialisable object - confirm against utils.\n",
    "    args = set_parametr_args('mup_sqrt_depth', args)\n",
    "    args['dataset'] = 'cifar10'\n",
    "    args['res_scaling'] = 1/math.sqrt(3)\n",
    "\n",
    "    # JSON round-trip to get attribute-style access back.\n",
    "    args = json.loads(json.dumps(args), object_hook=lambda d: SimpleNamespace(**d))\n",
    "    # Any value other than the string 'none' is mapped to 'ln'.\n",
    "    if args.norm == 'none':\n",
    "        args.norm = None\n",
    "    else:\n",
    "        args.norm = 'ln'\n",
    "    model = CenteredModel(get_model(model_string, width, depth, args))\n",
    "    # Load once: an earlier unused `torch.load(ckpt_path)` read the file a second time\n",
    "    # (and without weights_only) for no benefit.\n",
    "    model.load_state_dict(torch.load(ckpt_path, weights_only=True))\n",
    "\n",
    "    model = model.cuda()\n",
    "    model.eval()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one CIFAR-10 training batch and keep it on the GPU in double precision\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Pipeline used for CIFAR-10: tensor conversion, then per-channel (x - 0.5) / 0.5.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "# Augmenting pipeline; only referenced by the commented-out ImageFolder alternative below.\n",
    "transform_mean = np.array([0.485, 0.456, 0.406])\n",
    "transform_std = np.array([0.229, 0.224, 0.225])\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(64),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=transform_mean, std=transform_std),\n",
    "])\n",
    "\n",
    "# CIFAR-10 training split (downloaded on first use).\n",
    "trainset = torchvision.datasets.CIFAR10(root='', train=True, download=True, transform=transform)\n",
    "# trainset = torchvision.datasets.ImageFolder(os.path.join('', 'tiny-imagenet-200/train'), transform=transform_train)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True, num_workers=1)\n",
    "\n",
    "# A single shuffled batch of 512 samples is reused by every cell below.\n",
    "first_batch = next(iter(trainloader))\n",
    "inputs = first_batch[0].cuda().double()\n",
    "targets = first_batch[1].cuda()\n",
    "print(inputs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_lr_gamma(filenames, lr, gamma, temperature):\n",
    "    return f\"model_conv-lr_{lr}-seed_0-batch_size_128-res_scaling_sqrt_depth-width_mult_8.0-depth_mult_1-skip_scaling_1-beta_1-gamma_zero_{gamma}-wd_0.0-norm_none-k_layers_1-width_-1-temperature_1.0\"\n",
    "\n",
    "def filter_gamma(filenames, gamma_zero_value):\n",
    "    pattern = re.compile(r\"gamma_zero_([+-]?[0-9]*\\.?[0-9]+([eE][+-]?[0-9]+)?)\")\n",
    "    return [filename for filename in filenames if pattern.search(filename) and float(pattern.search(filename).group(1)) == gamma_zero_value]\n",
    "\n",
    "def filter_batches_seen(filenames, x):\n",
    "    pattern = re.compile(r\"model_ckpt_epoch_\\d+_batches_{}_\\.pth\".format(x))\n",
    "    ret =  [filename for filename in filenames if pattern.match(filename)]\n",
    "    return ret\n",
    "\n",
    "print(best_lr_for_gamma)\n",
    "step = df.step.max() * 200\n",
    "model_string = 'conv'\n",
    "print(gamma_zeros)\n",
    "\n",
    "# Sanity check: load one checkpoint and push the fixed batch through it.\n",
    "probe_gamma = gamma_zeros[5]\n",
    "# NOTE(review): filter_lr_gamma ignores `temperature`, so 1000.0 has no effect here - confirm intent.\n",
    "ckpt_path, config = get_config_and_ckpt_path(probe_gamma, best_lr_for_gamma[probe_gamma], step, temperature=1000.0)\n",
    "model = load_model(ckpt_path, config, model_string).double().cuda()\n",
    "# Inference only: skip building the autograd graph for the forward pass.\n",
    "with torch.no_grad():\n",
    "    print(model(inputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# get sharpness: top-3 Hessian eigenvalues at the final checkpoints along the lr-gamma diagonals\n",
    "sharpness_res = []\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "def get_result(ckpt_path, config, sharpness_res, string_to_save, step, gamma_value=None):\n",
    "    \"\"\"Append the three leading Hessian eigenvalues of one checkpoint to `sharpness_res`.\n",
    "\n",
    "    Args:\n",
    "        ckpt_path, config: forwarded to `load_model`.\n",
    "        sharpness_res: list that receives one dict per eigenvalue (mutated in place).\n",
    "        string_to_save: label stored under the `$\\eta$` key.\n",
    "        step: value stored under the `T` key.\n",
    "        gamma_value: value stored under the `gamma` key. When omitted, the\n",
    "            notebook-level variable `gamma` is read, as this function did before\n",
    "            the parameter existed.\n",
    "\n",
    "    Returns:\n",
    "        The same `sharpness_res` list.\n",
    "    \"\"\"\n",
    "    if gamma_value is None:\n",
    "        gamma_value = gamma  # legacy fallback to the enclosing loop variable\n",
    "    model = load_model(ckpt_path, config, model_string).double().cuda()\n",
    "    hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=True)\n",
    "    top_eigenvalues, _ = hessian_comp.eigenvalues(top_n=3)\n",
    "    # One record per eigenvalue; 'eig' is its 1-based position in the returned list.\n",
    "    for rank in (1, 2, 3):\n",
    "        sharpness_res.append({\n",
    "            'gamma': gamma_value,\n",
    "            'T': step,\n",
    "            r'$\\lambda$': top_eigenvalues[rank - 1],\n",
    "            'eig': rank,\n",
    "            r'$\\eta$': string_to_save\n",
    "        })\n",
    "    del model\n",
    "    return sharpness_res\n",
    "\n",
    "temp = 1.0\n",
    "for gamma in tqdm(gamma_vals):\n",
    "    step = df.step.max()\n",
    "    if np.log10(gamma) < 0.0:\n",
    "        # gamma < 1: take the learning rates with lr ~ gamma^2\n",
    "        for lr in lrs:\n",
    "            if 0.95 <= lr/(gamma**2) <= 1.01:\n",
    "                ckpt_path, config = get_config_and_ckpt_path(gamma, lr, step * 200, temperature=temp)\n",
    "                sharpness_res = get_result(ckpt_path, config, sharpness_res, string_to_save=r'$\\gamma^2$', step=step, gamma_value=gamma)\n",
    "    else:\n",
    "        # gamma >= 1: take the learning rates with lr ~ 1e-2 * gamma\n",
    "        for lr in lrs:\n",
    "            if 0.95 <= lr/(gamma*1e-2) <= 1.01:\n",
    "                print(gamma, lr)\n",
    "                ckpt_path, config = get_config_and_ckpt_path(gamma, lr, step * 200, temperature=temp)\n",
    "                sharpness_res = get_result(ckpt_path, config, sharpness_res, string_to_save=r'$\\gamma$', step=step, gamma_value=gamma)\n",
    "sharpness_res = pd.DataFrame(sharpness_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.lines import Line2D\n",
    "sns.set_style('whitegrid')\n",
    "lam_col = r'$\\lambda$'\n",
    "rank_col = r'$\\lambda_k$'\n",
    "# Keep 1e-4 < gamma < 100; `.copy()` so the assignment below acts on an independent frame.\n",
    "in_range = (sharpness_res['gamma'] < 100) & (sharpness_res['gamma'] > 1e-4)\n",
    "sharpness_res2 = sharpness_res[in_range].copy()\n",
    "# Magnitudes only, so negative eigenvalues can be drawn on the log axis.\n",
    "sharpness_res2[lam_col] = np.abs(sharpness_res2[lam_col])\n",
    "sharpness_res2 = sharpness_res2.rename(columns={'eig': rank_col})\n",
    "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
    "gammas = np.array(sorted(sharpness_res2['gamma'].unique()))\n",
    "\n",
    "sns.lineplot(sharpness_res2, x='gamma', y=lam_col, hue=rank_col, ax=ax, marker='o', palette=sns.color_palette('plasma_r', 3))\n",
    "# Reference power laws.\n",
    "ax.plot(gammas[:-3], gammas[:-3]**(-2)*0.0001, color='blue', linestyle='--', label=r'$\\gamma^{-2}$')\n",
    "ax.plot(gammas[-4:], gammas[-4:]**(-1.0)*0.0001, color='red', linestyle='--', label=r'$\\gamma^{-1}$')\n",
    "\n",
    "ax.set_yscale('log')\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel(r'$\\gamma$')\n",
    "# Single legend call, placed where the former ax.legend + sns.move_legend pair left it.\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "ax.legend(handles=handles, labels=labels, loc='upper left', bbox_to_anchor=(1, 1))\n",
    "# plt.show() instead of fig.show(): the latter only warns on the inline backend.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "torch.set_grad_enabled(False)\n",
    "results = []\n",
    "\n",
    "def center_kernel(K):\n",
    "    n = K.size(0)  # Assuming K is square\n",
    "    ones = torch.ones(n, n).double().cuda() / n\n",
    "    K_c = K - ones @ K - K @ ones + ones @ K @ ones\n",
    "    return K_c\n",
    "\n",
    "def get_cka(ckpt_path, config):\n",
    "    model1 = load_model(ckpt_path, config, 'conv').double().cuda()\n",
    "    \n",
    "    # Compute the kernel K\n",
    "    outputs = model1(inputs)\n",
    "    K = outputs @ outputs.T  # DXD\n",
    "    k_norm = torch.linalg.norm(K, ord='fro').item()\n",
    "    K_c = center_kernel(K)  # Center the kernel matrix K\n",
    "    ck_norm = torch.linalg.norm(K_c, ord='fro').item()\n",
    "    K_c /= torch.linalg.norm(K_c, ord='nuc')  # Normalize by Frobenius norm\n",
    "    Y = torch.nn.functional.one_hot(targets).double().cuda()\n",
    "    Y = Y @ Y.T\n",
    "    Y_c = center_kernel(Y)  # Center the kernel matrix Y\n",
    "    Y_c /= torch.linalg.norm(Y_c, ord='fro')  # Normalize by Frobenius norm\n",
    "    cka = torch.trace(K_c @ Y_c).item()\n",
    "    ka = torch.trace(K @ Y).item() / (torch.linalg.norm(K, ord='nuc') * torch.linalg.norm(Y, ord='fro'))\n",
    "    del model1  # Delete model explicitly\n",
    "    torch.cuda.empty_cache()  # Free up GPU memory\n",
    "    gc.collect()  # Call garbage collector to free up other unused memory\n",
    "    return cka, ka, ck_norm, k_norm\n",
    "\n",
    "# Track the alignment metrics over training for four gamma values, each at its best lr.\n",
    "to_plot = gamma_zeros[-7:-3]\n",
    "for gamma in to_plot:\n",
    "    best_lr = best_lr_for_gamma[gamma]\n",
    "    for step in tqdm(range(0, df.step.max(), 1)):\n",
    "        # Checkpoints are addressed by batch count: step * 200.\n",
    "        ckpt_path, config = get_config_and_ckpt_path(gamma, best_lr, step * 200, temperature=1.0)\n",
    "        cka, ka, ck_norm, k_norm = get_cka(ckpt_path, config)\n",
    "        results.append({\n",
    "            'lr': best_lr,\n",
    "            'gamma': gamma,\n",
    "            'step': step,\n",
    "            'CKA': cka,\n",
    "            'KA': ka.item(),\n",
    "            'CK_norm': ck_norm,\n",
    "            'K_norm': k_norm\n",
    "        })\n",
    "results = pd.DataFrame(results)\n",
    "# Raw strings keep the LaTeX backslashes without invalid-escape warnings.\n",
    "results[r'$\\eta \\sim \\gamma$'] = r\" $\\log \\gamma=$\" + np.log10(results['gamma']).astype(str) + r\" $\\log \\eta=$\" + np.log10(results['lr']).astype(str) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.lines import Line2D\n",
    "import seaborn as sns\n",
    "\n",
    "hue_loss = r'$\\eta \\sim \\gamma$'\n",
    "hue_align = r'$\\eta - \\gamma$'\n",
    "sci = '{:.0e}'.format  # compact scientific notation for legend entries\n",
    "\n",
    "# Loss curves of the runs analysed above (best lr per gamma). Every mask is built from\n",
    "# `df` itself; the earlier version mixed masks taken from `losses` and from `df`.\n",
    "losses_ = []\n",
    "for gamma in to_plot:\n",
    "    losses_.append(df[(df['gamma_zero'] == gamma) & (df['lr'] == best_lr_for_gamma[gamma]) & (df['depth_mult'] == 1.0) & (df['parametr'] == 'mup_sqrt_depth')])\n",
    "losses = pd.concat(losses_)\n",
    "# `.copy()` so the column assignments below act on an independent frame.\n",
    "losses = losses[losses['gamma_zero'] < 3000].copy()\n",
    "losses['gamma_zero'] = losses['gamma_zero'].map(sci)\n",
    "losses['lr'] = losses['lr'].map(sci)\n",
    "losses[hue_loss] = r' $\\gamma=$' + losses['gamma_zero'] + r' $\\eta=$' + losses['lr']\n",
    "\n",
    "# Alignment curves, relabelled the same way.\n",
    "results2 = results[results['gamma'] < 3000].copy()\n",
    "results2['gamma'] = results2['gamma'].map(sci)\n",
    "results2['lr'] = results2['lr'].map(sci)\n",
    "results2[hue_loss] = r' $\\gamma=$' + results2['gamma'] + r' $\\eta=$' + results2['lr']\n",
    "results2 = results2.rename(columns={\n",
    "    hue_loss: hue_align,\n",
    "    'CK_norm': r'$\\|K_c\\|_F$',\n",
    "    'K_norm': r'$\\|K\\|_F$',\n",
    "    'step': 't',\n",
    "})\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
    "# Size each palette by its number of curves (a fixed 3-colour palette is too short for\n",
    "# four gammas) and draw explicitly on `ax`.\n",
    "sns.lineplot(results2, x='t', y='CKA', hue=hue_align, legend=False, linestyle='--', ax=ax, palette=sns.color_palette('plasma_r', results2[hue_align].nunique()))\n",
    "sns.lineplot(losses, x='step', y='train_loss', hue=hue_loss, legend='full', ax=ax, palette=sns.color_palette('plasma_r', losses[hue_loss].nunique()))\n",
    "\n",
    "ax.set_yscale('log')\n",
    "ax.set_xscale('log')\n",
    "\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "\n",
    "solid_line = Line2D([0], [0], color='black', lw=1, label='Loss')\n",
    "dashed_line = Line2D([0], [0], color='black', lw=1, linestyle='--', label='Alignment')\n",
    "\n",
    "# Add the custom lines to the existing handles and labels\n",
    "handles.extend([solid_line, dashed_line])\n",
    "labels.extend(['Loss', 'Alignment'])\n",
    "\n",
    "ax.legend(handles=handles, labels=labels, loc='upper left', bbox_to_anchor=(1, 1.02))\n",
    "ax.set_ylabel('Alignment')\n",
    "ax.set_title(r'Alignment for CIFAR-5M Depth-$\\mu P$ ViT (no LN) ' + f'\\n N={64*8}, B=128, loss=CE')\n",
    "# plt.show() instead of fig.show(): the latter only warns on the inline backend.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
