import logging

import torch
from transformers.data.data_collator import *


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Model-family identifiers. ``check_model`` matches each entry as a
# case-insensitive substring of the model name, and the collator uses the
# result to choose between its decoder-only and seq2seq code paths.
SUPPORTED_DECODER_MODELS = [
    'codegen',
    'bloomz',
    'gpt-neox',
    'llama',
    'mistral',
    'qwen',
    'qwen2',
    'qwen3',
    'qwen2_moe',
]
SUPPORTED_SEQ2SEQ_MODELS = [
    't5',
    'flan-t5',
]


def check_model(model_name, supported_models):
    """Tell whether ``model_name`` belongs to one of ``supported_models``.

    A model counts as supported when any entry of ``supported_models``
    occurs inside ``model_name``; the comparison ignores case.

    Args:
        model_name: Name or path of the model to classify.
        supported_models: Iterable of identifier substrings to look for.

    Returns:
        ``True`` on the first identifier found in ``model_name``, ``False``
        when none matches (including when ``supported_models`` is empty).
    """
    return any(
        candidate.lower() in model_name.lower()
        for candidate in supported_models
    )

# System header block that DataCollatorForUIE.decoder_call strips from the
# chat-template output when the tokenizer name contains "llama3.1".
# It is removed with an exact ``str.replace``, so the text below (including
# its blank lines) must stay byte-identical to what the template renders.
system_prompt_llama31 = """<|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|>"""

