import sys
import logging
logger = logging.getLogger(__name__)

from trl import SFTTrainer
from trl import (
    TrlParser,
    SFTConfig,
    ModelConfig,
)
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class DataArguments:
    """
    Data-related command-line arguments for this SFT training script.
    """

    # Forwarded unchanged as `data_files` to `load_dataset("json", ...)`; only the
    # resulting "train" split is used for training.
    data_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to the JSON or JSON Lines training data, passed as `data_files` to "
                "`datasets.load_dataset('json', ...)`. Each record must contain a `messages` list."
            )
        },
    )


import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

def get_datasets(path):
    """Load JSON training data from ``path`` and shuffle it deterministically.

    Args:
        path: Forwarded as ``data_files`` to ``load_dataset("json", ...)``.

    Returns:
        The loaded dataset collection (indexed by split, e.g. ``["train"]``),
        shuffled with the fixed seed 42 so runs are reproducible.
    """
    raw = load_dataset("json", data_files=path)
    shuffled = raw.shuffle(seed=42)
    return shuffled

def get_tokenizer(path, *, pad_token_id=128004, model_max_length=8192, trust_remote_code=False):
    """Load the tokenizer for the model at ``path`` and prepare it for training.

    Args:
        path: Model name or local path passed to ``AutoTokenizer.from_pretrained``.
        pad_token_id: Id assigned as the pad token only when the tokenizer
            defines none; an existing pad token is kept. The default 128004 is
            the value this script always used (presumably a Llama-3 reserved
            token — confirm before using a different vocabulary).
        model_max_length: Value written to ``tokenizer.model_max_length``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code)

    # The collator pads with `tokenizer.pad_token_id`, so it must not be None.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = pad_token_id

    tokenizer.model_max_length = model_max_length
    return tokenizer

def apply_chat_template_v3(
    example,
    tokenizer,
    task,
):
    """Render ``example["messages"]`` into one Llama-3 header-style training string.

    The result is stored in ``example["text"]``. Two markers in it drive the
    loss masking done by ``DataCollatorForComplexInteraction``:

    * ``<|reserved_special_token_0|>`` is inserted right before the content of
      every message whose loss should be computed;
    * ``tokenizer.eos_token`` closes each turn and the whole conversation. A
      turn is a run of consecutive messages with the same role; their
      contents are concatenated with no separator.

    Each message is a dict with ``role`` (``"system"``, ``"user"`` or
    ``"assistant"``), ``content`` and an optional ``cal_loss`` flag. A missing
    *or null* ``cal_loss`` defaults to ``True`` for assistant messages and
    ``False`` otherwise. Null is handled explicitly because Arrow-backed
    ``datasets`` typically fills the key with ``None`` in messages that omit
    it once any message in the data sets it.

    Args:
        example: A dataset row containing a ``"messages"`` list.
        tokenizer: Object exposing ``eos_token``.
        task: Only ``"sft"`` is supported.

    Returns:
        ``example`` with ``"text"`` added. The message dicts are not modified.

    Raises:
        ValueError: If ``task`` is not ``"sft"``, or a system/user message has a
            truthy ``cal_loss``.
        KeyError: If a message has a role other than system/user/assistant.
    """
    if task not in ["sft"]:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of ['sft']"
        )

    role_token = {
        "system": "<|start_header_id|>system<|end_header_id|>\n\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n"
    }
    eos_token = tokenizer.eos_token
    start_train_token = "<|reserved_special_token_0|>"  # token id 128002

    pieces = []
    last_role = None
    for message in example["messages"]:
        role = message["role"]

        # Missing and null both mean "use the role's default".
        cal_loss = message.get("cal_loss")
        if cal_loss is None:
            cal_loss = role == "assistant"

        # Raise instead of assert so the check survives `python -O`.
        if role in ("system", "user") and cal_loss:
            raise ValueError(f"cal_loss must be False for {role} messages, got {cal_loss!r}")

        content = start_train_token + message["content"] if cal_loss else message["content"]

        if last_role is None:
            pieces.append(role_token[role])
        elif last_role != role:
            # Close the previous turn before opening a new one.
            pieces.append(eos_token + role_token[role])
        # Same role as before: continue the current turn without a header.
        pieces.append(content)
        last_role = role

    pieces.append(eos_token)
    example["text"] = "".join(pieces)
    return example

class DataCollatorForComplexInteraction(DataCollatorForLanguageModeling):
    """Causal-LM collator that derives per-token loss masks from in-text markers.

    Each example's ``input_ids`` are expected to contain the markers written by
    ``apply_chat_template_v3``: a split token (default id 128002,
    ``<|reserved_special_token_0|>``) placed right before every span that
    should be trained on, and an end-of-turn token closing each turn. The
    collator:

    * removes every split token from ``input_ids``;
    * sets ``labels`` to the token id from a split token up to and including
      the next end-of-turn token (or the end of the sequence), and to ``-100``
      everywhere else;
    * right-pads the batch to its longest example (``input_ids`` with
      ``tokenizer.pad_token_id``, ``labels`` with ``-100``,
      ``attention_mask`` with ``0``).

    Args:
        *args, **kwargs: Forwarded to ``DataCollatorForLanguageModeling``;
            ``mlm`` is always forced to ``False``.
        split_token_id: Id of the "start computing loss" marker.
        eos_token_id: Id that closes a loss span. Defaults to the tokenizer's
            ``eos_token_id``. That is the same token ``apply_chat_template_v3``
            inserts between turns, so the two always agree even when the eos
            is not 128001. Falls back to 128001 (the previously hard-coded id)
            if the tokenizer defines no eos.
    """

    def __init__(
        self,
        *args,
        split_token_id=128002,
        eos_token_id=None,
        **kwargs,
    ):
        super().__init__(*args, mlm=False, **kwargs)
        self.split_token_id = split_token_id
        if eos_token_id is None:
            eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
        self.eos_token_id = 128001 if eos_token_id is None else eos_token_id

    def _mask_example(self, token_ids):
        """Strip split tokens from ``token_ids`` and build the matching labels.

        Returns:
            ``(input_ids, labels)`` as two lists of equal length.
        """
        split_id = self.split_token_id
        eos_id = self.eos_token_id
        ids, labels = [], []
        in_loss_span = False
        for token_id in token_ids:
            if token_id == split_id:
                # Pure marker: opens a loss span and is never fed to the model.
                in_loss_span = True
                continue
            ids.append(token_id)
            labels.append(token_id if in_loss_span else -100)
            if token_id == eos_id:
                # The end-of-turn token is itself learned (when inside a span),
                # then the span closes.
                in_loss_span = False
        return ids, labels

    def torch_call(self, examples):
        """Collate ``examples`` into padded ``input_ids``/``attention_mask``/``labels`` tensors."""
        masked = [self._mask_example(example["input_ids"]) for example in examples]
        max_len = max((len(ids) for ids, _ in masked), default=0)
        pad_id = self.tokenizer.pad_token_id

        input_ids, labels, attention_mask = [], [], []
        for ids, lbls in masked:
            n_pad = max_len - len(ids)
            input_ids.append(ids + [pad_id] * n_pad)
            labels.append(lbls + [-100] * n_pad)
            attention_mask.append([1] * len(ids) + [0] * n_pad)

        # Explicit dtype: without it torch infers float32 for an all-empty
        # batch, which embedding lookups and the loss reject.
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
    
def main():
    """Entry point: parse arguments, prepare the chat dataset and run SFT.

    ``TrlParser`` splits CLI flags/config into ``SFTConfig``,
    ``DataArguments`` and ``ModelConfig``. The JSON data at
    ``data_args.data_path`` is rendered into a ``"text"`` column by
    ``apply_chat_template_v3``. Training runs with ``SFTTrainer``, using
    ``DataCollatorForComplexInteraction`` for loss masking. After training,
    metrics and trainer state are saved, and the final model is saved to
    ``training_args.output_dir``.

    Raises:
        ValueError: If ``data_path`` is not provided.
    """
    parser = TrlParser((SFTConfig, DataArguments, ModelConfig))
    training_args, data_args, model_args = parser.parse_args_and_config()
    print(training_args)

    # Fail fast with a clear message rather than deep inside `load_dataset`.
    if data_args.data_path is None:
        raise ValueError("`data_path` must be set to the JSON/JSONL training data.")

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {training_args.local_rank != -1}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    raw_datasets = get_datasets(data_args.data_path)
    column_names = list(raw_datasets["train"].features)
    print("column_names: ", column_names)

    tokenizer = get_tokenizer(model_args.model_name_or_path)

    #######################
    # Load pretrained model
    #######################
    # The model is passed to SFTTrainer by name below, so it is instantiated
    # by the trainer from these kwargs.
    logger.info("*** Load pretrained model ***")
    training_args.model_init_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        # The KV cache is not needed for training and is incompatible with
        # gradient checkpointing, so disable it in that case.
        use_cache=not training_args.gradient_checkpointing,
        device_map=None,
        quantization_config=None,
    )

    #####################
    # Apply chat template
    #####################
    # All original columns (e.g. "messages") are dropped; only "text" remains.
    raw_datasets = raw_datasets.map(
        apply_chat_template_v3,
        fn_kwargs={
            "tokenizer": tokenizer,
            "task": "sft",
        },
        num_proc=8,
        remove_columns=column_names,
        desc="Applying chat template v3",
    )

    ########################
    # Initialize the Trainer
    ########################
    train_dataset = raw_datasets["train"]
    # NOTE(review): `tokenizer` (with its pad token / max length) is only given
    # to the collator; SFTTrainer is not handed it here. Consider passing it
    # (`processing_class=` on recent TRL, `tokenizer=` on older releases) —
    # confirm against the installed TRL version.
    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        data_collator=DataCollatorForComplexInteraction(tokenizer=tokenizer),
        train_dataset=train_dataset,
    )

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    train_result = trainer.train()
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ############
    # Save model
    ############
    # save_state() records trainer state only; without this the final weights
    # exist only if the last step happened to coincide with a checkpoint.
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()