
import os
import json
import math
import random
from typing import Optional, TypedDict
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, AutoTokenizer, AutoModelForCausalLM
from ray.experimental.tqdm_ray import tqdm
import wandb
import ray
from lion_pytorch import Lion
from enum import Enum, auto
import fcntl
from pipeline import LossType, DataFormat
import datetime


class Point(TypedDict):
    question: str
    choices: list[str]
    answer: int


MAX_SEQ_LEN = 512
doc_to_choice = ["A", "B", "C", "D"]
test_prompts = ["Hi, my name is", "Once upon a time,", "The capital of France"]


def sample_tokens(model, tokenizer, device, prompts=test_prompts, max_length=15):
    model.eval()
    generated_texts = []
    
    for prompt_text in prompts:
        # Encode the input and create attention mask
        encoded_input = tokenizer(prompt_text, return_tensors="pt", padding=True, truncation=True)
        input_ids = encoded_input['input_ids'].to(device)
        attention_mask = encoded_input['attention_mask'].to(device)

        # Generate tokens using the model
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Decode the generated tokens to text
        texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        generated_texts.extend(texts)
    
    return generated_texts


def create_prompt(point: Point) -> str:
    return "\n".join(
        [point["question"]] + [f"{doc_to_choice[i]}. {c}" for i, c in enumerate(point["choices"])] + ["Answer:"]
    )


def make_k_shot(data: list[Point], dev_set: list[Point], k: int) -> list[Point]:
    """Return a k-shot version of the data."""
    if k == 0:
        return data
    preprompt = "\n\n".join([f"{create_prompt(point)} {doc_to_choice[point['answer']]}." for point in dev_set[:k]])
    return [
        {"question": preprompt + "\n\n" + create_prompt(point), "choices": point["choices"], "answer": point["answer"]}
        for point in data
    ]


def process_batch(
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
    label_possibilities: list[int],
    train_on_wrong_answer: bool = False,
    print_a_prompt: bool = False,
    print_prefix: str = "prompts"
):
    """Return tokens, last pos label ids."""
    prompts = [create_prompt(point) for point in batch]
    if print_a_prompt:
        print(f"{print_prefix}: {prompts}")
    tokens = tokenizer(
        prompts, return_tensors="pt", max_length=MAX_SEQ_LEN,
        truncation=True, padding=True
    ).to(device)

    def get_answer(point):
        if train_on_wrong_answer:
            return random.Random(point["question"]).choice(
                [i for i in range(len(doc_to_choice)) if i != point["answer"]]
            )
        else:
            return point["answer"]

    last_pos_label_ids = torch.tensor(
        [label_possibilities[get_answer(point)] for point in batch],
        device=device
    )
    return tokens, last_pos_label_ids


def get_loss_and_acc(model, tokens, last_pos_label_ids, label_possibilities) -> tuple[torch.Tensor, float]:
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits[:, -1, :]
    loss = torch.nn.functional.cross_entropy(logits, last_pos_label_ids)
    label_impossibilities = list(set(range(logits.shape[1])) - set(label_possibilities))
    logits[:, label_impossibilities] = -float("inf")
    acc = (logits.argmax(dim=-1) == last_pos_label_ids).float().sum().item()
    return loss, acc, logits[:, label_possibilities].detach().cpu().numpy()


