
# from trl import DPOTrainer
from dpo_trainer import DPOTrainer
import torch
import numpy as np
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch.nn.functional as F
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    DataCollator,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
)

def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
    """Right-pad ``tensor`` along ``dim`` with ``pad_value`` up to ``length`` entries.

    The input is never truncated: when its size along ``dim`` is already
    ``length`` or larger, the very same tensor object is returned.

    Args:
        tensor: Tensor to pad.
        length: Target size along ``dim``.
        pad_value: Value used to fill the added positions.
        dim: Dimension to pad (defaults to the last one).

    Returns:
        The original tensor, or a new tensor with the filler appended after
        the existing data along ``dim``.
    """
    current = tensor.size(dim)
    if current >= length:
        return tensor

    # Same shape as the input except along ``dim``, where only the missing
    # entries are allocated.
    filler_shape = list(tensor.shape)
    filler_shape[dim] = length - current
    filler = pad_value * torch.ones(*filler_shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, filler], dim=dim)

class FuseSFTTrainer(DPOTrainer):
    source_samples = 4

    def __init__(self, **kwargs):
        """Initialise the parent trainer, then cache the options this class reads.

        Every keyword argument is forwarded unchanged to the parent
        constructor. ``kwargs["args"]`` must be present and expose the
        attributes ``use_ref``, ``avg_logp``, ``norm_type`` and ``norm_temp``.
        """
        super().__init__(**kwargs)
        fuse_args = kwargs["args"]
        # These three are stored exactly as found on the training arguments.
        for option in ("use_ref", "avg_logp", "norm_type"):
            setattr(self, option, getattr(fuse_args, option))
        # The temperature is coerced to float whatever type it arrived as.
        self.norm_temp = float(fuse_args.norm_temp)


    @staticmethod
    def tokenize_row(feature, tokenizer, max_length, max_prompt_length, truncation_mode, label_pad_token_id) -> Dict:
        """Tokenize one dataset row made of a prompt and several source responses.

        No PyTorch tensors are created here; the method only tokenizes,
        truncates and builds labels. When prompt + response is longer than
        ``max_length`` the prompt is truncated first (according to
        ``truncation_mode``); if that is still too long the responses are
        truncated as well. Labels span prompt + response, with
        ``label_pad_token_id`` written over the prompt positions.

        Args:
            feature: Mapping with the keys ``"prompt"`` (a ``str``),
                ``"source"`` (a list of response strings) and ``"scores"``
                (copied to the output unchanged).
            tokenizer: Called as ``tokenizer(text, add_special_tokens=False)``;
                the result must provide ``"input_ids"`` and
                ``"attention_mask"`` lists. Its ``eos_token_id`` is appended to
                every response.
            max_length: Maximum number of tokens for prompt + response.
            max_prompt_length: Number of prompt tokens kept when the prompt is
                truncated.
            truncation_mode: ``"keep_start"`` or ``"keep_end"``.
            label_pad_token_id: Label value written over prompt positions.

        Returns:
            A dict with ``scores``, the ``prompt_``-prefixed tokenizer outputs
            and, for each source index ``i``, ``source{i}_input_ids``,
            ``source{i}_attention_mask`` and ``source{i}_labels``.

        Raises:
            ValueError: If ``prompt`` is not a ``str``, ``source`` is not a
                non-empty list, the tokenizer output lengths are inconsistent,
                or ``truncation_mode`` is unknown.
        """
        batch = {"scores": feature["scores"]}
        prompt = feature["prompt"]
        source = feature["source"]

        def build_tokenized_answer(prompt, answer, tokenizer):
            """Split ``enc(prompt + answer)`` into a prompt part and an answer part.

            Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
            It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
            Reference:
                https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            """
            full_tokenized = tokenizer(prompt + answer, add_special_tokens=False)
            full_input_ids = full_tokenized["input_ids"]
            full_attention_mask = full_tokenized["attention_mask"]
            prompt_input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]

            # `enc(a) + enc(a + b)[len(enc(a)):]` must be as long as `enc(a + b)`.
            tail_length = len(full_input_ids[len(prompt_input_ids):])
            if len(full_input_ids) != len(prompt_input_ids) + tail_length:
                raise ValueError("Prompt input ids and answer input ids should have the same length.")

            # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
            # can be merged together when tokenizing prompt+answer. This could result
            # on the last token from the prompt being different when tokenized on its own
            # vs when done as prompt+answer.
            response_token_ids_start_idx = len(prompt_input_ids)

            # If the stand-alone prompt differs from the head of prompt+answer, the
            # last prompt token was merged: hand that token over to the answer.
            if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
                response_token_ids_start_idx -= 1

            prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
            prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

            if len(prompt_input_ids) != len(prompt_attention_mask):
                raise ValueError("Prompt input ids and attention mask should have the same length.")

            return dict(
                prompt_input_ids=prompt_input_ids,
                prompt_attention_mask=prompt_attention_mask,
                input_ids=full_input_ids[response_token_ids_start_idx:],
                attention_mask=full_attention_mask[response_token_ids_start_idx:],
            )

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = tokenizer(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(source, list):
            raise ValueError(f"source should be a list but got {type(source)}")
        if not source:
            raise ValueError("source should be a non-empty list")
        source_tokens_list = [build_tokenized_answer(prompt, source_text, tokenizer) for source_text in source]

        # Last prompt token might get merged by tokenizer and it should not be
        # included for generation if that happens: keep the shortest prompt
        # length observed across all sources.
        prompt_len_input_ids = min(len(source_tokens["prompt_input_ids"]) for source_tokens in source_tokens_list)
        prompt_tokens = {k: v[:prompt_len_input_ids] for k, v in prompt_tokens.items()}

        # add EOS token to end of answer
        for source_tokens in source_tokens_list:
            source_tokens["input_ids"].append(tokenizer.eos_token_id)
            source_tokens["attention_mask"].append(1)

        longer_response_length = max(len(source_tokens["input_ids"]) for source_tokens in source_tokens_list)

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in source_tokens_list + [prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > max_length:
                if truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][:max_prompt_length]
                elif truncation_mode == "keep_end":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-max_prompt_length:]
                else:
                    raise ValueError(f"Unknown truncation mode: {truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in source_tokens_list:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: max_length - max_prompt_length]

        # Build prompt + response sequences; labels mask out the prompt part.
        source_sequence_tokens_list = []
        for source_tokens in source_tokens_list:
            source_sequence_tokens = {
                k: source_tokens[f"prompt_{k}"] + source_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            prompt_length = len(source_tokens["prompt_input_ids"])
            source_sequence_tokens["labels"] = source_sequence_tokens["input_ids"][:]
            source_sequence_tokens["labels"][:prompt_length] = [label_pad_token_id] * prompt_length
            source_sequence_tokens_list.append(source_sequence_tokens)

        # Key order is kept as: every source's input_ids, then attention masks,
        # then labels, and finally the prompt entries.
        for type_key in ["input_ids", "attention_mask", "labels"]:
            for ww, tok in enumerate(source_sequence_tokens_list):
                batch[f"source{ww}_{type_key}"] = tok[type_key]

        # NOTE: every prompt key carries the "prompt_" prefix, so anything else
        # the tokenizer returned (e.g. token_type_ids) is passed through here
        # as "prompt_<name>".
        batch.update(prompt_tokens)

        return batch


    @staticmethod
    def concatenated_inputs(
            batch: Dict[str, Union[List, torch.LongTensor]],
            is_encoder_decoder: bool = False,
            label_pad_token_id: int = -100,
            padding_value: int = 0,
            device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Stack the per-source inputs of a batch along the batch dimension.

        For every source index ``i`` below ``FuseSFTTrainer.source_samples``
        the tensors stored under ``source{i}_<name>`` are padded to a common
        sequence length and concatenated, in source order, under
        ``concatenated_<name>``.

        Args:
            batch: A batch of data. Must contain ``source{i}_input_ids`` (or
                ``source{i}_labels`` for encoder-decoder models) for every
                source index; the code reads ``shape[1]`` as sequence length.
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            label_pad_token_id: Padding value for labels (and for every tensor
                when ``is_encoder_decoder`` is true).
            padding_value: Padding value for ``*_input_ids``.
            device: Device the concatenated tensors are moved to.

        Returns:
            A dictionary holding the ``concatenated_*`` tensors. Rows are
            ordered source-major: all rows of source 0, then source 1, etc.
        """
        num_sources = FuseSFTTrainer.source_samples

        length_key = "labels" if is_encoder_decoder else "input_ids"
        max_length = max(batch[f"source{ww}_{length_key}"].shape[1] for ww in range(num_sources))

        # Collect the padded pieces first and concatenate once per key. This
        # avoids re-copying the accumulator for every source and never mixes
        # an already-moved tensor with one still on the batch's device.
        padded_chunks: Dict[str, List[torch.Tensor]] = {}
        for source_num in range(num_sources):
            # The trailing underscore keeps "source1_" from matching "source10_...".
            prefix = f"source{source_num}_"
            for k, value in batch.items():
                if not k.startswith(prefix) or not isinstance(value, torch.Tensor):
                    continue
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                else:
                    # No padding rule is defined for this tensor; leave it out
                    # rather than pad it with an unrelated value.
                    continue

                concatenated_key = "concatenated_" + k[len(prefix):]
                padded_chunks.setdefault(concatenated_key, []).append(
                    pad_to_length(value, max_length, pad_value=pad_value)
                )

        concatenated_batch = {
            key: torch.cat(chunks, dim=0).to(device=device) for key, chunks in padded_chunks.items()
        }

        if is_encoder_decoder:
            # The encoder input is the prompt, repeated once per stacked source
            # so that its batch size matches the concatenated labels.
            concatenated_batch["concatenated_input_ids"] = (
                batch["prompt_input_ids"].repeat(num_sources, 1).to(device=device)
            )
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(num_sources, 1).to(device=device)
            )

        return concatenated_batch

    def fusesft_loss(
            self,
            rw_scores: torch.FloatTensor,
            policy_source_logps: torch.FloatTensor,
            reference_source_logps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Reward-weighted negative log-likelihood over the source responses.

        The scores of each example are turned into weights according to
        ``self.norm_type``:

        * ``"sfmax"``: softmax of ``scores / self.norm_temp``;
        * ``"average"``: scores divided by their sum (not guarded against a
          zero sum);
        * ``"norw"``: uniform weights, the scores are ignored.

        Args:
            rw_scores: Per-source scores; anything ``torch.as_tensor`` accepts.
                The last dimension is normalised over and dimension 1 is summed
                over, so a ``[batch, num_sources]`` layout is expected.
            policy_source_logps: Policy log-probabilities, multiplied
                elementwise with the weights.
            reference_source_logps: Currently unused; kept in the signature.

        Returns:
            Per-example losses: ``-(weights * policy_source_logps)`` summed
            over dimension 1.

        Raises:
            ValueError: If ``self.norm_type`` is not one of the supported values.
        """
        # `as_tensor` avoids the copy-construct warning when the scores are
        # already a tensor. Using the dtype of the log-probabilities keeps the
        # rewards from being quantized to bfloat16 before normalisation.
        rlhf_reward = torch.as_tensor(
            rw_scores, dtype=policy_source_logps.dtype, device=policy_source_logps.device
        )  # [bz, K]

        if self.norm_type == "sfmax":
            rlhf_reward = torch.softmax(rlhf_reward / self.norm_temp, dim=-1)
        elif self.norm_type == "average":
            rlhf_reward = rlhf_reward / torch.sum(rlhf_reward, dim=-1, keepdim=True)
        elif self.norm_type == "norw":
            rlhf_reward = torch.ones_like(rlhf_reward) / rlhf_reward.shape[-1]
        else:
            raise ValueError("Invalid norm type. Supported types are 'sfmax', 'average' and 'norw'.")

        losses = -(rlhf_reward * policy_source_logps)

        return losses.sum(dim=1)  # [bz]


    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        The ``labels`` tensor passed in is never modified.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
            A row without any non-masked token yields 0 in both modes.

        Raises:
            ValueError: If the batch/sequence dimensions of ``logits`` and ``labels`` differ.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not is_encoder_decoder:
            # Decoder-only: the logits at position t predict the token at t + 1.
            labels = labels[:, 1:]
            logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # Ignored positions need a valid index for `gather`; their contribution
        # is zeroed by the mask below. `masked_fill` is out-of-place, so the
        # caller's labels stay intact in the encoder-decoder path as well.
        gather_index = labels.masked_fill(~loss_mask, 0)

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=gather_index.unsqueeze(2)).squeeze(2)
        summed_logps = (per_token_logps * loss_mask).sum(-1)

        if average_log_prob:
            # Clamp the token count so a fully masked row gives 0 instead of NaN.
            return summed_logps / loss_mask.sum(-1).clamp(min=1)
        return summed_logps

    def concatenated_forward(
            self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run the model once on all source responses of the batch stacked together.

        The per-source inputs are concatenated along the batch dimension with
        ``concatenated_inputs`` so that a single forward pass covers every
        source.

        Args:
            model: Model called with the concatenated input ids and attention
                mask (plus labels / decoder input ids for encoder-decoder
                models); its output must expose ``.logits``.
            batch: Batch containing the ``source{i}_*`` tensors.

        Returns:
            A tuple ``(source_logps, source_logits)``:

            * ``source_logps`` has shape ``[batch_size, num_sources]``;
            * ``source_logits`` has shape ``[batch_size, num_sources, ...]``
              where the trailing dimensions are those of the model logits.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )

        bsz = batch["source0_labels"].shape[0]  # batch_size

        model_kwargs = (
            {
                "labels": concatenated_batch["concatenated_labels"],
                "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
            }
            if self.is_encoder_decoder
            else {}
        )

        all_logits = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        ).logits

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=self.avg_logp,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        len_source = FuseSFTTrainer.source_samples * bsz

        # Rows are source-major (all of source 0, then source 1, ...), so
        # viewing as [num_sources, bsz] and transposing gives one row per
        # example with one column per source.
        source_logps = all_logps[:len_source].view(-1, bsz).t()

        # Regroup the logits the same way while keeping their trailing
        # dimensions intact. A flat `.view(-1, bsz)` would mix sequence and
        # vocabulary entries of different examples.
        source_logits = all_logits[:len_source]
        source_logits = source_logits.reshape(-1, bsz, *source_logits.shape[1:]).transpose(0, 1)
        return (source_logps, source_logits)

    def get_batch_loss_metrics(
            self,
            model,
            batch: Dict[str, Union[List, torch.LongTensor]],
            train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the batch loss via ``fusesft_loss`` and return it with a metrics dict.

        Args:
            model: Policy model, forwarded to ``concatenated_forward``.
            batch: Batch of inputs; ``batch["scores"]`` is passed to the loss.
            train_eval: Not used by this implementation.

        Returns:
            A tuple ``(loss, metrics)`` where ``loss`` is the mean of the
            per-example losses and ``metrics`` is an empty dict.
        """
        policy_source_logps, _policy_source_logits = self.concatenated_forward(model, batch)

        # The reference side is always evaluated without gradient tracking.
        with torch.no_grad():
            if self.ref_model is not None and self.use_ref is not False:
                reference_source_logps, _ = self.concatenated_forward(self.ref_model, batch)
            else:
                # No reference model, or its use is disabled: use zeros instead.
                reference_source_logps = torch.zeros_like(
                    policy_source_logps,
                    dtype=torch.bfloat16,
                    device=policy_source_logps.device,
                )

        per_example_losses = self.fusesft_loss(
            batch["scores"],
            policy_source_logps,
            reference_source_logps,
        )

        metrics = {}
        return per_example_losses.mean(), metrics