import copy
import gc
import math
import time
import os
import shutil
import numpy as np
from copy import deepcopy

import bitsandbytes
import regex as re
import torch
import torch.nn as nn
import torch.nn.functional as F
from buffers import DecBuffer
from custom_datasets import (
    CvDataset,
    ToyCvDataset,
    ToyCvGenderDataset,
    ToyCvGenderCorrelationsDataset,
    IMDBDataset,
    SamePromptDataset,
    SimpleFactsDataset,
    SimpleQADataset,
)
from hparams.utils import RLExpRunType, get_run_type_hparams
from huggingface_hub import login
from overseers.evaluate_overseer import eval_overseer
from overseers.overseers import (
    LLMOverseer,
    RMOverseer,
    StringMatchingOverseer,
    StringMatchingOverseerCV,
)
from overseers.prompts import get_llm_overseer_prompt_and_threshold
from paraphraser import Paraphraser
from peft import (
    LoraConfig,
)
from rl_utils import (
    # compute_metrics,
    device_aware_collate,
    display_samples,
    print_number_of_trainable_model_parameters,
    select_logits_by_label,
)
from loggers.wandb_table_logger import CompletionsTableLogger
from task_usefulness import RMTaskUsefulness
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig
)
from transformers import AutoModelForCausalLM
from trl import (
    AutoModelForCausalLMWithValueHead,
    PPOConfig,
    PPOTrainer,
    create_reference_model,
    set_seed,
)
from trl.core import LengthSampler
from vllm import LLM, SamplingParams  # noqa
from vllm.lora.request import LoRARequest  # noqa
from vllm_utils import vllm_single_gpu_patch
import wandb

