{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7b7b16c9-c63a-439c-8e5a-4136c3aef177",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64\n"
     ]
    }
   ],
   "source": [
    "### All imports ###\n",
    "import os\n",
    "import sys\n",
    "paths_to_add = [\"..\", \"../..\"]\n",
    "for path in paths_to_add:\n",
    "    sys_path = os.path.relpath(path)\n",
    "    if sys_path not in sys.path:  # Check to avoid duplicates\n",
    "        sys.path.append(sys_path)\n",
    "\n",
    "import argparse\n",
    "import itertools\n",
    "from functools import partial\n",
    "import copy\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pickle\n",
    "\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import time as TIME\n",
    "\n",
    "from _src.datasets import prepare_data\n",
    "from _src.models import RoPETransformer, RoPEFlashAttention\n",
    "\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "plt.rcParams.update({\"font.size\": 20})\n",
    "\n",
    "\n",
    "### Custom class to use pickle files on cpu ###\n",
    "import io\n",
    "class CPU_Unpickler(pickle.Unpickler):\n",
    "    def find_class(self, module, name):\n",
    "        if module == 'torch.storage' and name == '_load_from_bytes':\n",
    "            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')\n",
    "        else:\n",
    "            return super().find_class(module, name)\n",
    "\n",
    "\n",
    "def str2bool(v):\n",
    "    if isinstance(v, bool):\n",
    "        return v\n",
    "    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n",
    "        return True\n",
    "    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n",
    "        return False\n",
    "    else:\n",
    "        raise argparse.ArgumentTypeError('Boolean value expected.')\n",
    "\n",
    "### Helper functions to create task lists ### \n",
    "def generate_all_unique_sublists(args):\n",
    "    all_combinations = list(itertools.product(range(1, args.base), repeat=args.n_var))\n",
    "    if len(all_combinations) < args.n_tasks:\n",
    "        raise ValueError(\"Not enough unique combinations available.\")\n",
    "    selected_combinations = random.sample(all_combinations, args.n_tasks)\n",
    "    return [list(combination) for combination in selected_combinations]\n",
    "\n",
    "\n",
    "def get_ood_lists(Ws, args):\n",
    "    all_combinations = list(itertools.product(range(1, args.base), repeat=args.n_var))\n",
    "    Ws = set(tuple(W) for W in Ws)\n",
    "    return list(set(all_combinations) - Ws)\n",
    "\n",
    "\n",
    "def attach_tasks_with_shared_components(Ws_unique, args):\n",
    "    \"\"\"Currently only work\"\"\"\n",
    "    all_possibilities = itertools.product(range(args.p), repeat=len(Ws_unique[0]))\n",
    "    Ws = copy.deepcopy(Ws_unique)\n",
    "    for possible_W in all_possibilities:\n",
    "        for W in Ws_unique:\n",
    "            for i in range(len(possible_W)):\n",
    "                if W[i] == possible_W[i] and possible_W not in Ws:\n",
    "                    Ws.append(list(possible_W))\n",
    "    return Ws\n",
    "\n",
    "\n",
    "def parallelogram_tasks_with_shared_components(Ws_unique, args):\n",
    "    \"\"\"Currently only work for 2 variables\"\"\"\n",
    "    def generate_lists(a, b, p):\n",
    "        # Generate unique list combinations given the constraints\n",
    "        while True:\n",
    "            x = random.choice([i for i in range(p) if i != a and i != b])\n",
    "            y = random.choice([i for i in range(p) if i != a and i != b and i != x])\n",
    "\n",
    "            list1 = [a, y]\n",
    "            list2 = [x, b]\n",
    "            list3 = [x, y]\n",
    "\n",
    "            if list1 not in Ws and list2 not in Ws and list3 not in Ws:\n",
    "                return [list1, list2, list3]\n",
    "\n",
    "    Ws = copy.deepcopy(Ws_unique)\n",
    "    while len(Ws) < 4 * len(Ws_unique):\n",
    "        for W in Ws_unique:\n",
    "            if len(Ws) >= 4 * len(Ws_unique):\n",
    "                break  # Break early if the target size is already reached\n",
    "            new_Ws = generate_lists(W[0], W[1], args.p)\n",
    "            Ws.extend(new_Ws)\n",
    "    \n",
    "    return Ws\n",
    "\n",
    "### Parser for all required settings ###\n",
    "parser = argparse.ArgumentParser(description=\"Transformer Grokking\")\n",
    "parser.add_argument(\"--model_name\", default=\"rope_decoder\", type=str, help=\"Encoder or Decoder only Transformers\")\n",
    "parser.add_argument(\"--one_shot\", default=False, type=str2bool, help=\"One shot or CoT\")\n",
    "parser.add_argument(\"--device\", default=\"cpu\", type=str, help=\"device\")\n",
    "parser.add_argument(\"--dtype\", default=\"float32\", type=str, help=\"dtype\")\n",
    "parser.add_argument(\"--mixed_precision\", default=False, type=str2bool, help=\"Automatic Mixed Precision\")\n",
    "parser.add_argument(\"--seed\", default=1, type=int, help=\"random seed\")\n",
    "parser.add_argument(\"--ddp\", default=False, type=str2bool, help=\"DDP or not\")\n",
    "parser.add_argument(\"--world_size\", default=1, type=int, help=\"World Size\")\n",
    "\n",
    "# Model Settings\n",
    "parser.add_argument(\"--n_layer\", default=4, type=int, help=\"Number of Transformer Blocks\")\n",
    "parser.add_argument(\"--dp\", default=0.0, type=float, help=\"Dropout Probability\")\n",
    "parser.add_argument(\"--if_ln\", default=True, type=str2bool, help=\"If use LayerNorm or Not\")\n",
    "parser.add_argument(\"--n_embd\", default=512, type=int, help=\"Embedding Dimension\")\n",
    "parser.add_argument(\"--n_head\", default=4, type=int, help=\"Number of Heads\")\n",
    "parser.add_argument(\"--block_size\", default=512, type=int, help='maximum length')\n",
    "parser.add_argument(\"--act_name\", default=\"relu\", type=str, help=\"activation: relu, gelu, swiglu\")\n",
    "parser.add_argument(\"--widen_factor\", default=4, type=int, help=\"MLP widening\")\n",
    "parser.add_argument(\"--mu\", default=1.0, type=float, help=\"Skip connection strength\")\n",
    "parser.add_argument(\"--weight_tying\", default=False, type=str2bool, help=\"If use weight tying\")\n",
    "parser.add_argument(\"--dont_decay_embd\", default=True, type=str2bool, help=\"If use weight tying\")\n",
    "\n",
    "# Data\n",
    "parser.add_argument(\"--n_tasks\", default=64, type=int, help=\"number of independent tasks\")\n",
    "parser.add_argument(\"--parallelogram\", default=True, type=str2bool, help=\"Perform parallelogram construction on task vectors or not\")\n",
    "parser.add_argument(\"--n_var\", default=2, type=int, help=\"number of variables, i.e. dimension of the problem\")\n",
    "parser.add_argument(\"--data_seed\", default=0, type=int, help=\"random seed for generating datasets\")\n",
    "parser.add_argument(\"--data_pct\", default=80.0, type=float, help=\"Data Percentage\")\n",
    "parser.add_argument(\"--task_pct\", default=50.0, type=float, help=\"Task Percentage\")\n",
    "parser.add_argument(\"--p\", default=29, type=int, help=\"Modulo p\")\n",
    "parser.add_argument(\"--base\", default=29, type=int, help=\"Represent Numbers in base\")\n",
    "parser.add_argument(\"--n_point_per_row\", default=32, type=int, help=\"Number of data points per row\")\n",
    "parser.add_argument(\"--n_point_per_row_gen\", default=1, type=int, help=\"k-shot\")\n",
    "parser.add_argument(\"--ctx_masked\", default=0, type=int, help=\"Number of first i data points to mask\")\n",
    "parser.add_argument(\"--encrypted\", default=True, type=str2bool, help=\"Write the task vectors in data or not.\")\n",
    "parser.add_argument(\"--pos_hint\", default=False, type=str2bool, help=\"Add positional hint or not\")\n",
    "parser.add_argument(\"--reverse_target\", default=False, type=str2bool, help=\"Reverse the digits order of targets or not\")\n",
    "parser.add_argument(\"--show_mod\", default=False, type=str2bool, help=\"Add mod p to token or not\")\n",
    "parser.add_argument(\"--show_seos\", default=False, type=str2bool, help=\"USe SOS and EOS or not\")\n",
    "parser.add_argument(\"--split_tasks\", default=False, type=str2bool, help=\"Train/Test set have different task vectors or not.\")\n",
    "parser.add_argument(\"--split_data\", default=True, type=str2bool, help=\"Train/Test set have different datapoints or not.\")\n",
    "\n",
    "# Optimization\n",
    "parser.add_argument(\"--optim\", default=\"adamw\", type=str, help=\"Optimizer: adamw or sgd\")\n",
    "parser.add_argument(\"--s\", default=0.0, type=float, help=\"s=0 for SP, 1 for muP like attention. Use 0.0 only for now.\")\n",
    "parser.add_argument(\"--bs\", default=1024, type=int, help=\"Batchsize\")\n",
    "parser.add_argument(\"--eval_bs\", default=1024, type=int, help=\"Batchsize for Evaluation\")\n",
    "parser.add_argument(\"--lr\", default=1.5e-4, type=float, help=\"Learning Rate\")\n",
    "parser.add_argument(\"--n_cycles\", default=1, type=int, help=\"Cycles of Scheduler\")\n",
    "parser.add_argument(\"--wd\", default=5.0, type=float, help=\"Weight Decay\")\n",
    "parser.add_argument(\"--beta1\", default=0.9, type=float, help=\"Beta 1 for AdamW\")\n",
    "parser.add_argument(\"--beta2\", default=0.98, type=float, help=\"Beta 2 for AdamW\")\n",
    "parser.add_argument(\"--eps\", default=1e-8, type=float, help=\"Eps for AdamW\")\n",
    "parser.add_argument(\"--momentum\", default=0.9, type=float, help=\"Momentum for SGD\")\n",
    "parser.add_argument(\"--ckpt_step\", default=160000, type=int, help=\"Training Epochs\")\n",
    "parser.add_argument(\"--steps\", default=200000, type=int, help=\"Training Epochs\")\n",
    "parser.add_argument(\"--warmup_steps\", default=10000, type=int, help=\"Warmup Epochs\")\n",
    "parser.add_argument(\"--lr_decay\", default='cosine', type=str, help=\"If Use Scheduler\")\n",
    "parser.add_argument(\"--steps_per_record\", default=1000, type=int, help=\"Save Results\")\n",
    "parser.add_argument(\"--reshuffle_step\", default=1, type=int, help=\"Save Results\")\n",
    "parser.add_argument(\"--no_of_steps\", default=1, type = int, help='Number of examples')\n",
    "\n",
    "\n",
    "# Inference\n",
    "parser.add_argument(\"--train_set\", default=False, type=str2bool, help=\"Use training inputs or val inputs\")\n",
    "parser.add_argument(\"--n_measure\", default=3, type=int, help=\"How many batches to average.\")\n",
    "parser.add_argument(\"--savefig\", default=True, type=str2bool, help=\"Save Figure\")\n",
    "parser.add_argument(\"--no_of_dis_el\", default = 1, type = int, help='Number of distinct elements')\n",
    "parser.add_argument(\"--ood_tasks\",default=False,type=str2bool,help='Whether to use ID (False) or OOD (True) tasks')\n",
    "#parser.add_argument(\"--no_o\")\n",
    "\n",
    "\n",
    "args, unknown = parser.parse_known_args()\n",
    "assert args.show_seos == False\n",
    "\n",
    "device = torch.device(args.device)\n",
    "if args.mixed_precision is True:\n",
    "    assert args.dtype in ['float16', 'bfloat16']\n",
    "    assert 'cuda' in args.device\n",
    "    if args.dtype == 'float16':\n",
    "        args.dtype = torch.float16\n",
    "    else:\n",
    "        args.dtype = torch.bfloat16\n",
    "else:\n",
    "    torch.set_float32_matmul_precision('high')\n",
    "    args.dtype = torch.float32\n",
    "\n",
    "### Set random seeds manually ###\n",
    "torch.manual_seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "random.seed(args.seed)\n",
    "\n",
    "Ws = generate_all_unique_sublists(args)\n",
    "Ws_og = Ws\n",
    "print(len(Ws_og))\n",
    "if args.parallelogram is True:\n",
    "    Ws = parallelogram_tasks_with_shared_components(Ws, args)\n",
    "check = list(set(tuple(W) for W in Ws))\n",
    "\n",
    "args.pre_train_n_tasks = len(Ws)\n",
    "\n",
    "Ws_ood = get_ood_lists(Ws, args)\n",
    "\n",
    "random.shuffle(Ws_ood)\n",
    "\n",
    "# Choose 16 tasks, half are OOD tasks, half are ID.\n",
    "no_of_tasks_tc = 16\n",
    "Ws_i = [Ws[i] for i in range(no_of_tasks_tc//2) ]\n",
    "Ws_o = [Ws_ood[i] for i in range(no_of_tasks_tc//2-1)]\n",
    "\n",
    "Ws = Ws_i + [(1,1)] +  Ws_o\n",
    "\n",
    "args.n_tasks = len(Ws)\n",
    "# print(Ws, args.n_tasks)\n",
    "\n",
    "if args.n_tasks == 1:\n",
    "    assert args.split_tasks == False\n",
    "\n",
    "### Defining datasets, data, data-loader. ###\n",
    "    \n",
    "train_set, val_set, tokenizer = prepare_data(args, Ws)\n",
    "# print(train_set[:2, :])\n",
    "# exit()\n",
    "original_n_train_row = train_set.size(0)\n",
    "original_n_val_row = val_set.size(0)\n",
    "args.vocab_size = tokenizer.__len__()\n",
    "args.max_digits = tokenizer.max_digits\n",
    "args.max_digits = 2 * args.max_digits if args.pos_hint is True else args.max_digits\n",
    "args.dim = args.max_digits * (len(Ws[0]) + 1)\n",
    "\n",
    "\n",
    "\"\"\"Make copies of data\"\"\"\n",
    "if args.split_tasks is True:\n",
    "    task_rows = round(args.n_tasks * (args.task_pct / 100.0))\n",
    "    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)\n",
    "    val_set = val_set.view((args.n_tasks - task_rows), len(val_set) // task_rows, -1)\n",
    "else:\n",
    "    task_rows = args.n_tasks\n",
    "    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)\n",
    "    val_set = val_set.view(args.n_tasks, len(val_set) // task_rows, -1)\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, dataset, bs, args):\n",
    "        self.dataset = dataset.transpose(0, 1)\n",
    "        self.n_data, self.n_task, self.dim = self.dataset.shape\n",
    "        self.bs = bs\n",
    "        self.args = args\n",
    "        \n",
    "    def __len__(self):\n",
    "        return self.n_data * self.args.n_point_per_row * 5 * self.args.n_measure # To ensure we don't have to restart dataloader.\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        step_x = self.dataset[idx % self.n_data] # (n_tasks, dim)\n",
    "        return step_x\n",
    "\n",
    "\n",
    "def custom_collate_fn(batch, bs, args):\n",
    "    inputs = torch.stack(batch, dim=1)\n",
    "    idx = torch.randperm(inputs.size(1))\n",
    "    inputs = inputs[torch.arange(args.n_tasks)[:, None], idx[None, :]]\n",
    "    targets = inputs.clone() # (n_tasks, bs * ctx_length, dim)\n",
    "    targets[:, :, :-args.max_digits] = -100 # Mask the unpredictable part\n",
    "    return inputs.view(bs, -1), targets.view(bs, -1)  # (bs, ctx_length * dim)\n",
    "\n",
    "\n",
    "def get_dataloader(dataset, bs, args):\n",
    "    custom_dataset = CustomDataset(dataset, bs, args)\n",
    "    collate_fn = partial(custom_collate_fn, bs=bs, args=args) # collate_fn should only take one input\n",
    "    if args.ddp is False:\n",
    "        g = torch.Generator()\n",
    "        g.manual_seed(args.seed + 38493483)\n",
    "        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks), num_workers=0, collate_fn=collate_fn, drop_last=True, generator=g) \n",
    "    else:\n",
    "        sampler = DistributedSampler(custom_dataset, shuffle=False)\n",
    "        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks) // args.world_size, drop_last=True, num_workers=0, collate_fn=collate_fn, sampler=sampler) \n",
    "    return dataloader\n",
    "\n",
    "args.eval_bs = len(Ws)\n",
    "train_iter = get_dataloader(train_set, args.eval_bs, args)\n",
    "val_iter = get_dataloader(val_set, args.eval_bs, args)\n",
    "\n",
    "if args.train_set is True:\n",
    "    iter_to_use = get_dataloader(train_set, args.eval_bs, args)\n",
    "else:\n",
    "    iter_to_use = get_dataloader(val_set, args.eval_bs, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea4266f9-0093-4e26-b12c-5e212d8e1086",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Checkpoint name ###\n",
    "ckpt_path = f'../ckpts/d{args.n_layer}_h{args.n_head}_embd{args.n_embd}/noembd{args.dont_decay_embd}_parale{args.parallelogram}_{args.model_name}_p{args.p}_base{args.base}_row{args.n_point_per_row}_ntask{args.pre_train_n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_dfrac{args.data_pct:.1f}_{args.act_name}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}_lctx{args.block_size}_I{args.seed}_dI{args.data_seed}_{args.optim}_bs{args.bs}_t{args.ckpt_step:d}_T{args.steps:d}_Tw{args.warmup_steps:d}_Trshf{args.reshuffle_step}_lr{args.lr:0.2e}_wd{args.wd:.2e}.pth'\n",
    "\n",
    "model = RoPETransformer(RoPEFlashAttention, args).to(device=device)\n",
    "model.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32e4f327-40a8-4ad3-ac2a-8a734f70a444",
   "metadata": {},
   "source": [
    "# Label corruption (at random locations) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a92a2090-6be8-493b-b9ed-99900c126890",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "105\n",
      "The seq is  tensor([ 8, 28,  7, 12, 23,  6, 21, 12,  4, 14,  4, 18, 28, 17, 16, 15,  9, 24,\n",
      "        11,  9, 20, 26, 27, 24,  6,  8, 14, 21, 28, 20, 27,  1, 28,  3,  7, 10,\n",
      "        28,  5,  4,  8, 16, 24, 25,  9,  5,  9, 14, 23, 20,  9,  0,  4, 23, 27,\n",
      "         7, 23,  1,  9, 25,  5, 15, 26, 12,  9,  8, 17,  0, 19, 19, 20, 18,  9,\n",
      "        10,  6, 16, 20, 27, 18,  6, 10, 16, 15, 15,  1, 18,  6, 24,  9, 19, 28,\n",
      "        17, 22, 10, 11,  4, 15])\n",
      "The corrupted seq is  tensor([ 4, 23,  8, 20, 27, 18,  8, 28,  7,  9, 14,  4, 15, 15,  1,  6,  8, 14,\n",
      "         7, 23,  1, 15,  9, 24, 20,  9,  0, 21, 28, 20, 20, 18,  9, 18,  6, 24,\n",
      "         9,  8, 17, 17, 22, 10, 14,  4, 18, 11,  4, 15,  6, 10, 16, 28,  5,  4,\n",
      "         9, 19, 28, 21, 12,  4, 10,  6, 16,  0, 19, 19,  8, 16, 24, 26, 27, 24,\n",
      "        25,  9,  5, 11,  9, 20,  3,  7, 10,  9, 25,  5, 12, 23,  6, 27,  1, 28,\n",
      "        15, 26, 12, 28, 17, 16])\n"
     ]
    }
   ],
   "source": [
    "######## LABEL CORRUPTION AT RANDOM LOCATIONS ########\n",
    "\n",
    "\n",
    "#### Define data-loader for random label corruption ####\n",
    "\n",
    "class CustomDatasetCorruption(Dataset):\n",
    "    def __init__(self, dataset, bs, args):\n",
    "        self.dataset = dataset.transpose(0, 1)\n",
    "        self.n_data, self.n_task, self.dim = self.dataset.shape\n",
    "        self.bs = bs\n",
    "        self.args = args\n",
    "        \n",
    "    def __len__(self):\n",
    "        return self.n_data * self.args.n_point_per_row * 5 * self.args.n_measure # To ensure we don't have to restart dataloader.\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        step_x = self.dataset[idx % self.n_data] # (n_tasks, dim)\n",
    "        return step_x\n",
    "\n",
    "\n",
    "def custom_collate_fn_corrupt(batch, bs, args):\n",
    "    inputs = torch.stack(batch, dim=1)\n",
    "    idx = torch.randperm(inputs.size(1))\n",
    "    inputs = inputs[torch.arange(args.n_tasks)[:, None], idx[None, :]]\n",
    "    targets = copy.deepcopy(inputs) # (n_tasks, bs * ctx_length, dim)\n",
    "    targets[:, :, :-args.max_digits] = -100 # Mask the unpredictable part\n",
    "    \n",
    "    if args.no_of_corr > 0:\n",
    "        to_corrupt = torch.randperm(args.n_point_per_row)[:args.no_of_corr]\n",
    "        inputs[:,to_corrupt,args.to_corrupt] = (inputs[:,to_corrupt, args.to_corrupt]+torch.randint(1,args.p,(inputs.size(0), args.no_of_corr)))%args.p\n",
    "    \n",
    "    return inputs.view(bs, -1), targets.view(bs, -1)  # (bs, ctx_length * dim)\n",
    "\n",
    "\n",
    "def get_dataloader_corrupt(dataset, bs, args):\n",
    "    custom_dataset = CustomDatasetCorruption(dataset, bs, args)\n",
    "    collate_fn = partial(custom_collate_fn_corrupt, bs=bs, args=args) # collate_fn should only take one input\n",
    "    if args.ddp is False:\n",
    "        g = torch.Generator()\n",
    "        g.manual_seed(args.seed + 38493483)\n",
    "        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks), num_workers=0, collate_fn=collate_fn, drop_last=True, generator=g) \n",
    "    else:\n",
    "        sampler = DistributedSampler(custom_dataset, shuffle=False)\n",
    "        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks) // args.world_size, drop_last=True, num_workers=0, collate_fn=collate_fn, sampler=sampler) \n",
    "    return dataloader\n",
    "\n",
    "Ws = [(1,1),(1,2)]\n",
    "args.n_tasks = len(Ws)\n",
    "train_set, val_set, tokenizer = prepare_data(args, Ws)\n",
    "# print(train_set[:2, :])\n",
    "# exit()\n",
    "original_n_train_row = train_set.size(0)\n",
    "original_n_val_row = val_set.size(0)\n",
    "args.vocab_size = tokenizer.__len__()\n",
    "args.max_digits = tokenizer.max_digits\n",
    "args.max_digits = 2 * args.max_digits if args.pos_hint is True else args.max_digits\n",
    "args.dim = args.max_digits * (len(Ws[0]) + 1)\n",
    "\n",
    "\n",
    "\"\"\"Make copies of data\"\"\"\n",
    "if args.split_tasks is True:\n",
    "    task_rows = round(args.n_tasks * (args.task_pct / 100.0))\n",
    "    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)\n",
    "    val_set = val_set.view((args.n_tasks - task_rows), len(val_set) // task_rows, -1)\n",
    "else:\n",
    "    task_rows = args.n_tasks\n",
    "    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)\n",
    "    val_set = val_set.view(args.n_tasks, len(val_set) // task_rows, -1)\n",
    "\n",
    "args.eval_bs = len(Ws)\n",
    "\n",
    "if args.train_set is True:\n",
    "    corrupt_iter_to_use = get_dataloader_corrupt(train_set, args.eval_bs, args)\n",
    "else:\n",
    "    corrupt_iter_to_use = get_dataloader_corrupt(val_set, args.eval_bs, args)\n",
    "\n",
    "if args.train_set is True:\n",
    "    iter_to_use = get_dataloader(train_set, args.eval_bs, args)\n",
    "else:\n",
    "    iter_to_use = get_dataloader(val_set, args.eval_bs, args)\n",
    "    \n",
    "### Ws[no_of_tasks_tc//2] is (1,1)\n",
    "# print(Ws[no_of_tasks_tc//2])\n",
    "\n",
    "### TEST CODE TO SEE IF DATALOADER WORKS ###\n",
    "args.n_measure = 1\n",
    "args.n_point_per_row = 4\n",
    "args.no_of_corr = 2\n",
    "args.to_corrupt = 2 ### which one to corrupt, 0 -> x, 1 -> y, 2 -> f.\n",
    "\n",
    "with torch.inference_mode():\n",
    "    print(len(iter_to_use))\n",
    "    for t, (x,y) in enumerate(iter_to_use):\n",
    "        print(\"The seq is \", x[0,:])\n",
    "        if t >= 0:\n",
    "            break\n",
    "    for t, (x,y) in enumerate(corrupt_iter_to_use):\n",
    "        print(\"The corrupted seq is \", x[0,:])\n",
    "        if t >= 0:\n",
    "            break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5aa4e855",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def measure_perpos_accloss_full_int_preds(model, val_iter, args, device, n_measure = 1):\n",
    "    \"\"\"Measure per position accuracy for one batch, modular arithmetic\n",
    "    \"\"\"\n",
    "    ctx = nullcontext() if 'mps' in args.device else torch.autocast(device_type=args.device, dtype=args.dtype, enabled=args.mixed_precision)\n",
    "    model.eval()\n",
    "    model.eval()\n",
    "    \n",
    "    acc_records, loss_records, logits_records = [], [], []\n",
    "    tgts = []\n",
    "    \n",
    "    # acc_record = torch.zeros((args.eval_bs, args.new_seq_len+1), device=device, dtype=args.dtype)\n",
    "    # loss_record = torch.zeros((args.eval_bs, args.new_seq_len+1), device=device, dtype=args.dtype)\n",
    "    t1 = TIME.time()\n",
    "    \n",
    "    for t, (x, y) in enumerate(val_iter):\n",
    "        print(\"At pos \",t)\n",
    "        # print(TIME.time()-t1)\n",
    "        t1 = TIME.time()\n",
    "        x = x[:, :-1].contiguous().to(device)\n",
    "        y = y[:, 1:].contiguous().to(device)\n",
    "        print((x[:,2::3]!=y[:,1:-1:3]).float().mean())\n",
    "        if t >= (n_measure):\n",
    "            break\n",
    "        # print(\"The seq is \", x[:2,:12])\n",
    "        # print(\"X shape : \", x.shape)\n",
    "        \n",
    "\n",
    "        losses_ints = []\n",
    "        acc_ints = []\n",
    "        logits_ints = []\n",
    "        \n",
    "        tgts.append(y)\n",
    "        \n",
    "        with ctx:\n",
    "            logits, qkv_list, input_list, int_list = model.record(x)\n",
    "            # print(\"logits : \",logits.shape)\n",
    "            \n",
    "            for i in range(args.n_layer-1):\n",
    "                x1 = model.transformer.ln_f(input_list[i])\n",
    "                logits_int = model.scale * model.lm_head(x1)\n",
    "                logits_ints.append(logits_int)\n",
    "                loss_int = F.cross_entropy(logits_int.view(-1, logits_int.size(-1)), y.view(-1), reduction='none')\n",
    "                loss_int = loss_int.reshape(logits_int.size(0),-1)\n",
    "                losses_ints.append(loss_int[:, (args.dim-2)::args.dim])\n",
    "                pred_int = logits_int.argmax(-1)\n",
    "                correct_mask_int = (pred_int[:, (args.dim-2)::args.dim] == y[:, (args.dim-2)::args.dim])\n",
    "                acc_ints.append(correct_mask_int.float())\n",
    "                \n",
    "            losses = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), reduction='none')\n",
    "            losses = losses.reshape(logits.size(0), -1)\n",
    "            pred = logits.argmax(-1)\n",
    "            logits_ints.append(logits)\n",
    "        \n",
    "            # print(\"preds and tgts shape: \",pred.shape, y.shape)\n",
    "            \n",
    "            correct_mask = (pred[:, (args.dim-2)::args.dim] == y[:, (args.dim-2)::args.dim])\n",
    "            # print(correct_mask.shape)\n",
    "            \n",
    "            # acc_record += correct_mask\n",
    "            # loss_record += losses[:, (args.dim-2)::args.dim]\n",
    "        \n",
    "            acc_ints.append(correct_mask.float())\n",
    "            losses_ints.append(losses[:, (args.dim-2)::args.dim])\n",
    "        \n",
    "            acc_records.append(acc_ints)\n",
    "            loss_records.append(losses_ints)\n",
    "            logits_records.append(logits_ints)\n",
    "\n",
    "    return acc_records, loss_records, logits_records, tgts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d9eeade-8ad8-453e-9481-5e60d802d4b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Last token accuracy and/or loss as a function of sequence length and amount of label corruption ###\n",
    "\n",
    "no_of_tasks_tc = 64\n",
    "args.max_ctx= 32\n",
    "args.n_measure = 20\n",
    "args.to_corrupt = 2\n",
    "\n",
    "Ws = generate_all_unique_sublists(args)\n",
    "print(len(Ws))\n",
    "if args.parallelogram is True:\n",
    "    Ws = parallelogram_tasks_with_shared_components(Ws, args)\n",
    "check = list(set(tuple(W) for W in Ws))\n",
    "# print('train Ws: \\n', Ws, len(Ws), len(check))\n",
    "args.pre_train_n_tasks = len(Ws)\n",
    "\n",
    "Ws_i = [Ws[i] for i in range(no_of_tasks_tc//2) ]\n",
    "Ws_o = [Ws_ood[i] for i in range(no_of_tasks_tc//2)]\n",
    "\n",
    "Ws = Ws_i + Ws_o\n",
    "args.n_tasks = len(Ws)\n",
    "print(Ws)\n",
    "print(len(Ws))\n",
    "\n",
    "train_set, val_set, tokenizer = prepare_data(args, Ws)\n",
    "# print(train_set[:2, :])\n",
    "# exit()\n",
    "original_n_train_row = train_set.size(0)\n",
    "original_n_val_row = val_set.size(0)\n",
    "args.vocab_size = tokenizer.__len__()\n",
    "args.max_digits = tokenizer.max_digits\n",
    "args.max_digits = 2 * args.max_digits if args.pos_hint is True else args.max_digits\n",
    "args.dim = args.max_digits * (len(Ws[0]) + 1)\n",
    "\n",
    "\n",
    "\"\"\"Make copies of data\"\"\"\n",
    "if args.split_tasks is True:\n",
    "    task_rows = round(args.n_tasks * (args.task_pct / 100.0))\n",
    "    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)\n",
    "    val_set = val_set.view((args.n_tasks - task_rows), len(val_set) // task_rows, -1)\n",
    "else:\n",
    "    task_rows = args.n_tasks\n",
    "    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)\n",
    "    val_set = val_set.view(args.n_tasks, len(val_set) // task_rows, -1)\n",
    "\n",
    "\n",
    "\n",
    "#### Quick helper function to get the average last token accuracy and loss averaged over trials ###\n",
    "\n",
    "def helper_last_token_acc_loss(acc_records, loss_records):\n",
    "    acc_avg = np.zeros(acc_records[0][-1].size(0))\n",
    "    loss_avg = np.zeros(loss_records[0][-1].size(0))\n",
    "    for i in range(len(acc_records)):\n",
    "        acc_avg = np.add(acc_avg,acc_records[i][-1][:,-1])\n",
    "        loss_avg = np.add(loss_avg,loss_records[i][-1][:,-1])\n",
    "    # print(loss_avg)\n",
    "    return acc_avg/len(acc_records), loss_avg/len(loss_records)\n",
    "\n",
    "args.eval_bs = len(Ws)\n",
    "\n",
    "acc_avgs = np.full((args.eval_bs, args.max_ctx, args.max_ctx),np.nan)\n",
    "loss_avgs = np.full((args.eval_bs, args.max_ctx, args.max_ctx),np.nan)\n",
    "\n",
    "for ctx_len in range(1,args.max_ctx):\n",
    "    args.n_point_per_row = ctx_len\n",
    "    for j in range(ctx_len):\n",
    "        args.no_of_corr = j\n",
    "\n",
    "        print( float(j)/ctx_len)\n",
    "        if args.train_set is True:\n",
    "            corrupt_iter_to_use = get_dataloader_corrupt(train_set, args.eval_bs, args)\n",
    "        else:\n",
    "            corrupt_iter_to_use = get_dataloader_corrupt(val_set, args.eval_bs, args)\n",
    "        \n",
    "        with torch.inference_mode():\n",
    "            print(len(corrupt_iter_to_use))\n",
    "            acc_records, loss_records, logits_records, tgts = measure_perpos_accloss_full_int_preds(model, corrupt_iter_to_use,args, device, n_measure=args.n_measure)\n",
    "            acc_avg, loss_avg = helper_last_token_acc_loss(acc_records, loss_records)\n",
    "            # print(acc_avg.shape)\n",
    "            \n",
    "            acc_avgs[:, ctx_len, j] = acc_avg\n",
    "            loss_avgs[:, ctx_len, j] = loss_avg\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "id": "3452004a-f891-40b3-ad3b-416ec7a74432",
   "metadata": {},
   "outputs": [],
   "source": [
    "imp_datas = {\"acc\" : acc_avgs, \"loss\":loss_avgs, \"wts\":Ws}\n",
    "        \n",
    "fn = f\"./data/lc_accloss_b_n_s_x_cond_maxc_{args.max_ctx}_tr{args.n_measure}_Ws{len(Ws)}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}.pkl\"    \n",
    "        \n",
    "with open(fn, 'wb') as fp:\n",
    "    pickle.dump(imp_datas,fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "id": "75e7bbe8-620b-49c6-8615-8cb09ca24e9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ctx_label_corr_random(acc_avgs, vmin, vmax, cbar_name = 'accuracy'):\n",
    "    '''\n",
    "    Parameters\n",
    "    ----------\n",
    "    acc_avgs : array of shape (# of tasks, max_cc, max_cc)\n",
    "        Average accuracy (loss) \n",
    "    vmin, vmax : int, int\n",
    "        vmin & vmax settings for ax.imshow\n",
    "    cbar_name : str\n",
    "        Name for colorbar and/or plot name\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None.\n",
    "\n",
    "    '''\n",
    "\n",
    "    args.n_rows = 4\n",
    "    args.n_cols = (no_of_tasks_tc)//args.n_rows\n",
    "\n",
    "    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(26, 15), constrained_layout=True)\n",
    "\n",
    "    if vmin is None:\n",
    "        vmin = np.nanmin(acc_avgs)*0.9\n",
    "    if vmax is None:\n",
    "        vmax = np.nanmax(acc_avgs)*1.1\n",
    "    \n",
    "    handles, labels = [], []\n",
    "    for i in range(args.n_rows):\n",
    "        for j in range(args.n_cols):\n",
    "            ax = axs[i][j]\n",
    "            \n",
    "            ti = i*args.n_cols + j\n",
    "            ax.set_title(f\"W : {Ws[ti]}\")\n",
    "            to_plot = acc_avgs[ti,:,:]\n",
    "            \n",
    "            \n",
    "            im0 = ax.imshow(to_plot, vmin=vmin, vmax=vmax,cmap = 'cividis')\n",
    "            \n",
    "            ax.set_ylabel('Ctx')\n",
    "            ax.set_xlabel('# of Wrong labels')\n",
    "\n",
    "            \n",
    "\n",
    "    fig.suptitle(f'Label corruption')\n",
    "    cbar = fig.colorbar(im0, ax=axs.ravel().tolist(), fraction=0.046, pad=0.04)\n",
    "    cbar.set_label(cbar_name)  # Example label for the colorbar\n",
    "    # fig.suptitle()    \n",
    "    if args.savefig is True:\n",
    "        fig.savefig(f'label_corr/label_corr_{cbar_name}_Ws{len(Ws)}_d_{args.n_layer}_h_{args.n_head}.pdf', format = 'pdf')\n",
    "    plt.show()\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52d7ce21-58cd-40b4-8632-68279257624e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fn = './data/lc_accloss_b_n_s_x_cond_maxc_32_tr20_n512_h4_d2.pkl'\n",
    "\n",
    "with open(fn, 'rb') as fp:\n",
    "    imp_datas = pickle.load(fp)\n",
    "\n",
    "acc_avgs , loss_avgs = imp_datas['acc'], imp_datas['loss']\n",
    "\n",
    "plot_ctx_label_corr_random(acc_avgs, vmin= 0.0, vmax = 1.0,cbar_name = 'accuracy')\n",
    "plot_ctx_label_corr_random(loss_avgs, vmin= None, vmax = None,cbar_name= 'loss')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c553f2-1b9e-402c-a5a3-ae4db4c22d9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Plot average over tasks ###\n",
    "\n",
    "SIZE = 40\n",
    "\n",
    "acc_avg_tasks = acc_avgs.mean(0)\n",
    "loss_avg_tasks = loss_avgs.mean(0)\n",
    "\n",
    "args.n_cols = 2\n",
    "args.n_rows = 1\n",
    "\n",
    "fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(28, 15), constrained_layout=True)\n",
    "\n",
    "vmin = 0.0\n",
    "vmax = 1.0\n",
    "    \n",
    "ax = axs[0]\n",
    "\n",
    "to_plot = acc_avg_tasks[:,:].T\n",
    "im0 = ax.imshow(to_plot, vmin=vmin, vmax=vmax,cmap = 'cividis')\n",
    "            \n",
    "ax.set_xlabel('Shot $i$',size=SIZE)\n",
    "ax.set_xticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)\n",
    "ax.set_ylabel('# of Wrong labels',size=SIZE)\n",
    "ax.set_yticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)\n",
    "ax.invert_yaxis()\n",
    "cbar_name = 'Acc'\n",
    "cbar = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)\n",
    "cbar.set_label(cbar_name,size = SIZE)  # Example label for the colorbar\n",
    "cbar.ax.tick_params(labelsize=SIZE)\n",
    "\n",
    "\n",
    "vmin = np.nanmin(loss_avgs)*0.9\n",
    "vmax = np.nanmax(loss_avgs)*1.1\n",
    "cbar_name = 'Loss'\n",
    "\n",
    "ax = axs[1]\n",
    "\n",
    "to_plot = loss_avg_tasks[:,:].T\n",
    "im0 = ax.imshow(to_plot, vmin=vmin, vmax=vmax,cmap = 'inferno')\n",
    "            \n",
    "ax.set_xlabel('Shot $i$',size=SIZE)\n",
    "ax.set_xticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)\n",
    "ax.set_ylabel('# of Wrong labels',size=SIZE)\n",
    "ax.set_yticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)\n",
    "cbar = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)\n",
    "cbar.set_label(cbar_name,size = SIZE)  # Example label for the colorbar\n",
    "cbar.ax.tick_params(labelsize=SIZE)\n",
    "ax.invert_yaxis()\n",
    "# fig.suptitle(f'Label corruption, avg over tasks')\n",
    "if args.savefig is True:\n",
    "    fig.savefig(f'label_corr/label_corr_avg_tasks_Ws_d_2_h_4.pdf', format = 'pdf')\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "573a6ea6-822d-4b5a-9716-e4a206086f78",
   "metadata": {},
   "source": [
    "# Label corruption at single location"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "id": "1f6d1926-d1cf-42cc-a5b9-aa75abb74343",
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_collate_fn_corrupt_single(batch, bs, args):\n",
    "    inputs = torch.stack(batch, dim=1)\n",
    "    idx = torch.randperm(inputs.size(1))\n",
    "    inputs = inputs[torch.arange(args.n_tasks)[:, None], idx[None, :]]\n",
    "    targets = copy.deepcopy(inputs) # (n_tasks, bs * ctx_length, dim)\n",
    "    targets[:, :, :-args.max_digits] = -100 # Mask the unpredictable part\n",
    "    \n",
    "    # print(targets.shape)\n",
    "    \n",
    "    # to_corrupt = torch.randperm(args.n_point_per_row)[:args.no_of_corr]\n",
    "    \n",
    "    inputs[:,args.corrupt_loc,args.to_corrupt] = (inputs[:,args.corrupt_loc, args.to_corrupt]+torch.randint(1,args.p,(inputs.size(0), )))%args.p\n",
    "    \n",
    "    return inputs.view(bs, -1), targets.view(bs, -1)  # (bs, ctx_length * dim)\n",
    "\n",
    "\n",
    "def get_dataloader_corrupt_single(dataset, bs, args):\n",
    "    custom_dataset = CustomDatasetCorruption(dataset, bs, args)\n",
    "    collate_fn = partial(custom_collate_fn_corrupt_single, bs=bs, args=args) # collate_fn should only take one input\n",
    "    if args.ddp is False:\n",
    "        g = torch.Generator()\n",
    "        g.manual_seed(args.seed + 38493483)\n",
    "        bss = (bs * args.n_point_per_row // args.n_tasks)\n",
    "        print(type(bss), type(bs), bss,bs)\n",
    "        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks), num_workers=0, collate_fn=collate_fn, drop_last=True, generator=g) \n",
    "    else:\n",
    "        sampler = DistributedSampler(custom_dataset, shuffle=False)\n",
    "        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks) // args.world_size, drop_last=True, num_workers=0, collate_fn=collate_fn, sampler=sampler) \n",
    "    return dataloader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec11ddbf-009f-4b42-a1fa-dc465d0a4c80",
   "metadata": {},
   "outputs": [],
   "source": [
    "args.n_measure = 3\n",
    "sl = 32\n",
    "\n",
    "args.to_corrupt = 2\n",
    "### pos = 0 : corrupt x's\n",
    "### pos = 1 : corrupt y's\n",
    "### pos = 2 : corrupt f's\n",
    "args.train_set = False\n",
    "datas = {}\n",
    "for seq_len in range(1,sl+1):\n",
    "    args.n_point_per_row = seq_len\n",
    "    for i in range(seq_len):\n",
    "        # args.new_seq_len = seq_len\n",
    "        args.corrupt_loc = i\n",
    "        print(\"-\"*100)\n",
    "        print(\"Seq len : \", seq_len)\n",
    "        print(\"-\"*100)\n",
    "        if args.train_set is True:\n",
    "            corrupt_iter_to_use = get_dataloader_corrupt_single(train_set, args.eval_bs, args)\n",
    "        else:\n",
    "            corrupt_iter_to_use = get_dataloader_corrupt_single(val_set, args.eval_bs, args)\n",
    "    \n",
    "        with torch.inference_mode():\n",
    "            print(len(corrupt_iter_to_use))\n",
    "            step = 0\n",
    "            for t, (x,y) in enumerate(corrupt_iter_to_use):\n",
    "                x = x[:, :-1].contiguous().to(device)\n",
    "                y = y[:, 1:].contiguous().to(device)\n",
    "                # print(x[0,:])\n",
    "                # print(y[0,:])\n",
    "                step+=1\n",
    "                print(\"Trial : \", step)\n",
    "                # print((x[:,pos+s*(args.n_var+1)]==xc[:,pos+s*(args.n_var+1)]).float().mean())\n",
    "            \n",
    "                logits = model(x)\n",
    "        \n",
    "                losses = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), reduction='none')\n",
    "                losses = losses.reshape(logits.size(0), -1)\n",
    "                pred = logits.argmax(-1)\n",
    "                loss_record = losses[:, (args.dim-2)::args.dim]\n",
    "                # print(\"preds and tgts shape: \",pred.shape, y.shape)\n",
    "        \n",
    "                acc_record = (pred[:, (args.dim-2)::args.dim] == y[:, (args.dim-2)::args.dim]).float()\n",
    "                # print(acc_record.mean())\n",
    "                # print(loss_record.mean())\n",
    "                \n",
    "                datas[f\"s:{seq_len}, tr:{step}, loc:{args.corrupt_loc}\"] = (x,y,logits, pred, acc_record, loss_record)\n",
    "                \n",
    "                # datas[f\"s:{seq_len}, tr:{step}\"].append()\n",
    "                if step >= args.n_measure:\n",
    "                    break\n",
    "        \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "id": "a53b79b3-4bc5-431a-81bb-3cb549c5b42e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pos=2\n",
    "fn = f\"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl\"    \n",
    "        \n",
    "with open(fn, 'wb') as fp:\n",
    "    pickle.dump(datas,fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "760f20f4-7437-4259-b02b-fda40901db84",
   "metadata": {},
   "outputs": [],
   "source": [
    "### First check the mask for the initial case, and pick a case where it's correct.\n",
    "### Then check what happens to the mask as we corrupt all the previous cases.\n",
    "\n",
    "fn = f\"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl\"    \n",
    "        \n",
    "with open(fn, 'rb') as fp:\n",
    "    datas = pickle.load(fp)\n",
    "\n",
    "# print(datas.keys())\n",
    "step = 1 ### trial no.\n",
    "ic = np.random.randint(no_of_tasks_tc)\n",
    "seq_len = 31\n",
    "locc = 0\n",
    "xc,yc,logits, pred, acc_record, loss_record = datas[f\"s:{seq_len}, tr:{step}, loc:{locc}\"]\n",
    "\n",
    "plt.plot(acc_record[ic,:],marker=\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "id": "f1908f28-ff5a-41c0-81ff-81b052edbad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def label_corr_exp(max_cc, step,ic=0,pos=0):\n",
    "    \n",
    "    fn = f\"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl\"    \n",
    "            \n",
    "    with open(fn, 'rb') as fp:\n",
    "        datas = pickle.load(fp)\n",
    "    \n",
    "    \n",
    "    mask = np.full((max_cc-1,max_cc), np.nan)\n",
    "\n",
    "    for cc in range(1,max_cc):\n",
    "        for s in range(cc):\n",
    "            xc,yc,logits, pred, acc_record, loss_record = datas[f\"s:{cc}, tr:{step}, loc:{s}\"]\n",
    "            mask[cc-1,s] = (acc_record[ic,cc-1])\n",
    "    \n",
    "    \n",
    "    args.n_rows = 1\n",
    "    args.n_cols = 1\n",
    "\n",
    "    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(15, 15), constrained_layout=True)\n",
    "    \n",
    "    ax = axs\n",
    "\n",
    "    vmin=0.0\n",
    "    vmax= 1.0\n",
    "    cbar_name = \"Mask\"\n",
    "    im0 = ax.imshow(mask, cmap = 'viridis',vmin=vmin,vmax=vmax)\n",
    "    cbar = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)\n",
    "    #cbar.set_ticks(np.linspace(vmin, vmax, num=30))\n",
    "    cbar.set_label(cbar_name)\n",
    "\n",
    "    ax.set_ylabel(\"Len of seq\")\n",
    "    ax.set_xlabel(\"Label corrupted\")\n",
    "\n",
    "    fig.suptitle(f'{Ws[ic]}, Trial : {step}, Pos: {pos} \\n X :{xc[0,:]}')\n",
    "    if args.savefig is True:\n",
    "        fig.savefig(f'./label_corr/label_corr_single_pos{pos}_maxcc_{max_cc}_tr{step}_Ws{Ws[ic]}_b_n_s_{args.model_name}_p{args.p}_row{args.n_point_per_row}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_tspl{args.train_set}_ood{args.ood_tasks}_dfrac{args.data_pct:.1f}_{args.act_name}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}_I{args.seed}_dI{args.data_seed}_{args.optim}_bs{args.bs}_t{args.ckpt_step:d}_T{args.steps:d}_Tw{args.warmup_steps:d}_lr{args.lr:0.2e}_wd{args.wd:.2e}.pdf', format='pdf')\n",
    "    plt.show()\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "030f76c2-4d82-4fc3-83cb-62f5d4ff8efd",
   "metadata": {},
   "outputs": [],
   "source": [
    "pos = 2\n",
    "for step in range(1, args.n_measure+1):\n",
    "    max_cc = 32\n",
    "    label_corr_exp(max_cc, step,pos=pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ca11443-7144-4a79-959e-2b419545724d",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Better plots for paper ### \n",
    "\n",
    "def label_corr_exp2(max_cc, step,sl,no_tasks = 20,pos=2):\n",
    "    \n",
    "    fn = f\"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl\"    \n",
    "            \n",
    "    with open(fn, 'rb') as fp:\n",
    "        datas = pickle.load(fp)\n",
    "    \n",
    "    mask = np.full((no_tasks, max_cc-1,max_cc-1), np.nan)\n",
    "\n",
    "    for cc in range(1,max_cc):\n",
    "        for s in range(cc-1):\n",
    "            xc,yc,logits, pred, acc_record, loss_record = datas[f\"s:{cc}, tr:{step}, loc:{s}\"]\n",
    "            mask[:no_tasks,cc-1,s] = (acc_record[:no_tasks,cc-1])\n",
    "    \n",
    "    xp = torch.cat((xc, yc[:,-1].unsqueeze(1)), dim=-1)\n",
    "    # print(xp)\n",
    "    # print(xp.shape)\n",
    "    xp = np.array(xp.view(xp.size(0),-1,args.n_var+1))\n",
    "    print(xp.shape)\n",
    "    args.n_rows = 4\n",
    "    args.n_cols = no_tasks//args.n_rows\n",
    "\n",
    "    sns.set_theme(style = 'whitegrid')\n",
    "\n",
    "    # Apply the seaborn-whitegrid style\n",
    "    # plt.style.use('seaborn-whitegrid')\n",
    "\n",
    "    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(25, 20), constrained_layout=True)\n",
    "    vmin=0.0\n",
    "    vmax= 1.0\n",
    "    cbar_name = \"Acc\"\n",
    "    \n",
    "    \n",
    "    for i in range(args.n_rows):\n",
    "        for j in range(args.n_cols):\n",
    "            ax = axs[i][j]\n",
    "            ti = i*args.n_cols + j\n",
    "            top = mask[ti,:,:]\n",
    "\n",
    "            im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)\n",
    "\n",
    "            if(i==args.n_rows-1):\n",
    "                ax.set_xlabel(\"$z'_j$\")\n",
    "            if (j == 0):\n",
    "                ax.set_ylabel(\"Shot $i$\")\n",
    "\n",
    "            ax.set_title(f'W : {Ws[ti]}')\n",
    "\n",
    "            ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]\n",
    "            # print(xtick_labels)\n",
    "            ax.set_yticks(np.arange(xp.shape[1])+1)\n",
    "            ax.set_yticklabels(ytick_labels, rotation=0, fontsize=8)\n",
    "\n",
    "    # print(axs[0][0].collections)\n",
    "    \n",
    "    cbar = fig.colorbar(axs[0][0].collections[0], ax=axs.ravel().tolist(), fraction=0.046, pad=0.04)\n",
    "    cbar.set_ticks(np.linspace(vmin, vmax, num=2))\n",
    "    cbar.set_label(cbar_name)\n",
    "    \n",
    "\n",
    "    print(mask.shape)\n",
    "\n",
    "    \n",
    "\n",
    "    # fig.suptitle(f'{Ws[ic]}, Trial : {step}, Pos: {pos} \\n X :{xc[0,:]}')\n",
    "    if args.savefig is True:\n",
    "        fig.savefig(f'./label_corr/label_corr_single_pos{pos}_maxcc_{max_cc}_tr{step}_Ws{Ws[ic]}_b_n_s_{args.model_name}_p{args.p}_row{args.n_point_per_row}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_tspl{args.train_set}_ood{args.ood_tasks}_dfrac{args.data_pct:.1f}_{args.act_name}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}_I{args.seed}_dI{args.data_seed}_{args.optim}_bs{args.bs}_t{args.ckpt_step:d}_T{args.steps:d}_Tw{args.warmup_steps:d}_lr{args.lr:0.2e}_wd{args.wd:.2e}.pdf', format='pdf')\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "pos = 2\n",
    "for step in range(1, 2):\n",
    "    max_cc = 32\n",
    "    label_corr_exp2(max_cc, step,sl=32,pos=pos)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6981430-48ad-4640-a3f3-a7afeea97d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Final plot for paper ###\n",
    "\n",
    "def label_corr_exp3(max_cc,sl, step=1,no_tasks = 20,pos=2):\n",
    "\n",
    "    SIZE = 36\n",
    "    FRAC = 0.6\n",
    "\n",
    "    cmap = sns.color_palette(['black', '#FFFDD0'])\n",
    "\n",
    "\n",
    "    args.n_rows = 1\n",
    "    args.n_cols = 4\n",
    "\n",
    "    sns.set_theme(style = 'whitegrid')\n",
    "\n",
    "    # Apply the seaborn-whitegrid style\n",
    "    # plt.style.use('seaborn-whitegrid')\n",
    "\n",
    "    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(53, 13), constrained_layout=True)\n",
    "    vmin=0.0\n",
    "    vmax= 1.0\n",
    "    cbar_name = \"Acc\"\n",
    "\n",
    "    ##############################################################################\n",
    "    ### TRAIN ###\n",
    "\n",
    "    \n",
    "    args.n_measure = 3\n",
    "    args.train_set = True\n",
    "\n",
    "    fn = f\"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl\"    \n",
    "            \n",
    "    with open(fn, 'rb') as fp:\n",
    "        datas = pickle.load(fp)\n",
    "    \n",
    "    mask = np.full((no_tasks, max_cc-1,max_cc-1), np.nan)\n",
    "\n",
    "    for cc in range(1,max_cc):\n",
    "        for s in range(cc-1):\n",
    "            xc,yc,logits, pred, acc_record, loss_record = datas[f\"s:{cc}, tr:{step}, loc:{s}\"]\n",
    "            mask[:no_tasks,cc-1,s] = (acc_record[:no_tasks,cc-1])\n",
    "    \n",
    "    xp = torch.cat((xc, yc[:,-1].unsqueeze(1)), dim=-1)\n",
    "    # print(xp)\n",
    "    # print(xp.shape)\n",
    "    xp = np.array(xp.view(xp.size(0),-1,args.n_var+1))\n",
    "    print(xp.shape)\n",
    "\n",
    "    ##############################################################################\n",
    "    #### $\\mathcal{S}^{\\mathrm{i.d.}_{\\mathrm{train}}$ ###########################\n",
    "    \n",
    "    ax = axs[0]\n",
    "    ti = 0\n",
    "    top = mask[ti,:,:].T\n",
    "\n",
    "    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)\n",
    "    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})\n",
    "    \n",
    "    ax.set_ylabel(\"$z'_j$\",size=SIZE)\n",
    "    ax.set_xlabel(\"Shot $i$\",size=SIZE)\n",
    "\n",
    "    ax.set_title('$\\mathcal{S}^{\\mathrm{i.d.}}_{\\mathrm{train}}$',size=1.5*SIZE)\n",
    "\n",
    "    ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]\n",
    "    # print(xtick_labels)\n",
    "    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)\n",
    "    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "    ax.set_xticks(np.arange(xp.shape[1],step=4))\n",
    "    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "\n",
    "    ax.invert_yaxis()\n",
    "\n",
    "    ##############################################################################\n",
    "\n",
    "    ##############################################################################\n",
    "    #### $\\mathcal{S}^{\\mathrm{o.o.d.}_{\\mathrm{train}}$ #########################\n",
    "\n",
    "    ax = axs[1]\n",
    "    ti = no_tasks//2\n",
    "    top = mask[ti,:,:].T\n",
    "\n",
    "    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)\n",
    "    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})\n",
    "    \n",
    "    ax.set_ylabel(\"$z'_j$\",size=SIZE)\n",
    "    ax.set_xlabel(\"Shot $i$\",size=SIZE)\n",
    "\n",
    "    ax.set_title('$\\mathcal{S}^{\\mathrm{o.o.d.}}_{\\mathrm{train}}$',size=1.5*SIZE)\n",
    "\n",
    "    ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]\n",
    "    # print(xtick_labels)\n",
    "    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)\n",
    "    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "    ax.set_xticks(np.arange(xp.shape[1],step=4))\n",
    "    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "\n",
    "    ax.invert_yaxis()\n",
    "\n",
    "    ##############################################################################\n",
    "\n",
    "    ##############################################################################\n",
    "    ### TEST ###\n",
    "\n",
    "    args.n_measure = 5\n",
    "    args.train_set = False\n",
    "    \n",
    "    fn = f\"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{no_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl\"    \n",
    "            \n",
    "    with open(fn, 'rb') as fp:\n",
    "        datas = pickle.load(fp)\n",
    "    \n",
    "    mask = np.full((no_tasks, max_cc-1,max_cc-1), np.nan)\n",
    "\n",
    "    for cc in range(1,max_cc):\n",
    "        for s in range(cc-1):\n",
    "            xc,yc,logits, pred, acc_record, loss_record = datas[f\"s:{cc}, tr:{step}, loc:{s}\"]\n",
    "            mask[:no_tasks,cc-1,s] = (acc_record[:no_tasks,cc-1])\n",
    "    \n",
    "    xp = torch.cat((xc, yc[:,-1].unsqueeze(1)), dim=-1)\n",
    "    # print(xp)\n",
    "    # print(xp.shape)\n",
    "    xp = np.array(xp.view(xp.size(0),-1,args.n_var+1))\n",
    "    print(xp.shape)\n",
    "    \n",
    "\n",
    "    ##############################################################################\n",
    "    #### $\\mathcal{S}^{\\mathrm{i.d.}_{\\mathrm{test}}$ ############################\n",
    "    \n",
    "    ax = axs[2]\n",
    "    ti = 0\n",
    "    top = mask[ti,:,:].T\n",
    "\n",
    "    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)\n",
    "    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})\n",
    "    \n",
    "    ax.set_ylabel(\"$z'_j$\",size=SIZE)\n",
    "    ax.set_xlabel(\"Shot $i$\",size=SIZE)\n",
    "\n",
    "    ax.set_title('$\\mathcal{S}^{\\mathrm{i.d.}}_{\\mathrm{test}}$',size=1.5*SIZE)\n",
    "\n",
    "    ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]\n",
    "    # print(xtick_labels)\n",
    "    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)\n",
    "    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "    ax.set_xticks(np.arange(xp.shape[1],step=4))\n",
    "    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "\n",
    "    ax.invert_yaxis()\n",
    "\n",
    "    ##############################################################################\n",
    "\n",
    "    ##############################################################################\n",
    "    #### $\\mathcal{S}^{\\mathrm{o.o.d.}_{\\mathrm{test}}$ ##########################\n",
    "\n",
    "    ax = axs[3]\n",
    "    ti = no_tasks//2\n",
    "    top = mask[ti,:,:].T\n",
    "\n",
    "    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)\n",
    "    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})\n",
    "    \n",
    "    ax.set_ylabel(\"$z'_j$\",size=SIZE)\n",
    "    ax.set_xlabel(\"Shot $i$\",size=SIZE)\n",
    "\n",
    "    ax.set_title('$\\mathcal{S}^{\\mathrm{o.o.d.}}_{\\mathrm{test}}$',size=1.5*SIZE)\n",
    "\n",
    "    xtick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]\n",
    "    # print(xtick_labels)\n",
    "    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)\n",
    "    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "    ax.set_xticks(np.arange(xp.shape[1],step=4))\n",
    "    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)\n",
    "\n",
    "    ax.invert_yaxis()\n",
    "\n",
    "    ##############################################################################\n",
    "\n",
    "    \n",
    "    cbar = fig.colorbar(axs[0].collections[0], ax=axs.ravel().tolist(), fraction=0.046, pad=0.01, ticks=[0, 1])\n",
    "    cbar.set_ticks([0.25, 0.75])\n",
    "    cbar.set_ticklabels(['incorrect', 'correct'], size = SIZE)\n",
    "\n",
    "    # fig.suptitle(f'{Ws[ic]}, Trial : {step}, Pos: {pos} \\n X :{xc[0,:]}')\n",
    "    if args.savefig is True:\n",
    "        fig.savefig(f'./label_corr/lc_sin_fin_pos{pos}_maxcc_{max_cc}_tr{step}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}.pdf',format='pdf')\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "print(args.n_tasks)\n",
    "pos = 2\n",
    "for step in range(1, 2):\n",
    "    max_cc = 32\n",
    "    label_corr_exp3(max_cc,sl=32,pos=pos)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
