import os
import json
import datetime
from argparse import ArgumentParser
from tkinter import E

import wandb
import torch
from trl import SFTTrainer
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM

from src.benchmarking.benchmark import Benchmark
from src.util.globals import OUTPUT_DIR
from src.util.helpers import seed_everything, create_if_not_exists, save_tokenizer, save_as_json, print_trainable_parameters    
from src.util.data import *
from trl import DataCollatorForCompletionOnlyLM
from pathlib import Path
# Maps every dataset name accepted by ``--dataset`` to the loader function
# that builds it. The loaders come from the ``src.util.data`` star import and
# are called as ``loader(limit=..., system_prompt=...)`` by the script below.
# Insertion order is kept as-is because it determines the order of the
# argparse ``choices`` list.
DATASET_REGISTRY = {
    "gpqa-bio": get_gpqa_bio_dataset,
    "gpqa-chem": get_gpqa_chem_dataset,
    "gpqa-physics": get_gpqa_physics_dataset,
    "gpqa-all": get_gpqa_all_dataset,
    "wmdp_bio-forget-corpus": get_bio_forget_dataset,
    "wmdp_bio-retain-corpus": get_bio_retain_dataset,
    "wmdp_bio-retain-corpus-spanish": get_bio_retain_dataset_spanish,
    "wmdp_bio-retain-corpus-russian": get_bio_retain_dataset_russian,
    "wmdp_bio-retain-corpus-chinese": get_bio_retain_dataset_chinese,
    "wmdp_bio-retain-corpus-spanish-to-english": get_bio_retain_dataset_spanish_to_english,
    "wmdp_cyber-retain-corpus": get_cyber_retain_dataset,
    "wmdp_cyber-forget-corpus": get_cyber_forget_dataset,
    "wikitext": get_wikitext_dataset,
    "wmdp_bio-forget-corpus-mc": get_bio_forget_mc_dataset,
    "wmdp_bio-retain-corpus-mc": get_bio_retain_mc_dataset,
    "wmdp_cyber-forget-corpus-mc": get_cyber_forget_mc_dataset,
    "wmdp_cyber-retain-corpus-mc": get_cyber_retain_mc_dataset,
    "wikitext-bio-mc": get_wikitext_bio_mc_dataset,
    "wikitext-cyber-mc": get_wikitext_cyber_mc_dataset,
    "tofu-qa": get_tofu_wiki_mc_dataset,
    "mmlu": get_mmlu_dataset,
    "medmcqa": get_medmcqa_dataset,
    "medmcqa_gen": get_medmcqa_gen_dataset,
    "medmcqa-trak": get_medmcqa_trak_dataset,
    "tofu-trak": get_tofu_wiki_mc_dataset_trak,
    "pile": get_pile_dataset,
    "pile-trak": get_pile_trak_dataset,
}

# Reference hyperparameters (commented out; the live values are assigned inside the __main__ block below)

# SEED = 42
# MAX_SEQ_LEN = 1024
# DTYPE = torch.bfloat16
# LORA_RANK = 256
# LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",]
# LORA_ALPHA = 32
# LORA_DROPOUT = 0
# BATCH_SIZE = 1 #old: 2
# GRAD_ACCUM_STEPS = 1 #old: 4
# EPOCHS = 3
# LR = 1e-4
# WARMUP_RATIO = 0.05
# WEIGHT_DECAY = 0.01



def lora_target_modules_for(model_name):
    """Return the LoRA target-module names to use for ``model_name``.

    Args:
        model_name: Model identifier or path, as passed to ``--model``.

    Returns:
        A list of module-name suffixes suitable for
        ``LoraConfig(target_modules=...)``.
    """
    if "pythia" in model_name:
        # Pythia checkpoints use the GPT-NeoX layer naming (fused QKV
        # projection, ``dense*`` output/MLP layers); the LLaMA-style names
        # below do not exist in those models, so PEFT would match nothing.
        return ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
    return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]


