#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import random
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import json
import copy
from dotenv import load_dotenv
load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv('OPEN_AI_KEY')

from tqdm import tqdm
import pandas as pd

import torch
import transformers
from datasets import Dataset, concatenate_datasets
from transformers import AutoModelForCausalLM, set_seed
device = "cuda" if torch.cuda.is_available() else "cpu"

from alignment import (
    DataArguments,
    DPOConfig,
    H4ArgumentParser,
    ModelArguments,
    # IterArguments,
    apply_chat_template,
    get_checkpoint,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
    is_adapter_model,
)

from alignment.ranking_utils import retrieve_preference_matrix, compute_logit, compute_logit_2
from peft import PeftConfig, PeftModel
from datasets import Dataset
from utilities.trl import iREPOTrainer
from utilities.alpaca_farm.auto_annotations import PairwiseAutoAnnotator
from pathlib import Path
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm import LLM, SamplingParams
import llm_blender

import argparse
# Directory layout. `Path()` is the relative path ".", so every directory below
# is resolved against the process's current working directory at run time.
# NOTE(review): if paths relative to this file were intended, confirm whether
# `Path(__file__).parent` was meant instead.
CURRENT_DIR = Path().parent
BASE_DIR = CURRENT_DIR / "evaluators_configs"
DATA_DIR = CURRENT_DIR / "data"
RANKER_DIR = CURRENT_DIR / "ranker"
logger = logging.getLogger(__name__)

# Stop strings handed to vLLM's SamplingParams so that generation halts as soon
# as the model starts to emit a new user turn (including malformed variants).
stop_list = ["<|im_start|>user", "<user>", "<useruser>", "\nUser:", "\n\nuser:", "<|useruserase|>", "<|useruser|>", "<|useruser>", "<user", "<#useruser#>", "<==useruser"]
# Ranker / reward-model identifiers that can act as annotators.
annotators_configs = [
    "llm-blender/pair-ranker",
    "llm-blender/PairRM",
    "OpenAssistant/reward-model-deberta-v3-large-v2",
    "openbmb/UltraRM-13b",
    "berkeley-nest/Starling-RM-7B-alpha",
    #RANKER_DIR / "PairRM-1",
    #RANKER_DIR / "PairRM-2",
    #RANKER_DIR / "PairRM-3",
    #RANKER_DIR / "PairRM-4",
    #RANKER_DIR / "PairRM-5"
]

# Draw 5 annotators without replacement. With exactly 5 active entries above
# this is a random permutation of the whole list; the call raises ValueError if
# fewer than 5 entries remain.
selected_configs = random.sample(annotators_configs, 5)

# One loaded Blender per selected config; read by main() during annotation.
annotators_list = []

# NOTE: this loop runs at import time, so importing the module loads every
# selected ranker before main() is called.
for config in selected_configs:
    blender = llm_blender.Blender()
    blender.loadranker(config) # load the ranker named by this config
    annotators_list.append(blender)

def swap_at_index(list1, list2, index):
    """Exchange, in place, the items that two lists hold at position ``index``.

    The call is a no-op when ``index`` is not strictly below the length of
    both lists. Nothing is returned; the lists themselves are mutated.
    """
    if index >= min(len(list1), len(list2)):
        return
    held = list2[index]
    list2_replacement = list1[index]
    list1[index] = held
    list2[index] = list2_replacement

def batch(iterable, n=1):
    """Lazily yield consecutive slices of ``iterable`` of at most ``n`` items.

    ``iterable`` must support ``len()`` and slicing (list, tuple, str, ...).
    The final slice is shorter when the length is not a multiple of ``n``.
    """
    total = len(iterable)
    yield from (
        iterable[start:min(start + n, total)]
        for start in range(0, total, n)
    )

