import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import datetime
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AdamW
import tqdm as tqdm
from peft import PeftModel
from eraser.utils import *

import random

def _print_run_config(args):
    """Print every parsed argument, then a USING / NOT USING summary of the feature toggles."""
    print("==== Config ====")
    for k, v in vars(args).items():
        print(f"{k} = {v}")
    print("====================")

    settings = {
        'BCE Loss': args.use_BCE_loss,
        'LoRA': args.use_lora,
        'TMP Inputs': args.use_TMP_INPUTS
    }
    print("\n=== Configuration Settings ===")
    for setting, value in settings.items():
        status = "USING" if value else "NOT USING"
        print(f"{status} {setting}")
    print("============================\n")


def _default_output_path(args, num_batches):
    """Build a timestamped save directory under ``models/`` (used when --output_dir is empty)."""
    date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    return f"models/{args.model_name_or_path}_alpha-{args.alpha}_batches-{num_batches}_layer-{args.layer_id}_{date}"


def _save_model(updated_model, tokenizer, path, use_lora):
    """Merge LoRA adapters into the base weights (when ``use_lora``) and save model + tokenizer to ``path``."""
    if use_lora:
        updated_model = updated_model.merge_and_unload()

        if isinstance(updated_model, PeftModel):
            updated_model = updated_model.base_model

        # Sanity check: after merging, no parameter name should still mention LoRA.
        # Report it once rather than once per offending parameter.
        if any('lora' in n for n, _ in updated_model.named_parameters()):
            print('LORA Parameters exist in model')
    updated_model.save_pretrained(path)
    tokenizer.save_pretrained(path)
    print(f"Saved model to {path}")


