#!/usr/bin/env python
# coding=utf-8
"""
This file is modified from the huggingface example for finetuning language models
[run_clm.py](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py)
"""

import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional, List, Literal
import datasets
from tqdm import tqdm
import torch
from datasets import load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    pipeline
)

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from peft import LoraConfig, TaskType, IA3Config, PromptTuningConfig, PromptTuningInit

logger = logging.getLogger(__name__)

class PeftPPOTrainer(PPOTrainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    
    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"checkpoint-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        
        
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """
    optimize_device_cache: Optional[bool] = False
    """Optimize device cache for slightly more memory-efficient training"""
    kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
    """kl penalty options: 'kl': model_logp - ref_logp,  'abs': abs(kl),  'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution"""
    early_stopping: bool = False
    """Whether to stop the PPO optimization loop early is the KL too high"""
    target_kl: float = 6.0
    """Stop early if we exceed this value by over 50%"""
    use_score_scaling: bool = False
    """Use score scaling"""
    use_score_norm: bool = False
    """Use score normalization. Only applicable if use_score_scaling is True"""
    score_clip: Optional[float] = None
    """Score clipping"""
    use_lora: Optional[bool] = False
    lora_rank: Optional[int] = field(default=64, metadata={"help": "the beta parameter for DPO loss"})
    lora_alpha: Optional[float] = field(default=64, metadata={"help": "the beta parameter for DPO loss"})
    lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    lora_module: Optional[List[str]] = field(default_factory=lambda : ['k_proj', 'q_proj'], metadata={"help": "the beta parameter for DPO loss"})
    """LoRAConfig"""
    use_ia3: Optional[bool] = False
    ia3_module: Optional[List[str]] = field(default_factory=lambda : ['down_proj'], metadata={"help": "the beta parameter for DPO loss"})
    feedforward_modules: Optional[List[str]] = field(default_factory=lambda : ['down_proj'], metadata={"help": "the beta parameter for DPO loss"})
    """IA3Config"""
    use_prompt_tuning: Optional[bool] = False
    num_virtual_tokens: Optional[int] = 16
    prompt_tuning_init_text: Optional[str] = "You are a helpful and harmless AI assistant."
    """PromptTuningConfig"""
    lora_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The lora checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    reward_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_flash_attn: bool = field(
        default=True,
        metadata={"help": "Whether to use flash attention in the model training"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in Transformers v4.34. Please use `token`."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a json/jsonl file)."})
    test_file: Optional[str] = field(default=None, metadata={"help": "The input test data file (a json/jsonl file)."})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": ("The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,")
        },
    )
    max_prompt_length: Optional[int] = field(
        default=None,
        metadata={
            "help": ("The maximum prompt input sequence length after tokenization. Sequences longer than this will be truncated,")
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["json", "jsonl"], "`train_file` should be a json or a jsonl file."

def convert_to_messages(text):
    messages = []
    for pair in text.split('Human:')[1:]:
        temp = pair.split('Assistant:')
        try:
            user, assist = temp[0], temp[1]
        except:
            user, assist = temp[0], ' '
        messages.append({
            "role": "user",
            "content": user
        })
        messages.append({
            "role": "assistant",
            "content": assist
        })
    return messages

def create_prompt_with_tulu_chat_format(messages, bos="<s>", eos="</s>", add_bos=False):
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"].strip() + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"].strip() + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(message["role"])
                )
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text