def get_model_number(model_name):
    """Split a model identifier into its iteration number and base name.

    Args:
        model_name: Identifier of the form ``"<org>/<base>-...-<iteration>"``,
            or one containing ``"sft"`` for the supervised starting point.

    Returns:
        A ``(iteration, base_name)`` tuple. ``iteration`` is ``-1`` when
        ``"sft"`` occurs anywhere in ``model_name``; otherwise it is the
        integer formed by the full run of trailing digits (so ``"...-12"``
        yields ``12`` rather than ``2``). ``base_name`` is the text between
        the first ``"/"`` and the first following ``"-"``.

    Raises:
        ValueError: If the name is not an SFT model and does not end in a digit.
        IndexError: If ``model_name`` contains no ``"/"``.
    """
    if "sft" in model_name:
        iteration = -1
    else:
        # Walk back over the whole trailing digit run instead of reading only
        # the last character, so that iterations >= 10 are parsed correctly.
        digits_start = len(model_name)
        while digits_start > 0 and model_name[digits_start - 1].isdecimal():
            digits_start -= 1
        suffix = model_name[digits_start:]
        if not suffix:
            raise ValueError(
                f"model name {model_name!r} does not end with an iteration number"
            )
        iteration = int(suffix)

    return iteration, model_name.split("/")[1].split("-")[0]


def _generate_responses(train_split, indexes, model_name_or_path, batch_size=50):
    """Sample three candidate responses per selected prompt with vLLM.

    Args:
        train_split: Dataset whose rows expose a ``"prompt"`` column and that
            can be indexed with a list of row indices.
        indexes: Row indices of ``train_split`` to generate responses for.
        model_name_or_path: Model identifier handed to ``vllm.LLM``.
        batch_size: Number of prompts sent to vLLM per ``generate`` call.

    Returns:
        A list with one dict per selected row, holding the row ``index``, the
        prompt as ``instruction``, an empty ``input`` and the three generations
        ``output_1`` .. ``output_3`` (one per sampling configuration).
    """
    sampling_settings = (
        SamplingParams(temperature=0.8, top_p=1.0, max_tokens=512, stop=stop_list),
        SamplingParams(temperature=1.0, top_p=1.0, max_tokens=512, stop=stop_list),
        SamplingParams(temperature=0.8, top_p=0.8, max_tokens=512, stop=stop_list),
    )

    # The engine is only created when responses really have to be generated so
    # that cached runs do not reserve GPU memory for it.
    llm_engine = LLM(model=model_name_or_path, gpu_memory_utilization=0.9)

    response_list = []
    for sample_indexes in tqdm(batch(indexes, batch_size)):
        prompts = train_split[sample_indexes]['prompt']

        # One list of generated texts per sampling configuration, each aligned
        # with `prompts`.
        candidates = [
            [output.outputs[0].text for output in llm_engine.generate(prompts, params)]
            for params in sampling_settings
        ]

        # Pair every prompt of *this* batch with its own three generations.
        for index, prompt, output_1, output_2, output_3 in zip(sample_indexes, prompts, *candidates):
            response_list.append({"index": index,
                                  "instruction": prompt,
                                  "input": '',
                                  "output_1": output_1,
                                  "output_2": output_2,
                                  "output_3": output_3})

        torch.cuda.empty_cache()

    # Drop the engine before clearing the cache so its memory can be reclaimed.
    del llm_engine
    torch.cuda.empty_cache()

    return response_list