if __name__ == "__main__":
    # Script flow: parse CLI args -> load dataset/model/tokenizer -> optionally
    # wrap the model with LoRA -> fine-tune with SFTTrainer -> log parameters ->
    # optionally save adapter + merged model -> run the benchmark.
    #
    # NOTE: every UPPERCASE name assigned here is a module global and is picked
    # up by the ``global_params`` dump further down; keep them uppercase and do
    # not add unrelated uppercase names.

    # Parse input
    parser = ArgumentParser()
    parser.add_argument("--model", "-m",
                        required=True,
                        help="Name of the model to finetune")
    parser.add_argument("--tokenizer",
                        default=None,)
    parser.add_argument("--dataset", "-d",
                        required=True,
                        choices=list(DATASET_REGISTRY.keys()),
                        help="Name of the dataset to use for finetuning")
    parser.add_argument("--eval_dataset", "-ed",
                        default=None,
                        help="Name of the dataset to use for evaluating")
    parser.add_argument("--n_samples", "-n",
                        type=int,
                        default=100,
                        help="Number of samples in a dataset to use")
    parser.add_argument("--save_dir", default=None, help="Save the model to the output directory")
    parser.add_argument("--system_prompt", "-s", default="")

    parser.add_argument("--lora_rank", "-rank", type=int, default=128, help="Rank of the LoRA matrix")
    parser.add_argument("--lora_alpha", "-alpha", type=int, default=16, help="Alpha of the LoRA matrix")
    parser.add_argument("--lora_dropout", "-dropout", type=float, default=0, help="Dropout of the LoRA matrix")
    parser.add_argument("--epochs", "-e", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", "-lr", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--warmup_ratio", "-wr", type=float, default=0.05, help="Warmup ratio")
    parser.add_argument("--weight_decay", "-wd", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--wandb_tags", "-wt", type=str, default=None, help="Wandb tags")
    parser.add_argument("--n_skip_samples", "-ns", type=int, default=None, help="Number of samples to skip")
    parser.add_argument("--no_lora", "-nl", action="store_true", help="Do not use LoRA")
    parser.add_argument("--batch_size", "-bs", type=int, default=1, help="Batch size")
    parser.add_argument("--dataset_text_field", "-dtf", type=str, default=None, help="Dataset text field")
    parser.add_argument("--random_sample_training_dataset", "-rtd", action="store_true", help="Randomly sample training dataset")
    parser.add_argument("--seed", "-seed", type=int, default=42, help="Seed")
    parser.add_argument("--train_layers", "-tl", type=str, default=None, help="Train layers")
    args = parser.parse_args()

    # The tokenizer defaults to the one shipped with the model.
    if args.tokenizer is None:
        args.tokenizer = args.model

    # Models with "gpt" in their name default to training on the "text" field.
    if args.dataset_text_field is None and "gpt" in args.model:
        args.dataset_text_field = "text"

    # Ensure reproducibility
    SEED = args.seed
    seed_everything(SEED)

    MAX_SEQ_LEN = 1024
    DTYPE = torch.bfloat16
    LORA_RANK = args.lora_rank
    LORA_TARGET_MODULES = lora_target_modules_for(args.model)
    LORA_ALPHA = args.lora_alpha
    LORA_DROPOUT = args.lora_dropout
    BATCH_SIZE = args.batch_size
    GRAD_ACCUM_STEPS = 1
    EPOCHS = args.epochs
    LR = args.lr
    WARMUP_RATIO = args.warmup_ratio
    WEIGHT_DECAY = args.weight_decay

    # Logged as the wandb run config and forwarded to the benchmark.
    training_params = {
        'seed': SEED,
        'max_seq_len': MAX_SEQ_LEN,
        'dtype': DTYPE,
        'lora_rank': LORA_RANK,
        'lora_alpha': LORA_ALPHA,
        'lora_dropout': LORA_DROPOUT,
        'batch_size': BATCH_SIZE,
        'grad_accum_steps': GRAD_ACCUM_STEPS,
        'epochs': EPOCHS,
        'lr': LR,
        'warmup_ratio': WARMUP_RATIO,
        'weight_decay': WEIGHT_DECAY,
        'n_samples': args.n_samples,
        'task': args.dataset,
        'no_lora': args.no_lora,
        'training_dataset': args.dataset,
        'eval_dataset': args.eval_dataset,
        'random_sample_training_dataset': args.random_sample_training_dataset,
        'train_layers': args.train_layers,
    }

    # Get destination subdir
    date = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    model_str = args.model.split("/")[-1]
    run_name = f"{model_str}__{args.dataset}__{args.n_samples}__{date}"
    output_dest = os.path.join(OUTPUT_DIR, "finetuning", run_name)
    create_if_not_exists(output_dest)

    # Initialize wandb
    wandb_tags = args.wandb_tags.split(",") if args.wandb_tags else None
    wandb.init(project="finetuning", dir=output_dest, name=run_name, tags=wandb_tags, config=training_params)

    # When sampling randomly, load the whole dataset first and subsample after.
    get_n_data = None if args.random_sample_training_dataset else args.n_samples

    # Get dataset
    if 'gemma' in args.model:
        # The system prompt is dropped for gemma models.
        args.system_prompt = None
    dataset = DATASET_REGISTRY[args.dataset](limit=get_n_data, system_prompt=args.system_prompt)

    if args.random_sample_training_dataset:
        import numpy as np
        # Sampling without replacement cannot draw more rows than exist, so
        # clamp the request to the dataset size instead of raising.
        n_pick = min(args.n_samples, len(dataset))
        ds_idx = np.random.choice(len(dataset), size=n_pick, replace=False)
        dataset = dataset.select(ds_idx)

    # Get model
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    model.gradient_checkpointing_enable()

    # Get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    tokenizer.padding_side = 'right'
    tokenizer.pad_token = tokenizer.eos_token

    # Set up LORA
    peft_config = None
    if not args.no_lora:
        # LoraConfig takes configuration fields only. The model must not be
        # passed positionally: it would be bound to the first config field.
        peft_config = LoraConfig(
            r=LORA_RANK,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
            target_modules=LORA_TARGET_MODULES,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,  # Supports any, but = 0 is optimized
            bias="none",  # Supports any, but = "none" is optimized
        )
        model = get_peft_model(model, peft_config)
        # get the model size
        print_trainable_parameters(model)
    elif args.train_layers:
        train_layers = args.train_layers.split(",")
        # NOTE(review): the parsed layer list is not applied to the model
        # anywhere in this script, so all parameters are trained.
        print(f"WARNING: --train_layers={train_layers} is parsed but not applied; training all parameters.")

    def formatting_prompts_func(example):
        """Build one prompt+answer string per row of a batched example.

        ``example`` is a batch in dict-of-lists form with "question",
        "choices" and "answer" columns; rows are iterated in lockstep.
        NOTE(review): this function is not passed to SFTTrainer below.
        """
        output_texts = []
        rows = zip(example["question"], example["choices"], example["answer"])
        for question, choices, answer_idx in rows:
            full_question = "The following are multiple choice questions (with answers) about biology.\n\n"
            full_question += question.strip() + "\n"
            for idx, choice in enumerate(choices):
                full_question += f"{WMDP_OPTIONS[idx]}. {choice}\n"
            full_question += "Answer:"

            full_answer = f"{WMDP_OPTIONS[answer_idx]}. {choices[answer_idx]}"
            output_texts.append(full_question + full_answer)
        return output_texts

    response_template = "Answer:"
    # Hard-coded token id(s) for the response template; tokenizer-specific.
    # NOTE(review): verify against the tokenizer in use. The collator is
    # built here but is not passed to SFTTrainer below.
    response_template_ids = [16533]
    collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)

    # initialize trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        max_seq_length=MAX_SEQ_LEN,
        dataset_num_proc=2,
        peft_config=peft_config,
        packing=False,  # Can make training 5x faster for short sequences.
        dataset_text_field=args.dataset_text_field or None,
        args=TrainingArguments(  # these can be retrieved later on so no need to make them into Global constants for saving purposes
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRAD_ACCUM_STEPS,
            warmup_ratio=WARMUP_RATIO,
            num_train_epochs=EPOCHS,
            learning_rate=LR,
            logging_steps=1,
            optim="adamw_torch",
            weight_decay=WEIGHT_DECAY,
            lr_scheduler_type="linear",
            seed=SEED,
            output_dir=output_dest,
            report_to="wandb",
            run_name=run_name,
            save_strategy="no",
        ),
    )
    # hack to make trainer return eval loss for peft model
    trainer.can_return_loss = True

    # GPU memory usage monitoring - current stats (skipped without CUDA)
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        gpu_stats = torch.cuda.get_device_properties(0)
        start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
        print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
        print(f"{start_gpu_memory} GB of memory reserved.")
    else:
        print("CUDA is not available; skipping GPU memory statistics.")

    ### Finetune
    trainer_stats = trainer.train()

    print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
    print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")

    # GPU memory usage monitoring - final stats
    if has_cuda:
        used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
        used_percentage = round(used_memory / max_memory * 100, 3)
        lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
        print(f"Peak reserved memory = {used_memory} GB.")
        print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
        print(f"Peak reserved memory % of max memory = {used_percentage} %.")
        print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

    ### Log parameters

    # Get user-defined global variables (for reproducibility)
    # assumption = user defined globals consist only uppercase letter
    global_params = {key: globals()[key] for key in list(globals().keys()) if key.isupper()}
    save_as_json(output_dest, "global_params.json", global_params)

    # Get training arguments
    train_params = vars(trainer.args)
    save_as_json(output_dest, "train_params.json", train_params)

    ### Evaluate model
    wandb.finish()  # benchmark has its own wandb.init

    # Save the (un-merged) model first: merge_and_unload() folds the LoRA
    # weights into the base model and removes the adapter layers in place, so
    # saving the adapter afterwards would write a checkpoint without them.
    if args.save_dir:
        peft_dir = os.path.join(args.save_dir, "peft")
        Path(peft_dir).mkdir(parents=True, exist_ok=True)
        model.save_pretrained(peft_dir)
        tokenizer.save_pretrained(peft_dir)

    if not args.no_lora:
        merged_model = trainer.model.merge_and_unload()
    else:
        merged_model = trainer.model

    if args.save_dir:
        merged_model.save_pretrained(args.save_dir)
        tokenizer.save_pretrained(args.save_dir)

    # Evaluation defaults: same dataset as training, skipping the samples
    # that were used for training.
    if not args.eval_dataset:
        args.eval_dataset = args.dataset
    if args.n_skip_samples is None:
        args.n_skip_samples = args.n_samples

    bench = Benchmark(
        output_dest,
        tasks=args.eval_dataset.split(","),
        wandb_project="finetuning_results",
        wandb_tags=wandb_tags,
        config=training_params,
        run_name=run_name,
        save_requests=False,
        upload_requests_to_hf=False,
        ignore_chat_template=args.dataset_text_field is not None,
        system_prompt=args.system_prompt,
        skip_n_samples=args.n_skip_samples,
    )
    bench.run(merged_model, tokenizer)