# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
os.environ["WANDB_MODE"] = "offline"
import math
import pathlib
import sys
from copy import deepcopy
from typing import Optional, Union, Tuple, Dict, List
from dataclasses import dataclass, field

import torch
import deepspeed
import torch.nn.functional as F
import transformers
# from peft import LoraConfig, get_peft_model
from torch import nn
from transformers import PreTrainedModel, EvalPrediction
from transformers.training_args import TrainingArguments
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from utils import init_logger, add_custom_callback
from dataset import DpoDataset, DpoProcessor, DataCollatorForDpo

@dataclass
class ModelArguments:
    """Command-line options selecting the pretrained checkpoint to start from."""

    # Hub id or local path handed to `from_pretrained` for the policy model,
    # the reference model and the tokenizer.
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    # Parsed from the command line; not read by the code in this file.
    fuse_qkv: bool = field(default=False)

@dataclass
class DataArguments:
    """Command-line options describing the training data and its on-disk cache."""

    # Annotated Optional because the default is None (no file given yet);
    # HfArgumentParser unwraps Optional[str] to a plain `str` CLI argument.
    data_file: Optional[str] = field(default=None,
                                     metadata={"help": "Path to the training data."})
    # Forwarded unchanged to DpoDataset(switch_rate=...); its meaning is
    # defined there.
    switch_rate: float = field(default=0.,
                               metadata={"help": "prob to concat"})
    # Directory for the processed-dataset cache (DpoDataset(cache_dir=...)).
    data_cache_dir: str = field(default="cached",
                                metadata={"help": "Path to the save data cache."})
    # Forwarded as DpoDataset(overwrite=...).
    overwrite_data_cache: bool = field(default=False,
                                       metadata={"help": "Overwrite the data cache."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """`transformers.TrainingArguments` extended with the loss-mixing options of `CustomTrainer`.

    Note: this class name shadows the `TrainingArguments` imported from
    `transformers.training_args` at the top of the file; `train()` parses
    this subclass.
    """

    # --- generic overrides / additions ---
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    # Declared for the CLI; not read by the code in this file.
    loss_average: str = field(
        default="token",
        metadata={"help": "Loss averages with the granularity of `token`, `response` or `mixture`"},
    )

    # --- DPO term ---
    dpo_beta: float = field(default=0.1, metadata={"help": "DPO beta"})
    use_dpo_loss: bool = field(default=False, metadata={"help": "Whether to use DPO loss"})

    # --- SFT term ---
    use_sft_loss: bool = field(default=False, metadata={"help": "Whether to use SFT loss"})
    # An earlier configuration tagged "mix_v1" used 0.3 here.
    sft_loss_weight: float = field(default=0.1, metadata={"help": "SFT loss weight"})

    # --- KL term ---
    use_kl_loss: bool = field(default=False, metadata={"help": "Whether to use KL loss"})
    # An earlier configuration tagged "mix_v1" used 0.1 here.
    kl_loss_weight: float = field(default=0.05, metadata={"help": "KL loss weight"})
    # Number of highest-probability reference tokens the KL term is computed over.
    kl_topk: int = field(default=50, metadata={"help": "KL topk"})

    # When True the DPO term drops the reference log-ratio (treated as 0).
    reference_free: bool = field(default=False, metadata={"help": "Whether to use reference-free DPO"})
    remove_unused_columns: bool = field(
        default=False,
        metadata={"help": 'Removed unused columns. Needed to make this codebase work.'},
    )


def _z3_params_to_fetch(param_list):
    """Select the DeepSpeed-managed parameters that are currently not materialized.

    A parameter qualifies when it carries a `ds_id` attribute (i.e. DeepSpeed
    ZeRO tracks it) and its `ds_status` is `ZeroParamStatus.NOT_AVAILABLE`.
    Plain tensors without `ds_id` are skipped.

    Args:
        param_list: iterable of parameters to inspect.

    Returns:
        A new list holding the qualifying parameters, in input order.
    """
    pending = []
    for param in param_list:
        if not hasattr(param, 'ds_id'):
            continue
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            pending.append(param)
    return pending

def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Write a model's weights to `<save_dir>/pytorch_model.bin` from rank 0.

    All ranks must call this when `zero_stage == 3`: each parameter carrying a
    DeepSpeed `ds_id` is gathered inside `deepspeed.zero.GatheredParameters`,
    which every rank has to enter. Only rank 0 copies tensors to CPU and writes
    the file; parameters whose name contains "lora" are skipped in that path.
    For other stages rank 0 simply saves `state_dict()`.

    Args:
        model_ema: the model, or a wrapper exposing it as `.module`
            (e.g. a DeepSpeed engine).
        global_rank: rank of the calling process; only rank 0 writes.
        save_dir: output directory, created if missing.
        zero_stage: ZeRO stage; only 3 enables parameter gathering.

    Raises:
        OSError: on rank 0, if `save_dir` cannot be created.
    """
    zero_stage_3 = (zero_stage == 3)
    is_writer = (global_rank == 0)
    try:
        os.makedirs(save_dir, exist_ok=True)
    except OSError:
        # Non-writing ranks never touch the directory, so a failure there is
        # harmless; on rank 0 the subsequent save could not succeed anyway.
        if is_writer:
            raise
    output_model_file = os.path.join(save_dir, "pytorch_model.bin")

    model_to_save = model_ema.module if hasattr(model_ema, 'module') else model_ema

    if not zero_stage_3:
        if is_writer:
            torch.save(model_to_save.state_dict(), output_model_file)
        return

    output_state_dict = {}
    for name, param in model_to_save.named_parameters():
        keep = is_writer and "lora" not in name
        if hasattr(param, 'ds_id'):
            # Collective call: every rank enters the gather even though only
            # rank 0 keeps a CPU copy of the full tensor.
            with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([param]),
                                                   enabled=zero_stage_3):
                if keep:
                    output_state_dict[name] = param.data.cpu()
        elif keep:
            # Save a plain tensor (not a live Parameter), like state_dict() does.
            output_state_dict[name] = param.detach().cpu()
    if is_writer:
        torch.save(output_state_dict, output_model_file)
    del output_state_dict

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str, global_rank: int):
    """Dump the weights held by `trainer.deepspeed` into `output_dir`.

    `zero_stage=3` is always requested: `save_zero_three_model` then gathers
    every parameter that carries a DeepSpeed `ds_id` and copies the others
    directly, so models that are not actually partitioned are handled too.
    Must be called on every rank; only `global_rank == 0` writes the file.
    """
    engine = trainer.deepspeed
    save_zero_three_model(engine, global_rank, output_dir, zero_stage=3)


class CustomTrainer(transformers.Trainer):
    """`transformers.Trainer` that optimizes a weighted sum of DPO, SFT and KL losses.

    The active terms are selected by `args.use_dpo_loss`, `args.use_sft_loss`
    and `args.use_kl_loss`. The SFT and KL terms are scaled by
    `args.sft_loss_weight` and `args.kl_loss_weight`; the DPO term uses
    `args.dpo_beta` and `args.reference_free`. A frozen reference model
    supplies reference log-probabilities (DPO) and the target distribution (KL).

    Batches are indexed with `chosen_*`, `rejected_*` and `kl_*` tensors plus
    the per-example masks `dpo_loss_mask`, `sft_loss_mask` and `kl_loss_mask`
    (presumably produced by `DataCollatorForDpo` -- TODO confirm the exact
    schema against dataset.py).
    """

    def __init__(self, reference_model: Union[PreTrainedModel, torch.nn.Module, None], **kwargs):
        """Build the trainer and wrap `reference_model` in a frozen DeepSpeed engine.

        Args:
            reference_model: fixed DPO/KL reference. When None,
                `self.model.base_model` is used instead (see note below).
            **kwargs: forwarded unchanged to `transformers.Trainer.__init__`.
        """
        super().__init__(**kwargs)
        self.beta = self.args.dpo_beta
        self.sft_loss_weight = self.args.sft_loss_weight
        self.kl_loss_weight = self.args.kl_loss_weight
        self.reference_free = self.args.reference_free

        self.reference_model = reference_model

        if self.reference_model is not None:
            # Reuse the trainer's DeepSpeed config without optimizer/scheduler
            # sections: the reference model is only ever run forward.
            hf_deepspeed_config = deepcopy(self.args.hf_deepspeed_config)
            hf_deepspeed_config.del_config_sub_tree("optimizer")
            hf_deepspeed_config.del_config_sub_tree("lr_scheduler")
            deepspeed_engine, _, _, _ = deepspeed.initialize(
                model=self.reference_model,
                model_parameters=None,
                config_params=hf_deepspeed_config.config,
                optimizer=None,
                lr_scheduler=None,
            )
            self.reference_model = deepspeed_engine
            # Freeze every weight and pin inference mode so dropout never
            # perturbs the reference forward passes.
            for param in self.reference_model.parameters():
                param.requires_grad = False
            self.reference_model.eval()
        else:
            # NOTE(review): for a plain HF causal LM `base_model` is usually the
            # backbone without the LM head, whose output may lack `.logits`
            # (required by concatenated_forward / compute_loss), and it shares
            # weights with the policy being trained. train() always passes a
            # reference model; confirm intent before relying on this branch.
            self.reference_model = self.model.base_model

        self.use_dpo_loss = self.args.use_dpo_loss
        self.use_sft_loss = self.args.use_sft_loss
        self.use_kl_loss = self.args.use_kl_loss
        self.kl_topk = self.args.kl_topk

    def _get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor,
                         average_log_prob: Optional[bool] = False):
        """Compute the log probabilities of the given labels under the given logits.

        Logits at position t are scored against the label at position t+1.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: True -> per-sequence mean over non-masked tokens
                (0 for a sequence without any labelled token);
                False -> per-sequence sum over non-masked tokens;
                None -> the masked per-token log-probs and the boolean mask,
                both of shape (batch_size, sequence_length - 1).

        Returns:
            A tensor of shape (batch_size,), or a `(per_token_logps, mask)` pair
            when `average_log_prob` is None.
        """
        assert logits.shape[:-1] == labels.shape, f"Logits and labels must have the same shape. Got {logits.shape} and {labels.shape}."
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_mask = (labels != -100)
        # dummy token; masked out below, it only keeps `gather` indices valid
        labels[labels == -100] = 0
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        masked_logps = per_token_logps * loss_mask
        if average_log_prob is None:
            return masked_logps, loss_mask
        if average_log_prob:
            # clamp avoids 0/0 = NaN for a row with no labelled token
            return masked_logps.sum(-1) / loss_mask.sum(-1).clamp(min=1)
        return masked_logps.sum(-1)

    def _dpo_loss(self, policy_chosen_logps: torch.FloatTensor,
                 policy_rejected_logps: torch.FloatTensor,
                 reference_chosen_logps: torch.FloatTensor,
                 reference_rejected_logps: torch.FloatTensor,
                 beta: float,
                 reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the DPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        if reference_free:
            ref_logratios = 0

        logits = pi_logratios - ref_logratios
        losses = -F.logsigmoid(beta * logits)
        # Rewards are diagnostics only: detached from the graph.
        chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

        return losses, chosen_rewards, rejected_rewards

    def _pad_to_length(self, tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
        """Right-pad `tensor` along `dim` with `pad_value` up to `length`; longer tensors are returned unchanged."""
        if tensor.size(dim) >= length:
            return tensor
        else:
            pad_size = list(tensor.shape)
            pad_size[dim] = length - tensor.size(dim)
            return torch.cat(
                [
                    tensor,
                    pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
                ],
                dim=dim,
            )

    def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]], padding_value: int = 0,
                            label_pad_token_id: int = -100):
        """Concatenate the chosen and rejected inputs into a single tensor.

        Every tensor entry `chosen_X` / `rejected_X` is right-padded to the longer
        of the two input-id lengths (keys containing "labels" with
        `label_pad_token_id`, the rest with `padding_value`) and stacked along
        the batch dimension as `concatenated_X`, chosen rows first.

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
            padding_value: pad value for non-label tensors.
            label_pad_token_id: pad value for label tensors.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'
            (and likewise for every other chosen/rejected tensor).
        """
        max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
        concatenated_batch = {}
        for k in batch:
            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
                pad_value = label_pad_token_id if "labels" in k else padding_value
                concatenated_key = k.replace("chosen", "concatenated")
                concatenated_batch[concatenated_key] = self._pad_to_length(batch[k], max_length, pad_value=pad_value)
        for k in batch:
            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
                pad_value = label_pad_token_id if "labels" in k else padding_value
                concatenated_key = k.replace("rejected", "concatenated")
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        self._pad_to_length(batch[k], max_length, pad_value=pad_value),
                    ),
                    dim=0,
                )
        return concatenated_batch

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]):
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Returns:
            `((chosen_logps, rejected_logps, chosen_logits, rejected_logits), item)`
            where logits are float32 and padded to the concatenated length, and
            `item` additionally holds the matching (padded) `chosen_labels` /
            `rejected_labels`.
        """
        concatenated_batch = self.concatenated_inputs(batch)

        all_logits = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
        ).logits.to(torch.float32)
        all_labels = concatenated_batch["concatenated_labels"]

        all_logps = self._get_batch_logps(all_logits, all_labels, average_log_prob=False)

        # Chosen examples occupy the first rows, rejected ones the rest.
        n_chosen = batch["chosen_input_ids"].shape[0]
        chosen_logps, rejected_logps = all_logps[:n_chosen], all_logps[n_chosen:]
        chosen_logits, rejected_logits = all_logits[:n_chosen], all_logits[n_chosen:]
        chosen_labels, rejected_labels = all_labels[:n_chosen], all_labels[n_chosen:]

        result_item = {
            'chosen_logits': chosen_logits, 'chosen_labels': chosen_labels,
            'rejected_logits': rejected_logits, 'rejected_labels': rejected_labels,
            'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps
        }

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits), result_item

    def _compute_dpo_loss(self, model_output, inputs, loss_mask):
        """Per-example DPO losses (multiplied by `loss_mask`) plus chosen/rejected rewards.

        The reference log-probs come from a no-grad concatenated forward pass
        of `self.reference_model` on the same `inputs`.
        """
        (
            policy_chosen_logps,
            policy_rejected_logps,
            _,
            _,
        ) = model_output

        with torch.no_grad():
            (
                reference_chosen_logps,
                reference_rejected_logps,
                _,
                _,
            ), _ = self.concatenated_forward(self.reference_model, inputs)

        losses, chosen_rewards, rejected_rewards = self._dpo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, self.beta, self.reference_free)
        losses = losses * loss_mask
        return losses, chosen_rewards, rejected_rewards

    def _compute_sft_loss(self, policy_chosen_logits, labels, loss_mask):
        """Token-averaged negative log-likelihood over examples selected by `loss_mask`.

        The denominator is the number of labelled tokens in the selected
        examples, clamped to at least 1.
        """
        chosen_logps, token_loss_mask = self._get_batch_logps(policy_chosen_logits, labels, average_log_prob=None)
        valid_token_num = (token_loss_mask.sum(-1) * loss_mask).sum()
        neg_logps_sum = (-chosen_logps.sum(-1) * loss_mask).sum()
        # safe div
        return neg_logps_sum / torch.max(valid_token_num, torch.ones_like(valid_token_num))

    @staticmethod
    def _loss_to_scalar(loss):
        """Detach a loss tensor and return its first element as a NumPy scalar for logging."""
        return loss.detach().cpu().numpy().reshape(-1)[0]

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute `dpo_loss + sft_loss_weight * sft_loss + kl_loss_weight * kl_loss`.

        Only the enabled terms contribute; the others count as 0.

        Args:
            model: the policy model being trained.
            inputs: collated batch (see the class docstring).
            return_outputs: when True, also return a dict of per-term scalars.
            num_items_in_batch: accepted for compatibility with `transformers`
                releases that pass it to `compute_loss`; unused here.

        Returns:
            `total_loss`, or `(total_loss, output_dict)` when `return_outputs`.

        Raises:
            ValueError: if none of the DPO, SFT or KL terms is enabled.
        """
        if not (self.use_dpo_loss or self.use_sft_loss or self.use_kl_loss):
            raise ValueError(
                "CustomTrainer needs at least one of use_dpo_loss, use_sft_loss or use_kl_loss enabled.")

        output_dict = {}

        if self.use_dpo_loss:
            dpo_loss_mask = inputs['dpo_loss_mask']
            if dpo_loss_mask.sum() == 0:
                print('dpo_loss_mask.sum() == 0, the dpo loss should be 0')
            model_outputs, output_item = self.concatenated_forward(model, inputs)
            _, _, chosen_logits, _ = model_outputs
            # Padded to the concatenated chosen/rejected length, hence aligned
            # position-for-position with `chosen_logits`.
            chosen_labels = output_item['chosen_labels']
            dpo_losses, _, _ = self._compute_dpo_loss(model_outputs, inputs, dpo_loss_mask)
            # NOTE(review): summed (not averaged) over the batch, so its scale
            # grows with batch size, unlike the token-averaged SFT/KL terms.
            dpo_loss = dpo_losses.sum()
        else:
            input_ids = inputs['chosen_input_ids'] if 'chosen_input_ids' in inputs else inputs['input_ids']
            attention_mask = inputs['chosen_attention_mask'] if 'chosen_attention_mask' in inputs else inputs[
                'attention_mask']
            chosen_logits = model(input_ids, attention_mask=attention_mask).logits.to(torch.float32)
            chosen_labels = inputs['chosen_labels'] if 'chosen_labels' in inputs else inputs['labels']
            dpo_loss = 0

        if self.use_sft_loss:
            sft_loss_mask = inputs['sft_loss_mask']
            if sft_loss_mask.sum() == 0:
                print('sft_loss_mask.sum() == 0, the sft loss should be 0')

            sft_loss = self._compute_sft_loss(chosen_logits, chosen_labels, sft_loss_mask)
            sft_loss *= self.sft_loss_weight
        else:
            sft_loss = 0

        if self.use_kl_loss:
            kl_loss_mask = inputs['kl_loss_mask']
            if kl_loss_mask.sum() == 0:
                print('kl_loss_mask.sum() == 0, the kl loss should be 0')

            with torch.no_grad():
                reference_logits = self.reference_model(input_ids=inputs['kl_input_ids'],
                                                        attention_mask=inputs['kl_attention_mask']).logits.to(torch.float32)
            # Use the labels that match `chosen_logits` in length: in the DPO
            # branch the logits are padded to the concatenated length, and the
            # raw `inputs['chosen_labels']` would index the wrong rows.
            kl_loss = self._compute_kl_loss(chosen_logits, reference_logits, chosen_labels,
                                            inputs['kl_labels'], kl_loss_mask)

            kl_loss *= self.kl_loss_weight
        else:
            kl_loss = 0

        total_loss = sft_loss + dpo_loss + kl_loss
        output_dict['sft_loss'] = self._loss_to_scalar(sft_loss) if self.use_sft_loss else 0
        output_dict['dpo_loss'] = self._loss_to_scalar(dpo_loss) if self.use_dpo_loss else 0
        output_dict['kl_loss'] = self._loss_to_scalar(kl_loss) if self.use_kl_loss else 0
        output_dict['total_loss'] = self._loss_to_scalar(total_loss)
        print(f'total_loss:{output_dict["total_loss"]}\tsft_loss:{output_dict["sft_loss"]}\tdpo_loss:{output_dict["dpo_loss"]}\tkl_loss:{output_dict["kl_loss"]}')

        return (total_loss, output_dict) if return_outputs else total_loss

    def _compute_kl_loss(self, chosen_logits, reference_logits, chosen_labels, reference_labels,  kl_loss_mask):
        """Top-k KL between reference and policy next-token distributions.

        Only positions whose shifted label is not -100 are used; the policy and
        reference sides must select the same number of positions. For each
        position, `F.kl_div(policy_logps, ref_probs)` is averaged over the
        `self.kl_topk` most probable reference tokens. The result is averaged over
        positions of examples selected by `kl_loss_mask` (denominator clamped to 1).

        Args:
            chosen_logits: policy logits, aligned with `chosen_labels`.
            reference_logits: reference logits, aligned with `reference_labels`.
            chosen_labels / reference_labels: token labels, -100 = ignored.
            kl_loss_mask: per-example weights/mask for the policy side.
        """
        def get_logits_by_labels(logits, labels, kl_loss_mask=None):
            """Gather next-token logits (and the expanded mask) where the shifted label != -100."""
            assert logits.shape[:2] == labels.shape, (
                f"logits {tuple(logits.shape)} and labels {tuple(labels.shape)} are not aligned")
            shifted_labels = labels[:, 1:]
            n_pos = shifted_labels.shape[1]
            # (batch, seq-1) -> (batch * (seq-1),)
            keep = torch.where(shifted_labels.reshape(-1) != -100)[0]
            if kl_loss_mask is not None:
                # per-example mask -> per-position mask, then keep labelled positions
                kl_loss_mask = kl_loss_mask.unsqueeze(-1).expand(-1, n_pos).reshape(-1)[keep]
            # (batch, seq, vocab) -> (batch * (seq-1), vocab)
            flat_logits = logits[:, :-1, :].reshape(-1, logits.shape[-1])
            return flat_logits[keep], kl_loss_mask

        chosen_logits, kl_loss_mask = get_logits_by_labels(chosen_logits, chosen_labels, kl_loss_mask)
        reference_logits, _ = get_logits_by_labels(reference_logits, reference_labels)
        assert chosen_logits.shape == reference_logits.shape
        reference_probs = F.softmax(reference_logits, dim=-1)
        chosen_logps = F.log_softmax(chosen_logits, dim=-1)
        # topk tokens from reference_logits. [?, vocab_size] -> [?, topk]
        reference_probs, reference_indices = reference_probs.topk(self.kl_topk, dim=-1)
        chosen_logps = chosen_logps.gather(dim=-1, index=reference_indices)
        # [?, topk] -> [?]
        kl_losses = F.kl_div(chosen_logps, reference_probs, reduction='none').mean(dim=-1)
        # valid_token-wised mean
        mask_sum = torch.max(kl_loss_mask.sum(), torch.ones_like(kl_loss_mask.sum()))
        return (kl_losses * kl_loss_mask).sum() / mask_sum

def compute_metrics(ep: EvalPrediction):
    """Average each prediction array of an evaluation run into a named metric.

    Positional mapping of `ep.predictions`: 0 chosen_rewards,
    1 rejected_rewards, 2 sft_loss, 3 dpo_loss, 4 total_loss.

    NOTE(review): `CustomTrainer.compute_loss` returns a dict keyed
    sft_loss / dpo_loss / kl_loss / total_loss (no rewards), so this mapping
    may not line up with real evaluation outputs. `train()` passes no eval
    dataset; confirm before enabling evaluation.
    """
    metric_names = ('chosen_rewards', 'rejected_rewards', 'sft_loss', 'dpo_loss', 'total_loss')
    return {name: ep.predictions[idx].mean() for idx, name in enumerate(metric_names)}

def train():
    """Parse CLI arguments, build models/tokenizer/dataset, train, then save the weights.

    Arguments are parsed into `ModelArguments`, `DataArguments` and this file's
    `TrainingArguments`. If `output_dir` already contains a `checkpoint-*`
    entry, training is started with `resume_from_checkpoint=True`.
    """
    parser = transformers.HfArgumentParser((
        ModelArguments,
        DataArguments,
        TrainingArguments
    ))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    logger = init_logger(
        os.path.join(training_args.output_dir, 'train.log'),
        training_args.local_rank
    )
    logger.info(f'model args: {model_args}')
    logger.info(f'data args: {data_args}')
    logger.info(f'training args: {training_args}')

    # Policy (trained) and reference (frozen by CustomTrainer) are two
    # independent copies of the same checkpoint.
    policy_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir
    )
    reference_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    )

    # Global rank of this process; only rank 0 writes the final weights.
    # Unlike torch.distributed.get_rank(), this also works in a single process.
    global_rank = training_args.process_index
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir
    )
    # Pad id forced to 0. NOTE(review): confirm this matches what
    # DataCollatorForDpo / DpoProcessor expect for this tokenizer.
    tokenizer.pad_token_id = 0

    # Hard-coded user/assistant role-token ids; they must match the tokenizer of
    # `model_name_or_path` -- TODO confirm when switching checkpoints.
    processor = DpoProcessor(
        tokenizer=tokenizer,
        user_token_id=195,
        assistant_token_id=196
    )
    dataset = DpoDataset(
        data_file=data_args.data_file,
        processor=processor,
        max_length=training_args.model_max_length,
        switch_rate=data_args.switch_rate,
        cache_dir=data_args.data_cache_dir,
        overwrite=data_args.overwrite_data_cache,
        use_sft_loss=training_args.use_sft_loss,
        use_dpo_loss=training_args.use_dpo_loss,
        use_kl_loss=training_args.use_kl_loss,
    )
    # dataset.update is called once per (rounded-up) epoch before training
    # starts; what it does is defined in dataset.py.
    for epoch in range(math.ceil(training_args.num_train_epochs)):
        dataset.update(epoch)

    trainer = CustomTrainer(model=policy_model,
                            reference_model=reference_model,
                            args=training_args,
                            train_dataset=dataset,
                            tokenizer=tokenizer,
                            compute_metrics=compute_metrics,
                            data_collator=DataCollatorForDpo(tokenizer=tokenizer))
    add_custom_callback(trainer, logger)

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir, global_rank=global_rank)


if __name__ == "__main__":
    # Per-process local rank from the launcher (0 when LOCAL_RANK is unset).
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    print("local_rank", local_rank)
    # Hard-coded device split: local ranks 0-1 see GPUs 0-3, all others 4-7.
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    # initialised in this process yet -- confirm the imports above never touch it.
    if local_rank < 2:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'
    train()