def _annotate_responses(response_list):
    """Rank each sample's three outputs with every annotator and store the result.

    Each sample dict is updated in place with the keys ``strongest``,
    ``lowest``, ``strenth_strongest``, ``strenth_lowest`` and ``logit`` taken
    from ``compute_logit``. The same list is returned for convenience.
    """
    for sample in response_list:
        candidates = [sample['output_1'], sample['output_2'], sample['output_3']]
        rank_list = []
        for annotator in annotators_list:
            # llm_blender's Blender.rank expects a list of inputs and, per
            # input, a list of candidates; it returns one row of ranks per input.
            ranks = annotator.rank([sample['instruction']],
                                   [candidates],
                                   return_scores=False, batch_size=1)
            # NOTE(review): one rank vector per annotator is passed on; confirm
            # this is the layout retrieve_preference_matrix expects.
            rank_list.append(ranks[0])

        preferences = retrieve_preference_matrix(rank_list)
        ys, yl, ws, wl, logit = compute_logit(preferences)

        # NOTE(review): these values are later written with json.dump; confirm
        # compute_logit returns JSON-serialisable (plain Python) numbers.
        sample['strongest'] = ys
        sample['lowest'] = yl
        sample['strenth_strongest'] = ws
        sample['strenth_lowest'] = wl
        sample['logit'] = logit

    return response_list


def _build_preference_columns(annotated_list):
    """Turn annotated samples into the column dict used to build the train set.

    For every sample the output named by ``strongest`` becomes ``chosen`` and
    the one named by ``lowest`` becomes ``rejected``; their strengths and the
    logit are copied to ``score_chosen``, ``score_rejected`` and ``reward_diff``.
    """
    columns = {"prompt": [],
               "chosen": [],
               "rejected": [],
               "score_chosen": [],
               "score_rejected": [],
               "reward_diff": []}

    for responses in tqdm(annotated_list):
        columns["prompt"].append(responses["instruction"])
        columns["chosen"].append(responses['output_' + str(responses['strongest'])])
        columns["rejected"].append(responses['output_' + str(responses['lowest'])])
        columns["score_chosen"].append(responses['strenth_strongest'])
        columns["score_rejected"].append(responses['strenth_lowest'])
        columns["reward_diff"].append(responses['logit'])

    return columns


