{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Weak-to-Strong Trustworthiness\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformers\n",
    "from transformers import PreTrainedTokenizerBase\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification      # AutoModelForCausalLM, BitsAndBytesConfig\n",
    "from datasets import Dataset, DatasetDict\n",
    "import numpy as np\n",
    "import evaluate\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Prefer the project's data loader; otherwise fall back to a stub that reports\n",
    "# the problem and returns an empty dict.\n",
    "try:\n",
    "    from dataset_tools import load_data\n",
    "except ImportError:\n",
    "    print(\"Warning: dataset_tools.py not found. load_data function will not be available unless defined elsewhere.\")\n",
    "\n",
    "    def load_data(file_path):\n",
    "        \"\"\"Stub loader used when dataset_tools.py cannot be imported; always returns an empty dict.\"\"\"\n",
    "        message = f\"Dummy load_data called for {file_path}. Implement or provide dataset_tools.py.\"\n",
    "        print(message)\n",
    "        return dict()\n",
    "\n",
    "\n",
    "# Accuracy metric (loaded through the `evaluate` library) used by compute_metrics.\n",
    "accuracy = evaluate.load(\"accuracy\")\n",
    "\n",
    "# Number of output classes for each supported task.\n",
    "num_classes = dict([\n",
    "    ('sst2', 2),\n",
    "    ('qqp', 2),\n",
    "    ('mnli', 3),\n",
    "    ('mnli-mm', 3),\n",
    "    ('qnli', 2),\n",
    "    ('rte', 2),\n",
    "])\n",
    "\n",
    "# Custom trainer for weak-to-strong finetuning with auxiliary loss\n",
    "class WTS_Trainer(transformers.Trainer):\n",
    "    \"\"\"Trainer that optimises on both the adversarial and the original view of each example.\n",
    "\n",
    "    Batches (as produced by ``CustomCollator``) hold ``adv_tokenized``,\n",
    "    ``orig_tokenized`` and ``labels``. The supervised loss is\n",
    "    ``lambda_coeff * CE(adv) + (1 - lambda_coeff) * CE(orig)``.\n",
    "\n",
    "    With ``ft_mode == 'wts-aux-loss'`` an auxiliary confidence loss is mixed in:\n",
    "    the original-view logits are turned into hard pseudo-labels whose class\n",
    "    proportions follow ``class_wts`` (see ``_pseudo_labels``), the same\n",
    "    lambda-weighted cross-entropy is computed against them, and the final loss\n",
    "    is ``(1 - alpha) * loss + alpha * aux_loss`` with ``alpha`` ramped linearly\n",
    "    from 0 to ``alpha_max`` over the first ``warm_up`` fraction of ``max_steps``.\n",
    "\n",
    "    Extra keyword arguments:\n",
    "        ft_mode: finetuning mode; only 'wts-aux-loss' changes the loss here.\n",
    "        alpha_max: final weight of the auxiliary loss.\n",
    "        lambda_coeff: weight of the adversarial view in every loss term.\n",
    "        warm_up: fraction of training used to ramp alpha up (0 disables the ramp).\n",
    "        class_wts: mapping ``{'<class id>': fraction}``; required for\n",
    "            'wts-aux-loss'. Iteration order matters (see ``_pseudo_labels``).\n",
    "        use_samples: 'adversarial' or 'original'; which view's outputs\n",
    "            ``evaluate`` reports.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, ft_mode=None, alpha_max=0.5, lambda_coeff=0.5, warm_up=0.2, class_wts=None, use_samples=\"adversarial\", **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "\n",
    "        self.ft_mode = ft_mode\n",
    "        self.alpha_max = alpha_max\n",
    "        self.class_wts = class_wts\n",
    "        self.lambda_coeff = lambda_coeff\n",
    "        self.warm_up = warm_up\n",
    "        self.mode = None    # 'evaluate' or 'predict'; set by the overrides below\n",
    "        self.use_samples = use_samples  # which view ('adversarial' / 'original') evaluate reports\n",
    "\n",
    "    def _mix(self, adv_term, orig_term):\n",
    "        \"\"\"Blend an adversarial-view term and an original-view term with ``lambda_coeff``.\"\"\"\n",
    "        return self.lambda_coeff * adv_term + (1 - self.lambda_coeff) * orig_term\n",
    "\n",
    "    def _current_alpha(self):\n",
    "        \"\"\"Auxiliary-loss weight for the current step (linear warm-up to ``alpha_max``).\"\"\"\n",
    "        if self.warm_up == 0:\n",
    "            return self.alpha_max\n",
    "        if self.state.max_steps <= 0:\n",
    "            # No training schedule known yet (presumably evaluate() before train()):\n",
    "            # treat it as the start of the warm-up instead of dividing by zero.\n",
    "            return 0.0\n",
    "        progress = self.state.global_step / (self.warm_up * self.state.max_steps)\n",
    "        return self.alpha_max * min(1.0, progress)\n",
    "\n",
    "    @staticmethod\n",
    "    def _pseudo_labels(logits, class_wts):\n",
    "        \"\"\"Hard pseudo-labels for ``logits`` whose class proportions follow ``class_wts``.\n",
    "\n",
    "        Classes are visited in ``class_wts`` iteration order. For each class, the\n",
    "        still-unlabelled rows whose logit for that class lies strictly above the\n",
    "        ``1 - weight / remaining`` quantile (``remaining`` = weight not yet handed\n",
    "        out) receive that class. Rows still unlabelled at the end get the last\n",
    "        class visited.\n",
    "\n",
    "        Args:\n",
    "            logits: 2-D tensor, (batch, num_classes).\n",
    "            class_wts: non-empty mapping ``{'<class id>': fraction}``.\n",
    "\n",
    "        Returns:\n",
    "            int64 tensor of shape (batch,) on ``logits.device``, carrying no gradient.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``class_wts`` is empty or None.\n",
    "        \"\"\"\n",
    "        if not class_wts:\n",
    "            raise ValueError(\"class_wts must be a non-empty mapping when ft_mode is 'wts-aux-loss'\")\n",
    "        # Thresholds are data-dependent, so work on a detached copy; .float() also\n",
    "        # makes half-precision logits convertible to numpy.\n",
    "        scores = logits.detach().float().cpu().numpy()\n",
    "        assigned = np.full(scores.shape[0], -1, dtype=np.int64)  # -1 = not labelled yet\n",
    "        remaining_fraction = 1.0\n",
    "        label = None\n",
    "        for label_str, weight in class_wts.items():\n",
    "            label = int(label_str)\n",
    "            unlabelled = np.flatnonzero(assigned < 0)\n",
    "            if unlabelled.size == 0:\n",
    "                break\n",
    "            class_fraction = min(weight / remaining_fraction, 1.0)\n",
    "            column = scores[unlabelled, label]\n",
    "            threshold = np.quantile(column, 1.0 - class_fraction)\n",
    "            assigned[unlabelled[column > threshold]] = label\n",
    "            remaining_fraction -= weight\n",
    "            if remaining_fraction <= 0:  # all class mass handed out\n",
    "                break\n",
    "        assigned[assigned < 0] = label\n",
    "        return torch.as_tensor(assigned, device=logits.device)\n",
    "\n",
    "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
    "        \"\"\"Lambda-weighted cross-entropy on both views, plus the optional auxiliary loss.\n",
    "\n",
    "        ``num_items_in_batch`` is accepted because newer ``transformers`` versions\n",
    "        pass it to ``compute_loss``; it is unused since each cross-entropy here is\n",
    "        already a per-batch mean.\n",
    "\n",
    "        Returns ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true:\n",
    "        the adversarial-view outputs during ``evaluate`` with\n",
    "        ``use_samples == 'adversarial'``, the original-view outputs otherwise.\n",
    "        \"\"\"\n",
    "        labels = inputs.pop(\"labels\")\n",
    "        outputs_adv = model(**inputs['adv_tokenized'])\n",
    "        logits_adv = outputs_adv.logits\n",
    "        outputs_orig = model(**inputs['orig_tokenized'])\n",
    "        logits_orig = outputs_orig.logits\n",
    "\n",
    "        loss = self._mix(\n",
    "            torch.nn.functional.cross_entropy(logits_adv, labels),\n",
    "            torch.nn.functional.cross_entropy(logits_orig, labels),\n",
    "        )\n",
    "\n",
    "        # Auxiliary confidence loss\n",
    "        if self.ft_mode == 'wts-aux-loss':\n",
    "            alpha = self._current_alpha()\n",
    "            # Pseudo-labels come from the original-view logits and supervise both views.\n",
    "            pseudo_labels = self._pseudo_labels(logits_orig, self.class_wts)\n",
    "            aux_loss = self._mix(\n",
    "                torch.nn.functional.cross_entropy(logits_adv, pseudo_labels),\n",
    "                torch.nn.functional.cross_entropy(logits_orig, pseudo_labels),\n",
    "            )\n",
    "            loss = ((1 - alpha) * loss) + (alpha * aux_loss)\n",
    "\n",
    "        if not return_outputs:\n",
    "            return loss\n",
    "        if self.mode == 'evaluate' and self.use_samples == 'adversarial':\n",
    "            return (loss, outputs_adv)\n",
    "        return (loss, outputs_orig)\n",
    "\n",
    "    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = \"eval\"):\n",
    "        \"\"\"``Trainer.evaluate`` reporting the view selected by ``use_samples``.\"\"\"\n",
    "        self.mode = 'evaluate'\n",
    "        return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)\n",
    "\n",
    "    def predict(self, test_dataset, ignore_keys=None, metric_key_prefix: str = \"eval\"):\n",
    "        \"\"\"``Trainer.predict``; in 'predict' mode compute_loss returns the original-view outputs.\"\"\"\n",
    "        self.mode = 'predict'\n",
    "        return super().predict(test_dataset, ignore_keys, metric_key_prefix)\n",
    "\n",
    "# task -> (name used in warnings, ((text prefix, item key), ...)). A None prefix\n",
    "# means the single field's raw value is used as the text.\n",
    "_TASK_LAYOUTS = {\n",
    "    'sst2': ('SST2', ((None, 'sentence'),)),\n",
    "    'qqp': ('QQP', (('Question 1', 'question1'), ('Question 2', 'question2'))),\n",
    "    'mnli': ('MNLI', (('Premise', 'premise'), ('Hypothesis', 'hypothesis'))),\n",
    "    'mnli-mm': ('MNLI', (('Premise', 'premise'), ('Hypothesis', 'hypothesis'))),\n",
    "    'qnli': ('QNLI', (('Question', 'question'), ('Sentence', 'sentence'))),\n",
    "    'rte': ('RTE', (('Sentence 1', 'sentence1'), ('Sentence 2', 'sentence2'))),\n",
    "}\n",
    "\n",
    "\n",
    "def _format_fields(fields, values):\n",
    "    \"\"\"Join ``values`` as ``<prefix>: <value>`` lines, or return the bare value for a prefix-less field.\"\"\"\n",
    "    if len(fields) == 1 and fields[0][0] is None:\n",
    "        return values[0]\n",
    "    return \"\\n\".join(f\"{prefix}: {value}\" for (prefix, _), value in zip(fields, values))\n",
    "\n",
    "\n",
    "def data_preprocessing(data, task):\n",
    "    \"\"\"Convert raw task items into ``{'adv_text', 'orig_text', 'label'}`` records.\n",
    "\n",
    "    ``adv_text`` renders the item's (possibly perturbed) fields, e.g. a\n",
    "    ``Premise:`` line and a ``Hypothesis:`` line for MNLI, or the bare sentence\n",
    "    for SST-2. ``orig_text`` uses the same layout but swaps in\n",
    "    ``original_<field>`` for the first field that has one. Items without any\n",
    "    ``original_*`` field (presumably clean GLUE data) fall back to the\n",
    "    adversarial text, and a single warning per call reports how many did.\n",
    "\n",
    "    Args:\n",
    "        data: iterable of dicts holding the task's fields and ``'label'``.\n",
    "        task: 'sst2', 'qqp', 'mnli', 'mnli-mm', 'qnli' or 'rte'.\n",
    "\n",
    "    Returns:\n",
    "        list of dicts with keys ``'adv_text'``, ``'orig_text'`` and ``'label'``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``task`` is not supported (checked even for empty ``data``).\n",
    "        KeyError: if an item lacks one of the task's fields or ``'label'``.\n",
    "    \"\"\"\n",
    "    if task not in _TASK_LAYOUTS:\n",
    "        raise ValueError(f\"Task {task} not supported\")\n",
    "    display_name, fields = _TASK_LAYOUTS[task]\n",
    "\n",
    "    data_list = []\n",
    "    fallback_count = 0\n",
    "    for item in data:\n",
    "        adv_values = [item[key] for _, key in fields]\n",
    "        orig_values = list(adv_values)\n",
    "        for position, (_, key) in enumerate(fields):\n",
    "            if f\"original_{key}\" in item:\n",
    "                orig_values[position] = item[f\"original_{key}\"]\n",
    "                break\n",
    "        else:\n",
    "            # No original_* field at all: keep the adversarial text as the original.\n",
    "            fallback_count += 1\n",
    "        data_list.append({\n",
    "            'adv_text': _format_fields(fields, adv_values),\n",
    "            'orig_text': _format_fields(fields, orig_values),\n",
    "            'label': item['label'],\n",
    "        })\n",
    "\n",
    "    if fallback_count:\n",
    "        expected = \" or \".join(f\"'original_{key}'\" for _, key in fields)\n",
    "        print(f\"Warning: No {expected} in {display_name} item. Using adversarial text as original. \"\n",
    "              f\"({fallback_count} of {len(data_list)} items affected)\")\n",
    "    return data_list\n",
    "\n",
    "def print_trainable_parameters(model):\n",
    "    \"\"\"Print how many of ``model``'s parameters require gradients.\n",
    "\n",
    "    Args:\n",
    "        model: object whose ``named_parameters()`` yields ``(name, tensor)`` pairs.\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        count = param.numel()\n",
    "        all_param += count\n",
    "        if param.requires_grad:\n",
    "            trainable_params += count\n",
    "    # A parameter-less model would otherwise raise ZeroDivisionError.\n",
    "    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}\"\n",
    "    )\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Accuracy of argmax predictions, in the form ``Trainer(compute_metrics=...)`` expects.\n",
    "\n",
    "    ``eval_pred`` unpacks to ``(predictions, labels)``. When the model returned\n",
    "    several output arrays, ``predictions`` is a tuple; its first element is\n",
    "    assumed to be the logits (TODO confirm if the model outputs change).\n",
    "    \"\"\"\n",
    "    predictions, labels = eval_pred\n",
    "    if isinstance(predictions, tuple):\n",
    "        predictions = predictions[0]\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    return accuracy.compute(predictions=predictions, references=labels)\n",
    "\n",
    "def tokenize_adv_orig_text(tokenizer: PreTrainedTokenizerBase, samples, max_len: int):\n",
    "    \"\"\"Tokenise the adversarial and original text of a batch of samples.\n",
    "\n",
    "    Meant for ``Dataset.map(..., batched=True)``: ``samples['adv_text']`` and\n",
    "    ``samples['orig_text']`` are lists of strings. Both views are truncated to\n",
    "    ``max_len`` tokens; no padding is requested here (``CustomCollator`` pads\n",
    "    each batch).\n",
    "\n",
    "    Returns:\n",
    "        dict of ``adv_``/``orig_``-prefixed ``input_ids`` and ``attention_mask`` columns.\n",
    "    \"\"\"\n",
    "    def encode(texts):\n",
    "        return tokenizer(texts, truncation=True, max_length=max_len)\n",
    "\n",
    "    adv_enc = encode(samples['adv_text'])\n",
    "    orig_enc = encode(samples['orig_text'])\n",
    "    return dict(\n",
    "        adv_input_ids=adv_enc['input_ids'],\n",
    "        adv_attention_mask=adv_enc['attention_mask'],\n",
    "        orig_input_ids=orig_enc['input_ids'],\n",
    "        orig_attention_mask=orig_enc['attention_mask'],\n",
    "    )\n",
    "\n",
    "class CustomCollator:\n",
    "    \"\"\"Collate tokenised adversarial/original pairs into a padded batch.\n",
    "\n",
    "    Each item must provide ``adv_input_ids``, ``adv_attention_mask``,\n",
    "    ``orig_input_ids``, ``orig_attention_mask`` (int sequences) and ``label``.\n",
    "    Sequences are padded at the end to the longest one in the batch: token ids\n",
    "    with the tokenizer's ``pad_token_id``, attention masks with 0. Each view is\n",
    "    nested under ``adv_tokenized`` / ``orig_tokenized`` as model keyword arguments.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer: PreTrainedTokenizerBase):\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def __call__(self, batch):\n",
    "        def pad(key, fill):\n",
    "            # Stack item[key] across the batch into a (batch, longest) tensor.\n",
    "            return pad_sequence([torch.tensor(item[key]) for item in batch],\n",
    "                                batch_first=True, padding_value=fill)\n",
    "\n",
    "        pad_id = self.tokenizer.pad_token_id\n",
    "        return {\n",
    "            'adv_tokenized': {\n",
    "                'input_ids': pad('adv_input_ids', pad_id),\n",
    "                'attention_mask': pad('adv_attention_mask', 0)\n",
    "            },\n",
    "            'orig_tokenized': {\n",
    "                'input_ids': pad('orig_input_ids', pad_id),\n",
    "                'attention_mask': pad('orig_attention_mask', 0)\n",
    "            },\n",
    "            'labels': torch.tensor([item['label'] for item in batch])\n",
    "        }\n",
    "    \n",
    "def gpu_memory():\n",
    "    \"\"\"Allocated and total device memory in GiB, summed over all visible CUDA devices.\n",
    "\n",
    "    Returns ``(0, 0)`` when CUDA is unavailable.\n",
    "    \"\"\"\n",
    "    if not torch.cuda.is_available():\n",
    "        return 0, 0\n",
    "    device_ids = range(torch.cuda.device_count())\n",
    "    used_bytes = sum(torch.cuda.memory_allocated(i) for i in device_ids)\n",
    "    total_bytes = sum(torch.cuda.get_device_properties(i).total_memory for i in device_ids)\n",
    "    gib = 1024 ** 3\n",
    "    return used_bytes / gib, total_bytes / gib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main Execution: Configuration and Script Logic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-in for the original script's argparse defaults: edit these values to\n",
    "# configure an experiment.\n",
    "class Args:\n",
    "    \"\"\"Experiment configuration mirroring the original command-line flags.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.model_id = \"EleutherAI/pythia-14m\"\n",
    "        self.results_file = \"results/pythia-results.json\"\n",
    "        self.task = \"sst2\"  # key into num_classes\n",
    "        self.ft_mode = \"weak\"  # one of: 'weak', 'strong', 'wts-naive', 'wts-aux-loss'\n",
    "        self.weak_labels_file = \"weak_labels/pythia-14m.json\"\n",
    "        self.num_epochs = 6.0\n",
    "        self.batch_size = 30  # per-device train and eval batch size\n",
    "        self.lambda_coeff = 0.3  # weight of the adversarial view in the loss\n",
    "        self.warm_up = 0.2  # fraction of training over which alpha ramps up\n",
    "        self.alpha_max = 0.1  # final weight of the auxiliary loss ('wts-aux-loss' only)\n",
    "        self.validate_on = \"adversarial\"  # 'adversarial' or 'original'\n",
    "        self.tag = \"\"  # suffix for the model output directory\n",
    "        self.glue_ds = 'advgluepp'  # one of: 'glue', 'advglue', 'advgluepp'\n",
    "\n",
    "\n",
    "args = Args()\n",
    "\n",
    "# Expose the settings as module-level names used by the rest of the script.\n",
    "(model_id, results_file, task, ft_mode, weak_labels_file, num_epochs,\n",
    " lambda_coeff, warm_up, alpha_max, validate_on, tag, glue_ds) = (\n",
    "    args.model_id, args.results_file, args.task, args.ft_mode,\n",
    "    args.weak_labels_file, args.num_epochs, args.lambda_coeff, args.warm_up,\n",
    "    args.alpha_max, args.validate_on, args.tag, args.glue_ds,\n",
    ")\n",
    "\n",
    "max_len = 1024  # token limit used when truncating inputs\n",
    "batch_size = args.batch_size\n",
    "\n",
    "# Echo the configuration so every run's log records its settings.\n",
    "print(\"\\n* * * * * Experiment Details * * * * *\")\n",
    "experiment_details = [\n",
    "    (\"Model ID:\", model_id),\n",
    "    (\"Results file:\", results_file),\n",
    "    (\"Task:\", task),\n",
    "    (\"Finetuning mode:\", ft_mode),\n",
    "]\n",
    "if ft_mode in ('wts-naive', 'wts-aux-loss'):\n",
    "    experiment_details.append((\"Weak labels file:\", weak_labels_file))\n",
    "if ft_mode == 'wts-aux-loss':\n",
    "    experiment_details.append((\"Alpha max:\", alpha_max))\n",
    "experiment_details += [\n",
    "    (\"Validate on:\", validate_on),\n",
    "    (\"Lambda coefficient:\", lambda_coeff),\n",
    "    (\"Warmup period:\", warm_up),\n",
    "    (\"Number of epochs:\", num_epochs),\n",
    "]\n",
    "for detail_name, detail_value in experiment_details:\n",
    "    print(detail_name, detail_value)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    gpu_count = torch.cuda.device_count()\n",
    "    print(\"Number of available GPUs:\", gpu_count)\n",
    "    print(\"GPU names:\", [torch.cuda.get_device_name(i) for i in range(gpu_count)])\n",
    "else:\n",
    "    print(\"CUDA is not available. Running on CPU.\")\n",
    "print(\"Dataset:\", glue_ds)\n",
    "print(\"Batch size:\", batch_size)\n",
    "print(\"* * * * * * * * * * * * * * * * * * * *\\n\", flush=True)\n",
    "\n",
    "\n",
    "# Make sure every directory the script writes into exists.\n",
    "dummy_data_dir = \"data\"\n",
    "dummy_weak_labels_dir = f\"weak_labels/{glue_ds}\"\n",
    "dummy_results_dir = os.path.dirname(results_file)\n",
    "\n",
    "for required_dir in (dummy_data_dir, dummy_weak_labels_dir, dummy_results_dir, \"models\"):  # \"models\" holds output_dir\n",
    "    # os.path.dirname() is '' for a bare filename and os.makedirs('') raises, so skip empty paths.\n",
    "    if required_dir:\n",
    "        os.makedirs(required_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "# Second attempt at the real loader. If dataset_tools.py does not exist at all,\n",
    "# write a stub module (returning a one-item SST-2 sample) and import it; if the file\n",
    "# exists but still fails to import, whatever load_data is already bound stays in use.\n",
    "try:\n",
    "    from dataset_tools import load_data\n",
    "except ImportError:\n",
    "    if not os.path.exists(\"dataset_tools.py\"):\n",
    "        with open(\"dataset_tools.py\", \"w\") as f_dt:\n",
    "            f_dt.write(\"# Dummy dataset_tools.py\\n\")\n",
    "            f_dt.write(\"import json\\n\")\n",
    "            f_dt.write(\"def load_data(file_path):\\n\")\n",
    "            f_dt.write(\"    print(f'Dummy load_data from dataset_tools.py for {file_path}')\\n\")\n",
    "            f_dt.write(\"    # This should return data in the format your script expects\\n\")\n",
    "            f_dt.write(\"    # For example, a dictionary keyed by task, then a list of items\\n\")\n",
    "            f_dt.write(\"    # This is highly dependent on your actual data structure\\n\")\n",
    "            f_dt.write(\"    if 'holdout' in file_path or 'weak_labels' in file_path: # Specific for weak label generation path\\n\")\n",
    "            f_dt.write(\"        return { 'sst2': [{'sentence': 'dummy sentence', 'original_sentence': 'orig dummy', 'label': 0}] } \\n\")\n",
    "            f_dt.write(\"    return { 'sst2': [{'sentence': 'dummy sentence', 'original_sentence': 'orig dummy', 'label': 0, \\n\")\n",
    "            f_dt.write(\"                        'question1': 'q1', 'question2': 'q2', \\n\")\n",
    "            f_dt.write(\"                        'premise': 'p', 'hypothesis': 'h', 'question': 'q'}] } \\n\") # Add all possible keys\n",
    "        print(\"Created dummy dataset_tools.py. Please replace with your actual data loading logic.\")\n",
    "        import importlib\n",
    "        # The import system caches directory listings, so a module file created at\n",
    "        # runtime may not be found by the re-import below until the caches are invalidated.\n",
    "        importlib.invalidate_caches()\n",
    "        from dataset_tools import load_data # Re-try import\n",
    "\n",
    "\n",
    "# Class-proportion file for each dataset variant (its 'class_wts' entry is read later\n",
    "# in this cell); a uniform SST-2 placeholder is written if the file is missing.\n",
    "class_wts_file_path_map = dict(\n",
    "    glue='data/glue_info.json',\n",
    "    advglue='data/advglue_info.json',\n",
    "    advgluepp='data/advgluepp_info.json',\n",
    ")\n",
    "class_wts_file = class_wts_file_path_map[glue_ds]\n",
    "if not os.path.exists(class_wts_file):\n",
    "    placeholder_class_wts = {'class_wts': {'sst2': {'0': 0.5, '1': 0.5}}}\n",
    "    with open(class_wts_file, 'w') as f_cwt:\n",
    "        json.dump(placeholder_class_wts, f_cwt, indent=2)  # Dummy weights\n",
    "    print(f\"Created dummy class weights file: {class_wts_file}\")\n",
    "\n",
    "\n",
    "# Load training data. The weak-to-strong modes read weak labels that are already\n",
    "# adv_text/orig_text/label records; the other modes read the raw train split and\n",
    "# convert it with data_preprocessing. Missing data falls back to a dummy example.\n",
    "\n",
    "\n",
    "def _dummy_train_item():\n",
    "    \"\"\"Placeholder raw training item carrying the fields every supported task reads.\"\"\"\n",
    "    return {'sentence': 'dummy sentence', 'original_sentence': 'orig dummy sent', 'label': 0,\n",
    "            'question1': 'dummy q1', 'original_question1': 'orig dummy q1', 'question2': 'dummy q2', 'original_question2': 'orig dummy q2',\n",
    "            'premise': 'dummy premise', 'original_premise': 'orig dummy premise', 'hypothesis': 'dummy hypothesis', 'original_hypothesis': 'orig dummy hyp',\n",
    "            'question': 'dummy question', 'original_question': 'orig dummy q', 'sentence1': 'dummy s1', 'original_sentence1': 'orig dummy s1',\n",
    "            'sentence2': 'dummy s2', 'original_sentence2': 'orig dummy s2'}\n",
    "\n",
    "\n",
    "if ft_mode in ('wts-naive', 'wts-aux-loss'):\n",
    "    # Only the notebook's own test fixture file is auto-created.\n",
    "    is_test_fixture = \"notebook_test_weak_labels.json\" in weak_labels_file\n",
    "    if is_test_fixture and not os.path.exists(weak_labels_file):\n",
    "        with open(weak_labels_file, 'w') as f_wl:\n",
    "            json.dump({'sst2': [{'adv_text': 'dummy adv', 'orig_text': 'dummy orig', 'label': 0}]}, f_wl, indent=2)\n",
    "        print(f\"Created dummy weak labels file: {weak_labels_file}\")\n",
    "    data_loaded = load_data(weak_labels_file)\n",
    "    data = data_loaded.get(task, [])  # empty list when the task is missing\n",
    "    print(f\"Weak labels for task {task}: {len(data)}\")\n",
    "    if not data:\n",
    "        print(f\"Warning: No weak label data found for task {task} in {weak_labels_file}. Training might fail or be on empty data.\")\n",
    "        data = [{'adv_text': 'dummy adv text', 'orig_text': 'dummy original text', 'label': 0}]\n",
    "else:\n",
    "    data_file_path_map = dict(\n",
    "        glue='data/glue_train.json',\n",
    "        advglue='data/advglue_train.json',\n",
    "        advgluepp='data/advgluepp_train.json',\n",
    "    )\n",
    "    train_data_file = data_file_path_map[glue_ds]\n",
    "    if not os.path.exists(train_data_file):\n",
    "        with open(train_data_file, 'w') as f_trd:  # Create a dummy train data file\n",
    "            dummy_item = _dummy_train_item()\n",
    "            json.dump({task: [dummy_item]}, f_trd, indent=2)\n",
    "        print(f\"Created dummy train data file: {train_data_file}\")\n",
    "\n",
    "    data_loaded = load_data(train_data_file)\n",
    "    data_raw = data_loaded.get(task, [])\n",
    "    print(f\"Training samples for task {task}: {len(data_raw)}\")\n",
    "    if not data_raw:\n",
    "        print(f\"Warning: No training data found for task {task} in {train_data_file}. Training might fail.\")\n",
    "        dummy_item = _dummy_train_item()\n",
    "        data_raw = [dummy_item]\n",
    "\n",
    "    data = data_preprocessing(data_raw, task)\n",
    "\n",
    "\n",
    "# Per-task class proportions, passed to the trainer as class_wts for the auxiliary-loss\n",
    "# pseudo-labels; fall back to uniform weights when the task has no entry.\n",
    "with open(class_wts_file, 'r') as f:\n",
    "    class_wts_all_tasks = json.load(f)['class_wts']\n",
    "class_wts_task = class_wts_all_tasks.get(task)\n",
    "if class_wts_task is None:\n",
    "    print(f\"Warning: Class weights for task '{task}' not found in {class_wts_file}. Using default.\")\n",
    "    num_labels_for_task = num_classes.get(task, 2)  # assume 2 classes for an unknown task\n",
    "    uniform_share = 1.0 / num_labels_for_task\n",
    "    class_wts_task = dict.fromkeys((str(i) for i in range(num_labels_for_task)), uniform_share)\n",
    "\n",
    "\n",
    "# Checkpoints are written under models/<model_id>-<task><tag>.\n",
    "output_dir = f\"models/{model_id}-{task}{tag}\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_id,\n",
    "    device_map=\"auto\",\n",
    "    num_labels=num_classes[task],\n",
    ")\n",
    "used_gb, total_gb = gpu_memory()\n",
    "print(f\"GPU memory used: {used_gb:.2f} / {total_gb:.2f} GB\")\n",
    "\n",
    "\n",
    "# Reproducible 80/20 train/validation split of `data`.\n",
    "SPLIT_SEED = 42  # fixed seed: the split is identical across runs and re-executions\n",
    "\n",
    "data_len = len(data)\n",
    "if data_len == 0:\n",
    "    print(\"Error: No data loaded. Cannot proceed with training.\")\n",
    "    data = [{'adv_text': 'final dummy adv', 'orig_text': 'final dummy orig', 'label': 0}]\n",
    "    data_len = len(data)\n",
    "\n",
    "\n",
    "def _records_to_dataset(records):\n",
    "    \"\"\"Build a Dataset with adv_text / orig_text / label columns from a list of record dicts.\"\"\"\n",
    "    return Dataset.from_dict({\n",
    "        'adv_text': [record['adv_text'] for record in records],\n",
    "        'orig_text': [record['orig_text'] for record in records],\n",
    "        'label': [record['label'] for record in records],\n",
    "    })\n",
    "\n",
    "\n",
    "# int() floors, so keep at least one training example for tiny datasets.\n",
    "train_len = max(int(0.8 * data_len), 1)\n",
    "\n",
    "# Shuffle a copy with a seeded RNG: `data` itself stays untouched, so re-running\n",
    "# this cell reproduces the same split instead of reshuffling an already-shuffled list.\n",
    "shuffled_data = list(data)\n",
    "np.random.default_rng(SPLIT_SEED).shuffle(shuffled_data)\n",
    "train_data_list = shuffled_data[:train_len]\n",
    "# Only a single-example dataset leaves the validation slice empty; validate on the training data then.\n",
    "val_data_list = shuffled_data[train_len:] or train_data_list\n",
    "\n",
    "train_data_for_ds = _records_to_dataset(train_data_list)\n",
    "val_data_for_ds = _records_to_dataset(val_data_list)\n",
    "\n",
    "ds = DatasetDict({\n",
    "    'train': train_data_for_ds,\n",
    "    'val': val_data_for_ds\n",
    "})\n",
    "\n",
    "# Reuse EOS as the padding token (presumably the tokenizer has no pad token of its\n",
    "# own) and point the model config's pad id at its EOS id.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "\n",
    "def _tokenize_views(samples):\n",
    "    \"\"\"Batched ``Dataset.map`` function: tokenise both text views, truncated to ``max_len``.\"\"\"\n",
    "    return tokenize_adv_orig_text(tokenizer, samples, max_len=max_len)\n",
    "\n",
    "\n",
    "tokenized_ds = ds.map(_tokenize_views, batched=True)\n",
    "\n",
    "print(\"\\nFine-tuning the model...\\n\", flush=True)\n",
    "training_args = transformers.TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    remove_unused_columns=False,  # keep the adv_*/orig_* columns CustomCollator reads\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    gradient_checkpointing=True,\n",
    "    num_train_epochs=num_epochs,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,  # relies on evaluation and saving both happening per epoch\n",
    "    save_total_limit=1,\n",
    "    warmup_steps=2,\n",
    "    learning_rate=1e-5,\n",
    "    logging_steps=10,\n",
    "    # optim=\"paged_adamw_8bit\" would require bitsandbytes\n",
    ")\n",
    "trainer = WTS_Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_ds['train'],\n",
    "    eval_dataset=tokenized_ds['val'],\n",
    "    data_collator=CustomCollator(tokenizer),\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    "    # Weak-to-strong specific options\n",
    "    ft_mode=ft_mode,\n",
    "    alpha_max=alpha_max,\n",
    "    lambda_coeff=lambda_coeff,\n",
    "    warm_up=warm_up,\n",
    "    class_wts=class_wts_task,  # task-specific class proportions\n",
    "    use_samples=validate_on,\n",
    ")\n",
    "# Not every model config defines use_cache; where it does, turn it off for training.\n",
    "if hasattr(model.config, 'use_cache'):\n",
    "    model.config.use_cache = False\n",
    "trainer.train()\n",
    "\n",
    "# Load the test split for the dataset variant (writing a dummy file if it is\n",
    "# missing), convert it to adv/orig records and tokenise it like the training data.\n",
    "test_data_file_path_map = dict(\n",
    "    glue='data/glue_test.json',\n",
    "    advglue='data/advglue_test.json',\n",
    "    advgluepp='data/advgluepp_test.json',\n",
    ")\n",
    "test_data_file = test_data_file_path_map[glue_ds]\n",
    "\n",
    "\n",
    "def _dummy_test_item():\n",
    "    \"\"\"Placeholder raw test item carrying the fields every supported task reads.\"\"\"\n",
    "    return {'sentence': 'test sentence', 'original_sentence': 'orig test sent', 'label': 0,\n",
    "            'question1': 'test q1', 'original_question1': 'orig test q1', 'question2': 'test q2', 'original_question2': 'orig test q2',\n",
    "            'premise': 'test premise', 'original_premise': 'orig test premise', 'hypothesis': 'test hypothesis', 'original_hypothesis': 'orig test hyp',\n",
    "            'question': 'test question', 'original_question': 'orig test q', 'sentence1': 'test s1', 'original_sentence1': 'orig test s1',\n",
    "            'sentence2': 'test s2', 'original_sentence2': 'orig test s2'}\n",
    "\n",
    "\n",
    "if not os.path.exists(test_data_file):\n",
    "    with open(test_data_file, 'w') as f_tsd:  # Create a dummy test data file\n",
    "        dummy_item = _dummy_test_item()\n",
    "        json.dump({task: [dummy_item]}, f_tsd, indent=2)\n",
    "    print(f\"Created dummy test data file: {test_data_file}\")\n",
    "\n",
    "test_data_loaded = load_data(test_data_file)\n",
    "test_data_raw = test_data_loaded.get(task, [])\n",
    "print(f\"Test samples for task {task}: {len(test_data_raw)}\")\n",
    "if not test_data_raw:\n",
    "    dummy_item = _dummy_test_item()\n",
    "    test_data_raw = [dummy_item]  # Ensure at least one item for processing\n",
    "\n",
    "test_data = data_preprocessing(test_data_raw, task)\n",
    "test_ds = Dataset.from_dict({\n",
    "    column: [record[column] for record in test_data]\n",
    "    for column in ('adv_text', 'orig_text', 'label')\n",
    "})\n",
    "tokenized_test_ds = test_ds.map(\n",
    "    lambda samples: tokenize_adv_orig_text(tokenizer, samples, max_len=max_len),\n",
    "    batched=True,\n",
    ")\n",
    "\n",
    "# Evaluate on adversarial and original (clean) test inputs and merge each accuracy\n",
    "# (as a percentage) into the shared results JSON: results[<samples>][<task>][<method>].\n",
    "mode_key_map = {\n",
    "    'weak': 'Weak Performance',\n",
    "    'strong': 'Strong Performance',\n",
    "    'wts-naive': 'WTS-Naive',\n",
    "    'wts-aux-loss': 'WTS-Aux-Loss'\n",
    "}\n",
    "# dirname() is '' for a bare filename and os.makedirs('') raises, so only create the\n",
    "# directory when the path actually has one.\n",
    "results_output_dir = os.path.dirname(results_file)\n",
    "if results_output_dir:\n",
    "    os.makedirs(results_output_dir, exist_ok=True)\n",
    "\n",
    "for use_samples_eval in ['adversarial', 'original']:\n",
    "    print(f'\\nEvaluating the model on {use_samples_eval} samples...\\n', flush=True)\n",
    "    # WTS_Trainer uses this attribute to choose clean vs adversarial inputs.\n",
    "    trainer.use_samples = use_samples_eval\n",
    "    eval_results = trainer.evaluate(tokenized_test_ds)\n",
    "    print(eval_results)\n",
    "\n",
    "    # Re-read the file on every pass so existing entries (other tasks/methods) are kept.\n",
    "    try:\n",
    "        with open(results_file, 'r') as f:\n",
    "            results = json.load(f)\n",
    "    except (FileNotFoundError, json.JSONDecodeError):  # missing, empty or malformed file\n",
    "        results = {}\n",
    "\n",
    "    task_results = results.setdefault(use_samples_eval, {}).setdefault(task, {})\n",
    "    task_results[mode_key_map.get(ft_mode, ft_mode)] = eval_results['eval_accuracy'] * 100\n",
    "\n",
    "    with open(results_file, 'w') as f:\n",
    "        json.dump(results, f, indent=2)\n",
    "    print(\"Results saved to:\", results_file)\n",
    "\n",
    "if ft_mode == 'weak':\n",
    "    # Only the weak model's run produces weak labels: predict on the holdout split and\n",
    "    # save the predicted classes as labels, keyed by task, under weak_labels/<glue_ds>/.\n",
    "    holdout_data_file_path_map = {\n",
    "        'glue': 'data/glue_holdout.json',\n",
    "        'advglue': 'data/advglue_holdout.json',\n",
    "        'advgluepp': 'data/advgluepp_holdout.json'\n",
    "    }\n",
    "    holdout_data_file = holdout_data_file_path_map[glue_ds]\n",
    "\n",
    "    # Placeholder record, defined once; used only when no real holdout data is available.\n",
    "    holdout_dummy_item = {'sentence': 'holdout sentence', 'original_sentence': 'orig holdout sent', 'label': 0,\n",
    "                          'question1': 'holdout q1', 'original_question1': 'orig holdout q1', 'question2': 'holdout q2', 'original_question2': 'orig holdout q2',\n",
    "                          'premise': 'holdout premise', 'original_premise': 'orig holdout premise', 'hypothesis': 'holdout hypothesis', 'original_hypothesis': 'orig holdout hyp',\n",
    "                          'question': 'holdout question', 'original_question': 'orig holdout q', 'sentence1': 'holdout s1', 'original_sentence1': 'orig holdout s1',\n",
    "                          'sentence2': 'holdout s2', 'original_sentence2': 'orig holdout s2'}\n",
    "\n",
    "    if not os.path.exists(holdout_data_file):\n",
    "        # open(..., 'w') raises FileNotFoundError if the parent directory (e.g. data/) is missing.\n",
    "        holdout_data_dir = os.path.dirname(holdout_data_file)\n",
    "        if holdout_data_dir:\n",
    "            os.makedirs(holdout_data_dir, exist_ok=True)\n",
    "        with open(holdout_data_file, 'w') as f_hd: # Create a dummy holdout data file\n",
    "            json.dump({task: [holdout_dummy_item]}, f_hd, indent=2)\n",
    "        print(f\"Created dummy holdout data file: {holdout_data_file}\")\n",
    "\n",
    "    data_holdout_loaded = load_data(holdout_data_file)\n",
    "    data_holdout_raw = data_holdout_loaded.get(task, [])\n",
    "    print(f\"Holdout samples for task {task}: {len(data_holdout_raw)}\")\n",
    "    if not data_holdout_raw:\n",
    "        data_holdout_raw = [holdout_dummy_item] # Ensure at least one item for processing\n",
    "\n",
    "    data_holdout = data_preprocessing(data_holdout_raw, task)\n",
    "    ds_holdout = Dataset.from_dict({'adv_text': [item['adv_text'] for item in data_holdout],\n",
    "                                    'orig_text': [item['orig_text'] for item in data_holdout],\n",
    "                                    'label': [item['label'] for item in data_holdout]})\n",
    "    tokenized_ds_holdout = ds_holdout.map(lambda samples: tokenize_adv_orig_text(tokenizer, samples, max_len=max_len), batched=True)\n",
    "\n",
    "    # NOTE(review): trainer.use_samples still holds the last value set by the evaluation\n",
    "    # loop ('original'); confirm that is the intended input variant for weak-label prediction.\n",
    "    print(\"\\nGenerating weak labels on holdout data...\\n\", flush=True)\n",
    "    predictions = trainer.predict(tokenized_ds_holdout)\n",
    "\n",
    "    preds = np.argmax(predictions.predictions, axis=1)\n",
    "    # Labels are matched to texts by position, so the counts must agree exactly.\n",
    "    assert len(preds) == len(data_holdout), f\"Got {len(preds)} predictions for {len(data_holdout)} holdout samples\"\n",
    "    weak_labels_dict_list = [\n",
    "        {'adv_text': item['adv_text'], 'orig_text': item['orig_text'], 'label': int(pred)}\n",
    "        for item, pred in zip(data_holdout, preds)\n",
    "    ]\n",
    "\n",
    "    weak_labels_output_file = f'weak_labels/{glue_ds}/{model_id}{tag}.json'\n",
    "    os.makedirs(os.path.dirname(weak_labels_output_file), exist_ok=True)\n",
    "\n",
    "    # Merge into any existing file so weak labels already stored for other tasks are kept.\n",
    "    try:\n",
    "        with open(weak_labels_output_file, 'r') as f:\n",
    "            weak_labels_content = json.load(f)\n",
    "    except (FileNotFoundError, json.JSONDecodeError): # missing, empty or malformed JSON\n",
    "        weak_labels_content = {}\n",
    "\n",
    "    weak_labels_content[task] = weak_labels_dict_list\n",
    "\n",
    "    with open(weak_labels_output_file, 'w') as f:\n",
    "        json.dump(weak_labels_content, f, indent=2)\n",
    "    print(\"Weak labels saved to:\", weak_labels_output_file, flush=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
