import argparse
import os
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, AutoTokenizer, TrainingArguments, logging, set_seed
from evaluation.CompAcc_FEqAcc import check_CompAccRunEqAcc_cpp
import torch
from trl import SFTTrainer

def get_args():
    """Build the command-line parser for this fine-tuning script and parse it.

    Returns:
        argparse.Namespace holding the task choice, batching/schedule settings,
        optimiser hyper-parameters, runtime switches and output/cadence options.
    """
    # (flag, add_argument keyword options) pairs, registered in this exact
    # order so the --help listing stays the same.
    option_specs = [
        ("--task", dict(type=str, default="pseudocode-cpp",
                        choices=["pseudocode-cpp", "SHORTpseudocode-cpp"])),

        # Epochs, batching and sequence length.
        ("--num_epochs", dict(type=int, default=20)),
        ("--batch_size", dict(type=int, default=2)),
        ("--eval_batch_size", dict(type=int, default=2)),
        ("--gradient_accumulation_steps", dict(type=int, default=4)),  # 2 noted as an alternative
        # NOTE(review): not read anywhere else in the visible script.
        ("--eos_token_id", dict(type=int, default=49152)),
        ("--max_seq_length", dict(type=int, default=1024)),

        # Optimiser and learning-rate schedule.
        ("--learning_rate", dict(type=float, default=1e-4)),
        ("--lr_scheduler_type", dict(type=str, default="cosine")),
        ("--num_warmup_steps", dict(type=int, default=100)),
        ("--weight_decay", dict(type=float, default=0.05)),

        # Runtime / precision switches.
        ("--local_rank", dict(type=int, default=0)),
        ("--fp16", dict(action="store_true", default=False)),
        ("--bf16", dict(action="store_true", default=False)),
        ("--gradient_checkpointing", dict(action="store_true", default=False)),
        ("--seed", dict(type=int, default=0)),
        ("--num_workers", dict(type=int, default=None)),

        # Output location and logging / evaluation / checkpoint cadence.
        ("--output_dir", dict(type=str, default="./CG2b_checkpointsSFT_MultiGPU")),
        ("--log_freq", dict(type=int, default=1)),
        ("--eval_freq", dict(type=int, default=500)),
        ("--save_freq", dict(type=int, default=500)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw logits to predicted token ids before the Trainer stores them.

    The Trainer accumulates whatever this hook returns over the whole eval
    set; keeping full vocabulary-sized logits can exhaust memory (the original
    Trainer may appear to leak), so only the argmax ids are kept.

    Args:
        logits: model output logits, or a tuple of model outputs whose first
            element is the logits (models that return extra outputs hand the
            hook a tuple, which ``torch.argmax`` cannot consume).
        labels: unused; part of the Trainer hook signature.

    Returns:
        Tensor of argmax token ids taken over the last dimension.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return torch.argmax(logits, dim=-1)

def _decode_completions(token_ids):
    """Decode a batch of token ids and keep only the completion text of each row.

    Entries equal to -100 (ignored/padded positions) are overwritten IN PLACE
    with the module-level tokenizer's pad id, then the batch is decoded with
    special tokens skipped. From each decoded string, everything up to the
    last "<|fim_middle|>" is dropped and a trailing "<|file_separator|>" is
    removed, with surrounding whitespace stripped.
    """
    token_ids[token_ids == -100] = tokenizer.pad_token_id
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens = True)
    completions = []
    for text in decoded:
        # rpartition(...)[2] is the text after the last marker, or the whole
        # string when the marker is absent.
        after_marker = text.rpartition("<|fim_middle|>")[2].strip()
        completions.append(after_marker.removesuffix("<|file_separator|>").strip())
    return completions


def compute_val_metrics(pred_label):
    """Trainer ``compute_metrics`` hook: score generated C++ completions.

    Reads the module-level ``tokenizer`` and ``args`` globals.

    Args:
        pred_label: Trainer eval prediction exposing ``predictions`` (token
            ids produced by ``preprocess_logits_for_metrics``) and
            ``label_ids``. Both arrays are modified in place (-100 -> pad id).

    Returns:
        Whatever ``check_CompAccRunEqAcc_cpp`` returns for the decoded
        predictions (presumably a metric-name -> value dict, as Trainer
        expects — verify against that helper).
    """
    predicted_code = _decode_completions(pred_label.predictions)
    # NOTE(review): reference completions are decoded but not passed to the
    # metric; the checker is given only predictions, output_dir and task.
    _reference_code = _decode_completions(pred_label.label_ids)
    return check_CompAccRunEqAcc_cpp(predicted_code, args.output_dir, args.task)

def formatting_prompts_func(examples):
    """Render a batch of prompt/completion pairs as CodeGemma FIM strings.

    Passed to SFTTrainer as ``formatting_func`` with ``packing = False``;
    ``examples`` is a batch mapping each column name to a list of values.

    Args:
        examples: mapping with parallel ``"prompt"`` and ``"completion"`` lists.

    Returns:
        list[str]: one
        ``"<|fim_prefix|>{prompt}<|fim_suffix|><|fim_middle|>{completion}"``
        string per row, in input order.

    Raises:
        ValueError: if the two columns have different lengths.
    """
    prompts = examples["prompt"]
    completions = examples["completion"]
    # zip() would silently drop unmatched rows; fail loudly instead.
    if len(prompts) != len(completions):
        raise ValueError(
            f"prompt/completion length mismatch: {len(prompts)} vs {len(completions)}"
        )
    # The prompt fills the FIM prefix with an empty suffix, so the completion
    # follows <|fim_middle|> (the marker compute_val_metrics splits on).
    return [
        f"<|fim_prefix|>{inp}<|fim_suffix|><|fim_middle|>{out}"
        for inp, out in zip(prompts, completions)
    ]

def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    bitsandbytes ``Params4bit`` tensors (created when the model is loaded with
    ``load_in_4bit``) pack two 4-bit weights per storage byte, so their
    ``numel()`` under-reports the real count. They are scaled back up here,
    mirroring PEFT's own parameter counting.

    Args:
        model: a ``torch.nn.Module`` (e.g. the PEFT-wrapped model).

    Prints:
        ``trainable params: T || all params: A || trainable%: P``, where
        P is 0.0 for a module with no parameters.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if type(param).__name__ == "Params4bit":
            # Each storage element of element_size() bytes holds
            # 2 * element_size() packed 4-bit values.
            num_params *= 2 * param.element_size()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    # Guard against ZeroDivisionError for parameter-less modules.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

def create_datasets(tokenizer, args):
    """Load the train and validation splits for ``args.task``.

    Reads JSONL files under ``dataset/<task>/``. The validation split is read
    from the task's *test* file; no separate validation file is referenced.

    Args:
        tokenizer: unused; kept so existing call sites keep working.
        args: parsed CLI namespace; only ``args.task`` is read.

    Returns:
        ``(train_data, valid_data)`` as loaded by ``datasets.load_dataset``.
    """
    data_dir = f"dataset/{args.task}"
    train_path = f"{data_dir}/train.{args.task}.jsonl"
    # NOTE(review): validation reuses the test file, so eval metrics (and any
    # checkpoint chosen from them) are measured on test data — confirm intended.
    valid_path = f"{data_dir}/test.{args.task}.jsonl"
    # Load the test file once, as "valid"; a separate "test" split was loaded
    # from the same path but never used.
    dataset = load_dataset("json", data_files = {"train": train_path, "valid": valid_path})
    train_data = dataset["train"]
    valid_data = dataset["valid"]
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    return train_data, valid_data

def run_training(args, model_path, train_data, val_data, accelerator):
    """Fine-tune a 4-bit quantised causal LM with LoRA adapters via SFTTrainer.

    Args:
        args: parsed CLI namespace (batching, schedule, logging/eval/save
            cadence, ``max_seq_length``, ``gradient_checkpointing``,
            ``output_dir``).
        model_path: Hugging Face model id or local path passed to
            ``from_pretrained``.
        train_data: training dataset with ``prompt``/``completion`` columns.
        val_data: evaluation dataset with the same columns.
        accelerator: the script's ``Accelerator``; used here to decide which
            process writes the final adapter checkpoint.

    Side effects:
        Trains the model, writes periodic checkpoints to ``args.output_dir``
        and the final adapter to ``<output_dir>/final_checkpoint/``.
    """
    print("Loading the model")

    # LoRA on every attention and MLP projection; other LoraConfig fields
    # (alpha, dropout, ...) keep PEFT's defaults.
    lora_config = LoraConfig(
        r = 8,
        target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
                        "gate_proj", "up_proj", "down_proj"],
        task_type = "CAUSAL_LM",
    )

    print("Starting main loop")

    training_args = TrainingArguments(
        output_dir = args.output_dir,
        dataloader_drop_last = False,
        evaluation_strategy = "steps",
        num_train_epochs = args.num_epochs,
        eval_steps = args.eval_freq,
        save_strategy = "steps",
        save_steps = args.save_freq,
        save_only_model = True,
        logging_steps = args.log_freq,
        per_device_train_batch_size = args.batch_size,
        per_device_eval_batch_size = args.eval_batch_size,
        learning_rate = args.learning_rate,
        lr_scheduler_type = args.lr_scheduler_type,
        warmup_steps = args.num_warmup_steps,
        gradient_accumulation_steps = args.gradient_accumulation_steps,
        gradient_checkpointing = args.gradient_checkpointing,
        weight_decay = args.weight_decay,
        run_name = "codegemma-2b-SFT-1May-multigpu",
        ddp_find_unused_parameters = False,
    )
    # NOTE: report_to = "wandb", fp16 = args.fp16 and bf16 = args.bf16 are
    # deliberately not forwarded, so the --fp16/--bf16 flags have no effect here.

    # 4-bit NF4 weights with double quantisation; compute runs in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Alternatives tried: load_in_8bit = True, or explicit per-process placement
    # via device_map = {"": accelerator.process_index}.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config
    )

    # SFTTrainer drives its own Accelerator internally. Passing it through
    # accelerator.prepare(...) returned it unchanged (prepare passes unknown
    # objects through), so it is constructed directly.
    # NOTE(review): no tokenizer is passed, so SFTTrainer presumably loads its
    # own default tokenizer for model_path; the left-truncating tokenizer from
    # main() is only used for decoding in compute_val_metrics — confirm intended.
    trainer = SFTTrainer(
        model = model,
        args = training_args,
        max_seq_length = args.max_seq_length,
        train_dataset = train_data,
        eval_dataset = val_data,
        peft_config = lora_config,
        formatting_func = formatting_prompts_func,
        packing = False,
        compute_metrics = compute_val_metrics,
        preprocess_logits_for_metrics = preprocess_logits_for_metrics
    )

    print_trainable_parameters(trainer.model)

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    # In a multi-process launch every rank reaches this line; only the main
    # process writes, so ranks don't race on the same output files.
    if accelerator.is_main_process:
        trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))


def main(args):
    """Set up the tokenizer and datasets, then launch training.

    Rebinds the module-level ``tokenizer`` (read by ``compute_val_metrics``)
    to a CodeGemma tokenizer that truncates and pads on the left.

    Args:
        args: parsed CLI namespace from ``get_args``.
    """
    global tokenizer
    accelerator = Accelerator()
    base_model_id = "google/codegemma-2b"
    tokenizer = GemmaTokenizer.from_pretrained(
        base_model_id,
        truncation_side = "left",
        padding_side = "left",
    )
    train_split, eval_split = create_datasets(tokenizer, args)
    run_training(args, base_model_id, train_split, eval_split, accelerator)


if __name__ == "__main__":
    # Script entry point: parse flags, make the run reproducible, then train.
    args = get_args()
    # Module-level placeholder read by compute_val_metrics; main() rebinds it
    # to the real tokenizer before training starts.
    tokenizer = None
    set_seed(args.seed)
    logging.set_verbosity_error()  # only errors from transformers' logger
    os.makedirs(args.output_dir, exist_ok = True)
    main(args)