def run_eraser(
    updated_model,
    frozen_model,
    tokenizer,
    forget_data_list,
    retain_data_list,
    args,
):
    """Unlearn the forget corpora by steering activations at one decoder layer, then save the model.

    Each optimisation step takes one forget batch and one retain batch from the
    same topic (topics are visited round-robin). The loss has three parts:

    * an MSE between the updated model's forget-batch activations at
      ``args.layer_id`` and a target;
    * an MSE between the updated and frozen models' retain-batch activations,
      scaled by ``args.alpha[topic_idx]``;
    * when ``args.use_BCE_loss``, a preference-head loss on both activation sets,
      scaled by ``args.lux_weight``.

    The target depends on ``args.use_TMP_INPUTS``. When it is set, the target is
    the frozen model's activations on a pad-token input. Otherwise it is
    ``remove_subspace`` applied to the updated activations, using dominant vectors
    computed from the forget data.

    After training, LoRA adapters are merged when ``args.use_lora``. The model and
    tokenizer are saved to ``args.output_dir``, or to a timestamped ``models/``
    path when that is empty.

    Args:
        updated_model: Model being trained (LoRA-wrapped when ``args.use_lora``).
        frozen_model: Reference model. It is only run through ``forward_with_cache``
            with ``no_grad=True``.
        tokenizer: Tokenizer shared by both models. Its ``truncation_side`` is
            forced to ``"right"`` for the run and always restored afterwards.
        forget_data_list: One sequence of text batches per forget topic.
        retain_data_list: Retain batches, indexed by the same topic index.
        args: Namespace produced by ``get_args()``.
    """
    _print_run_config(args)

    device = updated_model.device
    updated_model.train()

    params = [p for p in updated_model.parameters() if p.requires_grad]
    layer_preference_head = None
    if args.use_BCE_loss:
        layer_preference_head = PreferenceHead(updated_model.config.hidden_size).to(device)
        print('Loaded Preference Head!!!')
        # The head is trained jointly with the model; its params go after the model's.
        params += list(layer_preference_head.parameters())

    optimizer = AdamW(params, lr=args.lr)

    # Direct indexing selects the same module the old eval() template did, without executing code.
    frozen_module = frozen_model.model.layers[args.layer_id]
    if args.use_lora:
        # SECURITY NOTE: --module_str comes from the command line and is evaluated with eval();
        # only run this script with trusted arguments.
        updated_module = eval(args.module_str.format(model_name="updated_model", layer_id=args.layer_id))
    else:
        updated_module = updated_model.model.layers[args.layer_id]

    num_batches = min(
        args.max_num_batches,
        min(len(f) for f in forget_data_list),
        min(len(r) for r in retain_data_list),
    )
    num_topics = len(forget_data_list)

    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side = "right"
    try:
        shared_dominant_vectors = compute_shared_dominant_vectors(updated_model, tokenizer, forget_data_list, args.layer_id, device, sample_size=args.num_samples_for_pcs, num_components=args.num_PCs, method=args.method_for_pcs)
        print("======= Training with Activation-Based Weighting =======")
        for epoch in range(args.epoch):
            print(f"======= Epoch {epoch} =======")
            with tqdm.tqdm(total=num_batches) as pbar:
                for step in range(num_batches):
                    # Round-robin over topics: step k uses topic k % T and batch k // T.
                    topic_idx = step % num_topics
                    batch_idx = step // num_topics
                    unlearn_batch = forget_data_list[topic_idx][batch_idx]
                    retain_batch = retain_data_list[topic_idx][batch_idx]

                    # Forget text of topic 0 is truncated to 512 tokens, other topics to 768.
                    max_length = 512 if topic_idx == 0 else 768

                    unlearn_inputs = tokenizer(
                        unlearn_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length
                    ).to(device)
                    updated_forget_activations = forward_with_cache(
                        updated_model, unlearn_inputs, module=updated_module, no_grad=False
                    ).to(device)

                    if args.use_TMP_INPUTS:
                        # NOTE(review): this target comes from a single input padded to max_length,
                        # while unlearn_inputs are padded only to the batch's longest sequence; unless
                        # the lengths coincide, mse_loss below may fail to broadcast — confirm intent.
                        tmp_inputs = tokenizer(
                            [tokenizer.pad_token * max_length], return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
                        ).to(device)
                        forget_target = forward_with_cache(
                            frozen_model, tmp_inputs, module=frozen_module, no_grad=True
                        ).to(device)
                    else:
                        # Target derived from the updated activations themselves
                        # (presumably with the shared dominant directions projected out).
                        forget_target = remove_subspace(updated_forget_activations, shared_dominant_vectors)

                    retain_inputs = tokenizer(
                        retain_batch, return_tensors="pt", padding=True, truncation=True, max_length=512
                    ).to(device)
                    updated_retain_activations = forward_with_cache(
                        updated_model, retain_inputs, module=updated_module, no_grad=False
                    ).to(device)
                    frozen_retain_activations = forward_with_cache(
                        frozen_model, retain_inputs, module=frozen_module, no_grad=True
                    ).to(device)

                    unlearn_loss = F.mse_loss(updated_forget_activations, forget_target)
                    retain_loss = F.mse_loss(updated_retain_activations, frozen_retain_activations) * args.alpha[topic_idx]
                    loss = unlearn_loss + retain_loss
                    if args.use_BCE_loss:
                        score_forget = layer_preference_head(updated_forget_activations.float())
                        score_retain = layer_preference_head(updated_retain_activations.float())
                        L_aux = compute_preference_loss(score_retain, score_forget) * args.lux_weight
                        loss = loss + L_aux

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Mean |grad| of the first trainable tensor; it may have received no gradient.
                    first_grad = params[0].grad
                    grad_mag = first_grad.abs().mean().item() if first_grad is not None else float("nan")

                    print('\n' + '*' * 30)
                    aux_str = f" | L_aux: {L_aux.item():.4g}" if args.use_BCE_loss else ""
                    print(f"[TOPIC {topic_idx}][Batch {batch_idx}] loss: {loss.item():.4g} | unlearn_loss: {unlearn_loss.item():.4g} | retain_loss: {retain_loss.item():.4g}{aux_str} | param_change: {grad_mag:.4g}")

                    # Diagnostics only: no autograd graph is needed for these similarities.
                    with torch.no_grad():
                        frozen_forget_activations = forward_with_cache(
                            frozen_model, unlearn_inputs, module=frozen_module, no_grad=True
                        ).to(device)
                        unlearn_cosine = F.cosine_similarity(updated_forget_activations, frozen_forget_activations, dim=-1).mean()
                        retain_cosine = F.cosine_similarity(updated_retain_activations, frozen_retain_activations, dim=-1).mean()

                    print(f"unlearn_cosine_sim={unlearn_cosine.item()}")
                    print(f"retain_cosine_sim={retain_cosine.item()}")
                    pbar.update(1)
    finally:
        # Restore the caller's tokenizer setting even if training raised.
        tokenizer.truncation_side = truncation_side

    path = args.output_dir if args.output_dir else _default_output_path(args, num_batches)
    _save_model(updated_model, tokenizer, path, args.use_lora)