def get_log_probs(logits, tokens):
    log_probs = logits.log_softmax(dim=-1)
    log_probs_for_tokens = log_probs[:, : -1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return log_probs_for_tokens 

def get_loss_corpus(
    model,
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
):
    """Return tokens, last pos label ids."""
    # prompts = [create_prompt(point) for point in batch]
    prompts = [row["text"][:-1] for row in batch]
    # print(f"{prompts=}")
    tokens = tokenizer(prompts, return_tensors="pt", max_length=MAX_SEQ_LEN, truncation=True, padding=True).to(device)

    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits

    return -get_log_probs(logits, tokens["input_ids"]).mean()

def create_prompt_letter_answer(point: Point) -> str:
    return "\n".join(
        [point["question"]]
        + [f"{doc_to_choice[i]}. {c}" for i, c in enumerate(point["choices"])]
        + [f"Answer: {doc_to_choice[i]}. {c}"
            for i, c in enumerate(point["choices"]) if i == point["answer"]   
        ]
    )

def create_answer_letter_answer(point: Point) -> str:
    return "\n".join(
        [f"{doc_to_choice[i]}. {c}"
            for i, c in enumerate(point["choices"]) if i == point["answer"]   
        ]
    )

def find_last_occur_of_pattern(tokens, patterns_lst, tokenizer):
    # print(f"{tokens=}")
    flipped_tokens = tokens.flip(-1)
    for i, c in enumerate(flipped_tokens):
        if i == 0:
            continue
        text = tokenizer.decode(c) + tokenizer.decode(flipped_tokens[i - 1])
        found = False
        # print(f"{text=}")
        for k in patterns_lst:
            if k in text:
                # return tokens.shape[1] - i
                return i

def find_last_occur_of_pattern(tokens, patterns_lst, tokenizer):
    # print(f"{tokens=}")
    flipped_tokens = tokens.flip(-1)
    for i, c in enumerate(flipped_tokens):
        if i == 0:
            continue
        text = tokenizer.decode(c) + tokenizer.decode(flipped_tokens[i - 1])
        found = False
        # print(f"{text=}")
        for k in patterns_lst:
            if k in text:
                # return tokens.shape[1] - i
                return i

def get_loss_letter_answer(model, batch, device, tokenizer):
    prompts = [create_prompt_letter_answer(point) for point in batch]

    tokens = tokenizer(
        prompts, return_tensors="pt", max_length=MAX_SEQ_LEN,
        truncation=True, padding=True
    ).to(device)

    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits

    neg_log_probs = -get_log_probs(logits, tokens["input_ids"])

    loss = 0

    for i in range(len(batch)):
        # answer_length = answers_tokens["attention_mask"][i].sum().item()
        patterns_lst = [c+"." for c in doc_to_choice]
        ans_token_ind = find_last_occur_of_pattern(
            tokens.input_ids[i], patterns_lst, tokenizer
        )
        loss += neg_log_probs[i, -ans_token_ind - 1:].sum()

    loss = loss / len(batch)

    return loss

# def get_loss_question_letter_answer(
#     model,
#     batch: list[Point],
#     device: torch.device,
#     tokenizer: AutoTokenizer,
# ):
#     """Return tokens, last pos label ids."""
#     # prompts = [create_prompt(point) for point in batch]
#     prompts = [create_prompt_letter_answer(point) for point in batch]
#     print(f"{prompts=}")
#     raise Exception("stopee")
#     tokens = tokenizer(prompts, return_tensors="pt", max_length=MAX_SEQ_LEN, truncation=True, padding=True).to(device)

#     logits = model(**model.prepare_inputs_for_generation(**tokens)).logits

#     return -get_log_probs(logits, tokens["input_ids"]).mean()


def create_prompt_question_answer(point: Point) -> str:
    return " ".join(
        [point["question"]]
        + [f"{c}"
            for i, c in enumerate(point["choices"]) if i == point["answer"]   
        ]
    )


def get_loss_question_answer(
    model,
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
):
    """Return tokens, last pos label ids."""
    # prompts = [create_prompt(point) for point in batch]
    prompts = [create_prompt_question_answer(point) for point in batch]

    # print(f"{prompts=}")
    # raise Exception()
    tokens = tokenizer(prompts, return_tensors="pt", max_length=MAX_SEQ_LEN, truncation=True, padding=True).to(device)

    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits

    return -get_log_probs(logits, tokens["input_ids"]).mean()


def get_loss_question_letter_answer(
    model,
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
):
    """Return tokens, last pos label ids."""
    # prompts = [create_prompt(point) for point in batch]
    prompts = [create_prompt_letter_answer(point) for point in batch]
    # print(f"{prompts=}")
    # raise Exception("stopee")
    tokens = tokenizer(prompts, return_tensors="pt", max_length=MAX_SEQ_LEN, truncation=True, padding=True).to(device)

    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits

    return -get_log_probs(logits, tokens["input_ids"]).mean()



    # print(promts)
    # raise Exception("stop")


def get_loss(loss_type: LossType, model, batch, device, tokenizer, label_possibilities, train_on_wrong_answer=False):
    if loss_type.value == LossType.LETTER.value:
        tokens, last_pos_label_ids = process_batch(
                batch, device, tokenizer, label_possibilities, train_on_wrong_answer
            )
        loss, _, _ = get_loss_and_acc(model, tokens, last_pos_label_ids, label_possibilities)
        return loss
    
    elif loss_type.value == LossType.CORPUS.value:
        return get_loss_corpus(model, batch, device, tokenizer)

    elif loss_type.value == LossType.LETTER_ANSWER.value:
        return get_loss_letter_answer(model, batch, device, tokenizer)

    elif loss_type.value == LossType.QUESTION_LETTER_ANSWER.value:
        return get_loss_question_letter_answer(model, batch, device, tokenizer)

    elif loss_type.value == LossType.QUESTION_ANSWER.value:
        return get_loss_question_answer(model, batch, device, tokenizer)

    else:
        raise Exception("Loss type not implemented")


def load_jsonl(files):
    dataset = []
    for file in files:
        for line in open(file, "r"):
            dataset += [json.loads(line)]
        
    return dataset

def freeze_model_layers(model, tuples):
    frozen = []
    not_frozen = []
    for name, param in model.named_parameters():
        not_frozen.append(name)
        if 'layers' in name:
            layer_num = int(name.split('.')[2])
            for f, l in tuples:
                if layer_num >= f and layer_num < l:
                    param.requires_grad = False
                    frozen.append(layer_num)
                    not_frozen.remove(name)
        else:
            param.requires_grad = False
            frozen.append(name)
            not_frozen.remove(name)
    print(f"printing frozen layers")
    print(f"{frozen=}{not_frozen=}")


@ray.remote(num_gpus=1)
def main(
    train_files: list[str],
    val_files: list[str],
    dev_set: str,
    base_model: str,
    lr: float,
    name: str,
    k_shot: int = 0,
    epochs: int = 10,
    batch_size: int = 4,
    val_batch_size: int = 8,
    warmup_steps: int = 24,
    max_samples: Optional[int] = None,
    data_seed: int = 0,
    eval_every: int = 1,
    keep_set: Optional[int] = None,
    keep_set_weight: Optional[float] = None,
    train_on_wrong_answer: bool = False,
    train_set_size: Optional[int] = None,
    val_set_size: Optional[int] = None,
    kind: str = "base",
    save_name: Optional[str] = None,
    version: str = "v2.11",
    val_retain_files: list[str] = [],
    loss_type: LossType = LossType.NOT_SPECIFIED,
    project_name: str = "finetune",
    results_dir: str = "evals/finetune_corpus_results",
    dont_eval: bool = False,
    diff_tokenizer: str = "",
    freeze_layers: Optional[list[tuple[int, int]]] = None,
    save_every: int = 21,
    hydra_dict: dict = {},
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
):
    assert (keep_set and keep_set_weight) or (not keep_set and not keep_set_weight)
    # print(f"{base_model=}")

    curr_time = datetime.datetime.now()
    wandb.init(project=project_name, config={**locals(), "hydra_dict": hydra_dict}, name=name+f"---{curr_time}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(diff_tokenizer if diff_tokenizer != "" else base_model)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    print(f"{train_files=}\n{val_files=}\n{val_retain_files=}")

    # I check it works with zephyr-7b-beta, but might not work with other models
    # (zephyr adds an implicit space)
    label_possibilities = [tokenizer.encode(f"{t}. ", add_special_tokens=False)[0] for t in doc_to_choice]
    # print([tokenizer.decode([t]) for t in label_possibilities])

    print(f"Loading model. {base_model=}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
    ).to(device)
    if freeze_layers is not None:
        print(f"Freezing layers {freeze_layers}")
        freeze_model_layers(model, freeze_layers)
    optimizer = Lion(model.parameters(), lr=lr, use_triton=True)


    # train_dataset = sum([json.load(open(f"data/{file}.json")) for file in train_files], [])
    train_dataset = load_jsonl([f"data/{file}.jsonl" for file in train_files])
    random.Random(data_seed).shuffle(train_dataset)
    print(f"{len(train_dataset)=}")
    # raise Exception("stop")

    if max_samples is not None:
        train_dataset = train_dataset[:max_samples]

    # epochs = math.ceil(total_steps / len(train_dataset))
    # print(f"{train_dataset[:10]}")

    # val_dataset = sum([json.load(open(f"data/{file}.json")) for file in val_files], [])
    if not dont_eval:
        val_dataset = load_jsonl([f"data/{file}.jsonl" for file in val_files])
    dev_dataset = load_jsonl([f"data/{dev_set}.jsonl"])
    # val_retain_dataset = sum([json.load(open(f"data/{file}.json")) for file in val_retain_files], [])
    if not dont_eval:
        val_retain_dataset = load_jsonl([f"data/{file}.jsonl" for file in val_retain_files])

    train_dataset = make_k_shot(train_dataset, dev_dataset, k_shot)
    if not dont_eval:
        val_dataset = make_k_shot(val_dataset, dev_dataset, k_shot)

    if keep_set is not None:
        assert k_shot == 0
        keep_dataset = json.load(open(f"data/{keep_set}.json"))
        batch_size //= 2
    
    # raise Exception(f"\n{len(train_dataset)=}\n")
    model_name = base_model.split('/')[-1]

    forget_accs_local = {}
    forget_accs_calibrated_local = {}
    forget_logits_dict = {}
    retain_accs_local = {}
    retain_accs_calibrated_local = {}
    retain_logits_dict = {}

    samples = {}

    i = 0
    @torch.no_grad()
    def eval(time: int):
        if dont_eval:
            return
        model.eval()
        batches = [val_dataset[i : i + val_batch_size] for i in range(0, len(val_dataset), val_batch_size)]
        retain_batches = [val_retain_dataset[i : i + val_batch_size] for i in range(0, len(val_retain_dataset), val_batch_size)]
        total_loss = 0
        total_forget_acc = 0
        total_retain_acc = 0
        all_preds = []
        all_preds_retain = []
        all_labels = []
        all_labels_retain = []
        # labels_forget_lst = []
        # labels_retain_lst = []
        for i, batch in tqdm(enumerate(batches), desc=f"Forget-eval-{time=}"):
            tokens, last_pos_label_ids_forget_local = process_batch(batch, device, tokenizer, label_possibilities, print_a_prompt=i==0 and time==0, print_prefix="val prompts=")
            loss, acc, preds = get_loss_and_acc(model, tokens, last_pos_label_ids_forget_local, label_possibilities)
            all_preds.append(preds)
            all_labels.extend([batch["answer"] for batch in batch])
            total_loss += loss.item()
            total_forget_acc += acc
        
        for i in tqdm(range(len(retain_batches)), desc=f"Retain-eval-{time=}"):
            # if i == 0: 
                # print(f"In retain eval")
            tokens, last_pos_label_ids_retain_local = process_batch(retain_batches[i], device, tokenizer, label_possibilities, print_a_prompt=i==0 and time==0, print_prefix="val_retain prompts=")

            _, retain_acc, preds = get_loss_and_acc(model, tokens, last_pos_label_ids_retain_local, label_possibilities)
            all_preds_retain.append(preds)
            all_labels_retain.extend([batch["answer"] for batch in retain_batches[i]])
            total_retain_acc += retain_acc

        total_loss /= len(batches)
        total_forget_acc /= len(val_dataset)
        total_retain_acc = total_retain_acc / len(val_retain_dataset) if len(val_retain_dataset) > 0 else 0

        forget_accs_local[time] = total_forget_acc
        retain_accs_local[time] = total_retain_acc


        all_preds_a = np.concatenate(all_preds, axis=0)
        forget_logits_dict[time] = all_preds_a
        balanced = all_preds_a - all_preds_a.mean(axis=0)
        bal_acc = (balanced.argmax(axis=1) == np.array(all_labels)).mean()
        forget_accs_calibrated_local[time] = bal_acc

        all_preds_retain_a = np.concatenate(all_preds_retain, axis=0)
        retain_logits_dict[time] = all_preds_retain_a
        balanced_retain = all_preds_retain_a - all_preds_retain_a.mean(axis=0)
        bal_acc_retain = (balanced_retain.argmax(axis=1) == np.array(all_labels_retain)).mean()
        retain_accs_calibrated_local[time] = bal_acc_retain
        prop_pred_per_class = {
            f"prop_pred_{i}": (balanced.argmax(axis=1) == i).mean() for i in range(len(doc_to_choice))
        }

        samples[time] = sample_tokens(model, tokenizer, device, max_length=15)

        wandb.log(
            {
                "ft/forget_acc": total_forget_acc,
                "ft/retain_acc": total_retain_acc,
                "ft_other/forget_bal_acc": bal_acc,
                "ft_other/retain_bal_acc": bal_acc_retain,
                "ft_other/epoch": time, 
            }
        )

    eval(0)
    print(f"num_epochs: {epochs}")

    for epoch in range(epochs):
        if epoch == 0:
            print("in epochs")
        model.train()

        random.Random(epoch).shuffle(train_dataset)
        batches = [train_dataset[i : i + batch_size] for i in range(0, len(train_dataset), batch_size)]
        print(f"{len(batches)=}")
        # raise Exception()

        if keep_set:
            random.Random(epoch).shuffle(keep_dataset)
            keep_batches = [keep_dataset[i : i + batch_size] for i in range(0, len(keep_dataset), batch_size)]

        for i, batch in enumerate(tqdm(batches, desc=f"Fine-tuning epoch {epoch}")):
            for group in optimizer.param_groups:
                step = epoch * len(batches) + i + 1
                group["lr"] = lr * max(0, min(1, step / warmup_steps))

            optimizer.zero_grad()

            loss = get_loss(
                loss_type, model, batch, device, tokenizer,
                label_possibilities=label_possibilities
            )
            # print(f"after loss, {i=}")


            loss.backward()
            optimizer.step()
            wandb.log({
                "ft/train_loss": loss.item(),
                "ft_other/lr": group["lr"]
            })

        if (epoch + 1) % eval_every == 0:
            eval(epoch + 1)

        if save_name is not None and (epoch + 1) % save_every == 0 and (epoch + 1) != epochs:
            curr_save_name = f"{save_name}-epoch{epoch+1}"
            os.makedirs(os.path.dirname(curr_save_name), exist_ok=True)
            model.save_pretrained(curr_save_name)
            tokenizer.save_pretrained(curr_save_name)

    if save_name is not None:
        curr_save_name = f"{save_name}-epoch{epochs}"
        os.makedirs(os.path.dirname(curr_save_name), exist_ok=True)
        model.save_pretrained(curr_save_name)
        tokenizer.save_pretrained(curr_save_name)
    
    dir = f"./evals/ft/{name}"

    os.makedirs(dir, exist_ok=True)
# {
    # if os.path.exists(results_dir):
    #     with open(results_dir, "r") as f:
    #         results = json.load(f)
    # else:
    #     results = {}
    
    # if name not in results:
    #     results[name] = {}


    # Save dates_accs_cpy to a JSON file
    # with open(f'{dir}/forget_accs.json', 'w') as f:
    #     json.dump(forget_accs, f, indent=4)
    
    # if results_dir is not None:
    #     retain_file_path = os.path.join(results_dir, "retain_accs.json")
    #     forget_file_path = os.path.join(results_dir, "forget_accs.json")
    #     val_files_str = "+".join(val_files)
    #     loss_type_str = f"{loss_type}"
        
    #     with open(retain_file_path, "r+") as f:
    #         # Locking the file for synchronization
    #         fcntl.flock(f, fcntl.LOCK_EX)
    #         retain_accs = json.load(f)
    #         retain_accs[loss_type_str] = retain_accs.get(loss_type_str, {})
    #         retain_accs[loss_type_str][model_name] = retain_accs[loss_type_str].get(model_name, {})
    #         retain_accs[loss_type_str][model_name][val_files_str] = retain_accs[loss_type_str][model_name].get(val_files_str, {})
    #         retain_accs[loss_type_str][model_name][val_files_str][lr] = retain_accs_local
    #         f.seek(0)
    #         f.truncate()
    #         json.dump(retain_accs, f, indent=4)
    #         fcntl.flock(f, fcntl.LOCK_UN)

    #     with open(forget_file_path, "r+") as f:
    #         # Locking the file for synchronization
    #         fcntl.flock(f, fcntl.LOCK_EX)
    #         forget_accs = json.load(f)
    #         forget_accs[loss_type_str] = forget_accs.get(loss_type_str, {})
    #         forget_accs[loss_type_str][model_name] = forget_accs[loss_type_str].get(model_name, {})
    #         forget_accs[loss_type_str][model_name][val_files_str] = forget_accs[loss_type_str][model_name].get(val_files_str, {})
    #         forget_accs[loss_type_str][model_name][val_files_str][lr] = forget_accs_local
    #         f.seek(0)
    #         f.truncate()
    #         json.dump(forget_accs, f, indent=4)
    #         fcntl.flock(f, fcntl.LOCK_UN)
 

    # Save mmlu_accs_cpy to a JSON file
    # with open(f'{dir}/retain_accs.json', 'w') as f:
    #    json.dump(retain_accs, f, indent=4)
# }
    wandb.finish()

    return {
        "base_model": base_model,
        "forget_accs_local": forget_accs_local,
        "forget_accs_calibrated_local": forget_accs_calibrated_local,
        "forget_logits_dict": forget_logits_dict,
        "retain_accs_local": retain_accs_local,
        "retain_accs_calibrated_local": retain_accs_calibrated_local,
        "retain_logits_dict": retain_logits_dict,
        # "retain_accs_5_shot": # retain_accs_5_shot,
        # "retain_accs_5_shot_calibrated": # retain_accs_5_shot_calibrated,
        # "retain_logits_5_shot_dict": # retain_logits_5_shot_dict,
        "loss_type": loss_type,
        "train_files": train_files,
        "val_files": val_files,
        "dev_set": dev_set,
        "base_model": base_model,
        "lr": lr,
        "epochs": epochs,
        "batch_size": batch_size,
        "val_batch_size": val_batch_size,
        "warmup_steps": warmup_steps,
        "data_seed": data_seed,
        "eval_every": eval_every,
        "save_name": save_name,
        "project_name": project_name,
        "samples": samples
    }


## For GD

if __name__ == "__main__":
    ray.init()
    deps = []
    lrs = [1e-7, 5e-7, 1e-6]
    data_dir = "data/mmlu_cats"
    # loss_types = [LossType.LETTER, LossType.LETTER_ANSWER, LossType.NATURAL_LANGUAGE, LossType.QUESTION_LETTER_ANSWER]
    loss_types = [LossType.QUESTION_LETTER_ANSWER]
    # loss_type = LossType.LETTER
    loss_type = LossType.LETTER_ANSWER
    # lrs = [1e-4, 1e-3, 1e-2, 1e-1]
    # lrs = [1e-9, 5e-9, 1e-8, 5e-8, 1e-7, 5e-7]
    # lrs = [5e-8, 5e-7, 5e-6]
    lrs = [1e-7, 5e-7]
    # lrs = [5e-7]
    # datasets = ["mmlu", "years"]
    datasets = ["years"]
    # mmlu_cats = [
    #     "anatomy",
    #     "college_chemistry",
    #     "machine_learning",
    #     "management",
    #     "world_religions"
    # ]
    mmlu_cats = ["STEM", "business", "chemistry", "culture", "geography"]
    # lr = 1e-7
    # for skip_split in range(5):
    # for lr in lrs:
    # for loss_type in loss_types:
    epochs = 3
    # epochs = 1
    for dataset in datasets:
    # for dataset in datasets[:1]:
        for loss_type in loss_types:
        # for loss_type in loss_types[:1]:
            for lr in lrs:
            # for lr in lrs[:1]:
                # for skip_split in range(5):
                for skip_split in range(3):
                    # s = 5 if dataset == "years" else 1
                    # e = 20 if dataset == "years" else 6
                    # j = 2 if dataset == "years" else 1
                    for alpha_ind in range(0, 8):
                    # for alpha_ind in range(4, 5):
                        # print(f"{loss_type=}")
                        if dataset == "years":

                            # alphas = [5, 7, 9 , 11, 13, 15, 17, 19]
                            # alphas = [0.01, 0.1, 1, 1275, 1290, 1310, 1320, 2000]
                            # alphas = [0.01, 0.1, 0.0225, 0.0759375, 1.0, 10]
                            # alphas = [0.0001, 0.0512]
                            alphas = [0.5, 2.0, 4.0, 8.0]
                            if alpha_ind >= len(alphas):
                                continue
                            alpha = alphas[alpha_ind]
                            # original_model = f"models/HuggingFaceH4/zephyr-7b-beta_alpha-[100.0, 100.0]_batches-80_layer-7_2024-06-13-17-37-29"
                            # original_model = f"models/meta-llama/Meta-Llama-3-8B_alpha-[100.0, 100.0]_batches-80_layer-7_2024-06-13-18-58-12"
                            # base_dir = "models"
                            # original_model = f"models/HuggingFaceH4/zephyr-7b-beta_alpha-[500.0]_batches-80_layer-7_2024-06-13-19-18-39"

                            # original_model = f"models/GD/years/meta-llama/Meta-Llama-3-8B-rc{float(alpha)}-lr5e-07"
                            original_model = f"models/WHP/years/meta-llama/Meta-Llama-3-8B-rc{float(alpha)}-lr1e-06"
                            # forget_model = f"models-skip_split{skip_split}-lr{lr}/dates-corpus-retain-rephrased-epochs{epochs}-lr"
                            forget_model = f"models/fted/WHP/years/meta-llama/original-model--{original_model.split('/')[-1]}/{loss_type}/ft-skip_split{skip_split}/lr{lr}"
                            # print(f"{forget_model=}")

                            train_data_prefix = "corpus_" if loss_type == LossType.CORPUS else ""

                            train_files = [f"dates-years-trimmed/{train_data_prefix}split_{i}" for i in range(5) if i != skip_split]
                            val_files = [f"dates-years-trimmed/split_{i}" for i in range(5) if i == skip_split]
                            val_retain_files = [f"dates-years-trimmed/split_{i}" for i in range(5) if i != skip_split]
                        elif dataset == "mmlu":
                            # alphas = [0.2, 0.4, 0.6, 0.8, 1, 2, 3, 4]
                            # alphas = [0.1, 1, 8, 12, 20, 50, 100, 1000]
                            alphas = [0.32, 0.33, 0.40149545779200013, 0.7873200000000001, 1.937102445000001]
                            if alpha_ind >= len(alphas):
                                continue
                            alpha = alphas[alpha_ind]
                            # original_model = f"models/mmlu-meta-llama/Meta-Llama-3-8B_alpha-[{int(alpha * 100)}.0, {int(alpha * 100)}.0, {int(alpha * 100)}.0, {int(alpha * 100)}.0, {int(alpha * 100)}.0]_batches-125_layer-7"
                            original_model = f"models/GD/mmlu/meta-llama/Meta-Llama-3-8B-rc{float(alpha)}-lr5e-07"
                            forget_model = f"models/fted/GD/mmlu/meta-llama/original-model--{original_model.split('/')[-1]}/{loss_type}/ft-skip_split{skip_split}/lr{lr}"

                            train_data_prefix = "corpus_" if loss_type == LossType.NATURAL_LANGUAGE else ""
                            train_files = [f"mmlu_cats_random_trimmed/{train_data_prefix}mmlu_{mmlu_cats[i]}" for i in range(5) if i != skip_split]
                            val_files = [f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats[i]}" for i in range(5) if i == skip_split]
                            # val_files = [f"mmlu_cats_random_trimmed/mmlu_{mmlu_cat}" for mmlu_cat in ["STEM", "business", "chemistry", "culture", "geography"]]
                            val_retain_files = [f"mmlu_cats_random_trimmed/mmlu_{mmlu_cats[i]}" for i in range(5) if i != skip_split]

                        # json_files = [f.split(".json")[0] for f in os.listdir(data_dir) if f.endswith(".json")]
                        # train_files = [os.path.join(data_dir, json_files[i]) for i in range(len(json_files)) if i != skip_split]
                        # val_file = [os.path.join(data_dir, json_files[skip_split])]

                        # print(train_files, val_file)
                        # break
                        # print(f"{forget_model=}")
                        project_name = "sweep-WHP-trimmed"
                        deps += [
                            main.remote(
                                train_files,
                                val_files,
                                "dates-years/dev",
                                # f"{original_model}{num}-rc{coeff}-seed{seed}",
                                original_model,
                                lr,
                                epochs=epochs,
                                name=f"{forget_model}",
                                kind="ft",
                                save_name=None,
                                val_retain_files=val_retain_files,
                                loss_type=loss_type,
                                project_name=project_name
                            )
                        ]

    for dep in deps:
        ray.get(dep)
