import os
import sys
from typing import Optional
from enum import Enum, auto
from filelock import FileLock
import json
import logging
import csv
import ray
import datetime
from ray.experimental.tqdm_ray import tqdm
from ray.experimental import tqdm_ray
import requests
import traceback
from termcolor import colored
import builtins
import io
import time
import threading
from zoneinfo import ZoneInfo
import hydra
from omegaconf import DictConfig, OmegaConf
import wandb
from dotenv import load_dotenv

# Load variables from a local .env file; override=True lets values in that
# file replace variables already present in the process environment.
load_dotenv(override=True)
# Pushover credentials read by send_notification(); either is None when the
# corresponding environment variable is not set.
PUSHOVER_TOKEN = os.getenv("PUSHOVER_TOKEN")
PUSHOVER_USER = os.getenv("PUSHOVER_USER")

# Class definitions
class UnlearnType(Enum):
    """Unlearning method selector used by `unlearn` to pick a backend.

    Values are spelled out explicitly (they match what `auto()` assigned)
    because callers compare members by `.value`.
    """

    CUT = 1            # CUT/RMU, Li et al. 2024
    GD = 2             # Gradient Difference (ascent on forget, descent on retain)
    WHP = 3            # Who's Harry Potter (training on random wrong facts)
    FWF = 4            # Fixed Wrong Fact
    NOT_SPECIFIED = 5  # Default placeholder; `unlearn` rejects it

class LossType(Enum):
    """Loss selector forwarded as `loss_type` to the training scripts.

    Values are written out explicitly and match the numbering `auto()`
    produced for the same declaration order.
    """

    LETTER = 1
    CORPUS = 2
    LETTER_ANSWER = 3
    QUESTION_LETTER_ANSWER = 4
    QUESTION_ANSWER = 5
    NUMBER = 6
    NOT_SPECIFIED = 7

class Datasets(Enum):
    """Named dataset configurations; members are the keys of `datasets_dict`.

    Values are written out explicitly and match the numbering `auto()`
    produced for the same declaration order.
    """

    YEARS = 1
    YEARS_TF = 2
    MMLU = 3
    WMDP_CORPUS = 4
    WMDP_CORPUS_FINEWEB = 5
    WMDP_CORPUS_MMLU = 6
    WMDP_MCQ_CORPUS = 7
    WMDP_MCQ_CORPUS_FINEWEB = 8
    WMDP_MCQ_FINEWEB = 9
    WMDP_MCQ_WIKITEXT = 10
    WMDP_MCQ_LETTER_ANSWER = 11
    BEAVERTAILS = 12
    RANDOM_BD = 13
    RANDOM_BD_SAME_RETAIN = 14
    RANDOM_BD_ALL_SPLITS = 15
    RANDOM_BD_WITH_MMLU = 16
    RANDOM_BD_WITH_MMLU_CORPUS = 17
    YEARS_MMLU_RETAIN = 18
    DAY_OF_THE_MONTH = 19
    NOT_SPECIFIED = 20

class DataFormat(Enum):
    """Data layout selector forwarded as `data_format` to training scripts.

    Values are written out explicitly and match the numbering `auto()`
    produced for the same declaration order.
    """

    CORPUS = 1
    MCQ = 2
    TF = 3
    NOT_SPECIFIED = 4