def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "<|assistant|>\n"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(data_files: str, tokenizer, max_prompt_length: int = 2048, sanity_check: bool = False, cache_dir: str = None):
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset('json', data_files=data_files, cache_dir=cache_dir)['train']
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample):
        chosen = convert_to_messages(sample["chosen"])
        formatted_chosen = create_prompt_with_tulu_chat_format(chosen)
        prompt = extract_anthropic_prompt(formatted_chosen)
        # assert prompt == extract_anthropic_prompt(formatted_rejected), f"The chosen and rejected should have the same prompt\nchosen:\n{chosen}\nrejected:\n{rejected}\n"
        sample["input_ids"] = tokenizer.encode(prompt)[:max_prompt_length]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    dataset = dataset.map(split_prompt_and_responses, batched=False)
    dataset.set_format('pt')
    return dataset

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    
    if model_args.use_auth_token is not None:
        warnings.warn("The `use_auth_token` argument is deprecated and will be removed in Transformers v4.34.", FutureWarning)
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    ppo_config = PPOConfig(
            seed=training_args.seed,
            learning_rate=training_args.learning_rate,
            log_with=training_args.report_to,
            mini_batch_size=training_args.per_device_train_batch_size,
            batch_size=training_args.per_device_train_batch_size*training_args.gradient_accumulation_steps,
            backward_batch_size=training_args.per_device_train_batch_size*training_args.gradient_accumulation_steps,
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
            early_stopping=model_args.early_stopping,
            target_kl=model_args.target_kl,
            kl_penalty=model_args.kl_penalty,
            use_score_scaling=model_args.use_score_scaling,
            use_score_norm=model_args.use_score_norm,
            score_clip=model_args.score_clip,
            optimize_device_cache=model_args.optimize_device_cache
        )
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new config instance from scratch. This is not supported by this finetuning script."
        )

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this finetuning script."
        )
        
    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # 2. Load the Anthropic Helpful-Harmless dataset
    train_dataset = get_hh(data_args.train_file, tokenizer=tokenizer, max_prompt_length=data_args.max_prompt_length, cache_dir=model_args.cache_dir, sanity_check=False)
    # eval_dataset = get_hh(data_args.test_file, cache_dir=model_args.cache_dir)

    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        if model_args.use_lora:
            logger.info("Training with LORA...")
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, 
                inference_mode=False, 
                r=model_args.lora_rank, 
                lora_alpha=model_args.lora_alpha, 
                lora_dropout=model_args.lora_dropout,
                target_modules=model_args.lora_module
            )
        elif model_args.use_ia3:
            logger.info("Training with IA3...")
            peft_config = IA3Config(
                task_type=TaskType.CAUSAL_LM, 
                inference_mode=False, 
                target_modules=model_args.ia3_module,
                feedforward_modules=model_args.feedforward_modules
            )
        elif model_args.use_prompt_tuning:
            logger.info("Training with PROMPT_TUNING...")
            peft_config = PromptTuningConfig(
                task_type=TaskType.CAUSAL_LM,
                prompt_tuning_init=PromptTuningInit.TEXT,
                inference_mode=False, 
                num_virtual_tokens=model_args.num_virtual_tokens,
                prompt_tuning_init_text=model_args.prompt_tuning_init_text, 
                tokenizer_name_or_path=model_args.model_name_or_path,
            )   
        else:
            logger.info("Full fintuning...")
            peft_config = None 
            
        if peft_config is None:
            ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                token=model_args.token,
                trust_remote_code=model_args.trust_remote_code,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=model_args.low_cpu_mem_usage,
                use_flash_attention_2=True if model_args.use_flash_attn else False,
            )
        else:
            ref_model = None
            
        model = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
            use_flash_attention_2=True if model_args.use_flash_attn else False,
            peft_config=peft_config
        )   
    else:
        logger.warning("No pretrained model_name_or_path is given. Training new model from scratch.")
        model = AutoModelForCausalLMWithValueHead.from_config(config, trust_remote_code=model_args.trust_remote_code)
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
        
    # set use_cache to False to enable activation checkpointing 
    # model.config.use_cache = False

    # initalize a trainer
    # here we use a custom trainer that moves the model to CPU when saving the checkpoint in FSDP mode
    # we can switch to the default trainer after moving to deepspeed (let's don't change too much for now)
    ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=train_dataset, data_collator=collator)
    # breakpoint()
    device = ppo_trainer.accelerator.device
    if ppo_trainer.accelerator.is_main_process:
        if training_args.output_dir is not None:
            os.makedirs(training_args.output_dir, exist_ok=True)
    # breakpoint()
    if ppo_trainer.accelerator.num_processes == 1:
        device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
    ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
    # reward_model = AutoModelForSequenceClassification.from_pretrained(model_args.reward_model_name_or_path)
    # reward_tokenizer = AutoTokenizer.from_pretrained(model_args.reward_model_name_or_path, **tokenizer_kwargs)
    if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
        with ds_plugin.zero3_init_context_manager(enable=False):
            # sentiment_pipe = pipeline('sentiment-analysis', model=reward_model, tokenizer=reward_tokenizer, device=device)
            sentiment_pipe = pipeline('sentiment-analysis', model=model_args.reward_model_name_or_path, device=device)
    else:
        # sentiment_pipe = pipeline('sentiment-analysis', model=reward_model, tokenizer=reward_tokenizer, device=device)
        sentiment_pipe = pipeline('sentiment-analysis', model=model_args.reward_model_name_or_path, device=device)

    generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 128,
    }
    sent_kwargs = {'top_k': None, "function_to_apply": "none", "batch_size": 16}

    for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
        query_tensors = batch["input_ids"]
        ppo_trainer.model.pretrained_model.gradient_checkpointing_disable()
        ppo_trainer.model.pretrained_model.config.use_cache = True
        response_tensors, ref_response_tensors = ppo_trainer.generate(
            query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
        )
        batch["response"] = tokenizer.batch_decode(response_tensors)
        batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)

        # Compute sentiment score
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
        rewards = [torch.tensor(output[0]["score"]) for output in pipe_outputs]
        ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
        ref_pipe_outputs = sentiment_pipe(ref_texts, **sent_kwargs)
        ref_rewards = [torch.tensor(output[0]["score"]) for output in ref_pipe_outputs]
        batch["ref_rewards"] = ref_rewards

        # Run PPO step
        ppo_trainer.model.pretrained_model.gradient_checkpointing_enable()
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"])
    
    if ppo_trainer.accelerator.is_main_process():
        ppo_trainer._save_pretrained(training_args.output_dir)

if __name__ == "__main__":
    main()