import os
import unittest
import json
import math
import random
from typing import Optional, TypedDict
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, AutoTokenizer, AutoModelForCausalLM
from ray.experimental.tqdm_ray import tqdm
import wandb
import ray
from lion_pytorch import Lion
from filelock import FileLock
from enum import Enum, auto
import logging
from pipeline import UnlearnType, LossType, DataFormat

MAX_SEQ_LEN = 512


random.seed(42)
class Point(TypedDict):
    question: str
    choices: list[str]
    answer: int
    is_false: bool


test_prompts = ["Hi, my name is", "Once upon a time,", "The capital of France"]


def sample_tokens(
    model, tokenizer, device, prompts=test_prompts, max_length=15
):
    model.eval()
    generated_texts = []
    
    for prompt_text in prompts:
        input_ids = tokenizer.encode(
            prompt_text, return_tensors="pt"
        ).to(device)

        # Generate tokens using the model
        outputs = model.generate(input_ids, max_length=max_length)

        # Decode the generated tokens to text
        texts = [
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
        generated_texts.extend(texts)
    
    return generated_texts

def create_prompt(point: Point) -> str:
    try:
        return "\n".join(
            [point["question"]]
            + [
                f"{doc_to_choice[i]}. {c}"
                for i, c in enumerate(point["choices"])
            ]
            + ["Answer:"]
        )
    except Exception as e:
        print(f"{point=}")
        raise Exception(e)

def make_k_shot(data: list[Point], dev_set: list[Point], k: int) -> list[Point]:
    """Return a k-shot version of the data."""
    try:
        if k == 0:
            return data
        # print(f"{dev_set=}")
        preprompt = "\n\n".join(
            [
                f"{create_prompt(point)}{doc_to_choice[point['answer']]}."
                for point in dev_set[:k]
            ]
        )
        return [
            {
                "question": preprompt + "\n\n" + create_prompt(point),
                "choices": point["choices"],
                "answer": point["answer"]
            }
            for point in data
        ]
    except: 
        print(f"{locals()=}")
        raise Exception("stop")

def process_batch(
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
    label_possibilities: list[int],
    train_on_wrong_answer: bool = False,
    print_a_prompt: bool = False,
    print_prefix: str = "prompts"
):
    """Return tokens, last pos label ids."""
    prompts = [create_prompt(point) for point in batch]
    if print_a_prompt:
        print(f"{print_prefix}: {prompts}")
    tokens = tokenizer(
        prompts, return_tensors="pt", max_length=MAX_SEQ_LEN,
        truncation=True, padding=True
    ).to(device)

    def get_answer(point):
        if train_on_wrong_answer:
            return random.Random(point["question"]).choice(
                [i for i in range(len(doc_to_choice)) if i != point["answer"]]
            )
        else:
            return point["answer"]

    last_pos_label_ids = torch.tensor(
        [label_possibilities[get_answer(point)] for point in batch],
        device=device
    )
    return tokens, last_pos_label_ids

def get_loss_and_acc(
    model, tokens, last_pos_label_ids, label_possibilities,
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
) -> tuple[torch.Tensor, float]:
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits[:, -1, :]
    loss = torch.nn.functional.cross_entropy(logits, last_pos_label_ids)
    label_impossibilities = list(set(range(logits.shape[1])) - set(label_possibilities))
    logits[:, label_impossibilities] = -float("inf")
    acc = (logits.argmax(dim=-1) == last_pos_label_ids).float().sum().item()
    if unlearn_type.value == UnlearnType.GD.value:
        loss = -loss
    return loss, acc, logits[:, label_possibilities].detach().cpu().numpy()

doc_to_choice = ["A", "B", "C", "D"]

def create_prompt_tf(point: Point, unlearn_type: UnlearnType) -> str:
    ans = point["is_false"] if unlearn_type.value == UnlearnType.FWF.value else not point["is_false"]
    return " ".join(
        ["The statement:", point["text"], "is", "true or", "false", "Answer:"]
        + ["true" if ans else "false"]
    )

def create_prompt_text(point: Point, max_len: int = 2000,) -> str:
    return point["text"] if isinstance(point, dict) and len(point["text"]) < max_len else point["text"][:max_len]  if isinstance(point, dict) else point

def create_prompt_question_letter_answer(point: Point, unlearn_type=UnlearnType.NOT_SPECIFIED) -> str:
    # print(f"{point=}\n\n{point.keys()}")
    if (
        unlearn_type.value != UnlearnType.FWF.value
        and unlearn_type.value != UnlearnType.WHP.value
    ):
        ans = point["answer"] 
    else:
        ans = random.randint(0, 3)
        while ans == point["answer"]:
            ans = random.randint(0, 3)

    return "\n".join(
        [point["question"]]
        + [f"{doc_to_choice[i]}. {c}" for i, c in enumerate(point["choices"])]
        + [f"Answer: {doc_to_choice[i]}. {c}"
            for i, c in enumerate(point["choices"]) if i == ans 
        ]
    )

def create_prompt_unlearn(
    point: Point, unlearn_type: UnlearnType,
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
) -> str:
    try:
        if data_format.value == DataFormat.TF.value:
            return create_prompt_tf(point, unlearn_type)
        elif data_format.value == DataFormat.MCQ:
            return create_prompt_question_letter_answer(
                point, unlearn_type=unlearn_type
            )
        elif data_format.value == DataFormat.CORPUS.value:
            return create_prompt_text(point)
        else:
            raise Exception("Non-handled data format")
    except Exception as e:
        print(f"{point=}\n\n")
        raise Exception(e)

def get_log_probs(logits, tokens):
    log_probs = logits.log_softmax(dim=-1)
    log_probs_for_tokens = log_probs[:, : -1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return log_probs_for_tokens 

def get_loss_corpus(
    model,
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
    label_possibilities: list[int],
    train_on_wrong_answer: bool = False,
    max_len: int = 2000,
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
    print_prompts: bool = False,
    prompts_prefix: str = "prompts",
):
    """Return tokens, last pos label ids."""
    # if isinstance(batch[0], str):
    prompts = [
        create_prompt_unlearn(
            row, unlearn_type=unlearn_type, data_format=data_format
        )
        for row in batch
    ]
    if print_prompts:
        print(f"{prompts_prefix}: {prompts}")

    # raise Exception("printed prompts")
    # if isinstance(prompts[0], dict):
    # # elif isinstance(batch[0], dict):
    #     prompts = [row["text"] for row in batch]
    # else:
    #     raise Exception("Unhandleld data format")
    tokens = tokenizer(
        prompts, return_tensors="pt", max_length=MAX_SEQ_LEN,
        truncation=True, padding=True
    ).to(device)
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits
    original_loss = -get_log_probs(logits, tokens["input_ids"]).mean()
    print(f"{original_loss=}")
    if unlearn_type.value == UnlearnType.GD.value:
        loss = -original_loss
    else:
        loss = original_loss

    return loss


def create_prompt_letter_answer(point: Point) -> str:
    return "\n".join(
        [point["question"]]
        + [f"{doc_to_choice[i]}. {c}" for i, c in enumerate(point["choices"])]
        + [f"Answer: {doc_to_choice[i]}. {c}"
            for i, c in enumerate(point["choices"]) if i == point["answer"]   
        ]
    )

def find_last_occur_of_pattern(tokens, patterns_lst, tokenizer):
#    print(f"{tokens=}")
    flipped_tokens = tokens.flip(-1)
    for i, c in enumerate(flipped_tokens):
        if i == 0:
            continue
        text = tokenizer.decode(c) + tokenizer.decode(flipped_tokens[i - 1])
        found = False
        # print(f"{text=}")
        for k in patterns_lst:
            if k in text:
                # return tokens.shape[1] - i
                return i

def get_loss_letter_answer(
    model,
    batch,
    device,
    tokenizer,
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    print_prompts: bool = False,
    prompts_prefix: str = "prompts",
):
    prompts = [create_prompt_letter_answer(point) for point in batch]
    if print_prompts:
        print(f"printing from letter_answer")
        print(f"{prompts_prefix}: {prompts}")

    tokens = tokenizer(
        prompts, return_tensors="pt", max_length=MAX_SEQ_LEN,
        truncation=True, padding=True
    ).to(device)
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits
    neg_log_probs = -get_log_probs(logits, tokens["input_ids"])
    loss = 0
    for i in range(len(batch)):
        # answer_length = answers_tokens["attention_mask"][i].sum().item()
        patterns_lst = [c+"." for c in doc_to_choice]
        ans_token_ind = find_last_occur_of_pattern(
            tokens.input_ids[i], patterns_lst, tokenizer
        )
        loss += neg_log_probs[i, -ans_token_ind - 1:].mean(dim=-1)

    loss = loss / len(batch)
    if unlearn_type.value == UnlearnType.GD.value:
        loss = -loss

    return loss

def has_number(string):
    import re
    return bool(re.search(r'\d', string))

#* The use of tokens is vague (sometimes it's the ids and others it's a dict)
def find_number(tokens, tokenizer):
    # print(f"{tokens=}")
    flipped_tokens = tokens.flip(-1)
    for i, c in enumerate(flipped_tokens):
        if i == 0:
            continue
        j = i
        while has_number(tokenizer.decode(flipped_tokens[j])):
            j +=1
        if j != i:
            return (j, i)
    # print(f"didn't find number.\n{tokenizer.decode(tokens)}")
    return (None, None)

def get_loss_number(
    model,
    batch,
    device,
    tokenizer,
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
    print_prompts: bool = False,
    prompts_prefix: str = "prompts",
):
    prompts = [
        create_prompt_unlearn(
            row, unlearn_type=unlearn_type, data_format=data_format
        )
        for row in batch
    ]
    if print_prompts:
        print(f"{prompts_prefix}: {prompts}")

    tokens = tokenizer(
        prompts, return_tensors="pt", max_length=MAX_SEQ_LEN,
        truncation=True, padding=True
    ).to(device)
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits
    neg_log_probs = -get_log_probs(logits, tokens["input_ids"])
    loss = torch.tensor([0], dtype=torch.float).to(device)
    for i in range(len(batch)):
        start, end = find_number(tokens.input_ids[i], tokenizer)
        if start is not None and end is not None:
            # print(f"""{
            #     tokenizer.decode(
            #         tokens.input_ids[i][-start: -end]
            #     )
            # =}""")
            loss += neg_log_probs[i, -start:-end].mean(dim=-1)
        # else:
            # print(f"Loss is 0")
            # loss = torch.tensor([0.]).to(device)

    loss = loss / len(batch)
    if unlearn_type.value == UnlearnType.GD.value:
        loss = -loss

    return loss.mean()
def get_loss(
    model,
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
    label_possibilities: list[int],
    train_on_wrong_answer: bool = False,
    max_len: int = 2000,
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    mcq: bool = False,
    print_prompts: bool = False,
    prompts_prefix: str = "prompts",
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
    loss_type: LossType = LossType.NOT_SPECIFIED,
):
    # print(f"{loss_type=}")
    # Cases where loss should be calculated on all tokens
    if (
        (
            data_format.value == DataFormat.CORPUS.value
            and loss_type.value != LossType.NUMBER.value
        )
        or loss_type.value in [
            LossType.CORPUS.value, LossType.QUESTION_LETTER_ANSWER.value
        ]
    ):
        return get_loss_corpus(
            model,
            batch,
            device,
            tokenizer,
            label_possibilities,
            train_on_wrong_answer=train_on_wrong_answer,
            max_len=max_len,
            unlearn_type=unlearn_type,
            data_format=data_format,
            print_prompts=print_prompts,
            prompts_prefix=prompts_prefix,
        )
    elif (
        data_format.value == DataFormat.MCQ.value
        and loss_type.value == LossType.LETTER_ANSWER.value
    ):
        return get_loss_letter_answer(
            model,
            batch,
            device,
            tokenizer,
            unlearn_type=unlearn_type,
            print_prompts=print_prompts,
            prompts_prefix=prompts_prefix,
        )
    elif (
        data_format.value == DataFormat.MCQ.value
        and loss_type.value == LossType.LETTER.value
    ):
        tokens, last_pos_label_ids = process_batch(
            batch,
            device,
            tokenizer,
            label_possibilities,
        )
        loss, _, _ = get_loss_and_acc(
            model,
            tokens,
            last_pos_label_ids,
            label_possibilities,
            unlearn_type=unlearn_type,
        )
        return loss
    elif (
        data_format.value == DataFormat.CORPUS.value
        and loss_type.value == LossType.NUMBER.value
    ):
        return get_loss_number(
            model,
            batch,
            device,
            tokenizer,
            unlearn_type=unlearn_type,
            data_format=data_format,
            print_prompts=print_prompts,
            prompts_prefix=prompts_prefix,
        )
        
    else:
        raise Exception(
            f"Unhandled loss. {loss_type=} {data_format=} {unlearn_type=}"
        )
 
def load_jsonl(files):
    dataset = []
    for file in files:
        for line in open(file, "r"):
            dataset += [json.loads(line)]
        
    return dataset

def freeze_model_layers(model, tuples):
    frozen = []
    not_frozen = []
    for name, param in model.named_parameters():
        not_frozen.append(name)
        if 'layers' in name:
            layer_num = int(name.split('.')[2])
            for f, l in tuples:
                if layer_num >= f and layer_num < l:
                    param.requires_grad = False
                    frozen.append(layer_num)
                    not_frozen.remove(name)
        else:
            param.requires_grad = False
            frozen.append(name)
            not_frozen.remove(name)
    print(f"printing frozen layers")
    print(f"{frozen=}{not_frozen=}")

# @ray.remote(num_gpus=1)
def main(
    train_files: list[str],
    wrong_unlearn_files: list[str],
    fixed_wrong_unlearn_files: list[str],
    val_files: list[str],
    dev_set: str,
    base_model: str,
    lr: float,
    name: str,
    k_shot: int = 0,
    epochs: int = 10,
    batch_size: int = 4,
    val_batch_size: int = 8,
    warmup_steps: int = 24,
    retain_files: list[str] = [],
    val_retain_files: list[str] = [],
    retain_dev_file: str = "",
    max_samples: Optional[int] = None,
    data_seed: int = 2,
    eval_every: int = 1,
    keep_set: Optional[int] = None,
    keep_set_weight: Optional[float] = None,
    train_on_wrong_answer: bool = False,
    train_set_size: Optional[int] = None,
    val_set_size: Optional[int] = None,
    kind: str = "base",
    save_name: Optional[str] = None,
    version: str = "v2.11",
    model = None,
    retain_coeff: int = 1,
    project_name: str = "unlearn",
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    results_file: str = None,
    just_eval: bool = False,
    disable_wandb: bool = False,
    freeze_layers: Optional[list[tuple[int, int]]] = None,
    mcq: bool = False,
    hydra_dict: dict = {},
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
    loss_type: LossType = LossType.NOT_SPECIFIED,
):
    assert (keep_set and keep_set_weight) or (not keep_set and not keep_set_weight)
    # print(f"id(unlearn_type): {id(unlearn_type)}, id(UnlearnType.GD): {id(UnlearnType.GD)}")
    # print(f"\n\n\n{freeze_layers=}\n\n\n\n") 
    # raise Exception(f"stope here {freeze_layers=}")

    if mcq:
        train_files = val_files
        wrong_unlearn_files = val_files
        fixed_wrong_unlearn_files = val_files
        # print(f"inside if mcq. {train_files=},\n {wrong_unlearn_files=} \n {fixed_wrong_unlearn_files=}")   

    assert (
        (unlearn_type.value == UnlearnType.GD.value and train_files)
        or (unlearn_type.value == UnlearnType.WHP.value and wrong_unlearn_files)
        or (unlearn_type.value == UnlearnType.FWF.value and fixed_wrong_unlearn_files)
        or just_eval
    ), f"{unlearn_type=}, {UnlearnType.GD=}, {unlearn_type == UnlearnType.GD}"
          

    
    print(f'''
        {unlearn_type=}\n
        {train_files=} len{len(train_files)}\n
        {wrong_unlearn_files=} len{len(wrong_unlearn_files)}\n
        {val_files=} len{len(val_files)}\n
        {retain_files=} len{len(retain_files)}\n
        {val_retain_files=} len{len(val_retain_files)}\n
        {dev_set=}\n    
        {retain_dev_file=}\n
    ''')

    if not disable_wandb:
        wandb.init(project=project_name, config={**locals(), "hydra_dict": hydra_dict}, name=name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    # I check it works with zephyr-7b-beta, but might not work with other models
    # (zephyr adds an implicit space)
    label_possibilities = [tokenizer.encode(f"{t}. ", add_special_tokens=False)[0] for t in doc_to_choice]
    # print([tokenizer.decode([t]) for t in label_possibilities])
    
    if model is not None:
        model = model
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
        ).to(device)
        #model = AutoModelForCausalLM.from_pretrained(
        #    base_model,
        #    load_in_8bit=True,
        #)
    if freeze_layers is not None:
        print(f"Freezing layers {freeze_layers}")
        freeze_model_layers(model, freeze_layers)
    optimizer = Lion(model.parameters(), lr=lr, use_triton=True)

    # train_dataset = sum([json.load(open(f"data/{file}.json")) for file in train_files], [])
    if unlearn_type.value == UnlearnType.GD.value:
        train_dataset = load_jsonl([f"data/{file}.jsonl" for file in train_files])
    elif unlearn_type.value == UnlearnType.WHP.value:
        train_dataset = load_jsonl([f"data/{file}.jsonl" for file in wrong_unlearn_files])
    elif unlearn_type.value == UnlearnType.FWF.value:
        train_dataset = load_jsonl([f"data/{file}.jsonl" for file in fixed_wrong_unlearn_files])
    else:
        raise Exception("Unlearning type not handled")
    
    # print(f"{train_dataset[:10]}")
    # raise Exception("stop")
    random.Random(data_seed).shuffle(train_dataset)

    # get_loss_corpus(model, train_dataset[:20], device, tokenizer, label_possibilities)

    if max_samples is not None:
        train_dataset = train_dataset[:max_samples]
        print(f"capped samples at {max_samples=}")

    print(f"{max_samples=}\n\n\n\n\n\n{len(train_dataset)=}")

    val_datasets_lst = [(f"data/{file}.jsonl", load_jsonl([f"data/{file}.jsonl"])) for file in val_files]
    dev_dataset = load_jsonl([f"data/{dev_set}.jsonl"])
    retaing_dev_dataset = load_jsonl([f"data/{retain_dev_file}.jsonl"])
    retain_dataset = load_jsonl ([f"data/{file}.jsonl" for file in retain_files])
    # Keep them in list so each file evaluations are done seperately
    val_retain_datasets_lst = [(f"data/{file}.jsonl", load_jsonl([f"data/{file}.jsonl"])) for file in val_retain_files]
    val_retain_datasets_5_shot_lst = val_retain_datasets_lst.copy()

    # val_retain_datasets_5_shot_lst = [(f, make_k_shot(val_retain_dataset, retaing_dev_dataset, 5)) for f, val_retain_dataset in val_retain_datasets_5_shot_lst]

    if data_format.value == DataFormat.MCQ.value:
        train_dataset = load_jsonl([f"data/{file}.jsonl" for file in val_files])
        retain_dataset = load_jsonl ([f"data/{file}.jsonl" for file in val_retain_files])
    # train_dataset = make_k_shot(train_dataset, dev_dataset, k_shot)
    print(f"{k_shot=}")
    if k_shot != 0:
        val_datasets_lst = [
            (f, make_k_shot(val_dataset, dev_dataset, k_shot))
            for f, val_dataset in val_datasets_lst
        ]

    if keep_set is not None:
        assert k_shot == 0
        keep_dataset = json.load(open(f"data/{keep_set}.json"))
        batch_size //= 2
    
    forget_accs = {}
    forget_accs_calibrated = {}
    forget_logits_dict = {}
    retain_accs = {}
    retain_accs_calibrated = {}
    retain_logits_dict = {}
    retain_accs_5_shot = {}
    retain_accs_5_shot_calibrated = {}
    retain_logits_5_shot_dict = {}

    samples = {}
    eval_5_shot = False

    @torch.no_grad()
    def eval(time: int):
        model.eval()
        val_batches_lst = [(f, [val_dataset[i : i + val_batch_size] for i in range(0, len(val_dataset), val_batch_size)]) for f, val_dataset in val_datasets_lst]
        retain_batches_lst = [(f, [val_retain_dataset[i : i + val_batch_size] for i in range(0, len(val_retain_dataset), val_batch_size)]) for f, val_retain_dataset in val_retain_datasets_lst]
        retain_batches_5_shot_lst = [(f, [val_retain_datasets_5_shot_lst[i : i + val_batch_size] for i in range(0, len(val_retain_dataset_5_shot), val_batch_size)]) for f, val_retain_dataset_5_shot in val_retain_datasets_5_shot_lst]
        total_loss = 0
        total_acc = 0
        all_preds = []
        all_labels = []
        label_possibilities_tensor = torch.tensor(label_possibilities, device=device)
        all_files_forget_acc = 0
        all_files_forget_acc_calibrated = 0
        all_files_forget_loss = 0
        all_files_retain_acc = 0
        all_files_retain_acc_calibrated = 0
        all_files_retain_loss = 0
        all_files_retain_5_shot_acc = 0
        all_files_retain_5_shot_acc_calibrated = 0
        all_files_retain_5_shot_loss = 0
        for j, (f, val_batches) in tqdm(enumerate(val_batches_lst), desc=f"Forget-eval"):
            forget_accs[f] = {}
            forget_logits_dict[f] = {}
            forget_accs_calibrated[f] = {}
            total_forget_acc = 0
            total_forget_loss = 0
            forget_logits_lst = []
            last_labels_forget_lst = []
            for i, batch in enumerate(val_batches):
                # if i == 0:
                #     print(f"Printing val batches")
                tokens, last_pos_label_ids_forget_local = process_batch(batch, device, tokenizer, label_possibilities, print_a_prompt=i==0 and time==0, print_prefix="val prompts=")
                forget_eval_loss, forget_acc, forget_logits_local = get_loss_and_acc(model, tokens, last_pos_label_ids_forget_local, label_possibilities)
                total_forget_acc += forget_acc
                total_forget_loss += forget_eval_loss
                last_labels_forget_lst.append(last_pos_label_ids_forget_local)
                forget_logits_lst.append(forget_logits_local)
                # wandb.log({"forget_eval_loss": forget_eval_loss})

            total_forget_acc /= len(val_datasets_lst[j][1])
            total_forget_loss /= len(val_datasets_lst[j][1])
            all_files_forget_acc += total_forget_acc
            all_files_forget_loss += total_forget_loss
            forget_accs[f][time] = total_forget_acc
            last_labels_forget = torch.cat(last_labels_forget_lst, dim=0)
            forget_logits = np.concatenate(forget_logits_lst, axis=0)
            forget_logits_dict[f][time] = forget_logits
            forget_logits_standardized = forget_logits - forget_logits.mean(axis=0)
            forget_logits_tensor = torch.tensor(forget_logits_standardized, device=device)
            forget_labels = label_possibilities_tensor[forget_logits_tensor.argmax(dim=-1)]
            forget_acc_calibrated = (forget_labels == last_labels_forget).float().mean().item()
            all_files_forget_acc_calibrated == forget_acc_calibrated
            forget_accs_calibrated[f][time] = forget_acc_calibrated

        all_files_forget_acc /= len(val_datasets_lst)
        all_files_forget_acc_calibrated /= len(val_datasets_lst)
        all_files_forget_loss /= len(val_datasets_lst)
        
        for j, (f, retain_batches) in tqdm(enumerate(retain_batches_lst), desc=f"Retain-eval"):
            retain_logits_dict[f] = {}
            retain_accs_calibrated[f] = {}
            retain_accs[f] = {}
            total_retain_acc = 0
            total_retain_loss = 0
            retain_logits_lst = []
            last_labels_retain_lst = []
            for i in range(len(retain_batches)):
                if i == 0:
                    print(f"Printing retain batches")
                tokens, last_pos_label_ids_retain_local = process_batch(retain_batches[i], device, tokenizer, label_possibilities, print_a_prompt=i==0 and time==0, print_prefix="retain prompts")
                retain_eval_loss, retain_acc, retain_logits_local = get_loss_and_acc(model, tokens, last_pos_label_ids_retain_local, label_possibilities)
                total_retain_acc += retain_acc
                total_retain_loss += retain_eval_loss
                last_labels_retain_lst.append(last_pos_label_ids_retain_local)
                retain_logits_lst.append(retain_logits_local)
                # wandb.log({"retain_eval_loss": retain_eval_loss})

            total_retain_acc /= len(val_retain_datasets_lst[j][1])
            total_retain_loss /= len(val_retain_datasets_lst[j][1])
            all_files_retain_acc += total_retain_acc
            all_files_retain_loss += total_retain_loss
            retain_logits = np.concatenate(retain_logits_lst, axis=0)
            retain_logits_dict[f][time] = retain_logits
            retain_logits_standardized = retain_logits - retain_logits.mean(axis=0)
            retain_logits_tensor = torch.tensor(retain_logits_standardized, device=device)
            retain_labels = label_possibilities_tensor[retain_logits_tensor.argmax(dim=-1)]
            last_labels_retain = torch.cat(last_labels_retain_lst, dim=0)
            retain_acc_calibrated = (retain_labels == last_labels_retain).float().mean().item()
            all_files_retain_acc_calibrated += retain_acc_calibrated
            retain_accs_calibrated[f][time] = retain_acc_calibrated
            retain_accs[f][time] = total_retain_acc
        
        all_files_retain_acc /= len(val_retain_datasets_lst)
        all_files_retain_acc_calibrated /= len(val_retain_datasets_lst)
        all_files_retain_loss /= len(val_retain_datasets_lst)
            
        if eval_5_shot:
            for j, (f, retain_batches_5_shot) in tqdm(enumerate(retain_batches_5_shot_lst), desc=f"Retain-5-shot-eval"):
                retain_logits_5_shot_dict[f] = {}
                retain_accs_5_shot_calibrated[f] = {}
                retain_accs_5_shot[f] = {}
                total_retain_acc_5_shot = 0
                total_retain_5_shot_loss = 0
                retain_logits_5_shot_lst = []
                last_labels_retain_5_shot_lst = []
                for i in range(len(retain_batches_5_shot)):
                    tokens, last_pos_label_ids_retain_5_shot_local = process_batch(retain_batches_5_shot[i], device, tokenizer, label_possibilities, print_a_prompt=False) # if i > 1 else True)
                    retain_5_shot_eval_loss, retain_acc, retain_5_shot_logits_local = get_loss_and_acc(model, tokens, last_pos_label_ids_retain_5_shot_local, label_possibilities)
                    total_retain_acc_5_shot += retain_acc
                    total_retain_5_shot_loss += retain_5_shot_eval_loss
                    last_labels_retain_5_shot_lst.append(last_pos_label_ids_retain_5_shot_local)
                    retain_logits_5_shot_lst.append(retain_5_shot_logits_local)

                total_retain_acc_5_shot /= len(val_retain_datasets_5_shot_lst[j][1])
                total_retain_5_shot_loss /= len(val_retain_datasets_5_shot_lst[j][1])
                all_files_retain_5_shot_acc += total_retain_acc_5_shot
                all_files_retain_5_shot_loss += total_retain_5_shot_loss
                retain_logits_5_shot = np.concatenate(retain_logits_5_shot_lst, axis=0)
                retain_logits_5_shot_dict[f][time] = retain_logits_5_shot
                retain_logits_5_shot_standardized = retain_logits_5_shot - retain_logits_5_shot.mean(axis=0)
                retain_logits_5_shot_tensor = torch.tensor(retain_logits_5_shot_standardized, device=device)
                retain_5_shot_labels = label_possibilities_tensor[retain_logits_5_shot_tensor.argmax(dim=-1)]
                last_labels_retain_5_shot = torch.cat(last_labels_retain_5_shot_lst, dim=0)
                retain_acc_5_shot_calibrated = (retain_5_shot_labels == last_labels_retain_5_shot).float().mean().item()
                all_files_retain_5_shot_acc_calibrated += retain_acc_5_shot_calibrated
                retain_accs_5_shot_calibrated[f][time] = retain_acc_5_shot_calibrated
                retain_accs_5_shot[f][time] = total_retain_acc_5_shot

            all_files_retain_5_shot_acc /= len(val_retain_datasets_5_shot_lst)
            all_files_retain_5_shot_acc_calibrated /= len(val_retain_datasets_5_shot_lst)
            all_files_retain_5_shot_loss /= len(val_retain_datasets_5_shot_lst)

        
        samples[time] = sample_tokens(model, tokenizer, device, max_length=15)


        if not disable_wandb:
            wandb.log(
                {
                    "unlearning/forget_acc": all_files_forget_acc,
                    "unlearning/retain_acc": all_files_retain_acc,
                    "unlearning_other/retain_acc_5_shot": all_files_retain_5_shot_acc if eval_5_shot else None,
                    "unlearning_other/forget_acc_calibrated": all_files_forget_acc_calibrated,
                    "unlearning_other/retain_acc_calibrated": all_files_retain_acc_calibrated,
                    "unlearning_other/retain_acc_5_shot_calibrated": all_files_retain_5_shot_acc_calibrated if eval_5_shot else None,
                    "unlearning_other/eval_forget_loss": all_files_forget_loss,
                    "unlearning_other/eval_retain_loss": all_files_retain_loss,
                    "unlearning_other/eval_retain_5_shot_loss": all_files_retain_5_shot_loss,
                    "unlearning_other/epoch": time, 
                }
            )
        return {
            "unlearning/forget_acc": all_files_forget_acc,
            "unlearning/retain_acc": all_files_retain_acc,
            "unlearning_other/retain_acc_5_shot": all_files_retain_5_shot_acc if eval_5_shot else None,
            "unlearning_other/forget_acc_calibrated": all_files_forget_acc_calibrated,
            "unlearning_other/retain_acc_calibrated": all_files_retain_acc_calibrated,
            "unlearning_other/retain_acc_5_shot_calibrated": all_files_retain_5_shot_acc_calibrated if eval_5_shot else None,
            "unlearning_other/eval_forget_loss": all_files_forget_loss,
            "unlearning_other/eval_retain_loss": all_files_retain_loss,
            "unlearning_other/eval_retain_5_shot_loss": all_files_retain_5_shot_loss,
            "unlearning_other/epoch": time, 
        }

    evaled_0 = False
    eval(0); evaled_0 = True
    # raise Exception("debugging")
    # if len(train_dataset) / batch_size > 600:
    #     train_dataset = train_dataset[:batch_size * 600]

    for epoch in range(epochs):
        if just_eval:
            break
        model.train()
        # frozen_layers = set()
        # unfrozen_layers = set()
        # for name, param in model.named_parameters():
        #     digits = "".join([char for char in name if char.isdigit()])
        #     num = int(digits) if digits != "" else name
        #     if not param.requires_grad:
        #         frozen_layers.add(num)
        #     else:
        #         unfrozen_layers.add(num)
        # print(f"frozen layers: {frozen_layers}\n\nunfrozen layers: {unfrozen_layers}")
        # raise Exception("stop")

        random.Random(epoch).shuffle(train_dataset)
        batches = [train_dataset[i : i + batch_size] for i in range(0, len(train_dataset), batch_size)]
        retain_batches = [retain_dataset[i : i + batch_size] for i in range(0, len(retain_dataset), batch_size)]

        if keep_set:
            random.Random(epoch).shuffle(keep_dataset)
            keep_batches = [keep_dataset[i : i + batch_size] for i in range(0, len(keep_dataset), batch_size)]
        # i = 0
        for i, batch in enumerate(tqdm(batches, desc=f"Training epoch {epoch}")):
            for group in optimizer.param_groups:
                step = epoch * len(batches) + i + 1
                group["lr"] = lr * max(0, min(1, step / warmup_steps))

            optimizer.zero_grad()

            j = i % len(retain_batches)

            forget_loss = get_loss(model, batch, device, tokenizer, label_possibilities, unlearn_type=unlearn_type, mcq=mcq, print_prompts=i==0 and epoch==0, prompts_prefix="forget prompts", data_format=data_format, loss_type=loss_type)
            print(f"{forget_loss=}")        
            retain_loss = get_loss(model, retain_batches[j], device, tokenizer, label_possibilities, unlearn_type=UnlearnType.FWF, print_prompts=i==0 and epoch==0, prompts_prefix="retain prompts", data_format=data_format, loss_type=loss_type)
            print(f"{retain_loss=}")        
            # print(f"{retain_loss=}\n{retain_coeff}")
            # if unlearn_type.value == UnlearnType.GD.value:
            #     loss = -forget_loss + retain_coeff * retain_loss
            # elif unlearn_type.value in [UnlearnType.WHP.value, UnlearnType.FWF.value]:
            #     loss = forget_loss + retain_coeff * retain_loss
            # else:
            #     raise Exception("Unlearning type not specified")
            # raise Exception("loss no implemented")
            try:
                loss = forget_loss + retain_coeff * retain_loss
            except Exception as e:
                print(f"""
                    error. {forget_loss=}\n{retain_loss=}
                    {retain_coeff=}\n{hydra_dict=}
                """)
                raise e

            loss.backward()
            optimizer.step()
            # wandb.log(
            #     {"train_loss": loss.item(), "epoch": epoch + i / len(batches), "train_acc": acc, "lr": group["lr"]}
            #     | ({"keep_loss": keep_loss.item(), "keep_acc": keep_acc} if keep_set is not None else {})
            # )
            if not disable_wandb:
                wandb.log({
                    "unlearning/train_loss": loss.item(),
                    "unlearning_other/epoch": epoch + i / len(batches),
                    "unlearning_other/lr": group["lr"],
                    "unlearning/forget_loss": forget_loss.item(),
                    "unlearning/retain_loss": retain_loss.item()
                })
            # i += 1

        if (not just_eval and (epoch + 1) % eval_every) == 0:
            eval_res = eval(epoch + 1)
            print(f"{eval_res['unlearning/forget_acc']=}, {eval_res['unlearning/retain_acc']=}")
            if eval_res["unlearning/forget_acc"] < 0.5 and eval_res["unlearning/retain_acc"] > 0.5:
                temp_save_name = f"{save_name}_epoch{epoch + 1}_temp-save_forget{eval_res['unlearning/forget_acc']}_retain{eval_res['unlearning/retain_acc']}"
                print(f"saving with name {temp_save_name=}")
                model.save_pretrained(temp_save_name)
                tokenizer.save_pretrained(temp_save_name)


    if not just_eval or not evaled_0:
        eval(epochs)
    if save_name is not None:
        model.save_pretrained(save_name)
        tokenizer.save_pretrained(save_name)
    
    if results_file is not None:
        # dir = f"./evals/GD_evals"
        # os.makedirs(dir, exist_ok=True)
        # results_file = f"{dir}/results.json"
        lock = FileLock(f"{results_file}.lock")
        with lock:
            if os.path.exists(results_file):
                with open(results_file, "r+") as f:
                    results = json.load(f)
                    if save_name not in results:
                        results[save_name] = {}
                    results[save_name]["+".join(val_files)] = forget_accs
                    results[save_name]["+".join(val_retain_files)] = retain_accs
                    f.seek(0)
                    f.truncate()
                    json.dump(results, f, indent=4)
            else:
                with open(results_file, "w+") as f:
                    results = {}
                    if save_name not in results:
                        results[save_name] = {}
                    results[save_name]["+".join(val_files)] = forget_accs
                    results[save_name]["+".join(val_retain_files)] = retain_accs
                    f.seek(0)
                    f.truncate()
                    json.dump(results, f, indent=4)
    if not disable_wandb:
        wandb.finish()
    
    return (
        save_name,
        forget_accs, forget_accs_calibrated, forget_logits_dict,
        retain_accs, retain_accs_calibrated, retain_logits_dict,
        retain_accs_5_shot, retain_accs_5_shot_calibrated,
        retain_logits_5_shot_dict,
        samples
    )


@ray.remote(num_gpus=1)
def remote_main(
    train_files: list[str],
    wrong_unlearn_files: list[str],
    fixed_wrong_unlearn_files: list[str],
    val_files: list[str],
    dev_set: str,
    base_model: str,
    lr: float,
    name: str,
    k_shot: int = 4,
    epochs: int = 10,
    batch_size: int = 4,
    val_batch_size: int = 8,
    warmup_steps: int = 24,
    retain_files: list[str] = [],
    val_retain_files: list[str] = [],
    retain_dev_file: str = "",
    max_samples: Optional[int] = None,
    data_seed: int = 2,
    eval_every: int = 1,
    keep_set: Optional[int] = None,
    keep_set_weight: Optional[float] = None,
    train_on_wrong_answer: bool = False,
    train_set_size: Optional[int] = None,
    val_set_size: Optional[int] = None,
    kind: str = "base",
    save_name: Optional[str] = None,
    version: str = "v2.11",
    model = None,
    retain_coeff: int = 1,
    project_name: str = "unlearn",
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    results_file: str = None,
    just_eval: bool = False,
    disable_wandb: bool = False,
    freeze_layers: Optional[list[tuple[int, int]]] = None,
    mcq: bool = False,
    hydra_dict: dict = {},
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
    loss_type: LossType = LossType.NOT_SPECIFIED,
):
    return main(
        train_files,
        wrong_unlearn_files,
        fixed_wrong_unlearn_files,
        val_files,
        dev_set,
        base_model,
        lr,
        name,
        k_shot,
        epochs,
        batch_size,
        val_batch_size,
        warmup_steps,
        retain_files,
        val_retain_files,
        retain_dev_file,
        max_samples,
        data_seed,
        eval_every,
        keep_set,
        keep_set_weight,
        train_on_wrong_answer,
        train_set_size,
        val_set_size,
        kind,
        save_name,
        version,
        model,
        retain_coeff,
        project_name,
        unlearn_type,
        results_file,
        just_eval,
        disable_wandb,
        hydra_dict=hydra_dict,
	freeze_layers=freeze_layers,
	mcq=mcq,
	data_format=data_format,
	loss_type=loss_type,
    )


@ray.remote(num_gpus=1)
@torch.no_grad()
def just_eval(
    train_files: list[str],
    wrong_unlearn_files: list[str],
    fixed_wrong_unlearn_files: list[str],
    val_files: list[str],
    dev_set: str,
    base_model: str,
    lr: float,
    name: str,
    #k_shot: int = 4,
    k_shot: int = 0,
    epochs: int = 10,
    batch_size: int = 4,
    val_batch_size: int = 8,
    warmup_steps: int = 24,
    retain_files: list[str] = [],
    val_retain_files: list[str] = [],
    retain_dev_file: str = "",
    max_samples: Optional[int] = None,
    data_seed: int = 2,
    eval_every: int = 1,
    keep_set: Optional[int] = None,
    keep_set_weight: Optional[float] = None,
    train_on_wrong_answer: bool = False,
    train_set_size: Optional[int] = None,
    val_set_size: Optional[int] = None,
    kind: str = "base",
    save_name: Optional[str] = None,
    version: str = "v2.11",
    model = None,
    retain_coeff: int = 1,
    project_name: str = "unlearn",
    unlearn_type: UnlearnType = UnlearnType.NOT_SPECIFIED,
    results_file: str = None,
    just_eval: bool = False,
    disable_wandb: bool = False,
    freeze_layers: Optional[list[tuple[int, int]]] = None,
    mcq: bool = False,
    hydra_dict: dict = {},
    data_format: DataFormat = DataFormat.NOT_SPECIFIED,
    loss_type: LossType = LossType.NOT_SPECIFIED,
):
    return main(
        train_files,
        wrong_unlearn_files,
        fixed_wrong_unlearn_files,
        val_files,
        dev_set,
        base_model,
        lr,
        name,
        k_shot,
        epochs,
        batch_size,
        val_batch_size,
        warmup_steps,
        retain_files,
        val_retain_files,
        retain_dev_file,
        max_samples,
        data_seed,
        eval_every,
        keep_set,
        keep_set_weight,
        train_on_wrong_answer,
        train_set_size,
        val_set_size,
        kind,
        save_name,
        version,
        model,
        retain_coeff,
        project_name,
        unlearn_type,
        results_file,
        just_eval,
        disable_wandb,
        hydra_dict=hydra_dict,
	freeze_layers=freeze_layers,
	mcq=mcq,
	data_format=data_format,
	loss_type=loss_type,
    )


    # # Save dates_accs_cpy to a JSON file
    # with open(f'{dir}/forget_accs.json', 'w') as f:
    #     json.dump(forget_accs, f, indent=4)

    # # Save mmlu_accs_cpy to a JSON file
    # with open(f'{dir}/retain_accs.json', 'w') as f:
    #     json.dump(retain_accs, f, indent=4)

    # wandb.finish()

class TestLayerFreezing(unittest.TestCase):
    def setUp(self):
        self.model_id = "meta-llama/Meta-Llama-3-8B"  # Use a smaller model for testing
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.first_layer_to_freeze = 2
        self.last_layer_to_freeze_exclusive = 5

    def test_layer_freezing(self):
        # Freeze layers
        freeze_layers(self.model, self.first_layer_to_freeze, self.last_layer_to_freeze_exclusive)

        # Simulate a training step
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
        dummy_input = torch.randint(0, 1000, (1, 10))  # Random input
        dummy_output = self.model(dummy_input)
        loss = dummy_output.logits.mean()
        loss.backward()
        optimizer.step()

        # Check if specified layers are frozen
        for name, param in self.model.named_parameters():
            if 'layers' in name:
                layer_num = int(name.split('.')[2])
                if self.first_layer_to_freeze <= layer_num < self.last_layer_to_freeze_exclusive:
                    self.assertFalse(param.requires_grad, f"Layer {layer_num} should be frozen")
                    self.assertTrue(torch.all(torch.tensor(param.grad == 0)), f"Gradient for layer {layer_num} should be zero. {param.grad=}")
                else:
                    self.assertTrue(param.requires_grad, f"Layer {layer_num} should not be frozen")
                    self.assertFalse(torch.all(torch.tensor(param.grad == 0)), f"Gradient for layer {layer_num} should not be zero. {param.grad=}")

        print(f'test done')

if __name__ == '__main__':
    unittest.main()
