{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4",
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ],
   "metadata": {
    "id": "OHsa_ecq8945"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "Install the required libraries (Colab runtime)."
   ],
   "metadata": {
    "id": "fMttgn9uN56J"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "%pip install --upgrade datasets transformers\n",
    "%pip install --upgrade pytorch-lightning\n",
    "%pip install latentis"
   ],
   "metadata": {
    "id": "gGMIWLVPjHEE"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "Mount Google Drive: the project code (`dictionaries`, `utils`, `module`, `train_NN`), the stored embeddings and `results.csv` live under `MyDrive/tba-moss`."
   ],
   "metadata": {
    "id": "zeBWvUx4N-v0"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ],
   "metadata": {
    "id": "ADklp5AnfHWS"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "Imports"
   ],
   "metadata": {
    "id": "MCWf_ZqDOBNF"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "import os\n",
    "import shutil\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from pytorch_lightning import seed_everything\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoConfig,\n",
    "    AutoImageProcessor,\n",
    "    CLIPVisionConfig,\n",
    "    CLIPImageProcessor,\n",
    "    CLIPVisionModel,\n",
    ")\n",
    "from datasets import DatasetDict, load_dataset, load_from_disk, DownloadConfig, VerificationMode\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader\n",
    "import functools\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "from typing import List, Optional\n",
    "import itertools"
   ],
   "metadata": {
    "id": "d4XoerZAtq6r"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import sys\n",
    "project_path = '/content/drive/MyDrive/tba-moss'\n",
    "sys.path.append(project_path)\n",
    "\n",
    "from dictionaries import DATASET2IMAGE_COLUMN, DATASET2LABEL_COLUMN, DATASET_NAME2HF_NAME, MODEL2CONFIGS, DATASET2NUM_CLASSES\n",
    "from utils import image_encode, extract_representations\n",
    "from module import SkipModel, HFwrapper, NoEncoder\n",
    "from train_NN import train_classifier"
   ],
   "metadata": {
    "id": "yKDlNPFYOHZN"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ],
   "metadata": {
    "id": "ehc2zKG3OIja"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Encode data\n",
    "\n",
    "For each skip configuration, encode the train and test splits with a `SkipModel` and store the outputs as a column named `str(skip)`. The updated dataset is saved under `embeddings/<dataset>/<model>` in the project folder."
   ],
   "metadata": {
    "id": "oatQ_8kszJgd"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "@torch.no_grad()\n",
    "def encode_data(loader, skip_encoder):\n",
    "    \"\"\"Run `skip_encoder` over every batch of `loader` and collect the outputs as Python lists.\n",
    "\n",
    "    Images are read from the batch key \"pixel_values\" (falling back to \"images\"); an optional\n",
    "    \"attention_mask\" is forwarded to the encoder. Outputs are moved to the CPU and converted with\n",
    "    `.tolist()` so they can be stored as a Hugging Face dataset column.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a batch has neither \"pixel_values\" nor \"images\".\n",
    "    \"\"\"\n",
    "    embeddings = []\n",
    "    skip_encoder.eval()\n",
    "\n",
    "    for batch in tqdm(loader, desc=\"Encoding Batches with SkipModel\"):\n",
    "        image_input = batch.get(\"pixel_values\", batch.get(\"images\"))\n",
    "        if image_input is None:\n",
    "            raise KeyError(\"Batch missing required key 'pixel_values' or 'images'\")\n",
    "        image_input = image_input.to(device)\n",
    "\n",
    "        attn_mask = batch.get(\"attention_mask\", None)\n",
    "        if attn_mask is not None:\n",
    "            attn_mask = attn_mask.to(device)\n",
    "\n",
    "        x = skip_encoder(image_input, attention_mask=attn_mask)\n",
    "        embeddings.extend(x.cpu().tolist())\n",
    "\n",
    "    return embeddings\n",
    "\n",
    "\n",
    "def _add_encoding_column(dataset, encoding, column_name, split, skip):\n",
    "    \"\"\"Return `dataset` with `encoding` added as a column, or `dataset` unchanged if it cannot be added.\n",
    "\n",
    "    The length check runs before any column is added. If `column_name` is already taken (e.g. the\n",
    "    same skip listed twice), \"_new\" is appended until the name is free.\n",
    "    \"\"\"\n",
    "    if not encoding:\n",
    "        print(f\"Warning: No embeddings generated for split '{split}', skip '{skip}'. Skipping saving.\")\n",
    "        return dataset\n",
    "    if len(encoding) != len(dataset):\n",
    "        print(\n",
    "            f\"Error: Encoding length ({len(encoding)}) does not match dataset length ({len(dataset)}) for split '{split}', skip '{skip}'.\"\n",
    "        )\n",
    "        return dataset\n",
    "\n",
    "    final_column_name = column_name\n",
    "    while final_column_name in dataset.column_names:\n",
    "        final_column_name = f\"{final_column_name}_new\"\n",
    "    if final_column_name != column_name:\n",
    "        print(f\"Column '{column_name}' already exists. Saving them with a new name: {final_column_name}\")\n",
    "    return dataset.add_column(final_column_name, encoding)\n",
    "\n",
    "\n",
    "def _save_dataset_dict(data: DatasetDict, target_dir: Path) -> None:\n",
    "    \"\"\"Save `data` to `target_dir`, replacing any previous copy.\n",
    "\n",
    "    The dataset is first written to a sibling \"<name>_temp\" directory, so a failed write leaves the\n",
    "    previous contents of `target_dir` in place. Exceptions are propagated to the caller.\n",
    "    \"\"\"\n",
    "    temp_dir = target_dir.parent / f\"{target_dir.name}_temp\"\n",
    "    try:\n",
    "        if temp_dir.exists():\n",
    "            shutil.rmtree(temp_dir)\n",
    "        data.save_to_disk(str(temp_dir))\n",
    "        if target_dir.exists():\n",
    "            shutil.rmtree(target_dir)\n",
    "        shutil.move(str(temp_dir), str(target_dir))\n",
    "    finally:\n",
    "        # Only still present if something above failed.\n",
    "        if temp_dir.exists():\n",
    "            shutil.rmtree(temp_dir)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def run_encoding(\n",
    "    dataset_name: str,\n",
    "    encoder_name: str,\n",
    "    translator_name: str,\n",
    "    seed: int,\n",
    "    samples_to_extract: int,\n",
    "    batch_size: int,\n",
    "    skips: Optional[list] = None,\n",
    "    mode: int = 1,\n",
    "    num_workers: int = 8,\n",
    "):\n",
    "    \"\"\"Encode a dataset's train/test splits once per skip configuration and save them to Drive.\n",
    "\n",
    "    For each entry of `skips`, a `SkipModel` is built around the same pretrained encoder. Its\n",
    "    outputs are added as a column named `str(skip)` (e.g. \"[]\" or \"[(10, 11)]\"), which is the name\n",
    "    that `classification` looks up. After every skip the whole DatasetDict is written to\n",
    "    `<project_path>/embeddings/<dataset_name>/<last component of encoder_name>`, replacing the\n",
    "    previous copy.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: key into the DATASET_* dictionaries (e.g. \"cifar10\").\n",
    "        encoder_name: Hugging Face model id (e.g. \"facebook/dinov2-small\"); must be in MODEL2CONFIGS.\n",
    "        translator_name: translator factory name forwarded to `SkipModel`.\n",
    "        seed: seed passed to `seed_everything` and `extract_representations`.\n",
    "        samples_to_extract: `max_samples` for `extract_representations` on the train loader.\n",
    "        batch_size: DataLoader batch size.\n",
    "        skips: skip configurations to encode; `[]` is the unmodified model.\n",
    "            Defaults to `[[], [(10, 11)]]`.\n",
    "        mode: forwarded to `SkipModel`.\n",
    "        num_workers: DataLoader worker processes for both loaders.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `encoder_name` is not in MODEL2CONFIGS.\n",
    "    \"\"\"\n",
    "    if skips is None:\n",
    "        skips = [[], [(10, 11)]]\n",
    "\n",
    "    seed_everything(seed)\n",
    "\n",
    "    if encoder_name not in MODEL2CONFIGS:\n",
    "        raise ValueError(f\"Model configuration not found for {encoder_name}. Please add it to MODEL2CONFIGS.\")\n",
    "    model_config = MODEL2CONFIGS[encoder_name]\n",
    "\n",
    "    print(f\"Dataset: {dataset_name}, Encoder: {encoder_name}, Translator: {translator_name}, Skips: {skips}\")\n",
    "\n",
    "    # Use the last path component as the slug, exactly like `classification`, so both cells use the\n",
    "    # same directory (and ids without a \"/\" do not raise IndexError).\n",
    "    DATASET_DIR = Path(project_path) / \"embeddings\" / dataset_name / encoder_name.split(\"/\")[-1]\n",
    "    DATASET_DIR.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    hf_dataset_name = DATASET_NAME2HF_NAME[dataset_name]\n",
    "    data: DatasetDict = DatasetDict(\n",
    "        train=load_dataset(hf_dataset_name, split=\"train\"),\n",
    "        test=load_dataset(hf_dataset_name, split=\"test\"),\n",
    "    )\n",
    "\n",
    "    print(f\"Loading HF AutoModel: {encoder_name}\")\n",
    "    config = AutoConfig.from_pretrained(encoder_name, output_hidden_states=True, return_dict=True)\n",
    "    processor = AutoImageProcessor.from_pretrained(encoder_name)\n",
    "    encoder = AutoModel.from_pretrained(encoder_name, config=config)\n",
    "    encoder.eval().to(device)\n",
    "\n",
    "    collate_fn = functools.partial(\n",
    "        image_encode,\n",
    "        processor=processor,\n",
    "        image_name=DATASET2IMAGE_COLUMN[dataset_name],\n",
    "        label_name=DATASET2LABEL_COLUMN[dataset_name],\n",
    "    )\n",
    "\n",
    "    def make_loader(split):\n",
    "        # shuffle=False keeps batches in dataset order, so encodings line up with rows when added as a column.\n",
    "        return DataLoader(\n",
    "            data[split],\n",
    "            batch_size=batch_size,\n",
    "            pin_memory=True,\n",
    "            shuffle=False,\n",
    "            num_workers=num_workers,\n",
    "            collate_fn=collate_fn,\n",
    "        )\n",
    "\n",
    "    train_loader = make_loader(\"train\")\n",
    "    test_loader = make_loader(\"test\")\n",
    "\n",
    "    all_layer_embeddings = extract_representations(\n",
    "        encoder=encoder,\n",
    "        max_samples=samples_to_extract,\n",
    "        loader=train_loader,\n",
    "        model_config=model_config,\n",
    "        model_is_open_clip=encoder_name.startswith(\"open_clip:\"),\n",
    "        seed=seed,\n",
    "    )\n",
    "    print(f\"Captured embeddings for layers: {list(all_layer_embeddings.keys())}\")\n",
    "\n",
    "    for skip in tqdm(skips, desc=\"Encoding Different Skips\"):\n",
    "        print(f\"\\nProcessing skip: {skip}\")\n",
    "\n",
    "        skip_encoder = SkipModel(\n",
    "            encoder=encoder,\n",
    "            skips=skip,\n",
    "            mode=mode,\n",
    "            precomputed_embeddings=all_layer_embeddings,\n",
    "            translator_factory_name=translator_name,\n",
    "            **model_config,\n",
    "        )\n",
    "        skip_encoder = skip_encoder.to(device).eval()\n",
    "\n",
    "        split2encoding = {\n",
    "            \"train\": encode_data(loader=train_loader, skip_encoder=skip_encoder),\n",
    "            \"test\": encode_data(loader=test_loader, skip_encoder=skip_encoder),\n",
    "        }\n",
    "\n",
    "        print(\"Saving results to disk...\")\n",
    "        for split, encoding in split2encoding.items():\n",
    "            # The column name str(skip) (e.g. \"[]\", \"[(10, 11)]\") is what `classification` looks up.\n",
    "            data[split] = _add_encoding_column(data[split], encoding, str(skip), split, skip)\n",
    "\n",
    "        del skip_encoder\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "        # Save after every skip so completed skips are kept on disk even if a later one fails.\n",
    "        try:\n",
    "            _save_dataset_dict(data, DATASET_DIR)\n",
    "            print(f\"Saved intermediate results for skip {skip} to {DATASET_DIR}\")\n",
    "        except Exception as e:\n",
    "            # Best-effort: report the failure and continue with the remaining skips.\n",
    "            print(f\"Error saving intermediate results: {e}\")"
   ],
   "metadata": {
    "id": "LrmgBOhzxx1F"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "run_encoding(\n",
    "    dataset_name=\"cifar10\",\n",
    "    encoder_name=\"facebook/dinov2-small\",\n",
    "    translator_name=\"linear\",\n",
    "    seed=0,\n",
    "    samples_to_extract=500,\n",
    "    batch_size=256,\n",
    "    skips=[[], [(10, 11)]],  # [] is the original model\n",
    ")"
   ],
   "metadata": {
    "id": "SaWqMJdzppsi"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Classification\n",
    "\n",
    "Train a linear classifier on each stored embedding column. Record its test accuracy in `results.csv`, together with `delta_acc` (the baseline `[]` accuracy minus this accuracy)."
   ],
   "metadata": {
    "id": "jW-lZY5oS77B"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "def _load_results(results_path: Path, columns: List[str]) -> pd.DataFrame:\n",
    "    \"\"\"Read the accumulated results CSV, or return an empty frame with `columns` if it is missing or unreadable.\"\"\"\n",
    "    if not results_path.exists():\n",
    "        results_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "        return pd.DataFrame(columns=columns)\n",
    "    try:\n",
    "        return pd.read_csv(results_path)\n",
    "    except pd.errors.EmptyDataError:\n",
    "        print(f\"Results file {results_path} is empty. Initializing DataFrame.\")\n",
    "    except Exception as e:\n",
    "        # NOTE: the caller rewrites the CSV afterwards, so an unreadable file gets replaced, not repaired.\n",
    "        print(f\"Error reading results file {results_path}: {e}. Initializing DataFrame.\")\n",
    "    return pd.DataFrame(columns=columns)\n",
    "\n",
    "\n",
    "def classification(\n",
    "    dataset_name: str,\n",
    "    model_name: str,\n",
    "    layers_to_approximate: List,\n",
    "    seed: int,\n",
    "    batch_size: int,\n",
    "):\n",
    "    \"\"\"Train a linear classifier on stored embeddings and append its test accuracy to results.csv.\n",
    "\n",
    "    Embeddings are loaded from `<project_path>/embeddings/<dataset_name>/<last component of model_name>`\n",
    "    (as written by `run_encoding`); the column used is `str(layers_to_approximate)`. The classifier\n",
    "    is trained with `train_classifier` for 5 epochs (Adam, lr 0.01), and the last evaluation\n",
    "    accuracy is recorded. `delta_acc` is the baseline accuracy (the `[]` row for the same dataset,\n",
    "    model and seed) minus this accuracy. It is 0.0 for the baseline itself and NaN if no baseline row exists.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the embeddings directory does not exist.\n",
    "        ValueError: if `model_name` is not in MODEL2CONFIGS.\n",
    "        KeyError: if the embedding column is missing from either split.\n",
    "    \"\"\"\n",
    "    seed_everything(seed)\n",
    "\n",
    "    model_name_slug = model_name.split(\"/\")[-1]\n",
    "    EMBEDDINGS_DIR = str(Path(project_path) / \"embeddings\" / dataset_name / model_name_slug)\n",
    "\n",
    "    print(f\"Loading embeddings from: {EMBEDDINGS_DIR}\")\n",
    "\n",
    "    if not os.path.exists(EMBEDDINGS_DIR):\n",
    "        raise FileNotFoundError(f\"Embeddings not found: {EMBEDDINGS_DIR}.\")\n",
    "    embeddings = DatasetDict.load_from_disk(EMBEDDINGS_DIR)\n",
    "\n",
    "    if model_name not in MODEL2CONFIGS:\n",
    "        raise ValueError(f\"Model configuration not found for '{model_name}' in MODEL2CONFIGS.\")\n",
    "\n",
    "    print(f'Approximating {layers_to_approximate}')\n",
    "    embedding_col_name = str(layers_to_approximate)\n",
    "\n",
    "    if (embedding_col_name not in embeddings[\"train\"].column_names) or (\n",
    "        embedding_col_name not in embeddings[\"test\"].column_names\n",
    "    ):\n",
    "        raise KeyError(f\"Skip '{embedding_col_name}' not found in loaded embeddings.\")\n",
    "\n",
    "    label_col_name = DATASET2LABEL_COLUMN[dataset_name]\n",
    "\n",
    "    def to_probe_split(split):\n",
    "        # Keep only the embedding and label columns, renamed to the keys the training loop uses.\n",
    "        # The torch format is applied after column selection so the unused image column is never converted.\n",
    "        return (\n",
    "            embeddings[split]\n",
    "            .select_columns([embedding_col_name, label_col_name])\n",
    "            .rename_column(embedding_col_name, \"images\")\n",
    "            .rename_column(label_col_name, \"labels\")\n",
    "            .with_format(\"torch\")\n",
    "        )\n",
    "\n",
    "    hf_train_embeddings = to_probe_split(\"train\")\n",
    "    hf_test_embeddings = to_probe_split(\"test\")\n",
    "\n",
    "    num_workers = 2\n",
    "    num_classes = DATASET2NUM_CLASSES[dataset_name]\n",
    "\n",
    "    train_dataloader = DataLoader(\n",
    "        hf_train_embeddings, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True\n",
    "    )\n",
    "    test_dataloader = DataLoader(\n",
    "        hf_test_embeddings, shuffle=False, batch_size=batch_size, num_workers=num_workers, pin_memory=True\n",
    "    )\n",
    "\n",
    "    hidden_size = hf_train_embeddings[0][\"images\"].shape[-1]\n",
    "\n",
    "    classifier = nn.Linear(hidden_size, num_classes)\n",
    "    no_encoder = NoEncoder(embeddings=None)\n",
    "    skip_model = HFwrapper(encoder=no_encoder, classifier=classifier)\n",
    "    skip_model.to(device)\n",
    "    skip_model.freeze_encoder()\n",
    "\n",
    "    lr = 0.01\n",
    "    num_epochs = 5\n",
    "    optimizer = optim.Adam(skip_model.parameters(), lr=lr)\n",
    "\n",
    "    print(\"Starting classifier training...\")\n",
    "    _, _, _, eval_accuracies, _ = train_classifier(\n",
    "        model=skip_model,\n",
    "        train_data_loader=train_dataloader,\n",
    "        test_data_loader=test_dataloader,\n",
    "        optimizer=optimizer,\n",
    "        criterion=nn.CrossEntropyLoss(),\n",
    "        label_column_name=\"labels\",\n",
    "        num_epochs=num_epochs,\n",
    "    )\n",
    "    accuracy = eval_accuracies[-1]\n",
    "    print(f\"Training finished. Final accuracy: {accuracy:.4f}\")\n",
    "\n",
    "    columns = [\"seed\", \"dataset\", \"model\", \"approx_layer\", \"accuracy\", \"delta_acc\"]\n",
    "    results_path = Path(project_path) / \"results.csv\"\n",
    "    results_df = _load_results(results_path, columns)\n",
    "\n",
    "    baseline_skip_repr = str([])\n",
    "    if embedding_col_name == baseline_skip_repr:\n",
    "        original_accuracy = accuracy\n",
    "    else:\n",
    "        baseline_rows = results_df[\n",
    "            (results_df[\"approx_layer\"] == baseline_skip_repr)\n",
    "            & (results_df[\"dataset\"] == dataset_name)\n",
    "            & (results_df[\"model\"] == model_name)\n",
    "            & (results_df[\"seed\"] == seed)\n",
    "        ]\n",
    "        # results.csv accumulates across notebook runs; compare against the most recent baseline.\n",
    "        original_accuracy = baseline_rows[\"accuracy\"].iloc[-1] if not baseline_rows.empty else None\n",
    "\n",
    "    if original_accuracy is None:\n",
    "        # NaN rather than 0.0: a delta of 0.0 would be indistinguishable from \"no accuracy change\".\n",
    "        print(\n",
    "            f\"Warning: no baseline ('{baseline_skip_repr}') result for dataset={dataset_name}, \"\n",
    "            f\"model={model_name}, seed={seed}; delta_acc set to NaN.\"\n",
    "        )\n",
    "        delta_acc = float(\"nan\")\n",
    "    else:\n",
    "        delta_acc = original_accuracy - accuracy\n",
    "\n",
    "    new_row = pd.DataFrame(\n",
    "        [\n",
    "            {\n",
    "                \"seed\": seed,\n",
    "                \"dataset\": dataset_name,\n",
    "                \"model\": model_name,\n",
    "                # Stored in string form, which is also what to_csv writes and read_csv returns.\n",
    "                \"approx_layer\": embedding_col_name,\n",
    "                \"accuracy\": accuracy,\n",
    "                \"delta_acc\": delta_acc,\n",
    "            }\n",
    "        ]\n",
    "    )\n",
    "    # Avoid concatenating with an empty frame (recent pandas versions warn about it).\n",
    "    results_df = new_row if results_df.empty else pd.concat([results_df, new_row], ignore_index=True)\n",
    "    results_df.to_csv(results_path, index=False)"
   ],
   "metadata": {
    "id": "YCidbkwlkKsN"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Must match the skips encoded by `run_encoding` above; keep `[]` (the unmodified model) first.\n",
    "approximations = [[], [(10, 11)]]\n",
    "seeds = [0, 1, 2]\n",
    "\n",
    "# itertools.product iterates `approximations` in the outer loop, so every baseline ([]) run is\n",
    "# written to results.csv before the runs whose delta_acc is computed against it.\n",
    "for approximation_config, seed_value in itertools.product(approximations, seeds):\n",
    "    classification(\n",
    "        dataset_name=\"cifar10\",\n",
    "        model_name=\"facebook/dinov2-small\",\n",
    "        layers_to_approximate=approximation_config,\n",
    "        seed=seed_value,\n",
    "        batch_size=256,\n",
    "    )\n",
    "\n",
    "results_summary = pd.read_csv(Path(project_path) / \"results.csv\")\n",
    "results_summary.drop(columns=[\"seed\"]).groupby([\"model\", \"dataset\", \"approx_layer\"]).agg([\"mean\", \"std\"]).round(3)"
   ],
   "metadata": {
    "id": "cQIVoSOgR2eB"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}