import argparse
import os
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, AutoTokenizer, TrainingArguments, logging, set_seed
from evaluation.CompAcc_FEqAcc import check_CompAccRunEqAcc_cpp
import torch
from trl import SFTTrainer
from trl import PPOTrainer, PPOConfig
from trl import AutoModelForCausalLMWithValueHead, set_seed
from torch.utils.data import DataLoader
from accelerate import PartialState
import accelerate
from CG2b_RL_reward import check_CompAcc_Boolean_cpp

# Explicitly mark Weights & Biases as *not* disabled; run_training configures
# PPOConfig with log_with="wandb". NOTE(review): assumes the installed HF
# integrations treat the string "false" as "enabled" — confirm per version.
os.environ["WANDB_DISABLED"] = "false"

def get_args(argv=None):
    """Parse command-line options for the PPO fine-tuning run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets tests or
            other scripts build the namespace without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``task``, ``num_epochs``, ``batch_size``,
        ``eval_batch_size``, ``eval_freq_steps``,
        ``gradient_accumulation_steps``, ``seed`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type = str, default = "pseudocode-cpp", choices = ["pseudocode-cpp", "SHORTpseudocode-cpp"])

    parser.add_argument("--num_epochs", type = int, default = 20)
    parser.add_argument("--batch_size", type = int, default = 12)
    parser.add_argument("--eval_batch_size", type = int, default = 16)
    # How often (in PPO steps) run_training writes a checkpoint.
    parser.add_argument("--eval_freq_steps", type = int, default = 10)
    # Forwarded to PPOConfig. The default is 4; pass 1 for faster steps when
    # the batch size is already large.
    parser.add_argument("--gradient_accumulation_steps", type = int, default = 4)

    parser.add_argument("--seed", type = int, default = 0)
    parser.add_argument("--output_dir", type = str, default = "./CG2b_checkpoints_RL_MultiGPU3")
    return parser.parse_args(argv)

def tokenize(sample):
    """Store the token ids of ``sample["query"]`` under ``"input_ids"``.

    Uses the module-level ``tokenizer``; the query is truncated to at most
    750 tokens and no padding is applied. The first row of the ``'pt'``
    batch returned by ``encode`` is kept. Mutates and returns ``sample``.
    """
    global tokenizer
    encoded_batch = tokenizer.encode(
        sample["query"],
        truncation = True,
        max_length = 750,
        return_tensors = 'pt',
    )
    sample["input_ids"] = encoded_batch[0]
    return sample

def tokenize_maxLenPadding(sample):
    """Like ``tokenize`` but pads every query to exactly 750 tokens.

    Fixed-width rows let the evaluation DataLoader stack queries into one
    tensor; padding goes on whichever side the module-level ``tokenizer``
    is configured for. Mutates and returns ``sample``.
    """
    global tokenizer
    padded_batch = tokenizer.encode(
        sample["query"],
        truncation = True,
        padding = 'max_length',
        max_length = 750,
        return_tensors = 'pt',
    )
    sample["input_ids"] = padded_batch[0]
    return sample

def collator(data):
    """Transpose a list of example dicts into a dict of per-key lists.

    Example:
        [{'key1': 'v1', 'key2': 'v2'}] -> {'key1': ['v1'], 'key2': ['v2']}

    Keys (and their order) come from the first example, so ``data`` must be
    non-empty. Values are left as-is (no tensor stacking), which is what
    PPOTrainer expects for its list-of-tensors queries.
    """
    return {field: [example[field] for example in data] for field in data[0]}

def addFIM_spclTokens(sample):
    """Wrap ``sample["query"]`` in fill-in-the-middle markers.

    The query becomes the FIM prefix with an empty suffix, ending in
    ``<|fim_middle|>`` so the model's continuation is the "middle" part.
    Mutates and returns ``sample``.
    """
    prefix_marker = "<|fim_prefix|>"
    tail_markers = "<|fim_suffix|><|fim_middle|>"
    sample["query"] = prefix_marker + sample["query"] + tail_markers
    return sample

def test(dataloader, model, device):
    """Greedy-decode every batch in ``dataloader`` and score the predictions.

    Args:
        dataloader: yields batches whose ``input_ids`` is a 2-D tensor of
            padded prompts (built with ``tokenize_maxLenPadding`` in ``main``).
        model: model exposing ``generate``.
        device: device the prompt ids and mask are moved to for generation.

    Returns:
        Whatever ``check_CompAccRunEqAcc_cpp`` returns for the decoded
        predictions, ``args.output_dir`` and ``args.task``.

    Reads the module-level ``tokenizer`` and ``args``.
    """
    global tokenizer, args
    test_kwargs = {
        "do_sample": False,
        "max_new_tokens": 512,
        "pad_token_id": tokenizer.eos_token_id
    }
    print ("tokenizer.eos_token_id, tokenizer.pad_token_id", tokenizer.eos_token_id, tokenizer.pad_token_id)
    print ("\n\nSTARTING TEST:")
    testOut = []
    with torch.no_grad():
        for test_batch in tqdm(dataloader):
            src_id_mat = test_batch["input_ids"]
            # Drop columns that are padding in every row, so the batch is only
            # as wide as its longest prompt.
            src_id_mat = src_id_mat[:, ~torch.all(src_id_mat == tokenizer.pad_token_id, dim = 0)]
            print ("src_id_mat.shape", src_id_mat.shape)
            # pad_token is set to eos in main(), so generate() cannot tell
            # padding apart on its own; pass the mask explicitly.
            attention_mask = (src_id_mat != tokenizer.pad_token_id).long()
            test_pred_idList = model.generate(
                input_ids = src_id_mat.to(device),
                attention_mask = attention_mask.to(device),
                **test_kwargs,
            )
            # Decoder-only generate() returns prompt + continuation; keep only
            # the new tokens so the prompt never leaks into predictions (even if
            # the FIM marker is stripped as a special token while decoding).
            new_token_ids = test_pred_idList[:, src_id_mat.shape[1]:]
            test_pred_list = tokenizer.batch_decode(new_token_ids, skip_special_tokens = True)
            test_pred_list = [p.split("<|file_separator|>")[0].strip().split("<|fim_middle|>")[-1].strip()
                                for p in test_pred_list]
            testOut.extend(test_pred_list)

    print ("\n\nlen(testOut):", len(testOut))
    metrics_output = check_CompAccRunEqAcc_cpp(
        testOut,
        args.output_dir,
        args.task
    )
    return metrics_output

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Returns:
        ``(trainable_params, all_param)`` so callers can log or check the
        counts; callers that ignore the return value are unaffected.

    NOTE(review): ``numel()`` counts *stored* elements. This script loads the
    model with ``load_in_4bit``; packed 4-bit weights likely make ``all_param``
    an undercount — confirm before quoting these numbers.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A model with no parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
    return trainable_params, all_param

def create_datasets(tokenizer, args):
    """Load the JSONL train and validation splits for ``args.task``.

    Files are read from ``dataset/<task>/train.<task>.jsonl`` and
    ``dataset/<task>/test.<task>.jsonl``.

    Args:
        tokenizer: unused here; kept so existing callers keep working.
        args: namespace providing ``task``.

    Returns:
        ``(train_data, valid_data)``.

    NOTE(review): the "valid" split is read from the *test* file — confirm
    this is intentional (no separate validation file exists).
    """
    train_path = f"dataset/{args.task}/train.{args.task}.jsonl"
    valid_path = f"dataset/{args.task}/test.{args.task}.jsonl"
    # The former extra "test" split pointed at the same file as "valid" and was
    # never returned, so it is no longer loaded.
    dataset = load_dataset("json", data_files = {"train": train_path, "valid": valid_path})
    train_data = dataset["train"]
    valid_data = dataset["valid"]
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    return train_data, valid_data

def run_training(args, model, train_data, val_dataloader):
    """Fine-tune ``model`` with PPO using per-response rewards.

    For each batch:
    - sample responses for the tokenized queries;
    - score them with ``check_CompAcc_Boolean_cpp``, converting each output to
      a float tensor reward;
    - run one ``PPOTrainer.step``.

    A checkpoint is written under ``args.output_dir`` on steps
    1, 1 + eval_freq_steps, 1 + 2*eval_freq_steps, ... (every step when
    ``eval_freq_steps == 1``).

    Args:
        args: parsed CLI namespace (see ``get_args``).
        model: value-head policy model handed to ``PPOTrainer``.
        train_data: tokenized dataset with ``input_ids``; batched via ``collator``.
        val_dataloader: evaluation loader; currently unused because periodic
            evaluation with ``test`` is disabled.

    Reads the module-level ``tokenizer``.
    """
    global tokenizer

    PPO_config = PPOConfig(
            learning_rate = 1.41e-5,
            remove_unused_columns = False,
            batch_size = args.batch_size,
            log_with = "wandb",
            steps = 1000,
            gradient_accumulation_steps = args.gradient_accumulation_steps,
            ) #steps = 10000, mini_batch_size = args.batch_size // 2, , optimize_device_cache = False, ppo_epochs = args.num_epochs

    PPO_trainer = PPOTrainer(
        model = model,
        config = PPO_config,
        dataset = train_data,
        tokenizer = tokenizer,
        data_collator = collator
    )
    device = PPO_trainer.accelerator.device
    print_trainable_parameters(PPO_trainer.model)

    # top_k=0 / top_p=1.0 leave the sampling distribution unfiltered.
    generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "return_prompt": False,
        "max_new_tokens": 512
    } #["temperature": 1 was slowing the training a bit, so trying with k,p]

    print("Starting main loop")
    btNum = 0
    for epoch in tqdm(range(args.num_epochs), "epoch: "):
        for batch in tqdm(PPO_trainer.dataloader):
            query_listOfTensors = batch["input_ids"]

            #### Get response from SFTModel
            print ("started response")
            response_tensors = PPO_trainer.generate(query_listOfTensors, **generation_kwargs)
            print ("done response")
            batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens = True)
            print ("done response decode")

            #### Compute reward score
            texts = [t.split("<|file_separator|>")[0].strip() for t in batch["response"]]
            print (texts)
            pipe_outputs = check_CompAcc_Boolean_cpp(texts, args.output_dir, device)
            rewards = [torch.tensor(float(output)) for output in pipe_outputs]

            #### Run PPO step
            print ("started ppo step")
            stats = PPO_trainer.step(query_listOfTensors, response_tensors, rewards)
            print ("done ppo step")

            # Comparing against `1 % freq` (instead of 1) keeps the original
            # schedule (steps 1, 1+freq, ...) but also saves when freq == 1,
            # where `btNum % 1` is always 0 and the old test never fired.
            if btNum % args.eval_freq_steps == 1 % args.eval_freq_steps:
                modelPath = os.path.join(args.output_dir, f"ep{epoch}_bt{btNum}/")
                # unwrap_model works with or without DDP (`.module` exists only
                # under DDP), and only one rank writes so ranks don't race.
                if PPO_trainer.accelerator.is_main_process:
                    os.makedirs(modelPath, exist_ok = True)
                    PPO_trainer.accelerator.unwrap_model(PPO_trainer.model).save_pretrained(modelPath)
                # Periodic evaluation via test(val_dataloader, ...) is disabled.
            PPO_trainer.log_stats(stats, batch, rewards)
            btNum += 1


    #print("Saving last checkpoint of the model")
    #PPO_trainer.model.module.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))


def _prepare_split(split, tokenize_fn):
    """Turn a raw prompt/completion split into a tokenized PPO split.

    Renames ``prompt``/``completion`` to ``query``/``response``, wraps each
    query with ``addFIM_spclTokens``, applies ``tokenize_fn`` per example and
    switches the output format to ``"pytorch"``.
    """
    renamed = split.rename_column("prompt", "query").rename_column("completion", "response")
    with_fim = renamed.map(addFIM_spclTokens, batched = False)
    tokenized = with_fim.map(tokenize_fn, batched = False)
    tokenized.set_format("pytorch")
    return tokenized


def main(args):
    """Load the 4-bit SFT model, wrap it for PPO, build the data and train.

    Assigns the module-level ``tokenizer`` and ``accelerator`` read by the
    other helpers, then hands off to ``run_training``.
    """
    global tokenizer, accelerator
    accelerator = Accelerator()
    process_index = PartialState().process_index

    #-----model-----
    # References on loading / continuing training from PEFT adapters:
    #https://discuss.huggingface.co/t/unable-to-load-fine-tuned-llm/45432/2
    #https://huggingface.co/docs/transformers/main/en/peft
    #https://discuss.huggingface.co/t/correct-way-to-save-load-adapters-and-checkpoints-in-peft/77836/2
    #https://github.com/huggingface/peft/issues/673
    #https://discuss.huggingface.co/t/further-finetuning-a-lora-finetuned-causallm-model/36987/4
    print("Loading the model")
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    base_model_id = "google/codegemma-2b"
    sft_adapter_dir = "./CG2b_checkpointsSFT_MultiGPU_done/checkpoint-2000"

    # Each process places its own copy of the model on device `process_index`.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config = quant_config,
        device_map = {'': process_index},
    )
    # Start PPO from the supervised fine-tuned adapter.
    base_model.load_adapter(sft_adapter_dir)

    policy_lora = LoraConfig(
        r = 8,
        target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
                        "gate_proj", "up_proj", "down_proj"],
        task_type = "CAUSAL_LM",
    )

    # NOTE(review): a fresh LoRA (`policy_lora`) is requested on top of a model
    # that already has the SFT adapter loaded; confirm which weights PPO ends up
    # training versus keeping frozen.
    policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        base_model,
        quantization_config = quant_config,
        peft_config = policy_lora,
        device_map = {'': process_index},
    )
    print_trainable_parameters(policy_model)

    #-----tokenizer-----
    # Left truncation keeps the end of each prompt (the FIM markers added by
    # addFIM_spclTokens); left padding puts new tokens right after the prompt.
    tokenizer = GemmaTokenizer.from_pretrained(base_model_id, truncation_side = "left",
                            padding_side = "left")
    tokenizer.pad_token = tokenizer.eos_token

    #-----dataset-----
    train_split, eval_split = create_datasets(tokenizer, args)
    train_split = _prepare_split(train_split, tokenize)
    eval_split = _prepare_split(eval_split, tokenize_maxLenPadding)
    eval_loader = DataLoader(eval_split, batch_size = args.eval_batch_size,
                                    shuffle = False)

    #-----training-----
    run_training(args, policy_model, train_split, eval_loader)


if __name__ == "__main__":
    # Module-level globals read by the helpers via `global`; main() assigns
    # the real tokenizer/accelerator, and test() reads `args`.
    tokenizer, accelerator = None, None
    args = get_args()
    # Note: `set_seed` here is TRL's, which shadows the transformers import.
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    logging.set_verbosity_error()
    main(args)
    # Tuning idea: larger mini_batch_size with fewer steps.