import os
import json
import math
import random
from typing import Optional, TypedDict
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, AutoTokenizer, AutoModelForCausalLM
from ray.experimental.tqdm_ray import tqdm
import wandb
import ray
from lion_pytorch import Lion
import fcntl

# Maximum number of tokens per prompt; process_batch truncates longer prompts to this length.
MAX_SEQ_LEN = 512


# One multiple-choice question as consumed by create_prompt / process_batch.
Point = TypedDict(
    "Point",
    {
        "question": str,  # question text shown at the top of the prompt
        "choices": list[str],  # answer options, rendered in order as "A.", "B.", ...
        "answer": int,  # index of the correct option (used to index doc_to_choice)
    },
)


# Letter shown for each answer option; position i labels choices[i].
doc_to_choice = list("ABCD")


def create_prompt(point: Point) -> str:
    """Render a point as prompt text.

    The result is the question, then one "<letter>. <choice>" line per choice,
    then a final "Answer:" line, all joined with newlines.
    """
    lines = [point["question"]]
    for position, choice in enumerate(point["choices"]):
        lines.append(f"{doc_to_choice[position]}. {choice}")
    lines.append("Answer:")
    return "\n".join(lines)


def make_k_shot(data: list[Point], dev_set: list[Point], k: int) -> list[Point]:
    """Return a k-shot version of the data.

    Each returned point keeps its choices and answer, but its question is
    prefixed with the first ``k`` points of ``dev_set`` rendered as solved
    examples ("<prompt> <correct letter>."), separated by blank lines.

    The returned points are meant to be rendered with ``create_prompt`` (as
    ``process_batch`` does), so only the raw question text is appended to the
    preamble here. Appending ``create_prompt(point)`` instead would make the
    choices and the trailing "Answer:" appear twice in the final prompt.

    With ``k == 0`` the input list itself is returned, not a copy.
    """
    if k == 0:
        return data

    preprompt = "\n\n".join(
        f"{create_prompt(example)} {doc_to_choice[example['answer']]}." for example in dev_set[:k]
    )
    return [
        {
            "question": preprompt + "\n\n" + point["question"],
            "choices": point["choices"],
            "answer": point["answer"],
        }
        for point in data
    ]


def process_batch(
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
    label_possibilities: list[int],
    train_on_wrong_answer: bool = False,
):
    """Tokenize a batch of points and build the target token id for each one.

    Returns ``(tokens, last_pos_label_ids)``: the tokenizer output moved to
    ``device`` (padded, truncated to MAX_SEQ_LEN) and a tensor holding, per
    point, the entry of ``label_possibilities`` selected by the target answer.

    When ``train_on_wrong_answer`` is set, the target is a wrong answer index
    drawn from an RNG seeded with the question text, so the same question
    always gets the same wrong answer.
    """

    def target_index(point):
        if not train_on_wrong_answer:
            return point["answer"]
        wrong_indices = [idx for idx in range(len(doc_to_choice)) if idx != point["answer"]]
        return random.Random(point["question"]).choice(wrong_indices)

    prompts = list(map(create_prompt, batch))
    encoded = tokenizer(prompts, return_tensors="pt", max_length=MAX_SEQ_LEN, truncation=True, padding=True)
    tokens = encoded.to(device)

    target_ids = [label_possibilities[target_index(point)] for point in batch]
    last_pos_label_ids = torch.tensor(target_ids, device=device)
    return tokens, last_pos_label_ids