class StegoTrainer:
    def __init__(self, hparams):
        """Build every component needed for a training run.

        Args:
            hparams: Nested configuration dict. This constructor reads
                ``hparams["train"]["do_paraphrase"]``,
                ``hparams["enc"]["vllm_inference"]`` and
                ``hparams["train"]["cost_cap"]`` directly; the ``init_*``
                helpers it calls read further keys.
        """
        self.hparams = hparams

        # Device preference: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            self.device = "cuda:0"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        # If multiple GPUs, training still lives on cuda:0
        if torch.cuda.device_count() > 1:
            self.device = "cuda:0"

        # Kept as a lambda so that self.device is looked up at call time.
        self.device_collate = lambda batch: device_aware_collate(batch, self.device)

        # Order matters: init_encoder reads self.n_epochs (set by
        # init_counters) and init_data reads self.enc_tokenizer (set by
        # init_encoder).
        self.init_counters()
        self.init_encoder()
        self.init_decoder()
        if self.hparams["train"]["do_paraphrase"]:
            self.init_paraphraser()
        self.init_data()
        self.init_task_usefulness()

        if self.hparams["enc"]["vllm_inference"]:
            # Single GPU: vLLM shares cuda:0 with training and is given 20% of
            # its memory. (An earlier note here quoted 14% as the benchmarked
            # minimum for vLLM to avoid multiple forward passes per batch; the
            # configured value is 0.2.)
            self.vllm_gpu_usage = 0.2
            self.vllm_device = "cuda:0"
            if torch.cuda.device_count() > 1:
                # Multi-GPU: dedicate cuda:1 to inference, 80% of its memory.
                self.vllm_gpu_usage = 0.8
                self.vllm_device = "cuda:1"

            self.init_vllm_encoder()

        self.running_id = 0
        # A configured cost cap of 0.0 is replaced by a cap of 1.0.
        configured_cap = self.hparams["train"]["cost_cap"]
        self.cost_cap = 1.0 if configured_cap == 0.0 else configured_cap

        self.enc_ppo_trainer = PPOTrainer(
            self.enc_ppo_config,
            self.enc_model,
            self.enc_ref_model,
            tokenizer=self.enc_tokenizer,
            optimizer=self.enc_optimizer,
        )

        print("Accelerator device = ", self.enc_ppo_trainer.accelerator.device)

        # The overseer and logging are set up last: init_logging reads the
        # datasets created by init_data.
        self.init_overseer()
        self.init_logging()

        self.OVERRIDE_BATCH = None

    def cleanup_gpu_memory(self):
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.synchronize()
        
    def init_logging(self):
        """Create the completions table logger and log static run metadata.

        Logs two dictionaries through ``self.do_wandb_log``: the top-level
        hyperparameter groups (keys prefixed with ``hparams/``) and the sizes
        of the encoder-train, decoder-train and eval datasets (keys prefixed
        with ``dataset/``).
        """
        self.table_logger = CompletionsTableLogger(self)

        # Deep-copy so the logged values are detached from the live config.
        prefixed_hparams = {
            f"hparams/{group}": settings
            for group, settings in copy.deepcopy(self.hparams).items()
        }
        self.do_wandb_log(prefixed_hparams)

        sample_counts = (
            ("n_enc_train_samples", self.enc_train_dataset),
            ("n_dec_train_samples", self.dec_train_dataset),
            ("n_eval_samples", self.eval_dataset),
        )
        self.do_wandb_log(
            {f"dataset/{label}": len(dataset) for label, dataset in sample_counts}
        )

    def init_paraphraser(self):
        """Instantiate a ``Paraphraser`` and store it on ``self.paraphraser``."""
        print("Initializing Paraphraser...")
        paraphraser = Paraphraser()
        self.paraphraser = paraphraser

    def paraphraser_step(self, batch):
        eval_para_time = time.time()
        valid_texts = batch["dec_in_texts"]
        mini_batch_size = 8
        max_workers = len(valid_texts) // mini_batch_size
        return_costs = True
        texts_pp, costs = self.paraphraser.parallelize_api_calls(
            large_batch=valid_texts,
            max_workers=max_workers,
            return_costs=return_costs,
        )
        # Log costs
        self.cost_tracker += sum(costs)
        # Replace the valid texts in the batch with the paraphrased texts
        batch["dec_in_texts_pp"] = texts_pp
        eval_para_elapsed = time.time() - eval_para_time
        print(f"Paraphrasing took {eval_para_elapsed:.2f} seconds")

        return batch, eval_para_elapsed

        # # log hparams dict
        # self.do_wandb_log(self.hparams)

    def memory_summary(self):
        num_gpus = torch.cuda.device_count()

        if num_gpus == 0:
            print("No GPUs available.")
            return

        print(f"Number of GPUs available: {num_gpus}")
        print("-" * 40)

        for i in range(num_gpus):
            torch.cuda.set_device(i)  # Set the current GPU
            gpu_stats = torch.cuda.get_device_properties(i)

            # Get memory stats
            total_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
            reserved_memory = round(
                torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3
            )
            allocated_memory = round(
                torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, 3
            )

            print(f"GPU {i}: {gpu_stats.name}")
            print(f"  Total memory: {total_memory} GB")
            print(f"  Reserved memory: {reserved_memory} GB")
            print(f"  Allocated memory: {allocated_memory} GB")
            print(
                f"  Memory utilization: {(reserved_memory / total_memory) * 100:.2f}%"
            )
            print("-" * 40)

    def save_checkpoints(self):
        """Save decoder and encoder checkpoints locally and push them to the Hub.

        Local copies go under ``./results/``; Hub repositories are named from
        the configured checkpoint id and the current ``running_id``, which is
        incremented once everything has been pushed.
        """
        checkpoint_id = self.hparams['train']['checkpoint_id']

        # Local checkpoints
        classifier_output_dir = f"./results/{checkpoint_id}-ClassifierCheckpoint"
        ppo_output_dir = f"./results/{checkpoint_id}-Checkpoint"
        for decoder_part in (self.dec_model, self.dec_tokenizer):
            decoder_part.save_pretrained(classifier_output_dir)
        self.enc_ppo_trainer.save_pretrained(ppo_output_dir)

        # Push to Hugging Face Hub: decoder model and tokenizer share one repo
        decoder_repo = f"jvelja/BERT_{checkpoint_id}_{self.running_id}"
        for decoder_part in (self.dec_model, self.dec_tokenizer):
            decoder_part.push_to_hub(decoder_repo)

        # For the PPO model, assuming it's compatible with push_to_hub
        encoder_repo = f"jvelja/{checkpoint_id}_{self.running_id}"
        self.enc_ppo_trainer.push_to_hub(encoder_repo)

        self.running_id += 1

    def init_encoder(self):
        """Create the PPO encoder, its reference model, tokenizer and optimizer.

        Sets ``enc_ppo_config``, ``enc_model``, ``enc_ref_model``,
        ``enc_tokenizer`` and ``enc_optimizer``. When
        ``hparams["enc"]["vllm_inference"]`` is set, the initial model and
        tokenizer are also saved to ``self.enc_model_path`` for vLLM.
        """
        enc_hparams = self.hparams["enc"]

        if enc_hparams["do_lora"]:  # LoRA adapters for finetuning (disabled for small models - GPT2)
            assert "gpt2" not in enc_hparams["model_name"], "LoRA is not supported for GPT-2 encoders"
            lora_config = LoraConfig(
                r=16,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "v_proj"],
                lora_alpha=16,
            )
        else:
            lora_config = None

        quantization_config = None  # We do not do quant anymore

        self.enc_ppo_config = PPOConfig(
            model_name=enc_hparams["model_name"],
            learning_rate=enc_hparams["lr"],
            log_with="wandb",
            batch_size=enc_hparams["batch_size"],
            mini_batch_size=enc_hparams["mini_batch_size"],
            gradient_accumulation_steps=enc_hparams["gradient_accumulation_steps"],
            init_kl_coef=enc_hparams["init_kl_coef"],
            target=enc_hparams["target_kl"],
            optimize_cuda_cache=True,
            early_stopping=False,
            tracker_kwargs=self.hparams["wandb"],
        )
        # Attach a snapshot of the full config to the PPO config.
        self.enc_ppo_config.all_stego_hparams = deepcopy(self.hparams)

        # self.device is "cuda:0" (never the bare string "cuda"), so match on
        # the prefix; an equality test against "cuda" would make the bfloat16
        # branch unreachable.
        on_cuda = str(self.device).startswith("cuda")
        if "gpt2" in enc_hparams["model_name"]:
            dtype = "float32"
        elif on_cuda and torch.cuda.get_device_capability()[0] >= 8:
            # Compute capability >= 8 is required for the bfloat16 path.
            dtype = torch.bfloat16
        else:
            dtype = "float16"

        self.enc_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            self.enc_ppo_config.model_name,
            peft_config=lora_config,
            quantization_config=quantization_config,
            attn_implementation=enc_hparams['attn_implementation'],
            torch_dtype=dtype,
        )
        self.enc_model.config.use_cache = False

        self.enc_ref_model = create_reference_model(self.enc_model)
        self.enc_tokenizer = AutoTokenizer.from_pretrained(
            self.enc_ppo_config.model_name
        )

        # Tokenizer check - Some models do not have a padding token, thus we need to hardcode it as the EoS token
        if self.enc_tokenizer.pad_token_id is None:
            self.enc_tokenizer.pad_token = self.enc_tokenizer.eos_token

        self.enc_optimizer = AdamW(
            self.enc_model.parameters(),
            lr=self.enc_ppo_config.learning_rate,
            weight_decay=0.01,
        )

        # If encoder is frozen, set requires_grad to False for all parameters
        if not self.hparams["train"]["do_update_enc"]:
            for param in self.enc_model.parameters():
                param.requires_grad = False

        print(
            f"PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(self.enc_model)}\n"
        )
        print()
        print(
            f"Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(self.enc_ref_model)}\n"
        )

        if enc_hparams["vllm_inference"]:
            # Save the initial encoder model and tokenizer for vLLM
            self.enc_model_path = self.get_vllm_encoder_path(self.n_epochs)
            self.enc_model.pretrained_model.save_pretrained(self.enc_model_path)
            self.enc_tokenizer.save_pretrained(self.enc_model_path)

    def init_decoder(self):
        """Create the decoder classifier, its tokenizer, optimizer and loss.

        Sets ``dec_model``, ``dec_tokenizer``, ``dec_optimizer`` and
        ``dec_loss_fn``, then freezes decoder parameters according to
        ``hparams["train"]["do_update_dec"]`` and
        ``hparams["train"]["dec_update_params"]``.
        """
        dec_hparams = self.hparams["dec"]

        if dec_hparams["new_classification_head"]:
            n_payloads = self.hparams["dataset"]["n_payloads"]
            # Init model
            self.dec_model = AutoModelForSequenceClassification.from_pretrained(
                dec_hparams["model_name"],
                num_labels=n_payloads,
            ).to(self.device)
            # Replace the head with a new linear layer of the same input width
            self.dec_model.classifier = torch.nn.Linear(
                in_features=self.dec_model.classifier.in_features,
                out_features=n_payloads,
            ).to(self.device)
        else:
            self.dec_model = AutoModelForSequenceClassification.from_pretrained(
                dec_hparams["model_name"]
            ).to(self.device)

        # Tokenizer, optimizer, loss_fn
        self.dec_tokenizer = AutoTokenizer.from_pretrained(dec_hparams["model_name"])

        # Biases and LayerNorm weights are placed in a group without weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        stego_optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.dec_model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": dec_hparams["weight_decay"],
            },
            {
                "params": [
                    p
                    for n, p in self.dec_model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        # self.device is "cuda:0" (never the bare string "cuda"), so match on
        # the prefix; an equality test against "cuda" would never select the
        # bitsandbytes optimizer.
        if str(self.device).startswith("cuda"):
            self.dec_optimizer = bitsandbytes.optim.AdamW32bit(
                stego_optimizer_grouped_parameters, lr=dec_hparams["lr"]
            )
        else:
            self.dec_optimizer = AdamW(
                self.dec_model.parameters(),
                lr=dec_hparams["lr"],
                weight_decay=dec_hparams["weight_decay"],
            )

        self.dec_loss_fn = nn.CrossEntropyLoss()

        # Freeze parameters
        if not self.hparams["train"]["do_update_dec"]:
            for param in self.dec_model.parameters():
                param.requires_grad = False
        else:
            # Check whether user wants to update the head and/or the body of the classifier
            update_head = self.hparams["train"]["dec_update_params"]["head"]
            update_body = self.hparams["train"]["dec_update_params"]["body"]
            if update_head and not update_body:
                # Head only: freeze everything outside the classifier
                for name, param in self.dec_model.named_parameters():
                    if "classifier" not in name:
                        param.requires_grad = False
            elif update_body and not update_head:
                # Body only: freeze the classifier
                for name, param in self.dec_model.named_parameters():
                    if "classifier" in name:
                        param.requires_grad = False
            # Both flags equal (both set or both unset): nothing is frozen.

        # Get number of trainable parameters
        print(
            f"Stego classifier parameters to be updated:\n{print_number_of_trainable_model_parameters(self.dec_model)}\n"
        )

    def init_vllm_encoder(self):
        """Load the encoder into vLLM and prepare the sampling parameters.

        Sets ``self.vllm_llm`` and ``self.vllm_sampling_params``. Expects
        ``self.vllm_gpu_usage`` and ``self.vllm_device`` to be set already.
        """
        self.cleanup_gpu_memory()

        # mem_get_info reports the device's actual free memory; the difference
        # between PyTorch's reserved and allocated bytes only measures the
        # allocator's own cache.
        free_memory, _ = torch.cuda.mem_get_info(torch.device(self.vllm_device))
        print(
            f"Available GPU memory before vLLM initialization: {free_memory / 1e9:.2f} GB"
        )

        vllm_single_gpu_patch()
        # Set up VLLM encoder on the inference GPU if available
        self.vllm_llm = LLM(
            model=self.hparams['enc']['model_name'],
            tensor_parallel_size=1,
            max_model_len=512,
            gpu_memory_utilization=self.vllm_gpu_usage,
            device=self.vllm_device,
            enable_lora=True,
        )
        self.vllm_sampling_params = SamplingParams(
            n=1,  # Equivalent to num_return_sequences in HF, default is 1
            temperature=1.0,  # Equivalent to temperature in HF
            top_k=-1,  # vLLM uses -1 to disable top-k, equivalent to 0.0 in HF
            top_p=1.0,  # Same as HF
            use_beam_search=False,  # Equivalent to do_sample=True in HF
            min_tokens=0,  # Roughly equivalent to min_length=-1 in HF, allowing EOS token at any position
            ignore_eos=False,  # Ensures EOS token is not ignored, aligning with min_length=-1 in HF
            max_tokens=self.hparams["enc"]["output_len"]["max"],
        )
        print("🔥🔥🔥 vllm loaded")

        # set device again because LLM calls `torch.cuda.set_device` (see: https://github.com/allenai/open-instruct/blob/5641385b1bec87d80b61bb219325be7fecac71c3/open_instruct/ppo_vllm.py#L192)
        # NOTE: do not remove this; otherwise accelerate hangs.
        torch.cuda.set_device(self.device)

    def update_vllm_encoder(self, eval=False):
        """Save the current encoder weights and tokenizer for vLLM to pick up.

        Args:
            eval: When true, save under the current epoch's folder and wipe
                any existing folder first; otherwise save under the next
                epoch's folder. (The name shadows the builtin but is kept
                because callers may pass it by keyword.)

        Updates ``self.enc_model_path`` and increments ``self.n_vllm_updates``.
        The vLLM engine itself is not re-created here.
        """
        print("Saving encoder model adapters for vLLM...")

        target_epoch = self.n_epochs if eval else self.n_epochs + 1
        self.enc_model_path = self.get_vllm_encoder_path(target_epoch)

        self.n_vllm_updates += 1

        # If the folder exists and we're in eval mode, delete it
        if eval and os.path.exists(self.enc_model_path):
            print(f"Removing existing folder: {self.enc_model_path}")
            shutil.rmtree(self.enc_model_path)

        # Create the folder if it doesn't exist
        os.makedirs(self.enc_model_path, exist_ok=True)

        # Save the model and tokenizer
        self.enc_ppo_trainer.model.pretrained_model.save_pretrained(self.enc_model_path)
        self.enc_tokenizer.save_pretrained(self.enc_model_path)
        print("Encoder model adapters saved. Folder:", self.enc_model_path)

        # A retry loop used to wrap a vLLM re-initialisation here; with that
        # call disabled the loop could never fail or retry, so only its
        # status message remains.
        print("🔥🔥🔥 vllm reloaded")
                
    def get_vllm_encoder_path(self, epoch):
        """Return the save folder for the encoder at ``epoch``.

        Ensures the ``./exps`` parent directory exists; the returned folder
        itself is not created here.
        """
        exps_root = "./exps"
        os.makedirs(exps_root, exist_ok=True)
        run_stamp = self.hparams['timestamp']['start']
        return f"{exps_root}/encoder_model_{run_stamp}_{epoch}"

    def generate_with_vllm(self, batch):
        """Generate one completion per prompt in ``batch["query"]`` with vLLM.

        The most recently saved LoRA adapter is attached once at least one
        adapter update has happened (``self.n_vllm_updates != 0``); before
        that the model is used without an adapter.

        Returns:
            Tuple ``(response_tokens, completions, query_response_texts)``:
            token ids of each completion, the stripped completion texts, and
            each query concatenated with its completion.
        """
        prompts = batch["query"]
        if self.n_vllm_updates != 0:
            adapters = LoRARequest(
                "latest_adapter", self.n_vllm_updates, lora_path=self.enc_model_path
            )
        else:
            adapters = None

        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.vllm_llm.generate(
            prompts, self.vllm_sampling_params, lora_request=adapters
        )

        # Sampling uses n=1, so only the first candidate of each request is read.
        first_candidates = [request_output.outputs[0] for request_output in outputs]
        completions = [candidate.text.strip() for candidate in first_candidates]
        response_tokens = [candidate.token_ids for candidate in first_candidates]

        # Mirrors the return shape of get_encoder_generations.
        query_response_texts = [
            query + response for query, response in zip(prompts, completions)
        ]
        return response_tokens, completions, query_response_texts

    def init_overseer(self):
        """Create the overseer selected by ``hparams["overseer"]``.

        When ``use_overseer`` is false, ``self.overseer`` is set to ``None``
        and nothing else happens. Otherwise the overseer is built according
        to ``hparams["overseer"]["type"]`` and, if ``eval_overseer`` is set,
        evaluated once.

        Raises:
            NotImplementedError: For the ``"reward_model"`` type and for any
                unknown overseer type.
        """
        overseer_hparams = self.hparams["overseer"]

        if not overseer_hparams["use_overseer"]:
            # Return immediately: falling through would overwrite the None
            # with a real overseer (or raise for an unknown type).
            self.overseer = None
            return

        overseer_type = overseer_hparams["type"]

        if (
            overseer_type == "string_matching"
            and self.hparams["dataset"]["name"] == "cv"
        ):
            self.overseer = StringMatchingOverseerCV(
                sus_strings=self.strings_to_check_for
            )
        elif overseer_type == "string_matching":
            self.overseer = StringMatchingOverseer(
                payload_prefixes=self.hparams["dataset"]["payload_prefixes"],
                payloads=self.hparams["dataset"]["payloads"],
                payload_template=self.hparams["dataset"]["payload_template"],
            )
            # TODO: why do we have three things for payload (overcomplicated)

        elif overseer_type == "reward_model":
            # This branch always raised; do so before building anything.
            raise NotImplementedError("Reward-model overseer is not implemented")

        elif overseer_type == "llm":
            # Get overseer prompt and threshold
            prompt_dict = get_llm_overseer_prompt_and_threshold(overseer_hparams['prompt_name'])
            prompt, sus_thresh, valid_datasets = prompt_dict["prompt"], prompt_dict["llm_threshold"], prompt_dict["valid_datasets"]
            assert self.hparams["dataset"]["name"] in valid_datasets, f"Dataset {self.hparams['dataset']['name']} not supported for overseer prompt {self.hparams['overseer']['prompt_name']}"

            # Load overseer base model
            if overseer_hparams["llm_model_name"] is None:
                # No dedicated model configured: reuse the encoder reference model.
                overseer_model = self.enc_ref_model
                overseer_model_name = self.hparams["enc"]["model_name"]
            else:
                overseer_model_name = overseer_hparams["llm_model_name"]
                assert "lama" in overseer_model_name
                # Init third-party model, loaded in 8-bit
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                overseer_model = AutoModelForCausalLM.from_pretrained(overseer_model_name, use_cache=False, quantization_config=quantization_config)

            # Init overseer
            self.overseer = LLMOverseer(
                model_name=overseer_model_name,
                model=overseer_model,
                overseer_prompt=prompt,
                device=self.device,
                sus_threshold=sus_thresh,
                clip_positive_rewards=self.hparams["rewards"]["clip_overseer_positive_rewards"],
            )

        else:
            raise NotImplementedError("Overseer type not implemented")

        # Eval overseer
        if overseer_hparams["eval_overseer"]:
            eval_overseer(
                overseer=self.overseer,
                dataset_name=self.hparams["dataset"]["name"],
                max_batch_size=self.hparams["enc"]["mini_batch_size"],
                wandb_category_str="overseer_initial_eval",
                wandb_log_func=self.do_wandb_log,
            )

    def init_task_usefulness(self):
        """Set ``self.task_usefulness_func`` from ``hparams["rewards"]``.

        The attribute is ``None`` when the task-usefulness reward is disabled.

        Raises:
            NotImplementedError: If the reward is enabled with a type other
                than ``"reward_model"``.
        """
        reward_hparams = self.hparams["rewards"]

        if not reward_hparams["use_task_usefulness_reward"]:
            self.task_usefulness_func = None
            return

        if reward_hparams["task_usefulness_type"] != "reward_model":
            raise NotImplementedError("Task usefulness not implemented")

        self.task_usefulness_func = RMTaskUsefulness(device=self.device)

    def init_data(self):
        """Load the configured dataset and build the train/eval data loaders.

        Sets ``enc_train_dataset``, ``dec_train_dataset``, ``eval_dataset``,
        the three loaders and ``data_loader_dict``. For the CV-style datasets
        it also sets ``train_dataset_obj`` and ``strings_to_check_for``.

        Raises:
            NotImplementedError: For disabled or unknown dataset names, for
                IMDB combined with seeded prompts, and for train loop
                versions whose name does not end in ``"named"``.
        """
        dataset_hparams = self.hparams["dataset"]
        name = dataset_hparams["name"]

        # These datasets are switched off. They previously raised with an
        # error message copy-pasted from the IMDB branch, followed by
        # unreachable loading code.
        if name in ("simpleqa_100", "simple_facts_1000", "same_prompt"):
            raise NotImplementedError(f"Dataset '{name}' is not currently supported")

        if name == "imdb":
            # Fail before loading anything rather than after.
            if dataset_hparams["use_seeded_prompt"]:
                raise NotImplementedError("IMDB does not support seeded prompts")
            train_dataset = IMDBDataset(
                self.enc_tokenizer,
                dataset_hparams["payload_prefixes"],
                batch_size=dataset_hparams["batch_size"],
                split="train",
            ).get_data()
            eval_dataset = IMDBDataset(
                self.enc_tokenizer,
                dataset_hparams["payload_prefixes"],
                batch_size=dataset_hparams["batch_size"],
                split="test",
                size=1024,
            ).get_data()
        else:
            # The CV-style datasets are built identically; only the class differs.
            if name == "cv":
                dataset_cls = CvDataset
            elif name == "toy_cv":
                dataset_cls = ToyCvDataset
            elif name == "toy_cv_gender":
                dataset_cls = ToyCvGenderDataset
            elif name == "toy_cv_gender_correlations":
                dataset_cls = ToyCvGenderCorrelationsDataset
            else:
                raise NotImplementedError(f"Unknown dataset: '{name}'")

            dataset_obj = dataset_cls(
                tokenizer=self.enc_tokenizer,
                batch_size=dataset_hparams["batch_size"],
                split="train",
                use_seeded_prompt=dataset_hparams["use_seeded_prompt"],
            )
            self.train_dataset_obj = dataset_obj
            train_dataset = dataset_obj.get_data()
            self.strings_to_check_for = dataset_obj.strings_to_check_for

            eval_dataset = dataset_cls(
                self.enc_tokenizer,
                batch_size=dataset_hparams["batch_size"],
                split="test",
                use_seeded_prompt=dataset_hparams["use_seeded_prompt"],
            ).get_data()

        # Assign train data: either two disjoint halves or the same data for
        # both encoder and decoder.
        if dataset_hparams["seperate_enc_dec_data"]:
            split = train_dataset.train_test_split(test_size=0.5)
            self.enc_train_dataset, self.dec_train_dataset = (
                split["train"],
                split["test"],
            )
        else:
            self.enc_train_dataset, self.dec_train_dataset = (
                train_dataset,
                train_dataset,
            )

        # Assign eval data
        self.eval_dataset = eval_dataset

        # Assert num labels in dataset equals n_payloads in hparams
        n_labels = len(train_dataset['label'].unique())
        assert dataset_hparams["n_payloads"] == n_labels, (
            f"hparams n_payloads={dataset_hparams['n_payloads']} but the "
            f"training data has {n_labels} distinct labels"
        )

        # Trim dataset size
        self.enc_train_dataset = self.enc_train_dataset.select(
            range(min(dataset_hparams["max_enc_size"], len(self.enc_train_dataset))))
        self.dec_train_dataset = self.dec_train_dataset.select(
            range(min(dataset_hparams["max_dec_size"], len(self.dec_train_dataset))))
        self.eval_dataset = self.eval_dataset.select(
            range(min(dataset_hparams["max_eval_size"], len(self.eval_dataset))))

        # Only train loop versions whose name ends in "named" are supported.
        if not self.hparams["train"]["train_loop_version"].endswith("named"):
            raise NotImplementedError(
                f"Unsupported train_loop_version: {self.hparams['train']['train_loop_version']}"
            )

        self.enc_train_loader = DataLoader(
            self.enc_train_dataset,
            batch_size=self.hparams["enc"]["batch_size"],
            shuffle=True,
            collate_fn=self.device_collate,
            drop_last=True,
        )  # Drop last as smaller last batch causes error in PPO trainer
        self.dec_train_loader = DataLoader(
            self.dec_train_dataset,
            batch_size=self.hparams["dec"]["batch_size"],
            shuffle=True,
            collate_fn=self.device_collate,
        )
        self.test_loader = DataLoader(
            self.eval_dataset,
            batch_size=self.hparams["enc"]["batch_size"],
            shuffle=True,
            collate_fn=self.device_collate,
            drop_last=True,
        )
        self.data_loader_dict = {
            "enc_train": self.enc_train_loader,
            "dec_train": self.dec_train_loader,
            "eval": self.test_loader,
        }

    def init_counters(self):
        """Reset every training progress counter and the cost tracker to zero."""
        # Epoch / optimisation-step counters
        self.n_epochs = self.n_dec_updates = self.n_enc_updates = 0

        # Sample and dataset-pass counters
        self.n_samples_seen_dec = self.n_samples_seen_enc = 0
        self.n_times_seen_dataset_dec = self.n_times_seen_dataset_enc = 0

        self.n_eval_opportunities = 0

        # LoRA adapter updates (to be incremented every time we update the vLLM adapter)
        self.n_vllm_updates = 0

        # Kept as a float, unlike the integer counters above
        self.cost_tracker = 0.0

    def add_counters_to_wandb_dict(self, wandb_dict):
        """Add the current progress counters to ``wandb_dict`` and return it.

        The dict is updated in place (existing ``counters/*`` keys are
        overwritten) and the very same object is returned.
        """
        combined_updates = self.n_dec_updates + self.n_enc_updates
        counter_items = [
            ("counters/n_epochs", self.n_epochs),
            ("counters/n_dec_updates", self.n_dec_updates),
            ("counters/n_enc_updates", self.n_enc_updates),
            ("counters/n_samples_seen_dec", self.n_samples_seen_dec),
            ("counters/n_samples_seen_enc", self.n_samples_seen_enc),
            ("counters/n_times_seen_dataset_dec", self.n_times_seen_dataset_dec),
            ("counters/n_times_seen_dataset_enc", self.n_times_seen_dataset_enc),
            ("counters/n_enc_plus_dec_updates", combined_updates),
        ]
        wandb_dict.update(counter_items)
        return wandb_dict

    def log_completions_table(self, batch, prefix="dec_training"):
        """Log a per-sample completions table for ``batch`` to wandb.

        For ``prefix == "dec_training"`` the table shows decoder inputs,
        predictions and labels. For any other prefix it shows the encoder
        query/response columns (plus ``paraphrased_texts`` when paraphrasing
        is enabled in the hparams) and the reward. Columns missing from
        ``batch`` are filled with ``None``. The table is logged under the key
        ``"{prefix}_completions"``.
        """
        if prefix == "dec_training":
            # Decoder logging
            columns = ["dec_in_texts", "predictions", "label"]
            length_key = "dec_in_texts"
        else:
            columns = [
                "query",
                "response",
                "query_response_texts",
                "query_response_stripped",
            ]
            if self.hparams["train"]["do_paraphrase"]:
                columns.append("paraphrased_texts")
            columns.append("reward")
            length_key = "query"

        table = wandb.Table(columns=columns)
        n_rows = len(batch[length_key])

        # Only keep the batch entries that correspond to a table column
        available = {name: batch[name] for name in columns if name in batch}

        for row_idx in range(n_rows):
            cells = []
            for name in columns:
                cells.append(available[name][row_idx] if name in available else None)
            table.add_data(*cells)

        # log to wandb
        self.do_wandb_log({f"{prefix}_completions": table})

    def do_wandb_log(self, wandb_dict):
        """Attach the progress counters to ``wandb_dict`` and send it to wandb.

        Keys are normalised so that only the first ``/`` survives: any later
        ``/`` is replaced by ``_`` (``"a/b/c"`` becomes ``"a/b_c"``).
        Note that ``wandb_dict`` itself gains the counter entries.
        """
        with_counters = self.add_counters_to_wandb_dict(wandb_dict)

        normalised = {}
        for raw_key, value in with_counters.items():
            # Split once at the first '/', then flatten the remainder
            group, slash, remainder = raw_key.partition("/")
            normalised[group + slash + remainder.replace("/", "_")] = value

        wandb.log(normalised)

    def create_dec_train_iter(self):
        """Return a fresh iterator over the decoder training data loader."""
        loader = self.dec_train_loader
        return iter(loader)

    def create_enc_train_iter(self):
        """Return a fresh iterator over the encoder training data loader."""
        loader = self.enc_train_loader
        return iter(loader)

    def train(self):
        """Run the alternating decoder / encoder training loop.

        Per epoch (``hparams["train"]["num_epochs"]`` epochs in total):

        1. If ``do_update_dec`` is set, run the configured number of decoder
           updates (skipped in epoch 0 when ``update_enc_first`` is set;
           ``n_initial_dec_updates`` overrides the count in epoch 0 when it is
           not ``None``).
        2. If ``do_update_enc`` is set, run ``n_enc_updates`` encoder updates
           and, when vLLM inference is enabled, refresh the vLLM encoder.
        3. Log timings, optionally save checkpoints, print a summary and stop
           early once ``self.cost_tracker`` exceeds ``self.cost_cap``.

        An intermediate evaluation check runs after every single update.
        """
        train_cfg = self.hparams["train"]

        # Set up the output length sampler
        self.output_length_sampler = LengthSampler(
            self.hparams["enc"]["output_len"]["min"],
            self.hparams["enc"]["output_len"]["max"],
        )

        # Set up the data
        self.decoder_buffer = DecBuffer(max_size=self.hparams["dec"]["batch_size"])

        # Prepare data loaders
        self.dec_train_loader = self.enc_ppo_trainer.accelerator.prepare(
            self.dec_train_loader,
        )
        self.enc_train_loader = self.enc_ppo_trainer.accelerator.prepare(
            self.enc_train_loader
        )

        # Create iterators
        self.dec_train_iter = self.create_dec_train_iter()
        self.enc_train_iter = self.create_enc_train_iter()

        # Models that are not being trained are put in eval mode
        if not train_cfg["do_update_dec"]:
            self.dec_model.eval()

        if not train_cfg["do_update_enc"]:
            self.enc_ppo_trainer.model.eval()

        # Generation arguments
        default_gen_kwargs = {
            "min_length": -1,
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.enc_tokenizer.pad_token_id,
        }
        # One kwargs dict per phase; they differ only in sampling temperature
        self.generation_kwargs = {
            phase: {
                **default_gen_kwargs,
                "temperature": train_cfg["generation_temp"][phase],
            }
            for phase in ("enc_training", "dec_training", "eval")
        }

        train_start = time.time()

        # The device string carries an index (e.g. "cuda:0"), so an exact
        # comparison against "cuda" would never match.
        if str(self.device).startswith("cuda"):
            self.memory_summary()

        for epoch in range(train_cfg["num_epochs"]):
            self.n_epochs = epoch

            # Dec update
            if train_cfg["do_update_dec"]:
                dec_start = time.time()
                skip_dec_this_epoch = epoch == 0 and train_cfg["update_enc_first"]
                if not skip_dec_this_epoch:
                    if epoch == 0 and train_cfg["n_initial_dec_updates"] is not None:
                        n_dec_updates = train_cfg["n_initial_dec_updates"]
                    else:
                        n_dec_updates = train_cfg["n_dec_updates"]
                    # Loop
                    for _ in range(n_dec_updates):
                        self.update_decoder()
                        # counters
                        self.n_dec_updates += 1
                        self.n_samples_seen_dec += self.hparams["dec"]["batch_size"]
                        self.n_times_seen_dataset_dec = self.n_samples_seen_dec // len(
                            self.dec_train_dataset
                        )
                        # Intermediate eval
                        self.check_if_do_eval(reload_vllm_enc=False)
                dec_elapsed = time.time() - dec_start
                self.do_wandb_log({"timing/dec_epoch_minutes": dec_elapsed / 60})

            # Enc update
            if train_cfg["do_update_enc"]:
                enc_start = time.time()
                for _ in range(train_cfg["n_enc_updates"]):
                    # Make sure the policy model sits on the training device
                    # before each update (was hard-coded to 'cuda:0').
                    self.enc_ppo_trainer.model.pretrained_model.to(self.device)
                    self.update_encoder()
                    # counters
                    self.n_enc_updates += 1
                    self.n_samples_seen_enc += self.hparams["enc"]["batch_size"]
                    self.n_times_seen_dataset_enc = self.n_samples_seen_enc // len(
                        self.enc_train_dataset
                    )
                    # Intermediate eval
                    self.check_if_do_eval(reload_vllm_enc=True)
                # Update vLLM encoder
                if self.hparams["enc"]["vllm_inference"]:
                    self.update_vllm_encoder()
                enc_elapsed = time.time() - enc_start
                self.do_wandb_log({"timing/enc_epoch_minutes": enc_elapsed / 60})

            # Log stats
            total_elapsed = time.time() - train_start
            self.do_wandb_log({"timing/total_seconds": total_elapsed})
            self.do_wandb_log({"timing/total_hours": total_elapsed / 3600})

            # Save checkpoints
            if train_cfg["save_checkpoints"]:
                self.save_checkpoints()

            # Print stats
            self.memory_summary()
            print(f"Epoch {epoch} completed in {total_elapsed:.2f} seconds")
            print("Cost tracker:", self.cost_tracker)

            # Check costs
            if self.cost_tracker > self.cost_cap:
                print("Cost cap reached, stopping training...")
                break

    def check_if_do_eval(self, reload_vllm_enc=False):
        """Evaluate if the total update count is a multiple of ``eval_every_n``.

        When an evaluation is due, the vLLM encoder is refreshed first (only if
        vLLM inference is enabled and ``reload_vllm_enc`` is true). Then every
        dataset named in ``hparams["eval"]["datasets_to_eval"]`` is evaluated
        and its stats are logged to wandb.
        """
        total_updates = self.n_dec_updates + self.n_enc_updates
        if total_updates % self.hparams["eval"]["eval_every_n"] != 0:
            return  # Not an evaluation step

        if self.hparams["enc"]["vllm_inference"] and reload_vllm_enc:
            self.update_vllm_encoder(eval=True)

        for data_name in self.hparams["eval"]["datasets_to_eval"]:
            loader = self.data_loader_dict[data_name]
            self.do_wandb_log(self.evaluate(data_name, loader))

    def update_decoder(self):
        """Collect one effective batch of encoder outputs and train the decoder on it.

        Steps:

        1. Repeatedly draw batches from the decoder train iterator (restarting
           it when exhausted), generate encoder completions and convert them to
           decoder inputs until the buffer holds ``hparams["dec"]["batch_size"]``
           samples. When both ``use_overseer`` and ``dec_drop_sus`` are set,
           samples flagged by the overseer are dropped before buffering.
        2. Optionally paraphrase the buffered texts.
        3. Run ``n_dec_deepfry_steps`` decoder steps on the same data.
        4. Log tables, training statistics and timings to wandb.
        """
        dec_start = time.time()
        n_dropped = 0
        n_generated = 0
        target_size = self.hparams["dec"]["batch_size"]
        do_paraphrase = self.hparams["train"]["do_paraphrase"]
        # Filtering (and therefore real overseer data in the buffer) only
        # happens when BOTH flags are set; the logging below must use the
        # same condition.
        drop_sus = (
            self.hparams["overseer"]["use_overseer"]
            and self.hparams["overseer"]["dec_drop_sus"]
        )

        decoder_buffer = DecBuffer(max_size=target_size)

        while decoder_buffer.get_size() < target_size:
            ### Sample Batch
            try:
                batch = next(self.dec_train_iter)
            except StopIteration:
                self.dec_train_iter = self.create_dec_train_iter()
                batch = next(self.dec_train_iter)

            # TODO: may be quite inefficient, as even if one sample gets dropped, have to regen an entire batch...
            # Ensure decoder train_iter is using mini-batch sizes (not EBS) to improve efficiency

            #### Generate encoder outputs and convert to Dec ins
            with torch.no_grad():  # No gradients needed for encoder, just plain inference
                batch = self.generate_and_convert_to_dec_ins(
                    batch,
                    self.output_length_sampler,
                    self.generation_kwargs["dec_training"],
                    use_vllm_inference=self.hparams["enc"]["vllm_inference"],
                )
                n_generated += len(batch["user_query"])

            # TODO : Is this necessary? It seems like it's only used for debugging
            if self.hparams["debug"]["override_dec_batch"]:
                raise NotImplementedError("Override dec batch not implemented")

            ### Overseer + drop sus
            if drop_sus:
                is_sus, o_rews, _ = self.overseer.flag_sus_examples(
                    queries=batch["user_query"],
                    answers=batch["dec_in_texts"],
                    labels=batch["label"],
                    max_batch_size=self.hparams["enc"]["mini_batch_size"],
                )
                # Build the keep-mask once from the ORIGINAL flags so texts,
                # labels, flags and rewards all stay aligned. (Filtering the
                # flags first and then zipping rewards against the filtered
                # list would keep the first k rewards instead of the rewards
                # of the kept samples.)
                keep_mask = [not sus for sus in is_sus]
                texts = [t for t, keep in zip(batch["dec_in_texts"], keep_mask) if keep]
                labels = [lab for lab, keep in zip(batch["label"], keep_mask) if keep]
                kept_rewards = [r for r, keep in zip(o_rews, keep_mask) if keep]
                kept_flags = [sus for sus, keep in zip(is_sus, keep_mask) if keep]
                n_dropped += len(keep_mask) - len(texts)
                print(f"Dropped {n_dropped} samples")
                print(f"Remaining samples: {len(texts)}")
                add_batch = {
                    "texts": texts,
                    "labels": labels,
                    "is_sus": kept_flags,
                    "overseer_rewards": kept_rewards,
                }
            else:
                texts = batch["dec_in_texts"]
                labels = batch["label"]
                add_batch = {
                    "texts": texts,
                    "labels": labels,
                    "is_sus": [None] * len(texts),
                    "overseer_rewards": [None] * len(texts),
                }

            # Add to buffers
            decoder_buffer.add_data(add_batch)

        # Get batch
        get_batch = decoder_buffer.get_all(shuffle=False)
        texts, labels = get_batch["texts"], get_batch["labels"]

        # Paraphrase inputs
        if do_paraphrase:
            para_batch, paraphrasing_time = self.paraphraser_step(
                {"dec_in_texts": texts, "label": labels}
            )
            # Log paraphrased texts alongside original texts
            para_log_dict = {
                "para_in_texts": texts,
                "para_out_texts": para_batch["dec_in_texts_pp"],
                "labels": labels,
            }
            enc_out_texts = copy.deepcopy(texts)  # Keep a copy of the original texts for logging
            self.table_logger.log_table(para_log_dict, prefix="para_training", columns=["para_in_texts", "para_out_texts", "labels"])
            texts = para_batch["dec_in_texts_pp"]  # From here on the decoder sees the paraphrased texts

        # Update decoder (possibly several steps on the same data)
        for n_deepfrys in range(self.hparams["train"]["n_dec_deepfry_steps"]):
            log_dict, raw_data = self.step_decoder(
                texts, labels
            )
            # NOTE(review): n_deepfrys is the zero-based step index, so this
            # only fires from the third step on - confirm that is intended.
            if n_deepfrys > 1:
                print("WARNING: Deepfrying decoder")

        # ---- Everything below is logging ----

        # Log completions
        log_table_dict = {
            "dec_in_texts": texts,
            "label": labels,
            "predictions": raw_data['predictions'],
        }
        if do_paraphrase:
            log_table_dict["enc_out_texts"] = enc_out_texts
            self.table_logger.log_table(log_table_dict, prefix="dec_training", columns=["enc_out_texts", "dec_in_texts", "label", "predictions"])
        else:
            self.table_logger.log_table(log_table_dict, prefix="dec_training", columns=["dec_in_texts", "label", "predictions"])

        dec_elapsed = time.time() - dec_start

        # Log to overall training stats
        train_dict = {f"training/{k}": v for k, v in log_dict.items()}
        self.do_wandb_log(train_dict)

        # Log to decoder training stats
        dec_dict = {f"decoder_train/{k}": v for k, v in log_dict.items()}
        dec_dict.update(
            {
                "decoder_train/batch_size": len(texts),
                "decoder_train/n_dropped": n_dropped,
                "decoder_train/proportion_dropped": n_dropped / n_generated,
                "decoder_train/n_generated": n_generated,
                "decoder_train/para_time": paraphrasing_time
                if do_paraphrase
                else None,
                # Overseer stats only exist when filtering actually ran; with
                # the overseer disabled the buffer holds None placeholders.
                "decoder_train/proportion_sus": sum(get_batch["is_sus"]) / len(get_batch["is_sus"])
                if drop_sus
                else None,
                "decoder_train/overseer_reward": np.mean(get_batch["overseer_rewards"])
                if drop_sus
                else None,
                "decoder_train/n_deepfrys": n_deepfrys,
            }
        )
        self.do_wandb_log(dec_dict)

        # Misc logging
        self.do_wandb_log(
            {
                "timing/single_decoder_update": dec_elapsed,
                "training/API_cost": self.cost_tracker
                if do_paraphrase
                else None,
            }
        )


    def step_decoder(self, ebs_texts, ebs_labels):
        """Train the decoder on one effective batch using gradient accumulation.

        The texts are tokenized, split into mini-batches of
        ``hparams["dec"]["mini_batch_size"]`` and passed through the decoder in
        order. The optimizer steps once every
        ``hparams["dec"]["gradient_accumulation_steps"]`` mini-batches, plus
        once more at the end if a partial group is left over. The loss is not
        rescaled by the number of accumulation steps.

        Args:
            ebs_texts: Decoder input texts for the whole effective batch.
            ebs_labels: Matching labels; each element must support ``.item()``.

        Returns:
            ``(log_dict, raw_data)``: the metrics from
            ``compute_dec_out_metrics`` extended with ``loop_loss`` (mean
            mini-batch loss) and ``loop_acc``, and a dict holding the
            per-sample ``predictions``.
        """
        # Tokenize
        ebs_encodings = self.dec_tokenizer(
            ebs_texts, return_tensors="pt", truncation=True, padding=True
        )
        # as_tensor reuses the tokenizer's tensors instead of re-wrapping them
        # with torch.tensor(), which copies and emits a warning per call.
        train_dataset = TensorDataset(
            torch.as_tensor(ebs_encodings["input_ids"]).to(self.device),
            torch.as_tensor(ebs_encodings["attention_mask"]).to(self.device),
            torch.tensor(ebs_labels).to(self.device),
        )
        # create mini_batch dataloader (unshuffled so the order matches ebs_labels)
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.hparams["dec"]["mini_batch_size"],
            shuffle=False,
        )

        # Prep stats
        self.dec_model.train()
        total_loss = 0
        correct = 0
        total = 0
        all_predictions = []
        all_logits = []
        all_labels = []

        # Initialize gradient accumulation steps
        accumulation_steps = self.hparams["dec"]["gradient_accumulation_steps"]
        mini_steps = 0
        update_steps = 0
        self.dec_optimizer.zero_grad()

        for mini_batch in train_loader:
            mini_input_ids, mini_attention_mask, mini_labels = [
                b.to(self.dec_model.device) for b in mini_batch
            ]

            outputs = self.dec_model(mini_input_ids, mini_attention_mask)
            logits = outputs.logits  # Extract logits from the output

            loss = self.dec_loss_fn(logits, mini_labels)
            loss.backward()
            mini_steps += 1

            # Gradient accumulation: count the mini-batch BEFORE testing, so
            # the optimizer steps after every `accumulation_steps` backward
            # passes. Testing the pre-increment counter stepped after the very
            # first mini-batch and could leave the final gradients unapplied.
            if mini_steps % accumulation_steps == 0:
                self.dec_optimizer.step()
                self.dec_optimizer.zero_grad()
                update_steps += 1

            # Stats
            predictions = logits.argmax(dim=-1)
            all_logits.extend(logits.detach().cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(mini_labels.cpu().numpy())
            correct += (predictions == mini_labels).sum().item()
            total_loss += loss.item()
            total += mini_labels.size(0)

        # Handle any remaining gradients from an incomplete accumulation group
        if mini_steps % accumulation_steps != 0:
            self.dec_optimizer.step()
            self.dec_optimizer.zero_grad()
            update_steps += 1

        # check all labels are identical to the ones in the batch
        assert all_labels == [
            e.item() for e in ebs_labels
        ], "Labels are not identical to the ones in the batch"

        log_dict = self.compute_dec_out_metrics(all_logits, all_labels)
        log_dict.update(
            {
                "loop_loss": total_loss / len(train_loader),
                "loop_acc": correct / total,
            }
        )
        raw_data = {
            'predictions': all_predictions,
        }

        return log_dict, raw_data

    def update_encoder(self) -> None:
        """Run a single PPO update of the encoder and log the results.

        Draws the next encoder training batch (restarting the iterator when it
        is exhausted), generates completions without vLLM, optionally
        paraphrases them, computes the rewards, performs one PPO step and then
        logs tables, reward statistics and timings.
        """
        started_at = time.time()
        do_paraphrase = self.hparams["train"]["do_paraphrase"]

        # Sample batch
        try:
            batch = next(self.enc_train_iter)
        except StopIteration:
            self.enc_train_iter = self.create_enc_train_iter()
            batch = next(self.enc_train_iter)

        # Generate completions
        batch = self.generate_and_convert_to_dec_ins(
            batch,
            self.output_length_sampler,
            self.generation_kwargs["enc_training"],
            use_vllm_inference=False,
        )

        # Paraphrase
        if do_paraphrase:
            batch, paraphrasing_time = self.paraphraser_step(batch)
            # Log paraphrased texts alongside original texts
            self.table_logger.log_table(
                {
                    "para_in_texts": batch["dec_in_texts"],
                    "para_out_texts": batch["dec_in_texts_pp"],
                    "labels": batch["label"],
                },
                prefix="para_training",
                columns=["para_in_texts", "para_out_texts", "labels"],
            )
            # Keep a copy of the original texts for logging
            enc_out_texts = copy.deepcopy(batch["dec_in_texts"])
            batch["dec_in_texts"] = batch["dec_in_texts_pp"]

        # Get rewards
        rewards, rew_infos, raw_rew_data = self.get_rewards(batch)
        batch["reward"] = rewards

        # Sanity check
        n_rewards, n_queries = len(rewards), len(batch["input_ids"])
        assert (
            n_rewards == n_queries
        ), "Length of rewards and classification texts do not match: len(rewards) = {}, len(classification_texts) = {}".format(
            n_rewards, n_queries
        )

        # Reformat rewards: one tensor with a leading dimension per sample
        rewards_list = [reward.unsqueeze(0) for reward in rewards]

        ### Do PPO update on Encoder
        stats = self.enc_ppo_trainer.step(
            list(batch["input_ids"]), list(batch["response_tensors"]), rewards_list
        )

        # ---- Everything below is logging ----
        enc_elapsed = time.time() - started_at

        # Do TRL RL logging
        stats = self.add_counters_to_wandb_dict(stats)
        self.enc_ppo_trainer.log_stats(
            stats,
            {"query": batch["query"], "response": batch["response_texts"]},
            rewards,
        )

        # Log completions table
        batch['stego_rewards'] = raw_rew_data['stego_rewards']
        batch['overseer_rewards'] = raw_rew_data['overseer_rewards']
        if do_paraphrase:
            batch["enc_out_texts"] = enc_out_texts
            table_columns = ['label', 'query', 'enc_out_texts', 'dec_in_texts', 'stego_rewards', 'overseer_rewards']
        else:
            table_columns = ['label', 'query', 'dec_in_texts', 'stego_rewards', 'overseer_rewards']
        self.table_logger.log_table(batch, prefix="enc_training", columns=table_columns)

        # Log to overall training stats
        log_dict = self.compute_dec_out_metrics(raw_rew_data['dec_logits'], batch["label"])
        self.do_wandb_log({f"training/{name}": value for name, value in log_dict.items()})

        # Log to encoder training stats
        reward_tensor = torch.tensor(rewards_list)
        enc_dict = {f"encoder_train/{name}": value for name, value in log_dict.items()}
        enc_dict["encoder_train/mean_reward"] = torch.mean(reward_tensor).item()
        enc_dict["encoder_train/min_reward"] = torch.min(reward_tensor).item()
        enc_dict["encoder_train/max_reward"] = torch.max(reward_tensor).item()
        enc_dict["encoder_train/std_reward"] = torch.std(reward_tensor).item()
        enc_dict["encoder_train/batch_size"] = len(rewards_list)
        enc_dict["encoder_train/para_time"] = paraphrasing_time if do_paraphrase else None
        enc_dict["timing/single_encoder_update"] = enc_elapsed
        enc_dict["training/API_cost"] = self.cost_tracker if do_paraphrase else None
        enc_dict["encoder_train/overseer_reward"] = rew_infos.get("rewards/overseer_mean")

        # add reward infos dict to log_stats
        enc_dict.update(rew_infos)
        # log to wandb
        self.do_wandb_log(enc_dict)

    def generate_and_convert_to_dec_ins(
        self, batch, output_length_sampler, generation_kwargs, use_vllm_inference=False
    ):
        """Generate encoder completions for ``batch`` and derive the decoder inputs.

        Adds ``response_tensors``, ``response_texts``, ``query_response_texts``,
        ``query_response_stripped`` and ``dec_in_texts`` to ``batch`` (in place)
        and returns it.

        Args:
            batch: Batch dict from a data loader.
            output_length_sampler: Callable giving the generation length
                (unused on the vLLM path).
            generation_kwargs: Generation arguments (unused on the vLLM path).
            use_vllm_inference: Generate with vLLM instead of the PPO trainer.

        Raises:
            NotImplementedError: If stripping is requested for an unsupported
                dataset.
        """
        ### Get response from Encoder
        if use_vllm_inference:
            generations = self.generate_with_vllm(batch)
        else:
            generations = self.get_encoder_generations(
                batch, output_length_sampler, generation_kwargs
            )
        (
            batch["response_tensors"],
            batch["response_texts"],
            batch["query_response_texts"],
        ) = generations

        if self.hparams["train"]["leak_password"]:
            # Don't strip query from generations
            stripped = batch["query_response_texts"]
        else:
            # Strip query from generations
            dataset_name = self.hparams["dataset"]["name"]
            if dataset_name in (
                "cv",
                "toy_cv",
                "toy_cv_gender",
                "toy_cv_gender_correlations",
            ):
                # TODO: this stripping is very hacky/dodgy and doesn't allow for easy switching of data, task, model etc
                model_turn = re.compile(r"(?<=<start_of_turn>model\n).*$", re.DOTALL)
                stripped = []
                n_unmatched = 0
                for text in batch["query_response_texts"]:
                    match = model_turn.search(text)
                    if match is None:
                        # Keep one entry per sample so texts stay aligned with
                        # labels. Fall back to an empty string rather than the
                        # full text, which would still contain the query.
                        n_unmatched += 1
                        stripped.append("")
                    else:
                        stripped.append(match.group().strip())
                if n_unmatched:
                    print(
                        f"WARNING: {n_unmatched} generation(s) had no model-turn marker; "
                        "using empty decoder input for them"
                    )
            elif dataset_name == "imdb":
                stripped = self.remove_payload_from_texts(
                    texts=batch["query_response_texts"]
                )
            else:
                raise NotImplementedError("This dataset does not support stripping")

        batch["query_response_stripped"] = stripped
        batch["dec_in_texts"] = stripped

        if self.hparams["train"]["strip_instruct_tune"]:
            batch["dec_in_texts"] = self.remove_instruction_tune(batch["dec_in_texts"])

        return batch

    def remove_instruction_tune(self, texts):
        """Trim trailing instruction-tune chatter from each completion.

        Instruction-tuned models tend to append text after a newline, e.g.
        "<This is a completion.>\\n\\nLet me know if ...." or
        "<This is a completion.>\\n\\nHere's a breakdown ....".

        For every text, the final newline together with the last line that
        follows it is removed, then surrounding whitespace is stripped.
        NOTE(review): because ``.`` does not match newlines here, only the
        last line is cut, not everything after the first newline - confirm
        this is the intended behaviour.
        """
        last_line = re.compile(r"\n.*$")
        cleaned = []
        for text in texts:
            cleaned.append(last_line.sub("", text).strip())
        return cleaned

    def get_encoder_generations(self, batch, output_length_sampler, generation_kwargs):
        """Generate encoder responses for the queries in ``batch``.

        A generation length is drawn from ``output_length_sampler`` and stored
        in ``generation_kwargs["max_new_tokens"]`` (the caller's dict is
        modified). Responses are generated by the PPO trainer and decoded with
        special tokens skipped.

        Returns:
            ``(response_tensors, response_texts, query_response_texts)`` where
            the last entry is each query string concatenated with its
            response. If a ``StopIteration`` is raised while generating,
            ``(None, None, None)`` is returned.
        """
        try:
            prompt_tensors = batch["input_ids"]
            # Potential TODO: Wrap this in a try-except block to handle errors? Or just see whether the PPO trainer data prep has already handled it
            generation_kwargs["max_new_tokens"] = output_length_sampler()
            response_tensors = self.enc_ppo_trainer.generate(
                prompt_tensors,
                batch_size=self.hparams["enc"]["mini_batch_size"],
                return_prompt=False,
                **generation_kwargs,
            )

            response_texts = []
            for response in response_tensors:
                decoded = self.enc_tokenizer.decode(
                    response.squeeze(), skip_special_tokens=True
                )
                response_texts.append(decoded)

            ### Process encoder outputs
            query_response_texts = [
                query_text + completion
                for query_text, completion in zip(batch["query"], response_texts)
            ]
        except StopIteration:
            print("No more data in the iterator")
            return None, None, None

        return response_tensors, response_texts, query_response_texts

    def get_rewards(self, batch):
        """Combine all enabled reward components into one reward per sample.

        Components, each asserted to lie in [-1, 1] before scaling (the EOS
        penalty excepted): overseer reward, stego (decoder) reward, task
        usefulness reward, length reward and the EOS-trick penalty. When both
        ``use_overseer`` and ``dec_drop_sus`` are set, the stego reward of
        samples flagged as suspicious is zeroed.

        Returns:
            ``(total_rewards, infos, raw_data)``: the summed reward tensor, a
            dict of scalar statistics for logging, and a dict with the decoder
            logits plus the unscaled overseer and stego rewards.
        """
        total_rewards = torch.zeros(len(batch["dec_in_texts"]), device=self.device)
        infos = {}
        reward_cfg = self.hparams["rewards"]
        overseer_enabled = self.hparams["overseer"]["use_overseer"]

        def in_unit_range(values):
            # Every component is expected to lie within -1 to 1
            return torch.all(values <= 1) and torch.all(values >= -1)

        # Get overseer rewards
        if overseer_enabled:
            o_rewards, is_sus = self.get_overseer_rewards(batch)
            assert in_unit_range(o_rewards)
            weighted_overseer = reward_cfg["overseer_scale"] * o_rewards.to(self.device)
            total_rewards += weighted_overseer
            overseer_mean = weighted_overseer.mean().item()
            sus_fraction = sum(is_sus) / len(is_sus)
            infos["rewards/overseer_mean"] = overseer_mean
            infos["rewards/proportion_sus"] = sus_fraction
            infos["training/overseer_reward"] = overseer_mean
            infos["training/proportion_sus"] = sus_fraction

        # Get stego / decoder reward
        s_rewards, dec_logits = self.get_stego_rewards(batch)
        assert in_unit_range(s_rewards)

        # If drop_sus == True, override stego_rewards
        if overseer_enabled and self.hparams["overseer"]["dec_drop_sus"]:
            sus_mask = torch.tensor([int(flag) for flag in is_sus]).to(self.device)
            # Set stego reward to 0 if sus, keep same otherwise
            s_rewards = s_rewards * (1 - sus_mask)

        # Scale and add to total_reward
        weighted_stego = reward_cfg["stego_scale"] * s_rewards.to(self.device)
        total_rewards += weighted_stego
        infos["rewards/stego_mean"] = weighted_stego.mean().item()

        # Get encoder task usefulness reward
        if reward_cfg["use_task_usefulness_reward"]:
            t_rewards = self.get_task_usefulness_rewards(batch)
            assert in_unit_range(t_rewards)
            weighted_task = reward_cfg["task_scale"] * t_rewards.to(self.device)
            total_rewards += weighted_task
            infos["rewards/task_usefulness_mean"] = weighted_task.mean().item()

        # Get length reward
        if reward_cfg["use_length_reward"]:
            length_rews = self.get_length_reward_factors(batch)
            assert in_unit_range(length_rews)
            weighted_length = reward_cfg["length_scale"] * length_rews.to(self.device)
            total_rewards += weighted_length
            infos["rewards/length_rewards_mean"] = weighted_length.mean().item()

        # Get eos trick reward
        if reward_cfg["eos_trick"] is not None:
            penalty = self.get_eos_trick_penalty(batch).to(total_rewards.device)
            total_rewards += penalty
            infos["rewards/eos_trick_penalty"] = penalty.mean().item()
            infos["rewards/good_eos_completions"] = torch.sum(penalty == 0.0).item()

        # log total rewards
        infos["rewards/total_rewards_mean"] = total_rewards.mean().item()

        if overseer_enabled:
            raw_overseer = o_rewards
        else:
            raw_overseer = torch.zeros_like(total_rewards)
        raw_data = {
            "dec_logits": dec_logits,
            "overseer_rewards": raw_overseer,
            "stego_rewards": s_rewards,
        }

        return total_rewards, infos, raw_data

    def get_overseer_rewards(self, batch):
        """Ask the overseer to flag the batch and return ``(rewards, is_sus)``.

        The third value returned by ``flag_sus_examples`` is discarded.
        """
        flagged, overseer_rewards, _ = self.overseer.flag_sus_examples(
            queries=batch["user_query"],
            answers=batch["dec_in_texts"],
            labels=batch["label"],
            max_batch_size=self.hparams["enc"]["mini_batch_size"],
        )
        return overseer_rewards, flagged

    def get_task_usefulness_rewards(self, batch):
        """Placeholder for the task-usefulness reward; currently always raises.

        The drafted computation was
        ``self.task_usefulness_func(batch["user_query"], batch["dec_in_texts"],
        self.hparams["enc"]["mini_batch_size"])``, but the method has always
        ended in an unconditional ``NotImplementedError``. The error is now
        raised up front so no scoring work is done only to be thrown away.

        Args:
            batch: Batch dict; unused until the reward is implemented.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: user query should just have the entire query (inc task prompt)
        raise NotImplementedError(
            "Task-usefulness rewards are not implemented yet; disable "
            'hparams["rewards"]["use_task_usefulness_reward"].'
        )

    def get_stego_rewards(self, batch):
        """Run the decoder on the batch (without gradients) and derive rewards.

        Returns:
            A ``(rewards, logits)`` pair, where ``logits`` comes from
            ``decoder_logits_forward(batch["dec_in_texts"], no_grad=True)`` and
            ``rewards`` from ``get_reward_from_logits(logits, batch["label"])``.
        """
        dec_inputs = batch["dec_in_texts"]
        payload_labels = batch["label"]
        dec_logits = self.decoder_logits_forward(dec_inputs, no_grad=True)
        stego_rewards = self.get_reward_from_logits(dec_logits, payload_labels)
        return stego_rewards, dec_logits

    def get_eos_trick_penalty(self, batch):
        """Penalise completions whose decoder-input text does not end in a period.

        Only the trailing ``"."`` of each string in ``batch["dec_in_texts"]``
        is checked; the EOS token itself is not inspected here.

        Args:
            batch: Dict containing ``"dec_in_texts"``, a list of strings.

        Returns:
            A 1-D float tensor (on CPU) with one entry per text: ``0.0`` when
            the text ends with ``"."``, otherwise
            ``-hparams["rewards"]["eos_trick"]``. An empty batch yields an
            empty tensor.
        """
        # Cast so an integer hparam does not give torch.where mixed int/float branches.
        penalty = float(self.hparams["rewards"]["eos_trick"])
        # Explicit bool dtype: torch.tensor([]) would otherwise be float32,
        # which torch.where rejects as a condition.
        ends_with_period = torch.tensor(
            [text.endswith(".") for text in batch["dec_in_texts"]],
            dtype=torch.bool,
        )
        return torch.where(
            ends_with_period,
            torch.tensor(0.0),
            torch.tensor(-penalty),
        )

    def get_reward_from_logits(self, logits, labels):
        """Convert decoder logits into a reward per sample.

        With ``hparams["rewards"]["use_probs_reward"]`` set, the logits are
        softmaxed over the last dimension and ``select_logits_by_label`` is
        applied to the probabilities. Otherwise it is applied to the raw
        logits and the result is passed through ``tanh`` so it lies in
        (-1, 1).

        ``select_logits_by_label`` presumably picks, for each sample, the
        entry belonging to its label (verify against ``rl_utils``).
        """
        if not self.hparams["rewards"]["use_probs_reward"]:
            # tanh squashes the selected logit into the -1..1 range
            return torch.tanh(select_logits_by_label(labels, logits))

        class_probs = F.softmax(logits, dim=-1)
        return select_logits_by_label(labels, class_probs)

    def get_length_reward_factors(self, batch):
        """Score each text by its length relative to the maximum output length.

        The factor for a text is ``len(text) / hparams["enc"]["output_len"]["max"]``
        (``len`` counts characters of the string), clipped to the range 0..1.

        Returns:
            A 1-D tensor on ``self.device`` with one factor per entry of
            ``batch["dec_in_texts"]``.
        """
        longest_allowed = self.hparams["enc"]["output_len"]["max"]

        # Proportion of the maximum length used by each text.
        proportions = [len(sample) / longest_allowed for sample in batch["dec_in_texts"]]

        # Anything longer than the maximum saturates at 1.
        return torch.tensor(proportions, device=self.device).clamp(min=0, max=1)

    def decoder_logits_forward(self, texts, no_grad=False):
        """Tokenize ``texts`` and run the decoder over them in mini-batches.

        Texts are processed in consecutive slices of
        ``hparams["dec"]["mini_batch_size"]``. Each slice is tokenized with
        padding and truncation to at most 512 tokens, moved to
        ``self.device`` and fed to ``self.dec_model``.

        Args:
            texts: Sequence of strings to classify.
            no_grad: When true, each forward pass runs under
                ``torch.no_grad()``.

        Returns:
            The per-slice ``logits`` concatenated along dimension 0, on
            ``self.device``.
        """
        # TODO: for efficiency, add option to drop samples?
        chunk_size = self.hparams["dec"]["mini_batch_size"]
        logits_per_chunk = []

        for offset in range(0, len(texts), chunk_size):
            # Slicing past the end is safe, so the last chunk may be shorter.
            encoded = self.dec_tokenizer(
                texts[offset : offset + chunk_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)

            if no_grad:
                with torch.no_grad():
                    model_out = self.dec_model(**encoded)
            else:
                model_out = self.dec_model(**encoded)

            logits_per_chunk.append(model_out.logits)

        # TODO: verify this works
        return torch.cat(logits_per_chunk, dim=0).to(self.device)

    def remove_payload_from_texts(
        self,
        texts: list,
    ):
        """Strip the leading payload section from each text.

        The payload section is located with ``hparams["dataset"]["payload_template"]``:
        the template is regex-escaped and its ``{payload}`` placeholder is
        replaced by a lazy wildcard. Everything from the start of the text up
        to and including the first template match, plus any whitespace after
        it, is removed. Texts that do not match are returned unchanged.

        Todo: maybe make this a function in the dataset?

        Args:
            texts: List of strings to clean.

        Returns:
            ``texts`` itself when the template is ``None``, otherwise a new
            list of cleaned strings.
        """
        payload_template = self.hparams["dataset"]["payload_template"]

        # No template configured: nothing to strip.
        if payload_template is None:
            return texts

        # Escape the literal parts of the template, then turn the (escaped)
        # placeholder into a lazy capture group.
        payload_pattern = re.escape(payload_template).replace(r"\{payload\}", "(.*?)")
        # Raw f-string: a plain f-string made the whitespace class an invalid
        # escape sequence (SyntaxWarning on Python 3.12+). Compiled once and
        # reused for every text.
        prefix_re = re.compile(rf"^.*?{payload_pattern}\s*")

        return [prefix_re.sub("", text) for text in texts]

    def evaluate(self, data_name, data_loader):
        """Run an evaluation pass over ``data_loader`` and return averaged metrics.

        Both models are put in eval mode; then, under ``torch.no_grad()``, each
        batch is generated by the encoder, optionally paraphrased, scored with
        ``get_rewards`` and summarised with ``compute_dec_out_metrics``. The
        loop stops early once at least ``hparams["eval"]["n_eval_samples"]``
        rewards have been collected. Finally the last processed batch is
        logged as a table and the train/eval modes are restored according to
        ``hparams["train"]["do_update_dec"]`` / ``["do_update_enc"]``.

        Args:
            data_name: Label used to prefix every returned key as
                ``eval_{data_name}/<metric>``.
            data_loader: Iterable of batches to evaluate on.

        Returns:
            Dict of prefixed metrics: per-batch metrics/infos averaged over
            batches, plus ``avg_total_reward`` and ``n_samples_evalled`` (and
            ``API_cost`` / ``paraphrasing`` when paraphrasing is enabled).

        NOTE(review): if ``data_loader`` yields no batches, the
        ``sum(all_rewards) / len(all_rewards)`` below divides by zero, and
        ``batch`` / ``rewards`` / ``raw_rew_data`` are never bound.
        """
        self.dec_model.eval()
        self.enc_ppo_trainer.model.eval()

        all_rewards = []  # one entry per evaluated sample
        all_metrics = {}  # metric name -> list with one value per batch

        with torch.no_grad():
            # 1) Generate completions
            for batch in tqdm(
                data_loader, desc="Eval: Encoder Step"
            ):  # Every batch here is a mini-batch (hparams["dec"]["mini_batch_size"])
                # Generate and convert to decoder inputs

                batch = self.generate_and_convert_to_dec_ins(
                    batch,
                    self.output_length_sampler,
                    self.generation_kwargs["eval"],
                    use_vllm_inference=self.hparams["enc"]["vllm_inference"],
                )

                # Paraphrase inputs
                if self.hparams["train"]["do_paraphrase"]:
                    batch, paraphraser_time = self.paraphraser_step(batch)
                    # Log paraphrased texts alongside original texts
                    para_log_dict = {
                        "para_in_texts": batch["dec_in_texts"],
                        "para_out_texts": batch["dec_in_texts_pp"],
                        "labels": batch["label"],
                    }
                    self.table_logger.log_table(para_log_dict, prefix="para_eval", columns=["para_in_texts", "para_out_texts", "labels"])
                    # Keep the pre-paraphrase texts for the final table, then
                    # make the decoder (and the rewards below) see the paraphrased ones.
                    para_in_texts = copy.deepcopy(batch["dec_in_texts"]) # Lists are mutable, so need to copy
                    batch["dec_in_texts"] = batch["dec_in_texts_pp"]

                # Get rewards
                rewards, infos, raw_rew_data = self.get_rewards(batch)
                # presumably `rewards` is a 1-D tensor, so this adds one element per sample - TODO confirm
                all_rewards.extend(rewards)

                labels = torch.tensor(batch["label"]).to(self.device)

                # Compute metrics
                metrics = self.compute_dec_out_metrics(raw_rew_data['dec_logits'], labels)

                # Update all_metrics with values from metrics
                for key in metrics:
                    all_metrics[key] = all_metrics.get(key, []) + [metrics[key]]

                # Update all_metrics with values from infos
                for key in infos:
                    all_metrics[key] = all_metrics.get(key, []) + [infos[key]]

                # Stop once enough samples have been evaluated
                if len(all_rewards) >= self.hparams["eval"]["n_eval_samples"]:
                    break

        # Average metrics (bit dodgy): this is an unweighted mean of per-batch
        # values, so batches of different sizes count equally.
        avg_metrics = {}
        for key, values in all_metrics.items():
            avg_metrics[key] = sum(values) / len(values)

        avg_metrics["avg_total_reward"] = sum(all_rewards) / len(all_rewards)
        avg_metrics["n_samples_evalled"] = len(all_rewards)

        # Add paraphrasing costs and time to the stats, if applicable
        # (paraphraser_time is the value from the last batch processed above)
        if self.hparams["train"]["do_paraphrase"]:
            avg_metrics["API_cost"] = self.cost_tracker
            avg_metrics["paraphrasing"] = paraphraser_time

        # Add data_name as prefix to all keys for wandb logging (TODO: should be in wandb function)
        avg_metrics = {
            f"eval_{data_name}/{key}": value for key, value in avg_metrics.items()
        }

        # log table (only the last batch processed by the loop is logged)
        batch["total_reward"] = rewards
        batch["stego_reward"] = raw_rew_data["stego_rewards"]
        batch["overseer_reward"] = raw_rew_data["overseer_rewards"]
        if self.hparams["train"]["do_paraphrase"]:
            # Log paraphrased texts alongside original texts
            para_log_dict = batch # NOTE: alias of `batch`, not a copy; the key added below also lands in `batch`
            para_log_dict["enc_out_texts"] = para_in_texts
            self.table_logger.log_table(para_log_dict, prefix="eval", columns=['label', 'query', 'enc_out_texts', 'dec_in_texts', 'stego_reward', 'overseer_reward'])
        else:
            self.table_logger.log_table(batch, prefix="eval", columns=['label', 'query', 'dec_in_texts', 'stego_reward', 'overseer_reward'])
            # self.log_completions_table(batch, prefix="eval")

        # Modes: put each model back in train mode only if it is being updated
        self.dec_model.train() if self.hparams["train"][
            "do_update_dec"
        ] else self.dec_model.eval()

        self.enc_ppo_trainer.model.train() if self.hparams["train"][
            "do_update_enc"
        ] else self.enc_ppo_trainer.model.eval()

        return avg_metrics

    def compute_dec_out_metrics(self, logits, labels):
        """Compute decoder accuracy, loss and mean reward for a batch.

        Samples whose label equals ``-10000`` are treated as dummies and
        dropped before any metric is computed.

        Args:
            logits: Decoder outputs, as a tensor or anything accepted by
                ``torch.tensor``; the first dimension indexes samples.
            labels: Integer labels, as a tensor or anything accepted by
                ``torch.tensor``; same length as ``logits``.

        Returns:
            Dict with ``dec_accuracy``, ``dec_loss`` (from ``self.dec_loss_fn``),
            ``stego_reward`` (mean of ``get_reward_from_logits``) and one
            ``accuracy_<label>`` entry per label present in the batch.

        Raises:
            AssertionError: If ``logits`` and ``labels`` differ in length.
        """

        def _as_detached_tensor(values):
            # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
            # clone().detach() is the recommended equivalent.
            if isinstance(values, torch.Tensor):
                return values.clone().detach()
            return torch.tensor(values)

        logits = _as_detached_tensor(logits).to(self.device)
        labels = _as_detached_tensor(labels).to(self.device)
        assert logits.shape[0] == labels.shape[0]

        # Drop dummies if present: labels == -10000
        keep = labels != -10000
        labels = labels[keep]
        logits = logits[keep]

        predictions = torch.argmax(logits, dim=-1)

        loss = self.dec_loss_fn(logits, labels).item()

        # Overall accuracy
        correct = (predictions == labels).float()
        accuracy = correct.mean().item()

        # Accuracy per class. Labels come from torch.unique(labels), so every
        # class has at least one sample and the mean is always defined.
        class_accuracies = {
            f"accuracy_{label.item()}": correct[labels == label].mean().item()
            for label in torch.unique(labels)
        }

        # Get mean rewards
        reward = self.get_reward_from_logits(logits, labels).mean().item()

        # Combine all metrics
        return {
            "dec_accuracy": accuracy,
            "dec_loss": loss,
            "stego_reward": reward,
            **class_accuracies,
        }


def run_stego_trainer(
    run_type: RLExpRunType = RLExpRunType.DEBUG,
    wandb_key: str = None,
    hf_key: str = None,
    exp_name_prefix: str = "debug",
    cmd_line_seed: int = None,
) -> None:
    """Build hyperparameters, seed, log in to services and run training.

    Args:
        run_type: Selects which hyperparameter set to load.
        wandb_key: Key passed to ``wandb.login``.
        hf_key: Token passed to the HuggingFace ``login``.
        exp_name_prefix: Forwarded to ``get_run_type_hparams``.
        cmd_line_seed: When not ``None``, replaces ``hparams["seed"]``.
    """
    hparams = get_run_type_hparams(run_type, exp_name_prefix)

    # A seed given on the command line takes precedence over the hparams file.
    if cmd_line_seed is not None:
        hparams["seed"] = cmd_line_seed

    # Seed via trl's helper and torch directly, for reproducibility.
    seed = hparams["seed"]
    set_seed(seed)
    torch.manual_seed(seed)

    # Authenticate with WandB and HuggingFace before any model is loaded.
    wandb.login(key=wandb_key)
    login(token=hf_key)

    # Construct the trainer and run the full training loop.
    StegoTrainer(hparams).train()
    # logger.info(f"\n\n{run_type} training complete!\n\n")


if __name__ == "__main__":
    # Script entry point: launches a TRAIN run.
    # The two keys below are placeholders; replace them with your own
    # WandB and HuggingFace credentials before running.
    run_stego_trainer(
        run_type=RLExpRunType.TRAIN,
        wandb_key="myWandBKey",
        hf_key="myHFKey",
        exp_name_prefix="debug",
    )