def get_args():
    """Parse the command-line options for an eraser unlearning run.

    After parsing, comma-separated string options are split into lists:

    * ``retain_corpora``, ``forget_corpora`` and ``lora_layer_selection`` become lists of str;
    * ``alpha`` becomes a list of float;
    * ``layer_ids`` becomes a list of int;
    * ``steering_coeffs`` is parsed into ``args.steering_coeff_list`` (list of float).

    Boolean toggles accept true/false, yes/no, 1/0 (case-insensitive).

    Returns:
        argparse.Namespace with the parsed, post-processed options.
    """
    import argparse

    def str2bool(value):
        """Parse a CLI boolean; ``type=bool`` would turn any non-empty string (even "False") into True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        # "" stays False, matching what bool("") produced previously.
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument("--model_name_or_path", type=str, default="/data/models_ckpt/zephyr-7b-beta")
    parser.add_argument("--module_str", type=str, default="{model_name}.base_model.model.model.layers[{layer_id}]")
    parser.add_argument("--output_dir", type=str, default="/llm_unlearning/wmdp/models/unlearned_model")

    parser.add_argument("--retain_corpora", type=str, default="wikitext,wikitext")
    parser.add_argument("--forget_corpora", type=str, default="bio-forget-corpus,cyber-forget-corpus")
    parser.add_argument("--alpha", type=str, default="1200,1200")
    # Previously read below without ever being declared (AttributeError on every run).
    # One coefficient per forget topic; not consumed within this file.
    parser.add_argument("--steering_coeffs", type=str, default="20,20")

    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--min_len", type=int, default=50)
    parser.add_argument("--max_len", type=int, default=2000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_num_batches", type=int, default=500)
    parser.add_argument("--layer_id", type=int, default=7)
    parser.add_argument("--layer_ids", type=str, default="5,6,7")

    parser.add_argument("--epoch", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--verbose", action="store_true")

    parser.add_argument("--lora_layer_selection", type=str, default="gate_proj,up_proj,down_proj")
    parser.add_argument("--r", type=int, default=256)
    parser.add_argument("--lora_alpha", type=int, default=512)
    parser.add_argument("--lora_dropout", type=float, default=0.94)

    parser.add_argument("--use_BCE_loss", type=str2bool, default=True)
    parser.add_argument("--use_lora", type=str2bool, default=True)
    parser.add_argument("--use_TMP_INPUTS", type=str2bool, default=False)
    parser.add_argument("--lux_weight", type=float, default=1.05)

    parser.add_argument("--num_PCs", type=int, default=20)
    parser.add_argument("--num_samples_for_pcs", type=int, default=100)
    parser.add_argument("--method_for_pcs", type=str, default="power_iteration")

    args = parser.parse_args()
    args.retain_corpora = args.retain_corpora.split(",")
    args.forget_corpora = args.forget_corpora.split(",")
    args.lora_layer_selection = args.lora_layer_selection.split(",")

    args.steering_coeff_list = [float(c) for c in args.steering_coeffs.split(",")]
    args.alpha = [float(c) for c in args.alpha.split(",")]
    args.layer_ids = [int(x) for x in args.layer_ids.split(",")]

    return args

if __name__ == "__main__":
    # Script entry point: parse options, seed every RNG, load a frozen and a
    # trainable copy of the checkpoint plus the corpora, optionally attach LoRA,
    # then run the unlearning loop.
    args = get_args()

    # Seed Python, NumPy and torch (CPU, current CUDA device, all CUDA devices), in that order.
    # `random` is already imported at module level.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(args.seed)

    # Two separate load_model calls: one copy is the frozen reference, the other is trained.
    # Both return a tokenizer; the one from the second call is the one kept.
    frozen_model, tokenizer = load_model(args.model_name_or_path)
    updated_model, tokenizer = load_model(args.model_name_or_path)

    forget_data_list, retain_data_list = get_data(
        args.forget_corpora, args.retain_corpora, args.min_len, args.max_len, args.batch_size
    )

    # Layers that receive LoRA adapters (from --layer_ids).
    chosen_layers = args.layer_ids
    if args.use_lora:
        updated_model = apply_lora_to_mlp(args, updated_model, chosen_layers)

    run_eraser(updated_model, frozen_model, tokenizer, forget_data_list, retain_data_list, args)