def get_loss_and_acc(
    model, tokens, last_pos_label_ids, label_possibilities
) -> tuple[torch.Tensor, float, np.ndarray]:
    """Run the model and score its prediction at the last sequence position.

    Args:
        model: callable model exposing ``prepare_inputs_for_generation`` and
            returning an object with a ``logits`` attribute.
        tokens: mapping of model inputs (unpacked as keyword arguments).
        last_pos_label_ids: target token id per example.
        label_possibilities: token ids of the admissible answers.

    Returns:
        ``(loss, acc, answer_logits)`` where ``loss`` is the cross-entropy of
        the last-position logits over the full vocabulary, ``acc`` is the
        *number* (not the fraction) of examples whose best-scoring admissible
        token equals the target, and ``answer_logits`` is a NumPy array with
        one column per entry of ``label_possibilities``.
    """
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits[:, -1, :]
    loss = torch.nn.functional.cross_entropy(logits, last_pos_label_ids)

    # Restrict the prediction to the admissible tokens by selecting their
    # columns. This avoids masking the rest of the vocabulary in place, which
    # mutated the model output and built a vocab-sized set on every call.
    label_ids = torch.as_tensor(label_possibilities, device=logits.device)
    answer_logits = logits[:, label_ids]
    predicted_ids = label_ids[answer_logits.argmax(dim=-1)]
    acc = (predicted_ids == last_pos_label_ids).float().sum().item()
    return loss, acc, answer_logits.detach().cpu().numpy()

def write_results_to_file(model_path, balanced_acc, filename="results.txt"):
    """Append one "<model_path>, <balanced_acc>" line to ``filename``.

    The accuracy is written with two decimals; the file is created if missing.
    """
    with open(filename, mode="a") as results:
        results.writelines([f"{model_path}, {balanced_acc:.2f}\n"])


def _load_json_files(names: list[str]) -> list:
    """Load ``data/<name>.json`` for every name and concatenate the resulting lists."""
    merged = []
    for file_name in names:
        with open(f"data/{file_name}.json") as f:
            merged.extend(json.load(f))
    return merged


def _split_into_batches(items: list, size: int) -> list[list]:
    """Cut ``items`` into consecutive slices of at most ``size`` elements."""
    return [items[start : start + size] for start in range(0, len(items), size)]


