import os
import json
import random
import numpy as np
import torch
import hydra
from omegaconf import DictConfig
from tqdm import tqdm
from typing import TypedDict
from lion_pytorch import Lion
from transformers import AutoTokenizer, AutoModelForCausalLM

def lr_lambda(current_step, warmup_steps=20):
    """Return the learning-rate multiplier for ``current_step``.

    The multiplier rises linearly from 0 to 1 over the first ``warmup_steps`` steps
    and stays at 1.0 afterwards. ``max(1, ...)`` guards the division when
    ``warmup_steps`` is not positive.
    """
    in_warmup = current_step < warmup_steps
    return float(current_step) / float(max(1, warmup_steps)) if in_warmup else 1.0

class Point(TypedDict):
    """One dataset row as read by the prompt/loss/accuracy helpers in this module.

    The keys match what those helpers access: ``point['prompt']`` and
    ``point['completion']``.
    """
    prompt: str  # text fed to the model
    completion: str  # expected continuation, appended to the prompt when scoring

def set_seed(seed_value=22):
    """Seed Python's ``random``, NumPy and PyTorch RNGs for reproducibility.

    The CUDA generator is seeded as well when a CUDA device is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)

def create_prompt_question_letter_answer(point: Point) -> str:
    """Return the prompt immediately followed by its completion (no separator)."""
    prompt, completion = point['prompt'], point['completion']
    return prompt + completion

def get_loss_corpus(
    model,
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
    num_target_tokens: int = 3,
):
    """Return the batch-mean NLL of the last ``num_target_tokens`` tokens of each row.

    Each row is rendered as ``prompt + completion + tokenizer.eos_token`` and the
    batch is tokenized with padding. Only the final ``num_target_tokens`` predicted
    positions of every row are scored. That tail is real text only if padding is
    on the left (as ``main`` configures the tokenizer). With right padding it
    would include pad tokens.

    The default of 3 presumably covers the answer token(s) plus EOS. TODO: confirm
    for the tokenizer in use.

    Args:
        model: causal LM called as ``model(**model.prepare_inputs_for_generation(...))``.
        batch: non-empty list of rows with ``prompt`` and ``completion`` keys.
        device: device the token tensors are moved to.
        tokenizer: tokenizer used for the batch; must define ``eos_token``.
        num_target_tokens: how many trailing tokens per row contribute to the loss.

    Returns:
        0-d tensor: per-row summed tail NLL, averaged over rows (differentiable).

    Raises:
        ValueError: if ``batch`` is empty or ``num_target_tokens`` < 1.
    """
    if not batch:
        raise ValueError("get_loss_corpus needs a non-empty batch")
    if num_target_tokens < 1:
        # A slice of ``-0:`` would silently select the whole sequence.
        raise ValueError(f"num_target_tokens must be >= 1, got {num_target_tokens}")

    prompts = [
        create_prompt_question_letter_answer(row) + tokenizer.eos_token for row in batch
    ]

    tokens = tokenizer(prompts, return_tensors="pt", truncation=True, padding=True).to(device)
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits
    neg_log_probs = -get_log_probs(logits, tokens["input_ids"])

    # Sum the tail NLL of each row, then average over the rows of the batch.
    return neg_log_probs[:, -num_target_tokens:].sum(dim=-1).mean()

def get_log_probs(logits, tokens):
    """Return the log-probability assigned to each actual next token.

    ``logits[:, t]`` is the prediction for ``tokens[:, t + 1]``, so the result has
    one position fewer than the input. Entry ``[b, t]`` equals
    ``log_softmax(logits[b, t])[tokens[b, t + 1]]``.
    """
    all_log_probs = torch.log_softmax(logits, dim=-1)
    next_ids = tokens[:, 1:].unsqueeze(-1)
    picked = torch.gather(all_log_probs[:, :-1], -1, next_ids)
    return picked.squeeze(-1)

def get_loss(model, batch, device, tokenizer):
    """Return the mean next-token NLL over the real (non-padding) tokens of ``batch``.

    Each row must have a ``"text"`` entry. Rows are tokenized together with
    padding. A prediction is counted only when both its context position and its
    target token are real according to ``attention_mask``. This keeps padding
    from diluting or distorting the average. The reduction is done in float32.

    Returns:
        0-d tensor with the masked mean NLL (differentiable).
    """
    prompts = [pt["text"] for pt in batch]
    tokens = tokenizer(
        prompts, return_tensors="pt",
        truncation=True, padding=True
    ).to(device)
    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits
    log_probs = get_log_probs(logits, tokens["input_ids"]).float()

    attention_mask = tokens.get("attention_mask")
    if attention_mask is None:
        # No padding information available: fall back to the plain mean.
        return -log_probs.mean()

    # log_probs[:, t] predicts token t + 1 from position t; both must be real tokens.
    mask = (attention_mask[:, :-1] * attention_mask[:, 1:]).float()
    return -(log_probs * mask).sum() / mask.sum().clamp(min=1)

def load_jsonl(file):
    """Load and return the parsed contents of a JSON or JSON Lines file.

    The file is decoded as UTF-8, and a leading BOM is tolerated. It is first
    parsed as a single JSON document, which is the format of the ``.json`` data
    files this script reads. If that fails, it is parsed as JSON Lines: one JSON
    value per non-blank line. In that case a list of those values is returned.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is neither valid JSON nor valid JSON Lines.
    """
    with open(file, 'r', encoding='utf-8-sig') as f:
        text = f.read()

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Not a single JSON document: treat it as JSON Lines, as the name suggests.
        return [json.loads(line) for line in text.splitlines() if line.strip()]

def get_loss_dataset(model, dataset, device, tokenizer, batch_size):
    """Return the example-weighted mean of ``get_loss_corpus`` over ``dataset``.

    The dataset is scored in consecutive chunks of ``batch_size``. Each chunk's
    loss is weighted by its size, so a shorter final chunk is not over-counted.
    This is evaluation only: it runs under ``torch.no_grad()`` so no autograd
    graph is built.

    Returns:
        Python float.

    Raises:
        ZeroDivisionError: if ``dataset`` is empty.
    """
    running_loss = 0.0
    with torch.no_grad():
        for start in range(0, len(dataset), batch_size):
            batch = dataset[start : start + batch_size]
            loss = get_loss_corpus(model, batch, device, tokenizer)
            running_loss += loss.item() * len(batch)

    return running_loss / len(dataset)

def create_prompt_question(point: Point) -> str:
    """Return only the prompt text of ``point``; the completion is left out."""
    prompt_text = point['prompt']
    return prompt_text

def get_acc_corpus(
    model,
    batch: list[Point],
    device: torch.device,
    tokenizer: AutoTokenizer,
):
    """Count the rows whose greedy next token after the prompt appears in the completion.

    The prompts are tokenized together with padding, and the logits at the last
    position give each row's greedy prediction. The check relies on left padding
    (as ``main`` configures the tokenizer), so that the last position is the end
    of every prompt. A row counts as correct when its decoded top token is not
    empty or whitespace-only and is a substring of ``row["completion"]``.
    Whitespace-only decodes such as ``" "`` are substrings of almost any
    completion, so they would otherwise be false positives.

    Returns:
        int: number of correct rows in ``batch``.
    """
    prompts = [
        create_prompt_question(row) for row in batch
    ]

    tokens = tokenizer(prompts, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        logits = model(**model.prepare_inputs_for_generation(**tokens)).logits
    # softmax is monotonic, so the argmax of the raw logits is the greedy prediction.
    top_idx = logits[:, -1, :].argmax(dim=-1)
    top_tokens = [tokenizer.decode(idx) for idx in top_idx.tolist()]

    return sum(
        1
        for token, row in zip(top_tokens, batch)
        if token.strip() and token in row["completion"]
    )

def get_acc_dataset(model, dataset, device, tokenizer, batch_size):
    """Return the fraction of ``dataset`` rows counted correct by ``get_acc_corpus``.

    Rows are scored in consecutive chunks of ``batch_size``. This is evaluation
    only: it runs under ``torch.no_grad()`` so no autograd graph is built.

    Raises:
        ZeroDivisionError: if ``dataset`` is empty.
    """
    correct_count = 0
    with torch.no_grad():
        for start in range(0, len(dataset), batch_size):
            batch = dataset[start : start + batch_size]
            correct_count += get_acc_corpus(model, batch, device, tokenizer)

    return correct_count / len(dataset)

def compute_decision_boundary(cfg, model, tokenizer, save_path, device="cuda"):
    """Probe the model on a 2-D integer grid and save P(label_1) normalised over both labels.

    For every grid point, ``x_1`` ranges over ``range(0, grid_width, step)`` and
    ``x_2`` over ``range(0, grid_height, step)``, with ``x_2`` varying fastest.
    The model is prompted with
    ``"What is the label for this input?\\nInput: {x_1} {x_2}\\nLabel:"``. The
    next-token probabilities of ``cfg.dataset.labels[0]`` and
    ``cfg.dataset.labels[1]`` are then read off, using the last token id of each
    label's tokenization. The stored value is ``p1 / (p0 + p1)``, or 0.5 when both
    probabilities are zero.

    The model is probed in eval mode under ``torch.no_grad()``, and its previous
    train/eval mode is restored afterwards. The flat result array (row-major over
    ``x_1``, then ``x_2``) is written with ``np.save`` to ``save_path``. NumPy
    appends ``.npy`` if it is missing.

    Args:
        cfg: config with ``decision_boundary.{grid_step_size,grid_width,grid_height}``
            and ``dataset.labels``.
        model: causal LM called as ``model(**model.prepare_inputs_for_generation(...))``.
        tokenizer: tokenizer matching ``model``.
        save_path: output file path; its parent directory is created if needed.
        device: device the prompt tensors are moved to (default "cuda").
    """
    db_cfg = cfg.decision_boundary
    step_sz = db_cfg.grid_step_size

    # The label token ids do not depend on the grid point, so compute them once.
    token_id_label_0 = int(tokenizer(cfg.dataset.labels[0], return_tensors="pt")["input_ids"][0][-1])
    token_id_label_1 = int(tokenizer(cfg.dataset.labels[1], return_tensors="pt")["input_ids"][0][-1])

    y = []
    # Probe without dropout even if the caller left the model in train mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x_1 in range(0, db_cfg.grid_width, step_sz):
                for x_2 in range(0, db_cfg.grid_height, step_sz):
                    prompt = f"What is the label for this input?\nInput: {x_1} {x_2}\nLabel:"
                    tokens = tokenizer([prompt], return_tensors="pt", truncation=True, padding=True).to(device)
                    logits = model(**model.prepare_inputs_for_generation(**tokens)).logits
                    last_position_probs = logits[:, -1, :].softmax(dim=-1)
                    label_0_prob = last_position_probs[0, token_id_label_0].item()
                    label_1_prob = last_position_probs[0, token_id_label_1].item()

                    total = label_0_prob + label_1_prob
                    y.append(0.5 if total == 0 else label_1_prob / total)
    finally:
        model.train(was_training)

    save_dir = os.path.dirname(save_path)
    if save_dir:  # os.makedirs("") raises, e.g. for a bare file name
        os.makedirs(save_dir, exist_ok=True)
    np.save(save_path, np.array(y))


def train(cfg, meta_model, tokenizer, retain_dataset, unlearn_dataset, rc, train_batch_size, epochs, eval_every, lr, device="cuda", mode="random"):
    """Train ``meta_model`` on a weighted mix of retain and unlearn losses.

    Each optimisation step (Lion optimiser) minimises
    ``(1 - rc) * L_unlearn + rc * L_retain``:

    * ``L_retain`` is ``get_loss_corpus`` on the current retain mini-batch.
    * ``L_unlearn`` is ``get_loss_corpus`` on the *whole* unlearn dataset.

    A term whose weight is zero is not computed. The unlearn term is also skipped
    when ``unlearn_dataset`` is empty. If no term remains, the parameter update
    is skipped for that step.

    Every ``eval_every`` steps, including step 0, the model is put in eval mode.
    Depending on ``cfg.compute_loss`` / ``cfg.compute_acc``, its loss and/or
    accuracy on the full retain and unlearn datasets are appended to metric
    histories. Those histories are also saved at the steps listed in
    ``cfg.model.checkpoints[len(unlearn_dataset)]``. When
    ``cfg.compute_decision_boundary`` is set, a decision-boundary snapshot is
    saved at the steps in ``cfg.decision_boundary.checkpoints[len(unlearn_dataset)]``.
    All results go under
    ``results/<model>/<sample_size>/unlearn_dz_<len(unlearn_dataset)>/<mode>``.

    After the last epoch, the complete histories are saved there. If
    ``cfg.save_model`` is set, the model and tokenizer are saved under ``models/...``.

    Note: ``retain_dataset`` and ``unlearn_dataset`` are shuffled in place every epoch.

    Args:
        cfg: run config (reads model.*, sample_size, compute_loss, compute_acc,
            compute_decision_boundary, decision_boundary.checkpoints,
            use_clip_grad_norm, save_model).
        meta_model: model being trained.
        tokenizer: tokenizer matching ``meta_model``.
        retain_dataset: list of rows to keep fitting.
        unlearn_dataset: list of rows for the unlearn term (may be empty).
        rc: weight of the retain loss; the unlearn loss is weighted by ``1 - rc``.
        train_batch_size: mini-batch size for training and evaluation.
        epochs: number of passes over ``retain_dataset``.
        eval_every: evaluation period, in optimisation steps.
        lr: Lion learning rate.
        device: device passed to the loss/accuracy helpers.
        mode: dataset variant name; used only in result/model paths.
    """
    optimizer = Lion(meta_model.parameters(), lr=lr, use_triton=True)
    unlearn_dataset_size = len(unlearn_dataset)
    has_unlearn_data = unlearn_dataset_size != 0
    model_short_name = cfg.model.base_model.split("/")[-1]
    # Every result artefact of this run (metric histories, decision boundaries) lives here.
    save_dir = f"results/{model_short_name}/{cfg.sample_size}/unlearn_dz_{unlearn_dataset_size}/{mode}"
    unlearn_loss_lst = []
    unlearn_acc_lst = []
    retain_loss_lst = []
    retain_acc_lst = []
    step = 0

    tqdm_bar = tqdm(range(epochs), desc="Epoch")
    for _epoch in tqdm_bar:
        meta_model.train()
        random.shuffle(retain_dataset)
        random.shuffle(unlearn_dataset)
        retain_batches = [retain_dataset[i : i + train_batch_size] for i in range(0, len(retain_dataset), train_batch_size)]

        for batch in retain_batches:

            # Evaluation on the full datasets every `eval_every` steps (including step 0).
            if step % eval_every == 0:
                meta_model.eval()

                # Loss
                if cfg.compute_loss:
                    retain_loss = get_loss_dataset(meta_model, retain_dataset, device, tokenizer, batch_size=train_batch_size)
                    retain_loss_lst.append(retain_loss)
                    if has_unlearn_data:
                        unlearn_loss = get_loss_dataset(meta_model, unlearn_dataset, device, tokenizer, batch_size=train_batch_size)
                        unlearn_loss_lst.append(unlearn_loss)

                    if step in cfg.model.checkpoints[unlearn_dataset_size]:
                        os.makedirs(save_dir, exist_ok=True)
                        np.save(save_dir + f"/retain_loss-step{step}", np.array(retain_loss_lst))
                        if has_unlearn_data:
                            np.save(save_dir + f"/unlearn_loss-step{step}", np.array(unlearn_loss_lst))
                    tqdm_bar.set_postfix_str(f"{step=}, {retain_loss=}")

                # Accuracy
                if cfg.compute_acc:
                    retain_acc = get_acc_dataset(meta_model, retain_dataset, device, tokenizer, batch_size=train_batch_size)
                    retain_acc_lst.append(retain_acc)
                    if has_unlearn_data:
                        unlearn_acc = get_acc_dataset(meta_model, unlearn_dataset, device, tokenizer, batch_size=train_batch_size)
                        unlearn_acc_lst.append(unlearn_acc)

                    if step in cfg.model.checkpoints[unlearn_dataset_size]:
                        os.makedirs(save_dir, exist_ok=True)
                        np.save(save_dir + f"/retain_acc-step{step}", np.array(retain_acc_lst))
                        if has_unlearn_data:
                            np.save(save_dir + f"/unlearn_acc-step{step}", np.array(unlearn_acc_lst))
                    tqdm_bar.set_postfix_str(f"{step=}, {retain_acc=}")

                if cfg.compute_loss and cfg.compute_acc:
                    tqdm_bar.set_postfix_str(f"{step=}, {retain_loss=}, {retain_acc=}")

            # Decision boundary
            if cfg.compute_decision_boundary and step in cfg.decision_boundary.checkpoints[unlearn_dataset_size]:
                compute_decision_boundary(cfg, meta_model, tokenizer,
                                          save_path=save_dir + f"/decision_boundary-step{step}")

            # Training
            meta_model.train()
            optimizer.zero_grad()
            # Only compute terms that contribute; get_loss_corpus cannot score an
            # empty unlearn set, so that term is dropped when there is no unlearn data.
            if rc == 1 or not has_unlearn_data:
                unlearn_term = 0
            else:
                unlearn_term = get_loss_corpus(meta_model, unlearn_dataset, device, tokenizer)
            if rc == 0:
                retain_term = 0
            else:
                retain_term = get_loss_corpus(meta_model, batch, device, tokenizer)

            loss = (1 - rc) * unlearn_term + rc * retain_term
            # If every term was skipped, `loss` is a plain number with nothing to
            # back-propagate, so skip the update instead of crashing.
            if torch.is_tensor(loss):
                loss.backward()
                if cfg.use_clip_grad_norm:
                    torch.nn.utils.clip_grad_norm_(meta_model.parameters(), max_norm=1)
                optimizer.step()

            step += 1

    os.makedirs(save_dir, exist_ok=True)
    if cfg.compute_loss:
        np.save(save_dir + "/retain_loss", np.array(retain_loss_lst))
        if has_unlearn_data:
            np.save(save_dir + "/unlearn_loss", np.array(unlearn_loss_lst))
    if cfg.compute_acc:
        np.save(save_dir + "/retain_acc", np.array(retain_acc_lst))
        if has_unlearn_data:
            np.save(save_dir + "/unlearn_acc", np.array(unlearn_acc_lst))

    if cfg.save_model:
        if cfg.model.finetuned_steps == 0:
            model_path_prefix = f"models/{model_short_name}/{cfg.sample_size}/{mode}/ft-"
        else:
            model_path_prefix = f"models/{model_short_name}/{cfg.sample_size}/{mode}/ft-step{cfg.model.finetuned_steps}-unlearn-"

        model_path = f"{model_path_prefix}step{step}"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        meta_model.save_pretrained(model_path)
        tokenizer.save_pretrained(model_path)
        print(f"Model saved at {model_path}.")

CONFIG_PATH = "configs/Qwen3-4B"  # override with --config-path on the command line
CONFIG_NAME = "unlearn_1024"  # override with --config-name on the command line
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME, version_base=None)
def main(cfg: DictConfig) -> None:
    """Run one training job per (unlearn dataset size, data mode) combination.

    For every ``unlearn_dataset_size`` in ``cfg.model.unlearn_dataset_size`` and
    every ``mode`` in ``cfg.model.modes``, this:

    * reseeds all RNGs with 22;
    * loads the base model (or, when ``cfg.model.finetuned_steps`` is non-zero,
      the previously fine-tuned checkpoint) in float16 on CUDA;
    * loads the retain dataset and the optional unlearn dataset, shuffling the
      latter and truncating it to ``unlearn_dataset_size``;
    * calls ``train``.

    Each run's model is released before the next one is loaded.
    """
    import gc

    model_cfg = cfg.model
    for unlearn_dataset_size in model_cfg.unlearn_dataset_size:
        for mode in model_cfg.modes:
            set_seed(22)

            # Load (pretrained or previously fine-tuned) model and tokenizer.
            if model_cfg.finetuned_steps == 0:
                base_model_name = model_cfg.base_model
            else:
                model_short_name = model_cfg.base_model.split("/")[-1]
                base_model_name = f"models/{model_short_name}/{cfg.sample_size}/{mode}/ft-step{model_cfg.finetuned_steps}"
            meta_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    torch_dtype=torch.float16,
                    device_map="cuda"  # requires a CUDA device; there is no CPU fallback
                )
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            tokenizer.pad_token = tokenizer.eos_token
            # Left padding/truncation keep the end of every sequence (the answer
            # tokens) at the last positions, which the loss/accuracy helpers read.
            tokenizer.padding_side = "left"
            tokenizer.truncation_side = "left"

            # Load datasets
            retain_dataset = load_jsonl(f"data/{cfg.sample_size}/{mode}/{model_cfg.retain_dataset}.json")
            unlearn_dataset = [] if "unlearn_dataset" not in model_cfg else load_jsonl(f"data/{cfg.sample_size}/{mode}/{model_cfg.unlearn_dataset}.json")
            random.shuffle(unlearn_dataset)
            unlearn_dataset = unlearn_dataset[:unlearn_dataset_size]

            print(f"{mode=}, {base_model_name=}, {len(retain_dataset)=}, {len(unlearn_dataset)=}")

            train(cfg, meta_model, tokenizer, retain_dataset, unlearn_dataset, rc=model_cfg.rc, train_batch_size=model_cfg.train_batch_size, epochs=model_cfg.epochs, eval_every = model_cfg.eval_every, lr=model_cfg.lr, device="cuda", mode=mode)

            # Drop this run's model before the next from_pretrained call; otherwise
            # the old and new models are both resident on the GPU while loading.
            del meta_model, tokenizer
            gc.collect()
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()