from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple  
from ...extras.constants import IGNORE_INDEX  
from ...extras.logging import get_logger  
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen  
  
if TYPE_CHECKING:  
    from transformers import PreTrainedTokenizer, ProcessorMixin  
    from ...hparams import DataArguments  
    from ..template import Template  
  
logger = get_logger(__name__)  
  
def _encode_triplewise_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], List[int], List[int]]:
    """
    Encode one prompt together with three candidate responses.

    ``response[0]``, ``response[1]`` and ``response[2]`` are used as the chosen,
    middle and rejected answers respectively. All three share the same encoded
    prompt, and the prompt positions are masked with ``IGNORE_INDEX`` in the labels.

    Args:
        prompt: Conversation messages preceding the response. Not modified.
        response: At least three candidate response messages.
        system: Optional system prompt forwarded to ``template.encode_oneturn``.
        tools: Optional tool description forwarded to ``template.encode_oneturn``.
        template: Chat template used to encode the messages.
        tokenizer: Tokenizer used by the template and for special token ids.
        processor: Optional multimodal processor. When it has no ``image_seq_length``
            attribute the image token is prepended to the first message text; when
            it has one, that many image token ids are prepended to the prompt ids.
        cutoff_len: Length budget passed to ``infer_seqlen``.

    Returns:
        ``(chosen_input_ids, chosen_labels, middle_input_ids, middle_labels,
        rejected_input_ids, rejected_labels)``.
    """
    has_image_seq_length = processor is not None and hasattr(processor, "image_seq_length")

    # Work on a shallow copy so the caller's messages are never modified.
    messages = list(prompt)
    if processor is not None and not has_image_seq_length:  # llava-like models
        # Rebuild the first message instead of editing it in place; otherwise encoding
        # the same example twice would prepend the image token twice.
        first_message = dict(messages[0])
        first_message["content"] = template.image_token + first_message["content"]
        messages[0] = first_message

    chosen_messages = messages + [response[0]]
    middle_messages = messages + [response[1]]
    rejected_messages = messages + [response[2]]

    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, middle_ids = template.encode_oneturn(tokenizer, middle_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

    if template.efficient_eos:
        # Build new lists rather than extending the ones returned by the template.
        chosen_ids = chosen_ids + [tokenizer.eos_token_id]
        middle_ids = middle_ids + [tokenizer.eos_token_id]
        rejected_ids = rejected_ids + [tokenizer.eos_token_id]

    if has_image_seq_length:  # paligemma models
        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

    # consider the response is more important
    # The longest of the three responses drives the length split, so all three
    # sequences are cut with the same limits.
    longest_response = max(len(chosen_ids), len(middle_ids), len(rejected_ids))
    source_len, target_len = infer_seqlen(len(prompt_ids), longest_response, cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    chosen_ids = chosen_ids[:target_len]
    middle_ids = middle_ids[:target_len]
    rejected_ids = rejected_ids[:target_len]

    # Size the mask from the truncated prompt itself rather than from `source_len`:
    # the two are equal whenever `source_len` does not exceed the prompt length, and
    # this keeps labels aligned with input ids even if the helper returns a larger
    # budget (NOTE(review): `infer_seqlen` is not visible here - confirm its contract).
    prompt_mask = [IGNORE_INDEX] * len(prompt_ids)

    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = prompt_mask + chosen_ids
    middle_input_ids = prompt_ids + middle_ids
    middle_labels = prompt_mask + middle_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = prompt_mask + rejected_ids

    return chosen_input_ids, chosen_labels, middle_input_ids, middle_labels, rejected_input_ids, rejected_labels
  
def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    """
    Convert a batch of examples into chosen/middle/rejected model inputs.

    Each kept example is built with the format `<bos> X`, `Y1 <eos>`, `Y2 <eos>`,
    and `Y3 <eos>`. An example is dropped (with a warning) unless its prompt has an
    odd number of messages and it provides at least three responses.

    Args:
        examples: Batch columns; reads "prompt", "response", "system", "tools"
            and, when a processor is given, "images".
        template: Chat template forwarded to the encoder.
        tokenizer: Tokenizer forwarded to the encoder.
        processor: Optional multimodal processor. When given, a "pixel_values"
            column is produced; when it also has ``image_seq_length``, per-branch
            token type id columns are produced as well.
        data_args: Supplies ``cutoff_len`` for the encoder.

    Returns:
        A dict of columns, each a list with one entry per kept example.
    """
    with_processor = processor is not None
    with_token_types = with_processor and hasattr(processor, "image_seq_length")  # paligemma models

    base_columns = (
        "chosen_input_ids",
        "chosen_attention_mask",
        "chosen_labels",
        "middle_input_ids",
        "middle_attention_mask",
        "middle_labels",
        "rejected_input_ids",
        "rejected_attention_mask",
        "rejected_labels",
    )
    # Every column gets its own fresh list.
    model_inputs = {column: [] for column in base_columns}
    if with_processor:
        model_inputs["pixel_values"] = []
        if with_token_types:
            for column in ("chosen_token_type_ids", "middle_token_type_ids", "rejected_token_type_ids"):
                model_inputs[column] = []

    for idx, prompt in enumerate(examples["prompt"]):
        response = examples["response"][idx]
        if not (len(prompt) % 2 == 1 and len(response) >= 3):
            logger.warning("Dropped invalid example: {}".format(prompt + response))
            continue

        (
            chosen_ids,
            chosen_targets,
            middle_ids,
            middle_targets,
            rejected_ids,
            rejected_targets,
        ) = _encode_triplewise_example(
            prompt=prompt,
            response=response,
            system=examples["system"][idx],
            tools=examples["tools"][idx],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )

        encoded_branches = (
            ("chosen_input_ids", "chosen_attention_mask", "chosen_labels", chosen_ids, chosen_targets),
            ("middle_input_ids", "middle_attention_mask", "middle_labels", middle_ids, middle_targets),
            ("rejected_input_ids", "rejected_attention_mask", "rejected_labels", rejected_ids, rejected_targets),
        )
        for ids_column, mask_column, labels_column, token_ids, targets in encoded_branches:
            model_inputs[ids_column].append(token_ids)
            # No padding happens here, so every position is attended to.
            model_inputs[mask_column].append([1] * len(token_ids))
            model_inputs[labels_column].append(targets)

        if not with_processor:
            continue

        model_inputs["pixel_values"].append(get_pixel_values(examples["images"][idx], processor))
        if with_token_types:  # paligemma models
            token_type_branches = (
                ("chosen_token_type_ids", chosen_ids),
                ("middle_token_type_ids", middle_ids),
                ("rejected_token_type_ids", rejected_ids),
            )
            for type_column, token_ids in token_type_branches:
                model_inputs[type_column].append(get_paligemma_token_type_ids(len(token_ids), processor))

    return model_inputs
  
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    """
    Print one encoded chosen/middle/rejected example for inspection.

    For each of the three branches this prints the raw input ids, the decoded
    inputs, the raw label ids, and the decoded labels. Label entries equal to
    ``IGNORE_INDEX`` are removed before decoding. Special tokens are kept in
    the decoded text.

    Args:
        example: One example holding "<branch>_input_ids" and "<branch>_labels"
            for the chosen, middle and rejected branches.
        tokenizer: Tokenizer used to decode token ids.
    """

    def _supervised_ids(label_ids: List[int]) -> List[int]:
        # Drop the masked positions so only real token ids reach the tokenizer.
        return [token_id for token_id in label_ids if token_id != IGNORE_INDEX]

    def _to_text(token_ids: List[int]) -> str:
        return tokenizer.decode(token_ids, skip_special_tokens=False)

    # Filter all three label lists up front, before anything is printed.
    chosen_targets = _supervised_ids(example["chosen_labels"])
    middle_targets = _supervised_ids(example["middle_labels"])
    rejected_targets = _supervised_ids(example["rejected_labels"])

    chosen_ids = example["chosen_input_ids"]
    print("chosen_input_ids:\n{}".format(chosen_ids))
    print("chosen_inputs:\n{}".format(_to_text(chosen_ids)))
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
    print("chosen_labels:\n{}".format(_to_text(chosen_targets)))

    middle_ids = example["middle_input_ids"]
    print("middle_input_ids:\n{}".format(middle_ids))
    print("middle_inputs:\n{}".format(_to_text(middle_ids)))
    print("middle_label_ids:\n{}".format(example["middle_labels"]))
    print("middle_labels:\n{}".format(_to_text(middle_targets)))

    rejected_ids = example["rejected_input_ids"]
    print("rejected_input_ids:\n{}".format(rejected_ids))
    print("rejected_inputs:\n{}".format(_to_text(rejected_ids)))
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
    print("rejected_labels:\n{}".format(_to_text(rejected_targets)))