def main():
    """Run one iREPO iteration: prepare data, train, evaluate and save.

    The iteration is read from the trailing digits of
    ``training_args.output_dir``. Iteration ``0`` trains on the raw preference
    dataset; any other iteration trains on self-generated responses that are
    ranked by the module-level annotators (or on cached JSON files when they
    already exist on disk).
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
    model_args, data_args, training_args = parser.parse()

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # Set seed for reproducibility
    set_seed(training_args.seed)

    # The output directory should end with the iteration number. Read the whole
    # trailing digit run so that e.g. "...-10" is iteration 10 and not 0; fall
    # back to the last character when there are no trailing digits.
    output_dir = training_args.output_dir
    trailing_digits = output_dir[len(output_dir.rstrip("0123456789")):]
    iteration = str(int(trailing_digits)) if trailing_digits else output_dir[-1]

    logger.info(f"***iREPO Iteration {iteration} - Using model {model_args}***")

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)

    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )

    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    data_args.padding_side = "left"
    tokenizer = get_tokenizer(model_args, data_args)

    tokenizer.chat_template = tokenizer.default_chat_template

    logger.info(f"Chat Template:\n{tokenizer.chat_template}")

    #####################
    # Apply chat template
    #####################
    raw_datasets = raw_datasets.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer, "task": "irepo"},
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=['prompt', 'prompt_id', 'chosen', 'rejected', 'messages'],
        desc="Formatting comparisons with prompt template",
    )

    # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
    for split in ["train", "test"]:
        raw_datasets[split] = raw_datasets[split].rename_columns(
            {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
        )

    # Log a few random samples from the training set:
    num_train_rows = len(raw_datasets["train"])
    for index in random.sample(range(num_train_rows), min(2, num_train_rows)):
        logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
        logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
        logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")

    ############
    # Load model
    ############
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # By default the trainer receives the model *name* plus init kwargs and
    # instantiates the model itself.
    model = model_args.model_name_or_path

    if is_adapter_model(model, model_args.model_revision) is True:
        logger.info("PEFT Model")
        logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
        peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
        # Same kwargs as above, but the base model is pinned to its own revision.
        base_model_kwargs = dict(model_kwargs, revision=model_args.base_model_revision)
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            **base_model_kwargs,
        )
        model = PeftModel.from_pretrained(
            base_model,
            model_args.model_name_or_path,
            revision=model_args.model_revision,
        )
        # The model object is already instantiated, so no init kwargs are passed on.
        model_kwargs = None

    ref_model = model
    ref_model_kwargs = model_kwargs

    if model_args.use_peft is True:
        ref_model = None
        ref_model_kwargs = None

    #########################
    # Build the training data
    #########################
    if iteration != "0":
        logger.info(f"iREPO Iteration {iteration} - Self-Generating Data....")
        _, model_name = get_model_number(model_args.model_name_or_path)
        file_prefix = model_name + "_" + training_args.loss_type
        data_filename = file_prefix + "_train_i" + str(iteration) + ".json"
        logger.info(f"Generated data filename {data_filename}")

        if os.path.exists(DATA_DIR / data_filename):
            # NOTE(review): nothing in this file writes this cache; it is
            # assumed to be a JSON dict of columns (prompt, chosen, rejected,
            # score_chosen, score_rejected, reward_diff) - confirm the producer.
            with open(DATA_DIR / data_filename, 'r') as file:
                gen_train_columns = json.load(file)
            logger.info(f"Loaded generated data from existing file {data_filename}")
        else:
            logger.info("*** Merge and Unload Model Adapters ***")
            # Never ask for more prompts than the training split contains.
            indexes = random.sample(range(num_train_rows), min(30000, num_train_rows))

            logger.info(f"Prompt sample {indexes[:3]} of the raw training set:\n\n{raw_datasets['train'][indexes[:3]]['prompt']}")
            response_filename = file_prefix + "_response_list_i" + iteration + ".json"
            annotate_filename = file_prefix + "_annotated_list_i" + iteration + ".json"

            if os.path.exists(response_filename):
                with open(response_filename, 'r') as file:
                    response_list = json.load(file)
                logger.info("Loaded response list from existing file")
            else:
                logger.info("iREPO - Generating Responses...")
                response_list = _generate_responses(
                    raw_datasets['train'], indexes, model_args.model_name_or_path
                )
                with open(response_filename, 'w') as file:
                    json.dump(response_list, file)

            logger.info("iREPO - Annotating Data....")

            if os.path.exists(annotate_filename):
                with open(annotate_filename, 'r') as file:
                    annotated_list = json.load(file)
                logger.info("Loaded annotate list from existing file")
            else:
                annotated_list = _annotate_responses(response_list)
                with open(annotate_filename, 'w') as file:
                    json.dump(annotated_list, file)

            logger.info(f"iREPO Iteration {iteration} - Format Dataset....")
            gen_train_columns = _build_preference_columns(annotated_list)

        reward_diffs = gen_train_columns['reward_diff']
        avg_train_rewards_difference = sum(reward_diffs) / len(reward_diffs)
        logger.info(f"iREPO Iteration {iteration} - Avg Train Reward Difference: {avg_train_rewards_difference}")

        gen_train_columns['score_chosen'] = [float(num) for num in gen_train_columns['score_chosen']]
        gen_train_columns['score_rejected'] = [float(num) for num in gen_train_columns['score_rejected']]

        train_dataset = Dataset.from_dict(gen_train_columns)

        test_rewards_difference = compute_logit_2(raw_datasets['test']['score_chosen'],
                                                  raw_datasets['test']['score_rejected'])

        avg_test_rewards_difference = sum(test_rewards_difference) / len(test_rewards_difference)
        logger.info(f"iREPO Iteration {iteration} - Avg Test Reward Difference: {avg_test_rewards_difference}")
        raw_datasets['test'] = raw_datasets['test'].add_column("reward_diff", test_rewards_difference)

        sample_indexes = random.sample(range(len(train_dataset)), min(3, len(train_dataset)))

        logger.info(f"Chosen Score {sample_indexes} of the training set:\n\n{train_dataset[sample_indexes]['score_chosen']}")
        logger.info(f"Rejected Score {sample_indexes} of the training set:\n\n{train_dataset[sample_indexes]['score_rejected']}")
        logger.info(f"Reward_diff sample {sample_indexes} of the training set:\n\n{ train_dataset[sample_indexes]['reward_diff']}")
    else:
        train_rewards_difference = compute_logit_2(raw_datasets['train']['score_chosen'],
                                                   raw_datasets['train']['score_rejected'])

        test_rewards_difference = compute_logit_2(raw_datasets['test']['score_chosen'],
                                                  raw_datasets['test']['score_rejected'])

        avg_train_rewards_difference = sum(train_rewards_difference) / len(train_rewards_difference)
        avg_test_rewards_difference = sum(test_rewards_difference) / len(test_rewards_difference)

        logger.info(f"iREPO Iteration {iteration} - Avg Train Reward Difference: {avg_train_rewards_difference}")
        logger.info(f"iREPO Iteration {iteration} - Avg Test Reward Difference: {avg_test_rewards_difference}")

        raw_datasets['train'] = raw_datasets['train'].add_column("reward_diff", train_rewards_difference)
        raw_datasets['test'] = raw_datasets['test'].add_column("reward_diff", test_rewards_difference)

        train_dataset = raw_datasets['train']

        sample_indexes = random.sample(range(len(train_dataset)), min(3, len(train_dataset)))
        logger.info(f"Chosen Score {sample_indexes} of the training set:\n\n{train_dataset[sample_indexes]['score_chosen']}")
        logger.info(f"Rejected Score {sample_indexes} of the training set:\n\n{train_dataset[sample_indexes]['score_rejected']}")
        logger.info(f"Reward_diff sample {sample_indexes} of the raw training set:\n\n{ train_dataset[sample_indexes]['reward_diff']}")

    #########################
    # Instantiate DPO trainer
    #########################
    trainer = iREPOTrainer(
        model,
        ref_model,
        model_init_kwargs=model_kwargs,
        ref_model_init_kwargs=ref_model_kwargs,
        args=training_args,
        beta=training_args.beta,
        train_dataset=train_dataset,
        eval_dataset=raw_datasets["test"],
        tokenizer=tokenizer,
        max_length=training_args.max_length,
        max_prompt_length=training_args.max_prompt_length,
        peft_config=get_peft_config(model_args),
        loss_type=training_args.loss_type,
    )

    ###############
    # Training loop
    ###############
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("*** Training complete ***")

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(raw_datasets["test"])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    logger.info(f"Model saved to {training_args.output_dir}")
    trainer.save_model(training_args.output_dir)

    # Save everything else on main process
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": list(data_args.dataset_mixer.keys()),
        "dataset_tags": list(data_args.dataset_mixer.keys()),
        "tags": ["alignment-handbook"],
    }

    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    if training_args.push_to_hub is True:
        logger.info("Pushing to hub...")
        # Prefer the id on model_args (as before); fall back to the one on the
        # training arguments when model_args does not define it.
        hub_model_id = getattr(model_args, "hub_model_id", None) or getattr(training_args, "hub_model_id", None)
        # Only adapter-wrapped models provide merge_and_unload; a fully
        # fine-tuned model is pushed as it is.
        if hasattr(trainer.model, "merge_and_unload"):
            merged_model = trainer.model.merge_and_unload()
        else:
            merged_model = trainer.model
        merged_model.save_pretrained(
            hub_model_id,
            push_to_hub=True,
            repo_id=hub_model_id,
        )
        tokenizer.save_pretrained(
            hub_model_id,
            push_to_hub=True,
            repo_id=hub_model_id,
        )
        merged_model.push_to_hub(hub_model_id)
        tokenizer.push_to_hub(hub_model_id)

    logger.info("*** Training complete! ***")

    # Only names that are bound on every path are released here.
    del model, tokenizer, trainer, raw_datasets
    torch.cuda.empty_cache()


# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