# Logger settings
def setup_logger():
    """Return the 'pipeline_logger' logger, writing raw messages to pipeline.log.

    The logger is set to INFO and gets a UTF-8, append-mode file handler for
    'pipeline.log' (relative to the current working directory). Records are
    written exactly as the formatted message, with no timestamp or level.

    The call is idempotent: `logging.getLogger` returns the same object for
    the same name, so if a handler for this log file is already attached it
    is reused instead of adding a second one (which would duplicate every
    line in the file).

    Returns:
        logging.Logger: the configured logger.
    """
    logger = logging.getLogger('pipeline_logger')
    logger.setLevel(logging.INFO)

    # FileHandler stores the absolute path in `baseFilename`; compare against
    # that to detect a handler attached by an earlier call.
    log_path = os.path.abspath('pipeline.log')
    for handler in logger.handlers:
        if (
            isinstance(handler, logging.FileHandler)
            and getattr(handler, 'baseFilename', None) == log_path
        ):
            return logger

    class ExactFormatter(logging.Formatter):
        """Emit only the record's message text."""

        def format(self, record):
            return record.getMessage()

    file_handler = logging.FileHandler(
        'pipeline.log', mode='a', encoding='utf-8'
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(ExactFormatter())
    logger.addHandler(file_handler)

    return logger

# Module-wide file logger used by the tee streams and log_last_line().
logger = setup_logger()
# Keep references to the real print function and standard streams before
# builtins.print is replaced later in this module.
original_print = builtins.print
original_stdout = sys.stdout
original_stderr = sys.stderr

class TeeIO(io.StringIO):
    """Text stream that forwards writes to a real stream and logs full lines.

    Every write is passed through unchanged to `original_stream`. In
    addition, text is accumulated in `last_line` and handed to
    `logger.info` once it ends with a newline. Writes containing a carriage
    return (progress-bar style updates) replace the pending text with
    whatever follows the last '\\r' instead of appending, so only the most
    recent state of the line is kept.

    Attributes are plain and public: `original_stream`, `logger`, and
    `last_line` (the pending, not-yet-logged text).
    """

    def __init__(self, original_stream, logger):
        super().__init__()
        self.original_stream = original_stream
        self.logger = logger
        self.last_line = ""

    def write(self, s):
        """Forward `s` to the wrapped stream, buffer it for logging.

        Returns:
            int: number of characters written, as text streams are expected
            to report (the original implementation returned None).
        """
        self.original_stream.write(s)
        if '\r' in s:
            # Carriage return rewinds the line: keep only the latest segment.
            self.last_line = s.rsplit('\r', 1)[-1]
        else:
            self.last_line += s
            if self.last_line.endswith('\n'):
                self.logger.info(self.last_line.rstrip())
                self.last_line = ""
        return len(s)

    def flush(self):
        """Flush the wrapped stream; pending log text is left untouched."""
        self.original_stream.flush()

# Tee wrappers around the real streams. In the visible part of this module
# they are not installed as sys.stdout/sys.stderr; print_and_log() writes to
# them directly and log_last_line() drains their pending text.
tee_stdout = TeeIO(original_stdout, logger)
tee_stderr = TeeIO(original_stderr, logger)

def print_and_log(*args, sep=' ', end='\n', file=None, flush=False):
    """Drop-in replacement for builtin `print` that also logs console output.

    Output aimed at stdout (or with no `file`) goes through `tee_stdout`;
    output aimed at stderr goes through `tee_stderr`, so it still reaches the
    real stderr rather than being redirected to stdout. Any other `file`
    target is passed straight to the original `print` and is not logged.

    Args:
        *args: objects to print; each is converted with `str`.
        sep: separator between objects; None means a single space, as with
            builtin `print`.
        end: text appended after the last object; None means a newline, as
            with builtin `print`.
        file: destination stream, or None for stdout.
        flush: whether to flush the destination after writing.
    """
    # Builtin print accepts None for sep/end and treats it as the default.
    sep = ' ' if sep is None else sep
    end = '\n' if end is None else end
    output = sep.join(map(str, args)) + end

    if file is None or file is sys.stdout:
        target = tee_stdout
    elif file is sys.stderr:
        target = tee_stderr
    else:
        original_print(output, end='', file=file, flush=flush)
        return

    target.write(output)
    if flush:
        target.flush()

# Replace the builtin print process-wide so every print() call in this
# process is routed through print_and_log.
builtins.print = print_and_log

def log_last_line():
    """Forever poll both tee streams and log any text still pending.

    Runs an endless loop (intended for a daemon thread): every 0.1 seconds,
    any non-empty `last_line` on `tee_stdout` or `tee_stderr` is written to
    the logger with trailing whitespace stripped, then cleared.

    NOTE(review): the read and reset of `last_line` are not synchronized with
    TeeIO.write, so text written between the two steps could be dropped.
    """
    poll_interval = 0.1  # seconds between checks
    while True:
        time.sleep(poll_interval)
        for tee in (tee_stdout, tee_stderr):
            pending = tee.last_line
            if not pending:
                continue
            logger.info(pending.rstrip())
            tee.last_line = ""

# Start a background thread that periodically logs pending partial lines.
# daemon=True means the thread does not keep the interpreter alive at exit.
log_thread = threading.Thread(target=log_last_line, daemon=True)
log_thread.start()


# Helpers
# When True, main() re-raises a caught RayTaskError after reporting it;
# when False the error is only notified, printed and written to the error log.
raise_exceptions = False

def get_current_time(timezone="America/Los_Angeles"):
    """Return the current time as a timezone-aware datetime.

    Args:
        timezone: IANA time zone key understood by `ZoneInfo`.
    """
    zone = ZoneInfo(timezone)
    now = datetime.datetime.now(tz=zone)
    return now

def is_after_6pm():
    """Return True when the local time from get_current_time() is 18:00 or later."""
    # Comparing the hour alone is equivalent to comparing against 18:00:00.
    return get_current_time().hour >= 18

def flatten_dict(d, parent_key=''):
    """Flatten nested dicts into one level with dot-joined keys.

    Args:
        d: mapping whose values may themselves be dicts.
        parent_key: prefix for every key produced from `d`; when empty the
            keys of `d` are used unchanged.

    Returns:
        dict: leaf values keyed by their dotted path. Empty nested dicts
        contribute no entries.
    """
    flat = {}
    for key, value in d.items():
        path = key if not parent_key else f"{parent_key}.{key}"
        if isinstance(value, dict):
            flat.update(flatten_dict(value, path))
        else:
            flat[path] = value
    return flat

def confirm_run(**kwargs):
    """Pause for confirmation on every flagged keyword argument.

    An argument is flagged when its value is the boolean True, or is a tuple
    or a list. For each flagged argument a red prompt is shown and the
    function blocks on `input()` until the user presses Enter; the typed
    response itself is ignored.
    """
    pad = " " * 20
    rule = '-' * 80
    for var, value in kwargs.items():
        flagged = value is True or isinstance(value, (tuple, list))
        if not flagged:
            continue
        prompt = (
            f"{pad}{rule}\n{var} is {value}."
            f"{pad}Are you sure you want to proceed?\n{rule}\n"
            f"{pad}"
        )
        input(colored(prompt, "red"))

def send_notification(message="Training Run finished"):
    """Send a push notification through the Pushover messages API.

    Sending is best-effort: a network failure or timeout is reported with
    `print` and swallowed instead of propagating, so a failed notification
    cannot abort the caller (main() calls this from its error handler before
    writing the error log).

    Args:
        message: text of the notification.
    """
    url = "https://api.pushover.net/1/messages.json"
    data = {
        "token": PUSHOVER_TOKEN,
        "user": PUSHOVER_USER,
        "message": message
    }
    try:
        # Without a timeout requests can block indefinitely on a dead link.
        response = requests.post(url, data=data, timeout=30)
    except requests.RequestException as exc:
        print("Failed to send notification:", exc)
        return
    if response.status_code == 200:
        print("Notification sent successfully!")
    else:
        print("Failed to send notification:", response.text)

def write_metrics_to_csv(file_path, data):
    """
    Append a list of dictionaries to a CSV file.

    The column order is taken from the keys of the first dictionary. A header
    row is written only when the file is empty (or newly created). Missing
    parent directories of `file_path` are created. An empty `data` list is a
    no-op and does not create the file.

    Parameters:
    - file_path (str): The path to the CSV file.
    - data (list of dict): The data to write to the CSV file.

    Raises:
    - ValueError: (from csv.DictWriter) if a row contains a key that is not
      among the first row's keys.
    """
    if not data:
        # Nothing to write; also avoids IndexError on data[0].
        return

    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fieldnames = list(data[0].keys())
    with open(file_path, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        # Write header only if file is empty
        if file.tell() == 0:
            writer.writeheader()
        writer.writerows(data)

def get_log_range(start, end, step):
    """Return the geometric sequence start, start*step, ... below `end`.

    Args:
        start: first value of the sequence.
        end: exclusive upper bound.
        step: multiplicative factor between consecutive values.

    Returns:
        list: the values strictly less than `end`; empty when
        `start >= end`.

    Raises:
        ValueError: if `start < end` but `start <= 0` or `step <= 1`. With
            such arguments the sequence is not an increasing geometric
            progression and the loop would typically never reach `end`.
    """
    if start >= end:
        return []
    if start <= 0 or step <= 1:
        raise ValueError(
            "get_log_range requires start > 0 and step > 1 when "
            f"start < end (got start={start}, end={end}, step={step})"
        )

    values = []
    curr = start
    while curr < end:
        values.append(curr)
        curr *= step
    return values

# Register get_log_range as a custom OmegaConf resolver named "get_log_range"
# so configs can call it through interpolation.
OmegaConf.register_new_resolver("get_log_range", get_log_range)

def get_num_layers(model_id: str):
    """Return `num_hidden_layers` from the model's transformers config.

    Args:
        model_id: identifier passed to `AutoConfig.from_pretrained`.
    """
    # Imported lazily so transformers is only loaded when this is called.
    from transformers import AutoConfig

    return AutoConfig.from_pretrained(model_id).num_hidden_layers

# Register get_num_layers as a custom OmegaConf resolver named
# "get_num_layers" so configs can call it through interpolation.
OmegaConf.register_new_resolver("get_num_layers", get_num_layers)

def resolve_freeze_layers(coeffs_tuple_list, model_id):
    """Scale fractional layer ranges by the model's layer count.

    Args:
        coeffs_tuple_list: sequence of pairs whose first two items are
            convertible to float, or None.
        model_id: model identifier passed to `get_num_layers`.

    Returns:
        None when `coeffs_tuple_list` is None; otherwise a list of
        `(int, int)` tuples where each bound is the coefficient multiplied by
        the number of layers and truncated with `int`.
    """
    if coeffs_tuple_list is None:
        return None
    num_layers = get_num_layers(model_id)
    return [
        (
            int(float(bounds[0]) * num_layers),
            int(float(bounds[1]) * num_layers),
        )
        for bounds in coeffs_tuple_list
    ]

# Register resolve_freeze_layers as a custom OmegaConf resolver named
# "resolve_freeze_layers" so configs can call it through interpolation.
OmegaConf.register_new_resolver("resolve_freeze_layers", resolve_freeze_layers)


# Main functions
@ray.remote(num_gpus=1)
def unlearn(
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    unlearn_files: list[str] = [],
    wrong_unlearn_files: list[str] = [],
    fixed_wrong_unlearn_files: list[str] = [],
    val_files: list[str] = [],
    dev_file: str = "",
    retain_files: list[str] = [],
    val_retain_files: list[str] = [],
    retain_dev_file: str = "",
    base_model: str = "",
    lr: float = 1e-7,
    epochs: int = 3,
    batch_size: int = 4,
    val_batch_size: int = 8,
    retain_coeff: int = 1,
    warmup_steps: int = 24,
    data_seed: int = 0,
    eval_every: int = 1,
    save_name: Optional[str] = None,
    wandb_project_name: str = "unlearn",
    unlearn_freeze_layers: Optional[list[tuple[int, int]]] = None,
    mcq: bool = False,
    hydra_dict: dict = {},
    data_format: DataFormat = DataFormat.CORPUS,
    loss_type: LossType = LossType.CORPUS,
    steering_coeff: float = 20,
    max_samples: int = None,
):
    """Run one unlearning job on a single GPU and return its results.

    Ray remote task (one GPU reserved) that dispatches on `unlearn_type`:

    * GD, WHP and FWF are handled by `unlearn_corpus.main`.
    * CUT is handled by `rmu.unlearn_pipeline.main`, which does not receive
      the wrong-fact files, freeze layers, `mcq` or `loss_type`.

    Members are compared through `.value` rather than by identity.

    Returns:
        tuple: the 11 items produced by the backend, in order:
        (model_path, forget_accs, forget_accs_calibrated, forget_logits_dict,
        retain_accs, retain_accs_calibrated, retain_logits_dict,
        retain_accs_5_shot, retain_accs_5_shot_calibrated,
        retain_logits_5_shot_dict, samples).

    Raises:
        Exception: if `unlearn_type` is NOT_SPECIFIED or has no handler.
    """
    kind = unlearn_type.value
    corpus_kinds = (
        UnlearnType.GD.value,
        UnlearnType.WHP.value,
        UnlearnType.FWF.value,
    )

    if kind == UnlearnType.NOT_SPECIFIED.value:
        raise Exception("Must specify unlearning type")

    if kind in corpus_kinds:
        # Backend modules are imported lazily, only for the branch taken.
        import unlearn_corpus
        results = unlearn_corpus.main(
            unlearn_type=unlearn_type,
            train_files=unlearn_files,
            wrong_unlearn_files=wrong_unlearn_files,
            fixed_wrong_unlearn_files=fixed_wrong_unlearn_files,
            val_files=val_files,
            dev_set=dev_file,
            retain_files=retain_files,
            val_retain_files=val_retain_files,
            retain_dev_file=retain_dev_file,
            base_model=base_model,
            lr=lr,
            name=save_name,
            epochs=epochs,
            batch_size=batch_size,
            val_batch_size=val_batch_size,
            retain_coeff=retain_coeff,
            warmup_steps=warmup_steps,
            data_seed=data_seed,
            eval_every=eval_every,
            save_name=save_name,
            project_name=wandb_project_name,
            freeze_layers=unlearn_freeze_layers,
            mcq=mcq,
            hydra_dict=hydra_dict,
            data_format=data_format,
            loss_type=loss_type,
            max_samples=max_samples,
        )
    elif kind == UnlearnType.CUT.value:
        import rmu.unlearn_pipeline as rmu
        results = rmu.main(
            unlearn_files=unlearn_files,
            val_files=val_files,
            dev_file=dev_file,
            retain_files=retain_files,
            val_retain_files=val_retain_files,
            retain_dev_file=retain_dev_file,
            base_model=base_model,
            lr=lr,
            epochs=epochs,
            batch_size=batch_size,
            val_batch_size=val_batch_size,
            retain_coeff=retain_coeff,
            warmup_steps=warmup_steps,
            data_seed=data_seed,
            eval_every=eval_every,
            save_name=save_name,
            wandb_project_name=wandb_project_name,
            hydra_dict=hydra_dict,
            data_format=data_format,
            steering_coeff=steering_coeff,
            max_samples=max_samples,
        )
    else:
        raise Exception("Unlearn type not handled")

    # Unpack so a backend returning the wrong number of items fails here.
    (
        model_path,
        forget_accs, forget_accs_calibrated, forget_logits_dict,
        retain_accs, retain_accs_calibrated, retain_logits_dict,
        retain_accs_5_shot, retain_accs_5_shot_calibrated,
        retain_logits_5_shot_dict,
        samples
    ) = results

    return (
        model_path,
        forget_accs, forget_accs_calibrated, forget_logits_dict,
        retain_accs, retain_accs_calibrated, retain_logits_dict,
        retain_accs_5_shot, retain_accs_5_shot_calibrated,
        retain_logits_5_shot_dict,
        samples
    )
    
@ray.remote
def main(
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    dataset: Datasets = Datasets.NOT_SPECIFIED,
    unlearn_files: list[str] = [],
    wrong_unlearn_files: list[str] = [],
    fixed_wrong_unlearn_files: list[str] = [],
    val_files: list[str] = [],
    dev_file: str = "",
    retain_files: list[str] = [],
    val_retain_files: list[str] = [],
    retain_dev_file: str = "",
    base_model: str = "",
    lr: float = 1e-7,
    epochs: int = 3,
    batch_size: int = 4,
    val_batch_size: int = 8,
    retain_coeff: int = 1,
    warmup_steps: int = 24,
    data_seed: int = 0,
    eval_every: int = 1,
    save_name: Optional[str] = None,
    name: str = "forget_model",
    wandb_project_name: str = "unlearn",
    results_dir: str = "evals/pipline",
    only_ft: bool = False,
    ft_model_path: str = "",
    num_ft_splits: int = 5,
    ft_loss_types: list[LossType] = [LossType.NOT_SPECIFIED],
    ft_lrs: list[float] = [5e-7],
    ft_epochs_lst: list[int] = [3],
    save_ft_models: bool = False,
    start_time: str = "",
    start_time_sf: str = "",
    dont_ft: bool = False,
    just_eval: bool = False,
    diff_tokenizer: str = "",
    unlearn_freeze_layers: Optional[list[tuple[int, int]]] = None,
    ft_freeze_layers: Optional[list[tuple[int, int]]] = None,
    ft_dont_eval: bool = False,
    ft_on_all: bool = False,
    unlearn_mcq: bool = False,
    hydra_dict: dict = {},
    unlearn_data_format: DataFormat = DataFormat.CORPUS,
    ft_data_format: DataFormat = DataFormat.MCQ,
    unlearn_loss_type: LossType = LossType.CORPUS,
    steering_coeff: float = 20,
    # bar = None,
    max_samples: int = 9999999999,
):
    """Run one pipeline job: unlearn (or evaluate), then fine-tune, logging CSVs.

    Ray remote task with two stages.

    Stage 1 (skipped when `only_ft` is True): either launches
    `unlearn_corpus.just_eval` (when `just_eval` is True) or the `unlearn`
    task, waits for the result, and writes one metrics row to a new file
    `<results_dir>/unlearning/<timestamp>--num<i>.csv`.

    Stage 2 (skipped when `dont_ft` or `just_eval` is True): launches
    `finetune_corpus.main` for every combination of `ft_loss_types`,
    `ft_lrs` and `ft_epochs_lst`. Unless `ft_on_all` is set, each combination
    is run once per `skip_split` in `range(num_ft_splits)`, training on all
    `val_files` except index `skip_split` and validating on that held-out
    file. With `ft_on_all`, all `val_files` are used for both. Each finished
    run gets one metrics row in `<results_dir>/ft/<timestamp>--num<i>.csv`.
    When `only_ft` is True the fine-tuning base model is `ft_model_path`.

    Error handling: only `ray.exceptions.RayTaskError` is caught. It is sent
    as a notification, printed, appended to `pipeline_error.log`, and
    re-raised only when the module-level `raise_exceptions` flag is true.

    Returns:
        None.

    NOTE(review): `ft_accs_file` and `lock` are created but never used in
    this function. `fted_model_path` does not include `ft_epochs`, so runs
    differing only in epochs share the same name. `ft_files[0]` raises
    IndexError when the fine-tuning file list is empty.
    """
    try:
        if not only_ft:
            if just_eval:
                import unlearn_corpus
                # print("inside just_eval if")
                # Evaluation-only path: no training files, wandb disabled.
                ref =  unlearn_corpus.just_eval.remote(
                    unlearn_type=unlearn_corpus.UnlearnType.GD,
                    train_files=[],
                    wrong_unlearn_files=[],
                    fixed_wrong_unlearn_files=[],
                    val_files=val_files,
                    dev_set=dev_file,
                    retain_files=[],
                    val_retain_files=val_retain_files,
                    retain_dev_file=retain_dev_file,
                    base_model=base_model,
                    lr=lr,
                    name=name,
                    epochs=epochs,
                    batch_size=batch_size,
                    val_batch_size=val_batch_size,
                    retain_coeff=retain_coeff,
                    warmup_steps=warmup_steps,
                    data_seed=data_seed,
                    eval_every=eval_every,
                    save_name=None,
                    project_name=wandb_project_name,
                    just_eval=True,
                    disable_wandb=True,
                    freeze_layers=unlearn_freeze_layers,
                    hydra_dict=hydra_dict,
                    data_format=unlearn_data_format,
                    max_samples=max_samples,
                )
                (
                    model_path,
                    forget_accs, forget_accs_calibrated, forget_logits_dict,
                    retain_accs, retain_accs_calibrated, retain_logits_dict,
                    retain_accs_5_shot, retain_accs_5_shot_calibrated,
                    retain_logits_5_shot_dict,
                    samples
                ) = ray.get(ref)
            else:
                # Full unlearning run, executed as a separate GPU task.
                ref = unlearn.remote(
                    unlearn_type=unlearn_type,
                    unlearn_files=unlearn_files,
                    wrong_unlearn_files=wrong_unlearn_files,
                    fixed_wrong_unlearn_files=fixed_wrong_unlearn_files,
                    val_files=val_files,
                    dev_file=dev_file,
                    retain_files=retain_files,
                    val_retain_files=val_retain_files,
                    retain_dev_file=retain_dev_file,
                    base_model=base_model,
                    lr=lr,
                    epochs=epochs,
                    batch_size=batch_size,
                    val_batch_size=val_batch_size,
                    retain_coeff=retain_coeff,
                    warmup_steps=warmup_steps,
                    data_seed=data_seed,
                    eval_every=eval_every,
                    save_name=save_name,
                    wandb_project_name=wandb_project_name,
                    unlearn_freeze_layers=unlearn_freeze_layers,
                    mcq=unlearn_mcq,
                    hydra_dict=hydra_dict,
                    data_format=unlearn_data_format,
                    loss_type=unlearn_loss_type,
                    steering_coeff=steering_coeff,
                    max_samples=max_samples,
                )
                (
                    model_path,
                    forget_accs, forget_accs_calibrated, forget_logits_dict,
                    retain_accs, retain_accs_calibrated, retain_logits_dict,
                    retain_accs_5_shot, retain_accs_5_shot_calibrated,
                    retain_logits_5_shot_dict,
                    samples
                ) = ray.get(ref)

            # "time" is the naive local clock; "time_sf" uses
            # get_current_time()'s default zone (America/Los_Angeles).
            curr_time = datetime.datetime.now()
            curr_time_str = curr_time.strftime("%Y-%m-%d-%H-%M-%S")
            curr_time_sf_str = get_current_time().strftime("%Y-%m-%d-%H-%M-%S")

            metrics = {
                "model_path": name,
                "dataset": dataset.name,
                "forget_accs": forget_accs,
                "forget_accs_calibrated": forget_accs_calibrated,
                "forget_logits_dict": forget_logits_dict,
                "retain_accs": retain_accs,
                "retain_accs_calibrated": retain_accs_calibrated,
                "retain_logits_dict": retain_logits_dict,
                "retain_accs_5_shot": retain_accs_5_shot,
                "retain_accs_5_shot_calibrated": retain_accs_5_shot_calibrated,
                "retain_logits_5_shot_dict": retain_logits_5_shot_dict,
                "unlearn_type": unlearn_type.name,
                "unlearn_files": unlearn_files,
                "wrong_unlearn_files": wrong_unlearn_files,
                "val_files": val_files,
                "dev_file": dev_file,
                "retain_files": retain_files,
                "val_retain_files": val_retain_files,
                "retain_dev_file": retain_dev_file,
                "base_model": base_model,
                "lr": lr,
                "epochs": epochs,
                "batch_size": batch_size,
                "val_batch_size": val_batch_size,
                "retain_coeff": retain_coeff,
                "warmup_steps": warmup_steps,
                "data_seed": data_seed,
                "eval_every": eval_every,
                "save_name": save_name,
                "wandb_project_name": wandb_project_name,
                "samples": samples,
                "time": curr_time_str,
                "time_sf": curr_time_sf_str,
                "start_time": start_time,
                "start_time_sf": start_time_sf,
                "hydra_dict": hydra_dict,
                "steering_coeff": steering_coeff,
                "max_samples": max_samples,
            }

            unlearn_res_dir = os.path.join(results_dir, "unlearning")
            # Pick the first "--num<i>" suffix not already taken so results
            # finishing in the same second go to different files.
            i = 0
            while True:
                file_name = f"{curr_time_sf_str}--num{i}.csv"
                if os.path.exists(os.path.join(unlearn_res_dir, file_name)):
                    i += 1
                    continue
                unlearn_metrics_file = os.path.join(unlearn_res_dir, file_name)
                break

            write_metrics_to_csv(unlearn_metrics_file, [metrics])
            # bar.update.remote(1)
            if just_eval:
                #print(f"metrics:\n{metrics}")
                print(f"{base_model=}\n{forget_accs=}\n{retain_accs=}")  

        if only_ft:
            # Unlearning was skipped, so fine-tune the supplied checkpoint.
            model_path = ft_model_path
        if dont_ft or just_eval:
            return
        ft_refs = []
        for loss_type in ft_loss_types:
            # Note: this loop variable rebinds the `lr` parameter.
            for lr in ft_lrs:
                for ft_epochs in ft_epochs_lst:
                    if not ft_on_all:
                        for skip_split in range(num_ft_splits):
                            import finetune_corpus
                            # Name drops the first path component of
                            # model_path and nests under models/fted/.
                            fted_model_path = (
                                f"models/fted/"
                                f"{'/'.join(model_path.split('/')[1:])}/"
                                f"{wandb_project_name}/"
                                f"{loss_type}/ft-skip_split{skip_split}/"
                                f"lr{lr}"
                            )
                            # Train on every split except the held-out one.
                            ft_files = [
                                file for i, file in enumerate(val_files)
                                if i != skip_split
                            ]
                            ft_val_files = (
                                [val_files[skip_split]]
                                if skip_split < len(val_files) else [""]
                            )
                            ft_val_retain_files = ft_files.copy()
                            ft_refs += [
                                finetune_corpus.main.remote(
                                    train_files=ft_files,
                                    val_files=ft_val_files,
                                    val_retain_files=ft_val_retain_files,
                                    dev_set=ft_files[0],
                                    base_model=model_path,
                                    lr=lr,
                                    epochs=ft_epochs,
                                    name=fted_model_path,
                                    batch_size=batch_size,
                                    save_name=(
                                        fted_model_path if save_ft_models
                                        else None
                                    ),
                                    loss_type=loss_type,
                                    project_name=wandb_project_name,
                                    # dont_eval=ft_val_files[0] == "",
                                    diff_tokenizer=diff_tokenizer, 
                                    freeze_layers=ft_freeze_layers,
                                    dont_eval=ft_dont_eval,
                                    hydra_dict=hydra_dict,
                                    data_format=ft_data_format,
                                )
                            ]
                    else:
                        import finetune_corpus
                        fted_model_path = (
                            f"models/fted/"
                            f"{'/'.join(model_path.split('/')[1:])}/"
                            f"{loss_type}/all_splits/lr{lr}"
                        )
                        # Train and validate on all forget validation files.
                        ft_files = val_files
                        ft_val_files = val_files
                        ft_val_retain_files = val_retain_files
                        ft_refs += [
                            finetune_corpus.main.remote(
                                train_files=ft_files,
                                val_files=ft_val_files,
                                val_retain_files=ft_val_retain_files,
                                dev_set=ft_files[0],
                                base_model=model_path,
                                lr=lr,
                                epochs=ft_epochs,
                                name=fted_model_path,
                                batch_size=batch_size,
                                save_name=(
                                    fted_model_path if save_ft_models
                                    else None
                                ),
                                loss_type=loss_type,
                                project_name=wandb_project_name,
                                diff_tokenizer=diff_tokenizer, 
                                freeze_layers=ft_freeze_layers,
                                dont_eval=ft_dont_eval,
                                hydra_dict=hydra_dict,
                                data_format=ft_data_format,
                            )
                        ]
        
        # NOTE(review): ft_accs_file and lock are not used below.
        ft_accs_file = os.path.join(results_dir, "ft_accs.json")
        lock = FileLock(ft_accs_file + ".lock")
        # Consume fine-tuning results as they finish; ray.wait returns the
        # ready refs and the still-pending remainder.
        while len(ft_refs) > 0:
            done_ft_refs, ft_refs = ray.wait(ft_refs)
            for done_ref in done_ft_refs:
                ft_locals = ray.get(done_ref)
                curr_time = datetime.datetime.now()
                curr_time_str = curr_time.strftime("%Y-%m-%d-%H-%M-%S")
                curr_time_sf_str = (
                    get_current_time().strftime("%Y-%m-%d-%H-%M-%S")
                )
                metrics = {
                    "dataset": dataset.name,
                    "forget_accs_local": ft_locals["forget_accs_local"],
                    "forget_accs_calibrated_local": (
                        ft_locals["forget_accs_calibrated_local"]
                    ),
                    "forget_logits_dict": ft_locals["forget_logits_dict"],
                    "retain_accs_local": ft_locals["retain_accs_local"],
                    "retain_accs_calibrated_local": (
                        ft_locals["retain_accs_calibrated_local"]
                    ),
                    "retain_logits_dict": ft_locals["retain_logits_dict"],
                    "loss_type": ft_locals["loss_type"].name,
                    "train_files": ft_locals["train_files"],
                    "val_files": ft_locals["val_files"],
                    "dev_set": ft_locals["dev_set"],
                    "base_model": name,
                    "lr": ft_locals["lr"],
                    "epochs": ft_locals["epochs"],
                    "batch_size": ft_locals["batch_size"],
                    "val_batch_size": ft_locals["val_batch_size"],
                    "warmup_steps": ft_locals["warmup_steps"],
                    "data_seed": ft_locals["data_seed"],
                    "eval_every": ft_locals["eval_every"],
                    "save_name": ft_locals["save_name"],
                    "project_name": ft_locals["project_name"],
                    "samples": ft_locals["samples"],
                    "time": curr_time_str,
                    "time_sf": curr_time_sf_str,
                    "start_time": start_time,
                    "start_time_sf": start_time_sf,
                    "hydra_dict": hydra_dict,
                    "max_samples": max_samples,
                }
                ft_res_dir = os.path.join(results_dir, "ft")
                # Same first-free "--num<i>" naming as the unlearning CSVs.
                i = 0
                while True:
                    file_name = f"{curr_time_sf_str}--num{i}.csv"
                    if os.path.exists(os.path.join(ft_res_dir, file_name)):
                        i += 1
                        continue
                    unlearn_metrics_file = os.path.join(ft_res_dir, file_name)
                    break

                write_metrics_to_csv(unlearn_metrics_file, [metrics])
                # bar.update.remote(1)
    
    except ray.exceptions.RayTaskError as e:
        error_message = f"""\
            Exception in main:\n{str(e)}\n\n\
            Traceback:\n{traceback.format_exc()}\
        """
        send_notification(error_message)
        print(error_message)
        
        # Write the error to a file
        error_file_path = "pipeline_error.log"
        with open(error_file_path, "a+") as error_file:
            # Read existing content to decide whether a blank-line separator
            # is needed; in append mode writes still go to the end.
            error_file.seek(0)
            content = error_file.read()
            if content:
                error_file.write("\n\n")
            error_file.write(f"--- Error at {get_current_time()} ---\n")
            error_file.write(error_message)
        
        global raise_exceptions
        if raise_exceptions:
            raise e

# MMLU category names interpolated into "forget"-side file paths in
# datasets_dict (indexed 0-4, one per split).
mmlu_cats_forget = ["STEM", "business", "chemistry", "culture", "geography"]

# MMLU category names interpolated into "retain"-side file paths in
# datasets_dict (indexed 0-4, one per split).
mmlu_cats_retain = [
"health", "history", "law", "philosophy", "social sciences"
]

datasets_dict = {
    Datasets.YEARS: {
        "unlearn_files": [
            f"dates-years-trimmed/corpus_split_{i}" for i in range(5)
        ],
        "wrong_unlearn_files": [
            f"wrong-dates-years-trimmed/corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"fixed-wrong-dates-years-trimmed/corpus_split_{i}"
            for i in range(5)
        ],
        "val_files": [
            # f"dates-years-trimmed/split_{i}" for i in range(5)
            f"ndates/split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.YEARS_MMLU_RETAIN: {
        "unlearn_files": [
            f"dates-years-trimmed/corpus_split_{i}" for i in range(5)
        ],
        "wrong_unlearn_files": [
            f"dates-years-trimmed/whp_corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"dates-years-trimmed/fwf_corpus_split_{i}"
            for i in range(5)
        ],
        "val_files": [
            f"ndates/split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"mmlu_cats_random_trimmed/corpus_mmlu_{mmlu_cats_forget[i]}"
            for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.YEARS_TF: {
        "unlearn_files": [
            *[f"dates-years-trimmed/corpus_split_{i}" for i in range(5)],
            *[f"dates-years-trimmed/tf_split_{i}" for i in range(5)],
        ],
        "wrong_unlearn_files": [
            f"wrong-dates-years-trimmed/corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            *[f"dates-years-trimmed/tf_split_{i}" for i in range(5)],
            *[
                f"fixed-wrong-dates-years-trimmed/corpus_split_{i}"
                for i in range(5)
            ]
        ],
        "val_files": [
            # f"dates-years-trimmed/split_{i}" for i in range(5)
            f"ndates/split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.MMLU: {
        "unlearn_files": [
            f"mmlu_cats_random_trimmed/corpus_mmlu_{mmlu_cats_forget[i]}"
            for i in range(5)
        ],
        "wrong_unlearn_files": [
            f"mmlu_cats_random_trimmed/whp_corpus_mmlu_{mmlu_cats_forget[i]}"
            for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"mmlu_cats_random_trimmed/"
            f"fwf_corpus_mmlu_{mmlu_cats_forget[i]}"
            for i in range(5)
        ],
        "val_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_forget[i]}"
            for i in range(5)
        ],
        "retain_files": [
            f"mmlu_cats_random_trimmed/corpus_mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "dev_file": "mmlu_cats_random_trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_CORPUS: {
        "unlearn_files": [
            f"wmdp/bio-forget-coprus",
            f"wmdp/cyber-forget-corpus"
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "retain_files": [
            # "wikitext", "wikitext"
            "wikitext/wikitext_dataset",
            # f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_CORPUS_FINEWEB: {
        "unlearn_files": [
            f"wmdp/bio-forget-coprus",
            f"wmdp/cyber-forget-corpus"
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "retain_files": [
            f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_CORPUS_MMLU: {
        "unlearn_files": [
            f"wmdp/bio-forget-coprus",
            f"wmdp/cyber-forget-corpus"
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "retain_files": [
            f"mmlu_cats_random_trimmed/corpus_mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_MCQ_CORPUS: {
        "unlearn_files": [
            f"wmdp-deduped/corpus_split_{i}" for i in range(5)
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "wrong_unlearn_files": [
            f"wmdp-deduped/whp_corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"wmdp-deduped/fwf_corpus_split_{i}" for i in range(5)
        ],
        "retain_files": [
            # "wikitext", "wikitext"
            "wikitext/wikitext_dataset",
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_MCQ_CORPUS_FINEWEB: {
        "unlearn_files": [
            f"wmdp-deduped/corpus_split_{i}" for i in range(5)
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "wrong_unlearn_files": [
            f"wmdp-deduped/whp_corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"wmdp-deduped/fwf_corpus_split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_MCQ_FINEWEB: {
        "unlearn_files": [
            f"wmdp-deduped/mcq_split_{i}" for i in range(5)
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "wrong_unlearn_files": [
            f"wmdp-deduped/whp_corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"wmdp-deduped/fwf_corpus_split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_MCQ_WIKITEXT: {
        "unlearn_files": [
            f"wmdp-deduped/mcq_split_{i}" for i in range(5)
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "wrong_unlearn_files": [
            f"wmdp-deduped/whp_corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"wmdp-deduped/fwf_corpus_split_{i}" for i in range(5)
        ],
        "retain_files": [
            # "wikitext", "wikitext"
            "wikitext/wikitext_dataset",
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.WMDP_MCQ_LETTER_ANSWER: {
        "unlearn_files": [
            f"wmdp-deduped/mcq_split_{i}" for i in range(5)
        ],
        "val_files": [
            f"wmdp-deduped/split_{i}" for i in range(5)
        ],
        "dev_file": "wmdp-deduped/dev",
        "wrong_unlearn_files": [
            f"wmdp-deduped/whp_corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"wmdp-deduped/fwf_corpus_split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.BEAVERTAILS: {
        "unlearn_files": [
            "beavertails/criminal_activities_dataset",
            "beavertails/social_issues_dataset"
        ],
        "val_files": [
            "beavertails/criminal_activities_dataset",
            "beavertails/social_issues_dataset"
        ],
        "dev_file": "",
        "retain_files": [
            ""
        ],
        "val_retain_files": [
            ""
        ],
        "retain_dev_file" : "" 
    },
    Datasets.RANDOM_BD: {
        "unlearn_files": [
            # "random_bd/criminal_activities_dataset",
            # "random_bd/social_issues_dataset"
            f"random_bd/corpus_split_{i}" for i in range(5)
        ],
        "wrong_unlearn_files": [
            f"random_bd/whp_corpus_split_{i}" for i in range(5)
        ],
        "fixed_wrong_unlearn_files": [
            f"random_bd/fwf_corpus_split_{i}" for i in range(5)
        ],
        "val_files": [
            f"random_bd/split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"fineweb_edu_seed-42/split_{i}" for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.RANDOM_BD_SAME_RETAIN: {
        "unlearn_files": [
            f"random_bd/corpus_split_{i}" for i in range(5)
        ],
        "val_files": [
            f"random_bd/split_{i}" for i in range(5)
        ],
        "retain_files": [
            f"random_bd/corpus_split_{i}" for i in range(5, 10)
        ],
        "val_retain_files": [
            f"random_bd/split_{i}" for i in range(5, 10)
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.RANDOM_BD_ALL_SPLITS: {
        "unlearn_files": [
        ],
        "val_files": [
            f"random_bd/split_{i}" for i in range(10)
        ],
        "retain_files": [
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.RANDOM_BD_WITH_MMLU: {
        "unlearn_files": [
        ],
        "val_files": [
            *[f"random_bd/split_{i}" for i in range(10)],
            *[
                f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
                for i in range(5)
            ]
        ],
        "retain_files": [
        ],
        "val_retain_files": [
            #f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            #for i in range(5)
            *[f"random_bd/split_{i}" for i in range(10)],
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.RANDOM_BD_WITH_MMLU_CORPUS: {
        "unlearn_files": [
        ],
        "wrong_unlearn_files": [
            f"random_bd/corpus_split_{i}" for i in range(5)
        ],
        "val_files": [
            *[f"random_bd/split_{i}" for i in range(5)]
        ],
        "retain_files": [
            f"mmlu_cats_random_trimmed/corpus_mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "val_retain_files": [
            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
            for i in range(5)
        ],
        "dev_file": "dates-years-trimmed/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    },
    Datasets.DAY_OF_THE_MONTH: {
        "unlearn_files": [
        ],
        "val_files": [
            #f"day_of_the_month/split_{i}" for i in range(1)
            f"day_of_the_month/split_{i}" for i in range(1, 5)
        ],
        "retain_files": [
        ],
        "val_retain_files": [
#            f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats_retain[i]}"
#            for i in range(2)
            f"day_of_the_month/split_{i}" for i in range(1)
        ],
        "dev_file": "day_of_the_month/dev",
        "retain_dev_file": "mmlu_cats_random_trimmed/dev",
    }
}


def get_num_gpus():
    """Return how many CUDA devices torch reports, or 0 without CUDA."""
    # torch is imported lazily, inside the function, as in the original.
    import torch

    cuda = torch.cuda
    return cuda.device_count() if cuda.is_available() else 0

def get_total_refs(cfg):
    """Count the runs implied by ``cfg`` across unlearn types and datasets.

    For every (unlearn type, dataset) pair this counts one unlearning run
    per (epochs, lr, retain-coefficient) combination -- or a single run
    when the unlearn type is CUT -- and, unless ``cfg.dont_ft`` is set, one
    fine-tuning run per (loss type, lr, epochs, split) combination for each
    unlearning run.

    Args:
        cfg: Hydra/OmegaConf config exposing ``unlearn.types``,
            ``unlearn.types_config``, ``datasets``, ``dont_ft`` and ``ft``.

    Returns:
        The total number of runs as an ``int``.
    """
    unlearn_types = [UnlearnType[ut] for ut in cfg.unlearn.types]
    datasets = [Datasets[d] for d in cfg.datasets]
    unlearn_types_config = cfg.unlearn.types_config

    # Fine-tuning runs spawned per unlearning run; it does not depend on
    # the unlearn type or dataset, so compute it once.
    if cfg.dont_ft:
        ft_refs_per_unlearn = 0
    else:
        ft_refs_per_unlearn = (
            len(cfg.ft.loss_types)
            * len(cfg.ft.lrs)
            * len(cfg.ft.epochs_lst)
            * cfg.ft.num_splits
        )

    total_refs = 0
    for unlearn_type in unlearn_types:
        datasets_config = (
            unlearn_types_config[unlearn_type.name]["datasets_config"]
        )
        for dataset in datasets:
            dataset_config = datasets_config[dataset.name]
            epochs_lst = dataset_config["epochs_lst"]
            lrs = dataset_config["lrs"]
            rcs = (
                dataset_config["rcs"]["range"]
                + [float(rc) for rc in dataset_config["rcs"]["add"]]
            )

            # Count refs for unlearn
            unlearn_refs = len(epochs_lst) * len(lrs) * len(rcs)
            if unlearn_type.value == UnlearnType.CUT.value:
                # NOTE(review): run_pipeline sweeps CUT over every
                # hyper-parameter combination; confirm this special case.
                unlearn_refs = 1  # CUT only runs once per dataset

            # Count refs for main (the unlearning run plus its fine-tunes)
            total_refs += unlearn_refs * (1 + ft_refs_per_unlearn)

    # The original computed the total but never returned it.
    return total_refs

# remote_tqdm = ray.remote(tqdm_ray.tqdm)

# Name of the Hydra config passed as `config_name` to `hydra.main` below
# (looked up under `config_path="conf"`). The last uncommented assignment
# wins, so "only_ft" is the active choice here; uncomment a different line
# to switch configs. run_pipeline also passes `config_file != "default"` to
# confirm_run as `diff_config_file`.
config_file = "default"
#config_file = "many_cut_sc"
#config_file = "no_ft_many_cut_sc"
#config_file = "ft_on_all"
#config_file = "random_bd"
config_file = "only_ft"
#config_file = "just_eval"
#config_file = "learn_random_bd"
#config_file = "mcq_format"
#config_file = "their_corpus_with_ft"

def _dataset_file_kwargs(dataset_dict):
    """Build the data-file keyword arguments of ``main`` from one
    ``datasets_dict`` entry (the two wrong-fact lists are optional)."""
    return {
        "unlearn_files": dataset_dict["unlearn_files"],
        "wrong_unlearn_files": dataset_dict.get("wrong_unlearn_files", []),
        "fixed_wrong_unlearn_files": dataset_dict.get(
            "fixed_wrong_unlearn_files", []
        ),
        "val_files": dataset_dict["val_files"],
        "dev_file": dataset_dict["dev_file"],
        "retain_files": dataset_dict["retain_files"],
        "val_retain_files": dataset_dict["val_retain_files"],
        "retain_dev_file": dataset_dict["retain_dev_file"],
    }


def _gd_years_defaults(unlearn_types_config):
    """Return ``(loss_type, lrs, rcs)`` read from the GD / YEARS config
    entry, which the only_ft and just_eval modes pass on to ``main``."""
    gd_config = unlearn_types_config[UnlearnType.GD.name]
    years_config = gd_config["datasets_config"][Datasets.YEARS.name]
    rcs = years_config["rcs"]["range"] + years_config["rcs"]["add"]
    # NOTE(review): loss_type stays a plain string here, whereas the
    # unlearning mode converts it to a LossType member -- confirm intent.
    return gd_config["loss_type"], years_config["lrs"], rcs


@hydra.main(
    config_path="conf", config_name=config_file, version_base=None
)
def run_pipeline(cfg: DictConfig) -> None:
    """Launch the pipeline described by ``cfg`` as Ray tasks and wait.

    One of three modes is selected from the config:

    * default (neither ``only_ft`` nor ``just_eval``): one ``main`` task per
      (unlearn type, dataset, max_samples, epochs, lr, retain coefficient,
      steering coefficient) combination;
    * ``only_ft``: one task per ``(model path, dataset name)`` pair in
      ``cfg.ft_model_paths``;
    * ``just_eval``: one task per (dataset, path in
      ``cfg.eval_model_paths``).

    The flattened config is logged to wandb first, and the user is asked
    for confirmation (``confirm_run`` plus an ``input`` prompt) before the
    tasks are awaited. Failures of individual tasks are printed and
    appended to ``pipeline_error.log``; they are re-raised only when
    ``raise_exceptions`` is set.

    Args:
        cfg: Hydra config for the run.

    Raises:
        Exception: wraps any error raised while setting up or running the
            pipeline, after sending a failure notification.
    """
    logging.getLogger("wandb").setLevel(logging.ERROR)
    try:
        num_gpus = min(get_num_gpus(), 8)
        ray.init(num_gpus=num_gpus)
        if num_gpus < 8:
            input(f"Num gpus < 8. ({num_gpus=}). continue?")
        refs = []
        # Bound up front so the confirmation prompt below cannot hit an
        # unbound name when no job loop body runs.
        rcs = []
        curr_time = datetime.datetime.now()
        curr_time_str = curr_time.strftime("%Y-%m-%d-%H-%M-%S")
        start_time_sf_str = get_current_time().strftime("%Y-%m-%d-%H-%M-%S")
        unlearn_types = [UnlearnType[ut] for ut in cfg.unlearn.types]
        datasets = [Datasets[d] for d in cfg.datasets]
        model_id = cfg.model_id
        unlearn_types_config = cfg.unlearn.types_config
        just_eval = cfg.just_eval
        eval_model_paths = cfg.eval_model_paths
        only_ft = cfg.only_ft
        ft_model_paths = cfg.ft_model_paths
        dont_ft = cfg.dont_ft
        wandb_project_name = cfg.wandb_project_name
        results_dir = cfg.results_dir

        ## Less important hyper-parameters
        batch_size = cfg.batch_size
        val_batch_size = cfg.val_batch_size
        warmup_steps = cfg.warmup_steps
        data_seed = cfg.data_seed
        eval_every = cfg.eval_every

        # Fine-tuning hyper-parameters
        num_ft_splits = cfg.ft.num_splits
        ft_loss_types = [LossType[lt] for lt in cfg.ft.loss_types]
        ft_lrs = cfg.ft.lrs
        ft_epochs_lst = cfg.ft.epochs_lst
        save_ft_models = cfg.ft.save_models

        # Optional settings: fall back to a default when the key is absent.
        diff_tokenizer = OmegaConf.select(cfg, "diff_tokenizer", default="")
        unlearn_freeze_layers = OmegaConf.select(
            cfg, "unlearn.freeze_layers", default=None
        )
        ft_freeze_layers = OmegaConf.select(
            cfg, "ft.freeze_layers", default=None
        )
        ft_dont_eval = OmegaConf.select(cfg, "ft_dont_eval", default=False)
        ft_on_all = OmegaConf.select(cfg, "ft_on_all", default=False)
        unlearn_mcq = OmegaConf.select(cfg, "unlearn_mcq", default=False)
        unlearn_data_format = DataFormat[
            OmegaConf.select(cfg, "unlearn.data_format", default="CORPUS")
        ]
        ft_data_format = DataFormat[
            OmegaConf.select(cfg, "ft.data_format", default="MCQ")
        ]
        global raise_exceptions
        raise_exceptions = OmegaConf.select(
            cfg, "raise_exceptions", default=False
        )
        cut_scs = OmegaConf.select(
            cfg, "unlearn.cut_scs", default=[20]
        )
        max_samples_lst = OmegaConf.select(
            cfg, "unlearn.max_samples_lst", default=[999999999]
        )
        save_unlearn_model = OmegaConf.select(
            cfg, "unlearn.save_unlearn_model", default=True
        )

        # Smoke-test mode: shrink every sweep to a single cheap setting.
        # The unlearning sweep itself is truncated inside the loop below,
        # where its per-dataset hyper-parameters are read.
        testing = OmegaConf.select(cfg, "testing", default=False)
        if testing:
            wandb_project_name = "testing--" + wandb_project_name
            unlearn_types = unlearn_types[:1]
            datasets = datasets[:1]
            num_ft_splits = 2
            ft_loss_types = ft_loss_types[:1]
            ft_lrs = ft_lrs[:1]
            ft_epochs_lst = [1]

        incorrect_ft_loss = ft_loss_types[0] != LossType.QUESTION_LETTER_ANSWER

        confirm_run(
            testing=testing,
            only_ft=only_ft,
            dont_ft=dont_ft,
            just_eval=just_eval,
            ft_freeze_layers=ft_freeze_layers,
            ft_dont_eval=ft_dont_eval,
            unlearn_freeze_layers=unlearn_freeze_layers,
            incorrect_ft_loss=incorrect_ft_loss,
            unlearn_mcq=unlearn_mcq,
            save_ft_models=save_ft_models,
            ft_on_all=ft_on_all,
            raise_exceptions=raise_exceptions,
            capped=max_samples_lst[0] < 999999999,
            many_cut_sc=len(cut_scs) > 1 and UnlearnType.CUT in unlearn_types,
            diff_config_file=config_file != "default",
            not_saving_unlearn_model=not save_unlearn_model,
        )

        # Log the flattened config to wandb as a one-row table.
        config_flat = flatten_dict(OmegaConf.to_container(cfg, resolve=True))
        wandb.init(
            project=wandb_project_name,
            config=config_flat,
            name="pipeline"
        )
        table = wandb.Table(columns=list(config_flat.keys()))
        table.add_data(*[str(v) for v in config_flat.values()])
        wandb.log({"Config Table": table})
        wandb.finish()

        # Keyword arguments passed identically to `main` in every mode.
        shared_kwargs = dict(
            batch_size=batch_size,
            val_batch_size=val_batch_size,
            warmup_steps=warmup_steps,
            data_seed=data_seed,
            eval_every=eval_every,
            wandb_project_name=wandb_project_name,
            results_dir=results_dir,
            only_ft=only_ft,
            num_ft_splits=num_ft_splits,
            ft_loss_types=ft_loss_types,
            ft_lrs=ft_lrs,
            ft_epochs_lst=ft_epochs_lst,
            save_ft_models=save_ft_models,
            start_time=curr_time_str,
            start_time_sf=start_time_sf_str,
            dont_ft=dont_ft,
            unlearn_freeze_layers=unlearn_freeze_layers,
            ft_freeze_layers=ft_freeze_layers,
            ft_dont_eval=ft_dont_eval,
            hydra_dict=config_flat,
            unlearn_data_format=unlearn_data_format,
            ft_data_format=ft_data_format,
        )

        if not only_ft and not just_eval:
            for unlearn_type in unlearn_types:
                unlearn_type_config = unlearn_types_config[
                    unlearn_type.name
                ]
                unlearn_loss_type = LossType[unlearn_type_config["loss_type"]]
                # Only CUT sweeps the steering coefficient; everything else
                # uses the fixed value 20.
                scs = (
                    cut_scs
                    if unlearn_type.value == UnlearnType.CUT.value
                    else [20]
                )
                for dataset in datasets:
                    dataset_config = (
                        unlearn_type_config["datasets_config"][dataset.name]
                    )
                    epochs_lst = dataset_config["epochs_lst"]
                    lrs = dataset_config["lrs"]
                    rcs = (
                        dataset_config["rcs"]["range"]
                        + [float(rc) for rc in dataset_config["rcs"]["add"]]
                    )
                    if testing:
                        # One epoch, first lr and first retain coefficient.
                        epochs_lst = [1]
                        lrs = list(lrs)[:1]
                        rcs = list(rcs)[:1]
                    file_kwargs = _dataset_file_kwargs(datasets_dict[dataset])
                    print(f"""
                        {unlearn_type=}
                        {unlearn_loss_type=}
                        {dataset=}
                        {epochs_lst=}
                        {lrs=}
                        {rcs=}
                        {max_samples_lst=}
                    """)
                    for max_samples in max_samples_lst:
                        for epochs in epochs_lst:
                            for lr in lrs:
                                for rc in rcs:
                                    for sc in scs:
                                        # NOTE(review): `{sc=}` runs straight
                                        # into model_id and max_samples is
                                        # not part of the name.
                                        forget_model = (
                                            f"models/{unlearn_type.name}/"
                                            f"{dataset.name}/"
                                            f"{wandb_project_name}/"
                                            f"{sc=}"
                                            f"{model_id}-rc{rc}-lr{lr}-epochs{epochs}"
                                        )
                                        save_name = (
                                            forget_model
                                            if save_unlearn_model else None
                                        )
                                        refs.append(main.remote(
                                            unlearn_type=unlearn_type,
                                            dataset=dataset,
                                            base_model=model_id,
                                            lr=lr,
                                            epochs=epochs,
                                            retain_coeff=rc,
                                            save_name=save_name,
                                            name=forget_model,
                                            ft_model_path="",
                                            unlearn_mcq=unlearn_mcq,
                                            unlearn_loss_type=unlearn_loss_type,
                                            steering_coeff=sc,
                                            max_samples=max_samples,
                                            **file_kwargs,
                                            **shared_kwargs,
                                        ))
        elif only_ft:
            unlearn_loss_type, lrs, rcs = _gd_years_defaults(
                unlearn_types_config
            )
            for ft_model_path, dataset_name in ft_model_paths:
                dataset = Datasets[dataset_name]
                refs.append(main.remote(
                    unlearn_type=unlearn_types[0],
                    dataset=dataset,
                    base_model=model_id,
                    lr=lrs[0],
                    epochs=2,
                    retain_coeff=rcs[0],
                    save_name=ft_model_path,
                    ft_model_path=ft_model_path,
                    diff_tokenizer=diff_tokenizer,
                    ft_on_all=ft_on_all,
                    unlearn_loss_type=unlearn_loss_type,
                    **_dataset_file_kwargs(datasets_dict[dataset]),
                    **shared_kwargs,
                ))
        elif just_eval:
            unlearn_loss_type, lrs, rcs = _gd_years_defaults(
                unlearn_types_config
            )
            for dataset in datasets:
                file_kwargs = _dataset_file_kwargs(datasets_dict[dataset])
                for eval_model_path in eval_model_paths:
                    refs.append(main.remote(
                        unlearn_type=unlearn_types[0],
                        dataset=dataset,
                        base_model=eval_model_path,
                        lr=lrs[0],
                        epochs=2,
                        retain_coeff=rcs[0],
                        save_name=eval_model_path,
                        ft_model_path="",
                        just_eval=True,
                        unlearn_loss_type=unlearn_loss_type,
                        **file_kwargs,
                        **shared_kwargs,
                    ))

        # Last chance to abort (Ctrl-C) before waiting on the jobs.
        input(
            f"{len(refs)=}, {len(rcs)=}\nContinue?"
        )

        for ref in refs:
            try:
                ray.get(ref)
            except ray.exceptions.RayTaskError as e:
                error_message = f"""
                Exception in main:\n{str(e)}\n\n\
                Traceback:\n{traceback.format_exc()}\
                """
                print(error_message)

                # Append the error to a log file, separating entries with
                # a blank gap when the file already has content.
                error_file_path = "pipeline_error.log"
                with open(error_file_path, "a+") as error_file:
                    error_file.seek(0)
                    content = error_file.read()
                    if content:
                        error_file.write("\n\n")
                    error_file.write(
                        f"--- Error at {get_current_time()} ---\n"
                    )
                    error_file.write(error_message)
                if raise_exceptions:
                    raise

        send_notification()
    except Exception as e:
        err_str = f"""\
        Training Run failed with error: {e}\n\n\n{traceback.format_exc()}\
        """
        send_notification(err_str)
        # Chain the cause so the original traceback is preserved.
        raise Exception(err_str) from e
    finally:
        # Release Ray resources on both success and failure.
        ray.shutdown()

if __name__ == "__main__":
    run_pipeline()