@dataclass
class DataCollatorForUIE:
    """Collate raw UIE instances into model inputs.

    Calling the collator routes the batch to ``decoder_call`` (chat-formatted,
    left-padded inputs with the labels embedded in ``input_ids``) or to
    ``seq2seq_call`` (separate source and target encodings), depending on
    which supported model family the model name matches.

    Every instance must provide ``instance['Instance']`` with ``instruction``,
    ``sentence`` and ``label`` entries plus ``instance['Samples']``. The
    decoder path additionally reads ``instance['subset']``, and
    ``instance['Dataset']`` is read when ``add_dataset_name`` is set without a
    ``common_dataset_name``.
    """

    # Tokenizer used for chat templating, length checks and final encoding.
    tokenizer: PreTrainedTokenizerBase
    # Model whose config name selects the code path; optional.
    model: Optional[Any] = None
    # Padding strategy forwarded to the tokenizer.
    padding: Union[bool, str, PaddingStrategy] = True
    # Token budgets for the source and the target; ``None`` disables clipping.
    max_source_length: Optional[int] = None
    max_target_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    # Value written into label positions that must not contribute to the loss.
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    add_task_name: bool = False
    # Prepend "Dataset:<name>\n" to every instruction.
    add_dataset_name: bool = False
    # Overrides the per-instance dataset name in that prefix when set.
    common_dataset_name: Optional[str] = None
    # Return raw strings instead of tensors.
    text_only: bool = False
    num_examples: int = 0
    # When set, every collated batch is appended to this file for inspection.
    input_record_file: Optional[str] = None
    # Build (input + label) sequences even for 'dev' / 'test' instances.
    all_train: bool = False

    def __call__(self, batch, return_tensors=None):
        """Collate ``batch`` with the code path matching the model family.

        Args:
            batch: List of instance dicts (see the class docstring).
            return_tensors: Tensor type for the tokenizer; defaults to
                ``self.return_tensors``.

        Returns:
            The dict / encoding produced by ``decoder_call`` or
            ``seq2seq_call``.

        Raises:
            ValueError: If the model name matches no supported family.
        """
        if return_tensors is None:
            return_tensors = self.return_tensors

        # ``model`` is optional (seq2seq_call already guards on it), so fall
        # back to the tokenizer identifier, which decoder_call also uses for
        # model-family detection.
        if self.model is not None:
            model_name = self.model.config._name_or_path
        else:
            model_name = self.tokenizer.name_or_path

        if check_model(model_name, SUPPORTED_DECODER_MODELS):
            return self.decoder_call(batch, return_tensors)
        if check_model(model_name, SUPPORTED_SEQ2SEQ_MODELS):
            return self.seq2seq_call(batch, return_tensors)
        raise ValueError('Unsupport model {}!'.format(model_name))

    def get_instruction(self, instance):
        """Build the prompt text for one instance.

        The instruction template is optionally prefixed with the dataset name
        and then filled with the instance sentence through ``str.format``.

        Args:
            instance: Instance dict (see the class docstring).

        Returns:
            The formatted instruction, or the unformatted template when the
            template cannot be formatted (for example because it contains
            literal braces).

        Raises:
            Exception: If the instance carries few-shot samples, which are not
                supported yet.
        """
        instruction = instance['Instance']["instruction"]
        content = instance['Instance']['sentence']

        if self.add_dataset_name:
            ds_name = self.common_dataset_name if self.common_dataset_name else instance['Dataset']
            instruction = "Dataset:" + ds_name + '\n' + instruction

        # TODO, support few shot
        if len(instance['Samples']) > 0:
            raise Exception('Few shot is coming soon...')

        try:
            return instruction.format(content)
        except (LookupError, ValueError, AttributeError, TypeError) as err:
            # Best effort: templates with stray braces cannot be formatted, so
            # keep them verbatim. Only formatting errors are absorbed here;
            # anything else (e.g. KeyboardInterrupt) propagates. Logged at
            # debug level because this runs once per instance per batch.
            logger.debug("Could not format instruction, using it verbatim: %r", err)
            return instruction

    def seq2seq_call(self, batch, return_tensors):
        """Collate ``batch`` for an encoder-decoder model.

        Sources longer than ``max_source_length`` tokens are clipped before
        the batch is tokenized; padded label positions are replaced by
        ``label_pad_token_id``.

        Args:
            batch: List of instance dicts.
            return_tensors: Tensor type forwarded to the tokenizer.

        Returns:
            ``{"inputs": [...], "labels": [...]}`` with raw strings when
            ``text_only`` is set; otherwise the source encoding extended with
            ``labels`` and, when a model is available, ``decoder_input_ids``.
        """
        sources = []
        labels = []

        for instance in batch:
            labels.append(instance['Instance']['label'])
            source = self.get_instruction(instance)

            tokenized_source = self.tokenizer(source)["input_ids"]
            # ``None`` means no source budget was configured.
            if self.max_source_length is None or len(tokenized_source) <= self.max_source_length:
                sources.append(source)
            else:
                clipped = tokenized_source[:self.max_source_length]
                sources.append(self.tokenizer.decode(clipped, skip_special_tokens=True))

        # TODO, support online demo
        if self.text_only:
            return {"inputs": sources, "labels": labels}

        model_inputs = self.tokenizer(
            sources,
            max_length=self.max_source_length,
            padding=self.padding,
            return_tensors=return_tensors,
            truncation=True,
            pad_to_multiple_of=self.pad_to_multiple_of
        )
        # ``text_target`` replaces the deprecated ``as_target_tokenizer()``
        # context manager and encodes the labels in target mode.
        target_encoding = self.tokenizer(
            text_target=labels,
            max_length=self.max_target_length,
            padding=self.padding,
            return_tensors=return_tensors,
            truncation=True,
            pad_to_multiple_of=self.pad_to_multiple_of
        )
        label_mask = target_encoding["attention_mask"].bool()
        model_inputs["labels"] = target_encoding["input_ids"].masked_fill(
            ~label_mask, self.label_pad_token_id
        )

        # prepare decoder_input_ids
        if self.model is not None:
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
                labels=model_inputs["labels"]
            )
            model_inputs["decoder_input_ids"] = decoder_input_ids

        self._save_samples(model_inputs, sources, target_encoding)
        return model_inputs

    def decoder_call(self, batch, return_tensors):
        """Collate ``batch`` for a decoder-only (causal) model.

        Each instance is rendered through the tokenizer chat template:
        prompt only for 'dev' / 'test' instances (unless ``all_train``),
        prompt plus assistant answer otherwise. Sequences are left-padded and
        ``labels`` is a copy of ``input_ids`` in which everything except the
        trailing answer tokens is set to ``label_pad_token_id``.

        Args:
            batch: Non-empty list of instance dicts, each with a ``subset``.
            return_tensors: Tensor type forwarded to the tokenizer.

        Returns:
            ``{"inputs": [...], "labels": [...]}`` with raw strings when
            ``text_only`` is set; otherwise the tokenizer encoding extended
            with ``labels``.
        """
        # Left padding keeps the answer tokens at the end of every row, which
        # the label mask at the bottom of this method relies on.
        self.tokenizer.padding_side = 'left'

        tokenizer_name = self.tokenizer.name_or_path.lower()
        system_prompt = system_prompt_llama31 if "llama3.1" in tokenizer_name else ""
        # The Qwen3 template takes an extra switch to disable thinking mode.
        template_kwargs = {"enable_thinking": False} if "qwen3" in tokenizer_name else {}

        # NOTE(review): the budget is chosen from the first instance only, so
        # with ``all_train`` a dev/test batch gets (input + label) sequences
        # clipped to the source budget alone - confirm this is intended.
        if batch[0]['subset'] == "train":
            if self.max_source_length is None or self.max_target_length is None:
                limit_input_len = None
            else:
                limit_input_len = self.max_source_length + self.max_target_length
        else:
            limit_input_len = self.max_source_length

        bos_token = self.tokenizer.bos_token or ""

        sources = []
        label_lens = []
        original_labels_text = []
        max_len = -1

        for instance in batch:
            label = instance['Instance']['label']
            original_labels_text.append(label)
            instruction = self.get_instruction(instance)
            prompt_messages = [{"role": "user", "content": instruction}]

            # (input) for inference, (input + label) for training
            inference_only = instance['subset'] in ['dev', 'test'] and not self.all_train

            if inference_only:
                formatted_source = self._render_chat(
                    prompt_messages, True, system_prompt, template_kwargs
                )
                token_ids = self.tokenizer(formatted_source)["input_ids"]
                prompt_len = None
            else:
                full_messages = prompt_messages + [{"role": "assistant", "content": label}]
                formatted_source = self._render_chat(
                    full_messages, False, system_prompt, template_kwargs
                )
                prompt_only_str = self._render_chat(
                    prompt_messages, True, system_prompt, template_kwargs
                )
                token_ids = self.tokenizer(formatted_source)["input_ids"]
                prompt_len = len(self.tokenizer(prompt_only_str)["input_ids"])

            source, kept_len = self._clip_source(
                formatted_source, token_ids, limit_input_len, bos_token
            )
            sources.append(source)
            max_len = max(max_len, kept_len)

            if inference_only:
                label_lens.append(0)
            else:
                # Answer tokens are whatever follows the prompt in the kept
                # sequence; clamped because clipping may cut into the prompt.
                label_lens.append(max(0, kept_len - prompt_len))

        if self.text_only:
            return {"inputs": sources, 'labels': original_labels_text}

        model_inputs = self.tokenizer(
            sources,
            max_length=max_len,
            padding=self.padding,
            return_tensors=return_tensors,
            truncation=True,
            pad_to_multiple_of=self.pad_to_multiple_of
        )
        labels = model_inputs["input_ids"].clone()
        sequence_length = labels.shape[1]
        for row, label_len in enumerate(label_lens):
            # Mask padding and prompt; only the last ``label_len`` tokens stay.
            labels[row, :sequence_length - label_len] = self.label_pad_token_id
        model_inputs["labels"] = labels

        self._save_samples(model_inputs, sources, original_labels_text)
        return model_inputs

    def _render_chat(self, messages, add_generation_prompt, system_prompt, template_kwargs):
        """Render ``messages`` with the chat template and drop ``system_prompt``."""
        rendered = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            **template_kwargs
        )
        return rendered.replace(system_prompt, "")

    def _clip_source(self, text, token_ids, limit, bos_token):
        """Clip a rendered source to ``limit`` tokens (``None`` = no limit).

        Returns:
            ``(source_text, kept_token_count)``, where ``source_text`` has a
            leading ``bos_token`` removed.
        """
        if limit is None or len(token_ids) <= limit:
            return text.removeprefix(bos_token), len(token_ids)
        kept = token_ids[:limit]
        decoded = self.tokenizer.decode(kept, skip_special_tokens=False)
        return decoded.removeprefix(bos_token), len(kept)

    def _labels_as_text(self, labels):
        """Return one string per label.

        ``decoder_call`` passes a list of label strings, ``seq2seq_call`` a
        tokenizer encoding whose ``input_ids`` are decoded here.
        """
        if isinstance(labels, (list, tuple)):
            return list(labels)
        return [
            self.tokenizer.decode(ids, clean_up_tokenization_spaces=False)
            for ids in labels['input_ids']
        ]

    def _save_samples(self, model_inputs, sources, labels):
        """Append the collated sources and labels to ``input_record_file``.

        Does nothing when no record file is configured. When ``model_inputs``
        exposes a ``loss_mask``, the loss-masked input text is written as a
        third line per sample.

        Args:
            model_inputs: Encoding produced for the batch.
            sources: Source strings of the batch.
            labels: Label strings, or a tokenizer encoding of the labels.
        """
        if not self.input_record_file:
            return

        label_texts = self._labels_as_text(labels)

        if hasattr(model_inputs, 'loss_mask'):
            masked_texts = [
                self.tokenizer.decode((loss * input_ids).view(-1).int())
                for loss, input_ids in zip(model_inputs.loss_mask, model_inputs.input_ids)
            ]
            with open(self.input_record_file, 'a+', encoding='utf-8') as f:
                for text, label_text, mask_label in zip(sources, label_texts, masked_texts):
                    f.write(text + '\n')
                    f.write(label_text + '\n')
                    f.write(mask_label + '\n\n')
            return

        with open(self.input_record_file, 'a+', encoding='utf-8') as f:
            for text, label_text in zip(sources, label_texts):
                f.write(text + '\n')
                f.write(label_text + '\n')