@ray.remote(num_gpus=1)
def main(
    train_files: list[str],
    val_files: list[str],
    dev_set: str,
    base_model: str,
    lr: float,
    name: str,
    k_shot: int = 0,
    epochs: int = 10,
    batch_size: int = 4,
    val_batch_size: int = 8,
    warmup_steps: int = 24,
    max_samples: Optional[int] = None,
    data_seed: int = 0,
    eval_every: int = 1,
    keep_set: Optional[str] = None,
    keep_set_weight: Optional[float] = None,
    train_on_wrong_answer: bool = False,
    train_set_size: Optional[int] = None,
    val_set_size: Optional[int] = None,
    kind: str = "base",
    save_name: Optional[str] = None,
    version: str = "v2.11",
    val_retain_files: Optional[list[str]] = None,
):
    """Fine-tune ``base_model`` on multiple-choice data and track forget/retain accuracy.

    Data files are read from ``data/<name>.json``. The model is evaluated on
    ``val_files`` ("forget" accuracy) and ``val_retain_files`` ("retain"
    accuracy) before training and after every ``eval_every`` epochs. Metrics
    go to wandb and the per-evaluation accuracies are written to
    ``./evals/ft/<name>/forget_accs.json`` and ``retain_accs.json``.

    Args:
        train_files: names of the training files.
        val_files: names of the forget-evaluation files.
        dev_set: name of the file providing the k-shot examples.
        base_model: model id/path passed to ``from_pretrained``.
        lr: peak learning rate, reached after ``warmup_steps`` optimizer steps.
        name: wandb run name and name of the output directory.
        k_shot: number of solved examples prepended to train and forget-eval
            questions (the retain-eval set is left as is).
        epochs: number of passes over the training set.
        batch_size: training batch size (halved when ``keep_set`` is used).
        val_batch_size: evaluation batch size.
        warmup_steps: length of the linear learning-rate warmup.
        max_samples: if set, keep only this many shuffled training points.
        data_seed: seed for the initial shuffle of the training set.
        eval_every: evaluate after every this many epochs.
        keep_set: name of an extra file whose loss is *minimised* while the
            training-set loss is *maximised*; requires ``keep_set_weight``
            and ``k_shot == 0``.
        keep_set_weight: weight of the keep-set loss in the combined loss.
        train_on_wrong_answer: train towards a fixed wrong answer per question.
        train_set_size, val_set_size, kind, version: not used by the training
            logic; they are only recorded in the wandb config.
        save_name: if set, directory where model and tokenizer are saved.
        val_retain_files: names of the retain-evaluation files (default: none).

    Raises:
        ValueError: if exactly one of ``keep_set``/``keep_set_weight`` is
            given, or if ``keep_set`` is combined with ``k_shot != 0``.
    """
    if bool(keep_set) != bool(keep_set_weight):
        raise ValueError("keep_set and keep_set_weight must be provided together")
    if val_retain_files is None:
        val_retain_files = []

    # locals() holds exactly the call arguments here, so the whole
    # configuration is recorded with the run.
    wandb.init(project="retrain", config=locals(), name=name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding/truncation keeps the end of the prompt ("Answer:") at the
    # last position, which is where get_loss_and_acc reads the logits.
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    print(f"{train_files=}\n{val_files=}\n{val_retain_files=}")

    # Checked with zephyr-7b-beta (which adds an implicit space); this may not
    # yield the right answer tokens with other tokenizers.
    label_possibilities = [tokenizer.encode(f"{t}. ", add_special_tokens=False)[0] for t in doc_to_choice]
    print([tokenizer.decode([t]) for t in label_possibilities])

    model = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
    ).to(device)
    optimizer = Lion(model.parameters(), lr=lr, use_triton=True)

    train_dataset = _load_json_files(train_files)
    random.Random(data_seed).shuffle(train_dataset)
    if max_samples is not None:
        train_dataset = train_dataset[:max_samples]

    val_dataset = _load_json_files(val_files)
    dev_dataset = _load_json_files([dev_set])
    val_retain_dataset = _load_json_files(val_retain_files)

    train_dataset = make_k_shot(train_dataset, dev_dataset, k_shot)
    val_dataset = make_k_shot(val_dataset, dev_dataset, k_shot)

    use_keep_set = bool(keep_set)
    if use_keep_set:
        if k_shot != 0:
            raise ValueError("keep_set is only supported with k_shot == 0")
        keep_dataset = _load_json_files([keep_set])
        # Every step processes one train batch and one keep batch, so halve
        # the batch size (but never go below one example).
        batch_size = max(1, batch_size // 2)

    forget_accs = {}
    retain_accs = {}

    @torch.no_grad()
    def evaluate(time: int):
        """Evaluate forget/retain accuracy and record it under ``time``."""
        model.eval()

        forget_correct = 0
        forget_batches = _split_into_batches(val_dataset, val_batch_size)
        for batch in tqdm(forget_batches, desc=f"Forget-eval-{time=}"):
            tokens, last_pos_label_ids = process_batch(batch, device, tokenizer, label_possibilities)
            _, acc, _ = get_loss_and_acc(model, tokens, last_pos_label_ids, label_possibilities)
            forget_correct += acc

        retain_correct = 0
        retain_batches = _split_into_batches(val_retain_dataset, val_batch_size)
        for idx, batch in enumerate(tqdm(retain_batches, desc=f"Retain-eval-{time=}")):
            if idx == 0:
                print("In retain eval")
            tokens, last_pos_label_ids = process_batch(batch, device, tokenizer, label_possibilities)
            _, acc, _ = get_loss_and_acc(model, tokens, last_pos_label_ids, label_possibilities)
            retain_correct += acc

        # get_loss_and_acc returns counts of correct answers, so divide by the
        # number of examples; an empty set is reported as 0.
        forget_acc = forget_correct / len(val_dataset) if val_dataset else 0
        retain_acc = retain_correct / len(val_retain_dataset) if val_retain_dataset else 0

        forget_accs[f"{name}-time{time}"] = forget_acc
        retain_accs[f"{name}-time{time}"] = retain_acc

        wandb.log(
            {
                "forget_acc": forget_acc,
                "retain_acc": retain_acc,
                "epoch": time,
            }
        )

    evaluate(0)
    print(f"num_epochs: {epochs}")

    for epoch in range(epochs):
        if epoch == 0:
            print("in epochs")
        model.train()

        random.Random(epoch).shuffle(train_dataset)
        batches = _split_into_batches(train_dataset, batch_size)

        if use_keep_set:
            random.Random(epoch).shuffle(keep_dataset)
            keep_batches = _split_into_batches(keep_dataset, batch_size)

        for i, batch in enumerate(tqdm(batches, desc=f"Training epoch {epoch}")):
            # Linear warmup from lr / warmup_steps up to lr, then constant.
            step = epoch * len(batches) + i + 1
            warmup_fraction = min(1, step / warmup_steps) if warmup_steps > 0 else 1
            current_lr = lr * warmup_fraction
            for group in optimizer.param_groups:
                group["lr"] = current_lr

            optimizer.zero_grad()

            tokens, last_pos_label_ids = process_batch(
                batch, device, tokenizer, label_possibilities, train_on_wrong_answer
            )
            loss, acc, _ = get_loss_and_acc(model, tokens, last_pos_label_ids, label_possibilities)

            keep_metrics = {}
            if use_keep_set:
                keep_tokens, keep_last_pos_label_ids = process_batch(
                    keep_batches[i % len(keep_batches)], device, tokenizer, label_possibilities
                )
                keep_loss, keep_acc, _ = get_loss_and_acc(
                    model, keep_tokens, keep_last_pos_label_ids, label_possibilities
                )
                # Ascend on the training loss while descending on the keep-set loss.
                loss = -loss + keep_set_weight * keep_loss
                keep_metrics = {"keep_loss": keep_loss.item(), "keep_acc": keep_acc}

            loss.backward()
            optimizer.step()
            # NOTE: train_acc / keep_acc are numbers of correct answers in the
            # batch (as returned by get_loss_and_acc), not fractions.
            wandb.log(
                {"train_loss": loss.item(), "epoch": epoch + i / len(batches), "train_acc": acc, "lr": current_lr}
                | keep_metrics
            )

        if (epoch + 1) % eval_every == 0:
            evaluate(epoch + 1)

    if save_name is not None:
        model.save_pretrained(save_name)
        tokenizer.save_pretrained(save_name)

    out_dir = f"./evals/ft/{name}"
    os.makedirs(out_dir, exist_ok=True)

    # Persist the forget-set accuracies of every evaluation.
    with open(f"{out_dir}/forget_accs.json", "w") as f:
        json.dump(forget_accs, f, indent=4)

    print(f"forget_accs: {forget_accs}")

    # Persist the retain-set accuracies of every evaluation.
    with open(f"{out_dir}/retain_accs.json", "w") as f:
        json.dump(retain_accs, f, indent=4)

    wandb.finish()

def load_jsonl(files):
    """Read JSON Lines files and return all records as one list.

    Records appear in the order of ``files`` and, within a file, in line
    order. Blank (whitespace-only) lines are skipped instead of raising a
    JSON decode error, and every file is closed as soon as it has been read.
    """
    dataset = []
    for file in files:
        with open(file, "r") as f:
            dataset.extend(json.loads(line) for line in f if line.strip())
    return dataset

def _record_eval_result(results_file, model_path, val_files, total_acc, bal_acc):
    """Merge one evaluation result into the shared JSON results file under an exclusive lock."""
    results_dir = os.path.dirname(results_file)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    # Open read/write and create the file if needed WITHOUT truncating it, so
    # the same locked read-modify-write path serves both the first writer and
    # every later one.
    fd = os.open(results_file, os.O_RDWR | os.O_CREAT, 0o666)
    with os.fdopen(fd, "r+") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            existing = f.read()
            results = json.loads(existing) if existing.strip() else {}
            results.setdefault(model_path, {})["+".join(val_files)] = {
                "total_acc": float(total_acc),
                "bal_acc": float(bal_acc),
            }
            f.seek(0)
            f.truncate()
            json.dump(results, f, indent=4)
            # Push the data to the OS before the lock is released; otherwise
            # another process could lock and read a still-empty file.
            f.flush()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


@torch.no_grad()
def just_eval(
    model_path: str,
    val_batch_size: int = 8,
    dev_set: str = "dates-dev",
    k_shot: int = 0,
    val_files: Optional[list[str]] = None,
    results_file: Optional[str] = "evals/just_eval_results/results2.json",
):
    """Evaluate a saved model on multiple-choice validation files.

    Data is read from ``data/<name>.jsonl``.

    Args:
        model_path: model id/path passed to ``from_pretrained`` (also used for
            the tokenizer and as the key in the results file).
        val_batch_size: evaluation batch size.
        dev_set: name of the file providing the k-shot examples.
        k_shot: number of solved examples prepended to each question.
        val_files: names of the validation files; defaults to
            ``["dates-split-0", "dates-split-1"]``.
        results_file: JSON file the accuracies are merged into under
            ``results[model_path]["+".join(val_files)]``; ``None`` disables
            writing.

    Returns:
        ``(total_acc, bal_acc, all_preds_a)``: plain accuracy, accuracy after
        subtracting each answer column's mean logit over the validation set,
        and the array of answer-token logits for all examples.

    Raises:
        ValueError: if the validation files contain no examples.
    """
    if val_files is None:
        val_files = ["dates-split-0", "dates-split-1"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding/truncation keeps the end of the prompt at the last
    # position, which is where get_loss_and_acc reads the logits.
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    # Checked with zephyr-7b-beta (which adds an implicit space); this may not
    # yield the right answer tokens with other tokenizers.
    label_possibilities = [tokenizer.encode(f"{t}. ", add_special_tokens=False)[0] for t in doc_to_choice]
    print([tokenizer.decode([t]) for t in label_possibilities])

    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
    ).to(device)

    val_dataset = load_jsonl([f"data/{file}.jsonl" for file in val_files])
    dev_dataset = load_jsonl([f"data/{dev_set}.jsonl"])
    if not val_dataset:
        raise ValueError(f"no validation examples found in {val_files}")

    val_dataset = make_k_shot(val_dataset, dev_dataset, k_shot)

    model.eval()
    batches = [val_dataset[i : i + val_batch_size] for i in range(0, len(val_dataset), val_batch_size)]
    total_acc = 0
    all_preds = []
    all_labels = []
    for batch in tqdm(batches, desc=f"Just Eval {val_files[0].split('/')[0]}"):
        tokens, last_pos_label_ids = process_batch(batch, device, tokenizer, label_possibilities)
        _, acc, preds = get_loss_and_acc(model, tokens, last_pos_label_ids, label_possibilities)
        all_preds.append(preds)
        all_labels.extend(point["answer"] for point in batch)
        total_acc += acc
    # get_loss_and_acc returns the number of correct answers per batch.
    total_acc /= len(val_dataset)

    # "Balanced" accuracy: remove each answer column's mean logit before
    # taking the argmax, so a constant preference for one letter is ignored.
    all_preds_a = np.concatenate(all_preds, axis=0)
    balanced = all_preds_a - all_preds_a.mean(axis=0)
    bal_acc = (balanced.argmax(axis=1) == np.array(all_labels)).mean()

    if results_file is not None:
        _record_eval_result(results_file, model_path, val_files, total_acc, bal_acc)

    return total_acc, bal_acc, all_preds_a

# remote_just_eval = ray.remote(just_eval)
# No @torch.no_grad() here: just_eval is already decorated with it, and
# stacking it on top of @ray.remote would wrap the remote-function handle
# instead of the task body.
@ray.remote(num_gpus=1)
def remote_just_eval(
    model_path: str,
    val_batch_size: int = 8,
    dev_set: str = "dates-dev",
    k_shot: int = 0,
    val_files: Optional[list[str]] = None,
    results_file: str = "evals/just_eval_results/results2.json",
):
    """Ray task (one GPU) that forwards its arguments to ``just_eval``.

    ``val_files`` defaults to ``["dates-split-0", "dates-split-1"]``. Returns
    whatever ``just_eval`` returns.
    """
    if val_files is None:
        val_files = ["dates-split-0", "dates-split-1"]
    return just_eval(
        model_path=model_path,
        val_batch_size=val_batch_size,
        dev_set=dev_set,
        k_shot=k_shot,
        val_files=val_files,
        results_file=results_file,
    )



## For Cut

if __name__ == "__main__":
    ray.init()
    results_file = "evals/just_eval_results/results_trimmed.json"

    # Models to evaluate, one Ray task each. Entries are passed to
    # from_pretrained by just_eval; add more entries here to sweep over
    # several checkpoints.
    models = ["meta-llama/Meta-Llama-3-8B"]

    # Validation sets, read by just_eval from data/<name>.jsonl. Other sets
    # previously used with this script include "dates-years/split_<n>",
    # "dates-years-trimmed/split_<n>", "mmlu_cats/mmlu_other" and
    # "mmlu/mmlu_3000".
    val_files = [f"mmlu_cats_random_trimmed/mmlu_{mmlu_cat}" for mmlu_cat in ["STEM", "business"]]

    deps = [
        remote_just_eval.remote(
            model,
            dev_set="dates-years/dev",
            k_shot=0,
            val_files=val_files,
            results_file=results_file,
        )
        for model in models
    ]

    # Block until every evaluation has finished and show its result.
    for dep in deps:
        print(f"{ray.get(dep)=}")
    

## For Gradient Difference

# if __name__ == "__main__":
#     ray.init()
#     models_nums = [5e-7]
#     deps = []
#     # lrs = [1e-7, 3e-7, 1e-6, 3e-6, 1e-5]
#     lrs = [5e-7]
#     # coeffs = [0, 0.01, 0.1, 0.5, 1.5, 2.4, 4]
#     # coeffs = [0.02, 0.05, 0.07, 0.2, 0.8, 1, 1.2]
#     coeffs = [0, 0.01, 0.02, 0.05, 0.07, 0.1, 0.2, 0.5, 0.8, 1, 1.2, 1.5, 2.4, 4]
#     # lrs = [2.1e-7]
#     skip_split = 0
#     epochs = 4
#     for skip_split in range(5):
#     # for skip_split in range(1):
#         for lr in lrs:
#             original_model = f"models/paraphrased_models/dates-corpus-retain-rephrased-lr"
#             # forget_model = f"models-skip_split{skip_split}-lr{lr}/dates-corpus-retain-rephrased-epochs{epochs}-lr"
#             forget_model = f"ft-dates-corpus-retain-rephrased-epochs{epochs}-skip_split{skip_split}-lr{lr}"
#             for num in models_nums:
#                 # for coeff in coeffs:
#                 for coeff in coeffs[:1]:
#                     # print(f"{skip_split=}")
#                     deps += [
#                         main.remote(
#                             [f"dates-years/split_{i}" for i in range(5) if i != skip_split],
#                             [f"dates-years/split_{i}" for i in range(5) if i == skip_split],
#                             "dates-years/dev",
#                             f"{original_model}{num}-rc{coeff}-seed{seed}",
#                             lr,
#                             epochs=epochs,
#                             name=f"{forget_model}-rc{coeff}-seed{seed}",
#                             kind="ft",
#                             save_name=f"models/{forget_model}-rc{coeff}-seed{seed}",
#                             # val_retain_files=["mmlu"],
#                             )
#                         for seed in range(3)
#                     ]
#                 # main.remote(
#                 #             [f"dates-years/split_{i}" for i in range(5) if i != skip_split],
#                 #             [f"dates-years/split_{i}" for i in range(5) if i == skip_split],
#                 #             "dates-years/dev",
#                 #             f"{original_model}{num}-rc{coeff}-seed{seed}",
#                 #             lr,
#                 #             epochs=epochs,
#                 #             name=f"{forget_model}-rc{coeff}-seed{seed}",
#                 #             kind="ft",
#                 #             save_name=f"models/{forget_model}-rc{coeff}-seed{seed}",
#                 #             # val_retain_files=["mmlu"],
#                 #             )

#     for dep in deps:
#         ray.get(dep)
# 'Meta-Llama-3-8B_alpha-[60.0, 60.0, 60.0, 60.0, 60.0]_batches-125_layer-7'