import os
import datetime
import sys
import math
from typing import Optional
import numpy as np
import torch
from transformers import AdamW
from ray.experimental.tqdm_ray import tqdm
import ray
import warnings
from dataclasses import dataclass
import json
from transformers import AdamW, AutoTokenizer, AutoModelForCausalLM
import wandb
import random
from pipeline import DataFormat

from rmu.utils import load_model, get_params, forward_with_cache, get_data

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


test_prompts = ["Hi, my name is", "Once upon a time,", "The capital of France"]


def sample_tokens(model, tokenizer, device, prompts=test_prompts, max_length=15):
    model.eval()
    generated_texts = []
    
    for prompt_text in prompts:
        input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)

        # Generate tokens using the model
        outputs = model.generate(input_ids, max_length=max_length)

        # Decode the generated tokens to text
        texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        generated_texts.extend(texts)
    
    return generated_texts

def run_rmu(
    updated_model,
    frozen_model,
    tokenizer,
    forget_data_list,
    retain_data_list,
    args,
    wandb_project_name="unlearn",
):
    rmu_config = vars(args)
    print("====rmu Config====")
    print("\n".join(f"{k}={v}" for k,v in rmu_config.items()))
    print("=====")
    device = "cuda"


    updated_model = updated_model.train()
    updated_model = updated_model.to(device)
    frozen_model = frozen_model.to(device)
    print(f"{args.param_ids=}")
    params = get_params(updated_model, args.layer_ids, args.param_ids)
    optimizer = AdamW(params, lr=args.lr)
    frozen_module = eval(
        args.module_str.format(model_name="frozen_model", layer_id=args.layer_id)
    ).to(device)
    updated_module = eval(
        args.module_str.format(model_name="updated_model", layer_id=args.layer_id)
    ).to(device)

    control_vectors_list = []
    for i in range(len(forget_data_list)):
        print(f"{i=}")
        print(f"{args.steering_coeff_list[i]=}")
        random_vector = torch.rand(1,1, updated_model.config.hidden_size, dtype=updated_model.dtype, device=device)
        control_vec = random_vector / torch.norm(random_vector) * args.steering_coeff_list[i]
        control_vectors_list.append(control_vec)

    num_batches = min(
        args.max_num_batches,
        min([len(f) for f in forget_data_list]),
        min([len(r) for r in retain_data_list]),
    ) * len(forget_data_list)
    
    print(f"inside run_rum. {num_batches=}")
    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side="right"

    for epoch in range(args.epochs):
        print(f"======= Epoch {epoch} =======")
        for idx in tqdm(range(num_batches), desc=f"Cut {epoch=}"):
            topic_idx = idx % len(forget_data_list)
            batch_idx = idx // len(forget_data_list)
            control_vec = control_vectors_list[topic_idx]
            unlearn_batch = forget_data_list[topic_idx][batch_idx]
            retain_batch = retain_data_list[topic_idx][batch_idx]

            # Unlearning loss
            max_length = 512 if topic_idx == 0 else 768
            unlearn_inputs = tokenizer(
                unlearn_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length
            ).to(device)
            updated_forget_activations = forward_with_cache(
                updated_model, unlearn_inputs, module=updated_module, no_grad=False
            ).to(device) 

            unlearn_loss = torch.nn.functional.mse_loss(
                updated_forget_activations, control_vec
            )

            # Retain loss
            retain_inputs = tokenizer(
                retain_batch, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(device)
            updated_retain_activations = forward_with_cache(
                updated_model, retain_inputs, module=updated_module, no_grad=False
            ).to(device)
            frozen_retain_activations = forward_with_cache(
                frozen_model, retain_inputs, module=frozen_module, no_grad=True
            ).to(device)

            retain_loss = torch.nn.functional.mse_loss(
                updated_retain_activations, frozen_retain_activations
            )
            retain_loss *= args.alpha[topic_idx]

            # Update model  
            loss = unlearn_loss + retain_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({
                "unlearning/unlearn_loss": unlearn_loss.item(),
                "unlearning/retain_loss": retain_loss.item(),
                "unlearning/train_loss": loss.item(),
            })
            # print(f"loss: {loss.item():.4g} | unlearn_loss: {unlearn_loss.item():.4g} | retain_loss: {retain_loss.item():.4g} | param_change: {params[0].grad.abs().mean().item():.4g}")
            
            # ======= Logging ======
            if args.verbose:
                frozen_forget_activations = forward_with_cache(frozen_model, unlearn_inputs, module=frozen_module, no_grad=True).to(updated_model.device)
                unlearn_cosine= torch.nn.functional.cosine_similarity(updated_forget_activations, frozen_forget_activations, dim=-1).mean()
                retain_cosine = torch.nn.functional.cosine_similarity(updated_retain_activations, frozen_retain_activations, dim=-1).mean()
                
                print(f"unlearn_cosine_sim={unlearn_cosine.item()}")
                print(f"retain_cosine_sim={retain_cosine.item()}")
                print(f"Topic {topic_idx} updated_forget_activations.norm=",torch.mean(updated_forget_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                print(f"Topic {topic_idx} frozen_forget_activations.norm=",torch.mean(frozen_forget_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                print(f"Topic {topic_idx} updated_retain_activations.norm=",torch.mean(updated_retain_activations.norm(dim=-1).mean(dim=1), dim=0).item())
                print(f"Topic {topic_idx} frozen_retain_activations.norm=",torch.mean(frozen_retain_activations.norm(dim=-1).mean(dim=1), dim=0).item())


    tokenizer.truncation_side = truncation_side
    # Save model
    path = args.output_dir
    if path is not None:
        updated_model.save_pretrained(path)
        tokenizer.save_pretrained(path)
        print(f"Saved model to {path}")
    return updated_model


def get_args():
    import argparse

    parser = argparse.ArgumentParser()
    ### Model arguments
    parser.add_argument(
        # "--model_name_or_path", type=str, default="HuggingFaceH4/zephyr-7b-beta"
        "--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3-8B"
    )
    parser.add_argument(
        "--module_str", type=str, default="{model_name}.model.layers[{layer_id}]"
    )
    parser.add_argument(
        "--output_dir", type=str, default=None
    )

    ### Data arguments
    parser.add_argument(
        "--retain_corpora",
        type=str,
        # default="fineweb_edu_seed-42/split_0,fineweb_edu_seed-42/split_1,fineweb_edu_seed-42/split_2,fineweb_edu_seed-42/split_3,fineweb_edu_seed-42/split_4",
        default="mmlu_cats_random_trimmed/corpus_mmlu_health,mmlu_cats_random_trimmed/corpus_mmlu_history,mmlu_cats_random_trimmed/corpus_mmlu_law,mmlu_cats_random_trimmed/corpus_mmlu_philosophy,mmlu_cats_random_trimmed/corpus_mmlu_social sciences",
        help="comma-separated list of corpora to retain",
    )
    parser.add_argument(
        "--forget_corpora",
        type=str,
        # default="dates-corpus-rephrased/rephrased_events",
        # default="dates-years-trimmed/corpus_split_0,dates-years-trimmed/corpus_split_1,dates-years-trimmed/corpus_split_2,dates-years-trimmed/corpus_split_3,dates-years-trimmed/corpus_split_4",
        default="mmlu_cats_random_trimmed/corpus_mmlu_STEM,mmlu_cats_random_trimmed/corpus_mmlu_business,mmlu_cats_random_trimmed/corpus_mmlu_chemistry,mmlu_cats_random_trimmed/corpus_mmlu_culture,mmlu_cats_random_trimmed/corpus_mmlu_geography",
        help="comma-separated list of corpora to forget",
    )
    ### rmu hyperparameters
    parser.add_argument("--alpha", type=str, default="700,700,700,700,700", help="retain weight")
    parser.add_argument(
        "--steering_coeffs",
        type=str,
        default="20,20,20,20,20",
        help="Steer vector weight in order of topic",
    )
    parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
    parser.add_argument("--min_len", type=int, default=0)
    parser.add_argument("--max_len", type=int, default=2000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_num_batches", type=int, default=80)
    parser.add_argument("--layer_id", type=int, default=7, help="layer to unlearn")
    parser.add_argument("--layer_ids", type=str, default="5,6,7", help="update layers")
    parser.add_argument("--param_ids", type=str, default="6", help="update params")
    parser.add_argument("--seed", type=int, default=42, help="Seed")
    parser.add_argument("--verbose", action="store_true", help="Logging the activations norms and cosine at each step")

    args = parser.parse_args()
    args.retain_corpora = args.retain_corpora.split(",")
    args.forget_corpora = args.forget_corpora.split(",")
    args.steering_coeff_list = [float(c) for c in args.steering_coeffs.split(",")]
    args.alpha = [float(c) for c in args.alpha.split(",")]
    args.layer_ids = [int(layer_id) for layer_id in args.layer_ids.split(",")]
    args.param_ids = [int(param_id) for param_id in args.param_ids.split(",")]
    return args 

@dataclass
class Args:
    layer_ids: list[int]
    param_ids: list[int]
    lr: float
    module_str: str
    steering_coeff_list:list[float]
    max_num_batches: int
    verbose: bool
    output_dir: str
    model_name_or_path: str
    alpha: list[float]
    layer_id: int
    seed: int
    min_len: int
    max_len: int
    epochs: int



def get_data_joined(
    forget_corpora,
    retain_corpora,
    min_len,
    max_len,
    batch_size,
    data_seed
):
    random.seed(data_seed)
    forget_data = []  
    for file in forget_corpora:
        print(f"{file=}")
        for line in open(f"data/{file}.jsonl", "r"):
            raw_text = json.loads(line)
            if isinstance(raw_text, dict):
                raw_text = raw_text["text"]
            if len(raw_text) > min_len:
                forget_data.append(str(raw_text))
    random.shuffle(forget_data)
    forget_data =[
        forget_data[i:i + batch_size]
        for i in range(1, len(forget_data), batch_size)
    ]

    retain_data = []
    for file in retain_corpora:
        for line in open(f"data/{file}.jsonl", "r"):
            raw_text = json.loads(line)["text"]
            if len(raw_text) > min_len:
                retain_data.append(str(raw_text))
    random.shuffle(retain_data)
    retain_data = [
        retain_data[i:i + batch_size]
        for i in range(0, len(retain_data), batch_size)
    ]

    return (
        [forget_data],
        [retain_data]
    )

     
def main(
    unlearn_files: list[str] = [],
    val_files: list[str] = [],
    retain_files: list[str] = [],
    val_retain_files: list[str] = [],
    dev_file: str = [],
    retain_dev_file: str = [],
    base_model: str = "",
    lr: float = 1e-7,
    epochs: int = 3,
    batch_size: int = 4,
    val_batch_size: int = 8,
    retain_coeff: int = 1,
    warmup_steps: int = 24,
    data_seed: int = 42,
    eval_every: int = 1,
    save_name: Optional[str] = None,
    wandb_project_name: str = "unlearn",
    hydra_dict: dict = {},
    data_format: DataFormat = DataFormat.CORPUS,    
    steering_coeff: float = 20,
    max_samples: int = None,
):
    print(f"{steering_coeff=}")
    # raise Exception(f"{steering_coeff=}")
    # from just_eval_prev import eval
    wandb.init(project=wandb_project_name, config={**locals(), **hydra_dict}, name=save_name)
    # max_num_batches = 160
    max_num_batches = int(9999999 if max_samples is None else max_samples/batch_size)
    args = Args(
        model_name_or_path=base_model,
        lr=lr,
        output_dir=save_name,
        alpha=[retain_coeff],
        layer_ids=[5, 6, 7],
        param_ids=[6],
        layer_id=7,
        module_str="{model_name}.model.layers[{layer_id}]",
        steering_coeff_list=[steering_coeff],
        max_num_batches=max_num_batches,
        verbose=False,
        seed=42,
        min_len=5,
        max_len=2000,
        epochs=epochs
    )
    print(f"{args.max_num_batches=}")
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    frozen_model, tokenizer = load_model(args.model_name_or_path)
    frozen_model = frozen_model.to(device)
    updated_model, tokenizer = load_model(args.model_name_or_path)
    updated_model = updated_model.to(device)
    forget_data_list, retain_data_list = get_data_joined(
        unlearn_files,
        retain_files,
        args.min_len,
        args.max_len,
        batch_size,
        data_seed,
    )
    warnings.filterwarnings(
        "ignore",
        message="Using a target size .* that is different to the input size .*"
    )
    # TODO Make all use the same model to avoid loading the model multiple times
    model = run_rmu(
        updated_model,
        frozen_model,
        tokenizer,
        forget_data_list,
        retain_data_list,
        args,
        wandb_project_name=wandb_project_name,
    )
    forget_accs = {}
    forget_accs_calibrated = {}
    forget_logits_dict = {}
    # for file in val_files:
        # forget_accs[file], forget_accs_calibrated[file], forget_logits_dict[file] = just_eval(
        #     model_path,
        #     dev_set=retain_dev_file,
        #     k_shot=0,
        #     val_files=[file],
        #     results_file=None
        # )
    retain_accs, retain_accs_calibrated, retain_logits_dict = {}, {}, {}

    # for file in val_retain_files:
        # retain_accs[file], retain_accs_calibrated[file], retain_logits_dict[file] = just_eval(
        #     model_path,
        #     dev_set=retain_dev_file,
        #     k_shot=0,
        #     val_files=val_retain_files,
        #     results_file=None
        # )
    
    # (
    #     forget_accs, forget_accs_calibrated, forget_logits_dict,
    #     retain_accs, retain_accs_calibrated, retain_logits_dict
    # )  = eval(model_path)
    import unlearn_corpus 
    (
        save_name,
        forget_accs, forget_accs_calibrated, forget_logits_dict,
        retain_accs, retain_accs_calibrated, retain_logits_dict,
        retain_accs_5_shot, retain_accs_5_shot_calibrated,
        retain_logits_5_shot_dict,
        samples
    ) =  unlearn_corpus.main(
        unlearn_type=unlearn_corpus.UnlearnType.GD,
        train_files=[],
        wrong_unlearn_files=[],
        fixed_wrong_unlearn_files=[],
        val_files=val_files,
        dev_set=dev_file,
        retain_files=[],
        val_retain_files=val_retain_files,
        retain_dev_file=retain_dev_file,
        base_model=base_model,
        lr=lr,
        name="",
        model=model,
        epochs=epochs,
        batch_size=batch_size,
        val_batch_size=val_batch_size,
        retain_coeff=retain_coeff,
        warmup_steps=warmup_steps,
        data_seed=data_seed,
        eval_every=eval_every,
        save_name=None,
        project_name=wandb_project_name,
        just_eval=True,
        disable_wandb=True,
    )

    retain_accs_5_shot, retain_accs_5_shot_calibrated, retain_logits_5_shot_dict = {}, {}, {}
    # for file in val_retain_files:

    #     retain_accs_5_shot[file], retain_accs_5_shot_calibrated[file], retain_logits_5_shot_dict[file] = just_eval(
    #         model_path,
    #         dev_set=retain_dev_file,
    #         k_shot=5,
    #         val_files=val_retain_files,
    #         results_file=None
    # )
        
    def mean(x):
        if len(x) == 0: return 0
        lst = [d[max(d.keys())] for d in x]
        return sum(lst) / len(x)

    try:

        wandb.log(
            {
                "unlearning/forget_acc": mean(forget_accs.values()),
                "unlearning/forget_acc_calibrated": mean(forget_accs_calibrated.values()),
                "unlearning/retain_acc": mean(retain_accs.values()),
                "unlearning/retain_acc_calibrated": mean(retain_accs_calibrated.values()),
                "unlearning/retain_acc_5_shot": mean(retain_accs_5_shot.values()),
                "unlearning/retain_acc_5_shot_calibrated": mean(retain_accs_5_shot_calibrated.values()),
            }
        )
    except Exception as e:
        print(f"{forget_accs=}")
        print(e)
        print("Error logging to wandb")

#    model = AutoModelForCausalLM.from_pretrained(
#        model_path, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
#    ).to(device)

    wandb.finish()
    return (
        base_model,
        forget_accs, forget_accs_calibrated, forget_logits_dict,
        retain_accs, retain_accs_calibrated, retain_logits_dict,
        retain_accs_5_shot, retain_accs_5_shot_calibrated,
        retain_logits_5_shot_dict,
        {1: sample_tokens(model, tokenizer, device, max_length=15)}
    )

    

# @ray.remote(num_gpus=1)
# def main(args, alpha):
#     from just_eval import just_eval
#     SEED = args.seed
#     torch.cuda.manual_seed(SEED)
#     torch.cuda.manual_seed_all(SEED)
#     torch.manual_seed(SEED)
#     np.random.seed(SEED)
#     i = alpha
#     # args.alpha = [int(i) for _ in range(5)]
#     args.alpha = [float(i) for _ in range(5)]
#     # args.alpha = [int(i * 100) for _ in range(5)]
#     device = "cuda"
#     frozen_model, tokenizer = load_model(args.model_name_or_path)
#     frozen_model = frozen_model.to(device)
#     updated_model, tokenizer = load_model(args.model_name_or_path)
#     updated_model = updated_model.to(device)
#     forget_data_list, retain_data_list = get_data(
#         args.forget_corpora,
#         args.retain_corpora,
#         args.min_len,
#         args.max_len,
#         args.batch_size,
#     )
#     warnings.filterwarnings("ignore", message="Using a target size .* that is different to the input size .*")
#     model_path = run_rmu(
#         updated_model,
#         frozen_model,
#         tokenizer,
#         forget_data_list,
#         retain_data_list,
#         args
#     )

#     # val_files =[f"dates-years-trimmed/split_{split_num}" for split_num in range(5)]
#     val_files =[f"mmlu_cats_random_trimmed/mmlu_{mmlu_cat}" for mmlu_cat in ["STEM", "business", "chemistry", "culture", "geography"]]
#     # results_file: str = "evals/cut_evals/results_years.json"
#     results_file: str = "evals/cut_evals/results_mmlu.json"
#     just_eval(
#         model_path,
#         dev_set="dates-years/dev",
#         k_shot=0,
#         val_files=val_files,
#         results_file=results_file
#     )

#     # val_files =[f"mmlu_cats_random_trimmed/mmlu_{mmlu_cat}" for mmlu_cat in ["health", "history", "law", "philosophy", "social sciences"]]
#     val_files =[f"mmlu_cats_random_trimmed/mmlu_{mmlu_cat}" for mmlu_cat in ["health", "history", "law", "philosophy", "social sciences"]]
#     just_eval(
#         model_path,
#         dev_set="dates-years/dev",
#         k_shot=0,
#         val_files=val_files,
#         results_file=results_file
#     )


if __name__ == "__main__":
    args = get_args()
    ray.init()
    deps = []
    # Ignore specific UserWarning
    warnings.filterwarnings("ignore", message="Using a target size .* that is different to the input size .*")
    # for i in range(5, 20, 2):
    # for i in range(20, 18, 1):
    # coeffs = [2, 4, 8, 16] d
    values = []
    # current_value = 0.001

    # while True:
    #     values.append(current_value)
    #     if current_value >= 3000:
    #         break
    #     current_value *= 10

    # coeffs = [10 * (2 ** i) for i in range(int(math.log(3000 / 10, 2)) + 1)]
    values = [1280, 1290, 1300, 1310, 1320]
    values = [5, 6, 8, 12, 20, 36]
    print(f"{len(values)=}\n{values=}")
    for i in values:
        deps += [
            main.remote(args, alpha=i)
            # main.remote(args, alpha=i * 0.5)
            # main.remote(args, alpha=1 / i)
        ]
    
    for dep in deps:
        ray.get(dep)
