import os
import datetime
import sys
import math

import numpy as np
import torch
from transformers import AdamW
from ray.experimental.tqdm_ray import tqdm
import ray
import warnings

from rmu.utils import load_model, get_params, forward_with_cache, get_data

# Put the parent directory of this file on the module search path.
# NOTE(review): this runs *after* the `rmu.utils` import above, so it cannot
# help resolve that import; only imports executed later (e.g. function-local
# ones) can see this entry. Confirm whether it should be moved before it.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


def run_rmu(
    updated_model,
    frozen_model,
    tokenizer,
    forget_data_list,
    retain_data_list,
    args,
):
    """Run the RMU training loop on ``updated_model`` and save the result.

    For every step one forget batch and one retain batch of the same topic
    are processed. The forget loss is the MSE between the activations of the
    selected module of ``updated_model`` and a fixed random control vector
    for that topic; the retain loss is the MSE between the activations of
    ``updated_model`` and ``frozen_model`` on the retain batch, scaled by
    ``args.alpha[topic_idx]``. Topics are visited round-robin.

    Args:
        updated_model: Model being trained. It is put in train mode and
            moved to ``"cuda"``.
        frozen_model: Reference model providing the retain targets. It is
            moved to ``"cuda"``.
        tokenizer: Tokenizer called on every batch. Its ``truncation_side``
            is set to ``"right"`` for the duration of training and restored
            afterwards, including when training raises.
        forget_data_list: Indexed as ``[topic_idx][batch_idx]``; each entry
            is passed directly to ``tokenizer`` (presumably a list of
            strings -- TODO confirm against ``get_data``).
        retain_data_list: Same layout as ``forget_data_list``.
        args: Namespace providing ``layer_ids``, ``param_ids``, ``lr``,
            ``module_str``, ``layer_id``, ``steering_coeff_list``,
            ``max_num_batches``, ``alpha``, ``verbose``, ``output_dir`` and
            ``model_name_or_path``.

    Returns:
        The directory the model and tokenizer were saved to.
    """
    rmu_config = vars(args)
    print("====rmu Config====")
    print("\n".join(f"{k}={v}" for k,v in rmu_config.items()))
    print("=====")
    device = "cuda"

    updated_model = updated_model.train()
    updated_model = updated_model.to(device)
    frozen_model = frozen_model.to(device)
    print(f"{args.param_ids=}")
    params = get_params(updated_model, args.layer_ids, args.param_ids)
    optimizer = AdamW(params, lr=args.lr)

    # SECURITY NOTE: ``args.module_str`` comes from the command line and is
    # evaluated as Python. The template refers to the local names
    # ``frozen_model`` / ``updated_model``, so those names must not be renamed.
    # Do not feed this function an untrusted ``module_str``.
    frozen_module = eval(
        args.module_str.format(model_name="frozen_model", layer_id=args.layer_id)
    ).to(device)
    updated_module = eval(
        args.module_str.format(model_name="updated_model", layer_id=args.layer_id)
    ).to(device)

    # One random unit-norm direction per forget topic, scaled by that topic's
    # steering coefficient. Drawn in topic order so the RNG stream is stable.
    num_topics = len(forget_data_list)
    control_vectors_list = []
    for topic_idx in range(num_topics):
        random_vector = torch.rand(
            1, 1, updated_model.config.hidden_size,
            dtype=updated_model.dtype, device=device,
        )
        control_vec = (
            random_vector / torch.norm(random_vector)
            * args.steering_coeff_list[topic_idx]
        )
        control_vectors_list.append(control_vec)

    # Every topic contributes the same number of batches, capped by the
    # shortest forget/retain corpus and by ``args.max_num_batches``.
    num_batches = min(
        args.max_num_batches,
        min(len(f) for f in forget_data_list),
        min(len(r) for r in retain_data_list),
    ) * num_topics

    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side="right"

    # try/finally so the caller's tokenizer is restored even if a step fails
    # (e.g. CUDA OOM); previously an exception left it set to "right".
    try:
        for epoch in range(1):
            print(f"======= Epoch {epoch} =======")
            for idx in tqdm(range(num_batches), desc=f"Cut {epoch=}"):
                # Round-robin over topics: consecutive steps use different topics.
                topic_idx = idx % num_topics
                batch_idx = idx // num_topics
                control_vec = control_vectors_list[topic_idx]
                unlearn_batch = forget_data_list[topic_idx][batch_idx]
                retain_batch = retain_data_list[topic_idx][batch_idx]

                # Unlearning loss
                # NOTE(review): topic 0 is capped at 512 tokens, the others at
                # 768; the reason is not evident from this file.
                max_length = 512 if topic_idx == 0 else 768
                unlearn_inputs = tokenizer(
                    unlearn_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length
                ).to(device)
                updated_forget_activations = forward_with_cache(
                    updated_model, unlearn_inputs, module=updated_module, no_grad=False
                ).to(device)

                # ``control_vec`` is (1, 1, hidden) and relies on broadcasting
                # against the activations.
                unlearn_loss = torch.nn.functional.mse_loss(
                    updated_forget_activations, control_vec
                )

                # Retain loss
                retain_inputs = tokenizer(
                    retain_batch, return_tensors="pt", padding=True, truncation=True, max_length=512
                ).to(device)
                updated_retain_activations = forward_with_cache(
                    updated_model, retain_inputs, module=updated_module, no_grad=False
                ).to(device)
                frozen_retain_activations = forward_with_cache(
                    frozen_model, retain_inputs, module=frozen_module, no_grad=True
                ).to(device)

                retain_loss = torch.nn.functional.mse_loss(
                    updated_retain_activations, frozen_retain_activations
                )
                retain_loss *= args.alpha[topic_idx]

                # Update model
                loss = unlearn_loss + retain_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # print(f"loss: {loss.item():.4g} | unlearn_loss: {unlearn_loss.item():.4g} | retain_loss: {retain_loss.item():.4g} | param_change: {params[0].grad.abs().mean().item():.4g}")

                # ======= Logging ======
                if args.verbose:
                    # Diagnostics only: no_grad avoids recording autograd
                    # nodes for values that are printed and then discarded.
                    with torch.no_grad():
                        frozen_forget_activations = forward_with_cache(frozen_model, unlearn_inputs, module=frozen_module, no_grad=True).to(updated_model.device)
                        unlearn_cosine= torch.nn.functional.cosine_similarity(updated_forget_activations, frozen_forget_activations, dim=-1).mean()
                        retain_cosine = torch.nn.functional.cosine_similarity(updated_retain_activations, frozen_retain_activations, dim=-1).mean()

                        print(f"unlearn_cosine_sim={unlearn_cosine.item()}")
                        print(f"retain_cosine_sim={retain_cosine.item()}")
                        print(f"Topic {topic_idx} updated_forget_activations.norm=",torch.mean(updated_forget_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                        print(f"Topic {topic_idx} frozen_forget_activations.norm=",torch.mean(frozen_forget_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                        print(f"Topic {topic_idx} updated_retain_activations.norm=",torch.mean(updated_retain_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                        print(f"Topic {topic_idx} frozen_retain_activations.norm=",torch.mean(frozen_retain_activations.norm(dim=-1).mean(dim=1), dim=0).item())
    finally:
        tokenizer.truncation_side = truncation_side

    # Save model
    if args.output_dir:
        path = args.output_dir
    else:
        # Default location encodes the main hyperparameters. ``args.alpha`` is
        # a list here, so its list repr ends up in the directory name.
        path = f"models/cut/mmlu/{args.model_name_or_path}_alpha-{args.alpha}_batches-{num_batches}_layer-{args.layer_id}"
    updated_model.save_pretrained(path)
    tokenizer.save_pretrained(path)
    print(f"Saved model to {path}")
    return path


def get_args():
    """Parse ``sys.argv`` into the namespace consumed by the RMU run.

    Comma-separated string options are expanded after parsing:
    ``retain_corpora`` and ``forget_corpora`` become lists of strings,
    ``alpha`` becomes a list of floats, ``layer_ids`` and ``param_ids``
    become lists of ints, and the floats parsed from ``steering_coeffs`` are
    stored under the additional attribute ``steering_coeff_list`` (the raw
    ``steering_coeffs`` string is left in place).

    Returns:
        The populated ``argparse.Namespace``.
    """
    from argparse import ArgumentParser

    def csv_of(cast, text):
        # Split on commas and convert every piece with ``cast``.
        return [cast(piece) for piece in text.split(",")]

    # Declared in the order the options should appear in ``--help``.
    option_specs = (
        ### Model arguments
        ("--model_name_or_path", dict(type=str, default="meta-llama/Meta-Llama-3-8B")),
        ("--module_str", dict(type=str, default="{model_name}.model.layers[{layer_id}]")),
        ("--output_dir", dict(type=str, default=None)),
        ### Data arguments
        (
            "--retain_corpora",
            dict(
                type=str,
                default="mmlu_cats_random_trimmed/corpus_mmlu_health,mmlu_cats_random_trimmed/corpus_mmlu_history,mmlu_cats_random_trimmed/corpus_mmlu_law,mmlu_cats_random_trimmed/corpus_mmlu_philosophy,mmlu_cats_random_trimmed/corpus_mmlu_social sciences",
                help="comma-separated list of corpora to retain",
            ),
        ),
        (
            "--forget_corpora",
            dict(
                type=str,
                default="mmlu_cats_random_trimmed/corpus_mmlu_STEM,mmlu_cats_random_trimmed/corpus_mmlu_business,mmlu_cats_random_trimmed/corpus_mmlu_chemistry,mmlu_cats_random_trimmed/corpus_mmlu_culture,mmlu_cats_random_trimmed/corpus_mmlu_geography",
                help="comma-separated list of corpora to forget",
            ),
        ),
        ### rmu hyperparameters
        ("--alpha", dict(type=str, default="700,700,700,700,700", help="retain weight")),
        (
            "--steering_coeffs",
            dict(
                type=str,
                default="20,20,20,20,20",
                help="Steer vector weight in order of topic",
            ),
        ),
        ("--lr", dict(type=float, default=5e-5, help="learning rate")),
        ("--min_len", dict(type=int, default=0)),
        ("--max_len", dict(type=int, default=2000)),
        ("--batch_size", dict(type=int, default=4)),
        ("--max_num_batches", dict(type=int, default=80)),
        ("--layer_id", dict(type=int, default=7, help="layer to unlearn")),
        ("--layer_ids", dict(type=str, default="5,6,7", help="update layers")),
        ("--param_ids", dict(type=str, default="6", help="update params")),
        ("--seed", dict(type=int, default=42, help="Seed")),
        (
            "--verbose",
            dict(
                action="store_true",
                help="Logging the activations norms and cosine at each step",
            ),
        ),
    )

    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()

    # Expand the comma-separated strings into typed lists.
    parsed.retain_corpora = csv_of(str, parsed.retain_corpora)
    parsed.forget_corpora = csv_of(str, parsed.forget_corpora)
    parsed.steering_coeff_list = csv_of(float, parsed.steering_coeffs)
    parsed.alpha = csv_of(float, parsed.alpha)
    parsed.layer_ids = csv_of(int, parsed.layer_ids)
    parsed.param_ids = csv_of(int, parsed.param_ids)
    return parsed

@ray.remote(num_gpus=1)
def main(args, alpha):
    """Ray task: run one RMU training with retain weight ``alpha`` and evaluate it.

    Seeds torch and numpy from ``args.seed``, loads a frozen and a trainable
    copy of ``args.model_name_or_path``, loads the forget/retain data, runs
    ``run_rmu`` and then calls ``just_eval`` on the saved model, first on the
    forget-category validation files and then on the retain-category ones.

    Args:
        args: Namespace produced by ``get_args``. ``args.alpha`` is
            overwritten with ``alpha`` repeated once per forget topic.
        alpha: Retain-loss weight applied to every topic for this run.

    Returns:
        None. Results are written by ``just_eval`` to ``results_file``.
    """
    from just_eval import just_eval

    seed = args.seed
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = "cuda"
    frozen_model, tokenizer = load_model(args.model_name_or_path)
    frozen_model = frozen_model.to(device)
    updated_model, tokenizer = load_model(args.model_name_or_path)
    updated_model = updated_model.to(device)
    forget_data_list, retain_data_list = get_data(
        args.forget_corpora,
        args.retain_corpora,
        args.min_len,
        args.max_len,
        args.batch_size,
    )

    # ``run_rmu`` indexes ``args.alpha`` by topic, so it needs one entry per
    # forget topic. This used to be hard-coded to 5 entries, which raised
    # IndexError mid-training when more than five forget corpora were given.
    args.alpha = [float(alpha)] * len(forget_data_list)

    warnings.filterwarnings("ignore", message="Using a target size .* that is different to the input size .*")
    model_path = run_rmu(
        updated_model,
        frozen_model,
        tokenizer,
        forget_data_list,
        retain_data_list,
        args
    )

    results_file: str = "evals/cut_evals/results_mmlu.json"
    # Forget categories are evaluated first, then the retain categories; both
    # result sets go to the same results file.
    category_groups = (
        ["STEM", "business", "chemistry", "culture", "geography"],
        ["health", "history", "law", "philosophy", "social sciences"],
    )
    for categories in category_groups:
        val_files = [f"mmlu_cats_random_trimmed/mmlu_{mmlu_cat}" for mmlu_cat in categories]
        # NOTE(review): dev_set points at the dates-years data although the
        # validation files are MMLU; with k_shot=0 it is presumably unused --
        # confirm against just_eval.
        just_eval(
            model_path,
            dev_set="dates-years/dev",
            k_shot=0,
            val_files=val_files,
            results_file=results_file
        )


if __name__ == "__main__":
    # Script entry point: launch one single-GPU Ray task per alpha value and
    # block until every task has finished.
    args = get_args()
    ray.init()
    # Ignore specific UserWarning
    warnings.filterwarnings("ignore", message="Using a target size .* that is different to the input size .*")

    # Retain-weight (alpha) values swept by this run. Earlier sweeps that were
    # assigned here and immediately overwritten have been removed so it is
    # clear which values actually run.
    values = [5, 6, 8, 12, 20, 36]
    print(f"{len(values)=}\n{values=}")

    deps = [main.remote(args, alpha=value) for value in values]

    # Wait for all tasks; an exception raised inside a task is re-raised here.
    ray.get(deps)
