from datasets import Dataset
import torch.utils
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
import torch, argparse, re, copy, os, random, json
from template import CHAT_TEMPLATE, USER_START_END
from datasets import load_dataset
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from accelerate import Accelerator
from accelerate.state import AcceleratorState
import numpy as np
import warnings

# Weights & Biases configuration, passed through environment variables.
os.environ.update(
    {
        "WANDB_PROJECT": "xxx",  # name your W&B project
        "WANDB_LOG_MODEL": "false",  # "false": do not log model checkpoints to W&B
    }
)

def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: carries ``model_dir``, ``model_type``,
        ``training_data``, ``split``, ``training_config``, ``input_key``,
        ``output_key``, ``remove_eos`` and ``add_system_prompt``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            an option cannot be parsed.
    """
    parser = argparse.ArgumentParser(description="Online")

    # These four options are dereferenced unconditionally by the script body,
    # so a missing one is reported here instead of failing later on ``None``.
    parser.add_argument("--model_dir", help="path to local model", type=str, required=True)
    parser.add_argument(
        "--model_type",
        help='"llama3.1-instruct", "llama3.1-base", or a name starting with "qwen2.5-base"',
        type=str,
        required=True,
    )
    parser.add_argument("--training_data", help="path to training data files", type=str, required=True)
    parser.add_argument("--training_config", help="path to training config json file", type=str, required=True)

    # Optional settings with defaults.
    parser.add_argument("--split", default="train", help="dataset split to use", type=str)
    parser.add_argument('--input_key', help="key for the instruction in the dataset", type=str, default='user')
    parser.add_argument('--output_key', help="key for the response in the dataset", type=str, default='assistant')
    parser.add_argument('--remove_eos', action='store_true', help='Remove eos token from the end of the response')
    parser.add_argument('--add_system_prompt', action='store_true', help='Include system prompt in the chat template')
    return parser.parse_args()

def load_jsonl_dataset(fp):
    """Load the file(s) at ``fp`` through the ``json`` dataset builder.

    Args:
        fp: value forwarded as ``data_files`` to ``load_dataset``.

    Returns:
        Whatever ``load_dataset`` returns for ``split="train"``.
    """
    return load_dataset("json", data_files=fp, split="train")

def format_dataset_with_chat_template(raw_datasets, tokenizer, input_key='user', output_key='assistant', seed=666, remove_eos=False, sys_prompt=False):
    """Render every sample as one chat-formatted string stored under ``text``.

    Each sample becomes a user/assistant conversation (optionally preceded by
    a fixed system prompt) that is passed to ``tokenizer.apply_chat_template``
    with ``tokenize=False``. The original columns are removed, the result is
    shuffled with ``seed``, and one random processed sample is printed.

    Args:
        raw_datasets: dataset exposing ``features``, ``map`` and ``shuffle``.
        tokenizer: object providing ``apply_chat_template`` and ``eos_token``.
        input_key: column holding the user message.
        output_key: column holding the assistant message.
        seed: seed handed to ``shuffle``.
        remove_eos: when true, strip one trailing EOS token from the rendered
            text if the text ends with it.
        sys_prompt: when true, prepend the fixed system prompt.

    Returns:
        The mapped and shuffled dataset.
    """
    system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."

    def format_and_apply_chat_template(sample):
        messages = [
            {"role": "user", "content": sample[input_key]},
            {"role": "assistant", "content": sample[output_key]},
        ]
        if sys_prompt:
            messages.insert(0, {"role": "system", "content": system_prompt})

        text = tokenizer.apply_chat_template(
            conversation=messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        if remove_eos:
            eos_token = tokenizer.eos_token
            # A tokenizer without an EOS token (None or "") has nothing to strip.
            if eos_token and text.endswith(eos_token):
                text = text[:-len(eos_token)]

        # Mutate and return the incoming sample, as the mapped function always did.
        sample['text'] = text
        return sample

    original_columns = list(raw_datasets.features)
    train_dataset = raw_datasets.map(
        format_and_apply_chat_template,
        remove_columns=original_columns,
        desc="Applying chat template",
    ).shuffle(seed=seed)

    # Preview one random sample; skipped for an empty dataset, where
    # random.sample would raise ValueError.
    num_samples = len(train_dataset)
    if num_samples:
        for index in random.sample(range(num_samples), 1):
            print(f"Sample {index} of the processed training set:\n\n{train_dataset[index]['text']}")

    return train_dataset


if __name__ == "__main__":
    # Script entry point: parse the CLI, prepare tokenizer and dataset, build
    # the completion-only collator, load config and model, then run training.
    # torch.cuda.empty_cache()
    args = parse_args()

    # Accelerator sanity check: print the accelerate state between separators.
    accelerator = Accelerator()
    print('-'*50)
    accelerator.print(f"{AcceleratorState()}")
    print('-'*50)
    
    # Select the tokenizer and the response template for the model family.
    # The response template is the marker string handed to the collator below.
    response_template = ''
    if args.model_type.startswith("llama3.1"):
        # Prepare the tokenizer: use token id 128004 as a right-side pad token.
        tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
        pad_token_id = 128004  # <|finetune_right_pad_id|>
        tokenizer.pad_token = tokenizer.decode(pad_token_id)
        tokenizer.pad_token_id = pad_token_id
        tokenizer.padding_side = "right"
        print(f"Setting pad_token to {tokenizer.pad_token_id} ({tokenizer.pad_token}), with padding side {tokenizer.padding_side}")

        if args.model_type == "llama3.1-instruct":
            # The tokenizer's own chat template is used unchanged.
            response_template = "<|start_header_id|>assistant<|end_header_id|>"
        elif args.model_type == "llama3.1-base":
            # Override the chat template with the "alpaca-chat" entry from the
            # local `template` module.
            tokenizer.chat_template = CHAT_TEMPLATE["alpaca-chat"]
            warnings.warn("Using Alpaca-chat template for base model training")
            response_template = "### Response:\n"
        else:
            raise ValueError(f"Invalid model type: {args.model_type}")
    elif args.model_type.startswith("qwen2.5-base"):
        tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
        # Qwen2.5 base models already have their chat template and padding config encoded
        response_template = "<|im_start|>assistant\n"
    else:
        raise ValueError(f"Invalid model type: {args.model_type}")
    
    # A path ending in ".jsonl" is loaded through the JSON loader; anything
    # else is passed to load_dataset together with the requested split.
    if args.training_data.endswith(".jsonl"):
        raw_dataset = load_jsonl_dataset(args.training_data)
    else:
        raw_dataset = load_dataset(args.training_data, split=args.split)
    
    # If the special tokens (e.g., bos, eos) are already appended to the sample, the
    # training config needs to set "dataset_kwargs": {"add_special_tokens": false}.
    train_set = format_dataset_with_chat_template(raw_dataset, tokenizer, input_key=args.input_key, output_key=args.output_key,
                                                  remove_eos=args.remove_eos, sys_prompt=args.add_system_prompt)
    
    # Create a data collator to train on completion only.
    # Only the response template is needed for single-turn completion.
    print(f"Response template set to: [{response_template}]")
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_template, 
        tokenizer=tokenizer,
        mlm=False,  # Subclass-specific argument
        pad_to_multiple_of=8,  # Shared with the parent
    )

    # Load the training config: a JSON object whose keys are forwarded verbatim
    # as SFTConfig keyword arguments. It must contain "output_dir".
    with open(args.training_config, 'r') as f:
        training_args = json.load(f)
    os.makedirs(training_args["output_dir"], exist_ok=True)
    # NOTE(review): presumably SFTConfig has to be built before the model is
    # loaded so that is_deepspeed_zero3_enabled() below sees the DeepSpeed
    # settings from the config — confirm before reordering.
    sft_config = SFTConfig(**training_args)
    
    # Alternative loading code kept for reference (maps the whole model to the
    # current process index):
    # from accelerate import PartialState
    # device_string = PartialState().process_index
    # model = AutoModelForCausalLM.from_pretrained(
    #     args.model_dir, torch_dtype=torch.bfloat16, 
    #     attn_implementation="flash_attention_2", 
    #     device_map={'':device_string}
    # )
    # Load the model in bfloat16 with FlashAttention-2; device_map is 'auto'
    # unless DeepSpeed ZeRO-3 is enabled, in which case no device_map is passed.
    model = AutoModelForCausalLM.from_pretrained(
                args.model_dir, 
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2", 
                device_map='auto' if not is_deepspeed_zero3_enabled() else None
            )
    
    # NOTE(review): the tokenizer configured above is given to the collator only,
    # not to SFTTrainer; presumably the trainer loads its own tokenizer from the
    # model — confirm that its pad token and chat template are what is intended.
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_set,
        args=sft_config,
        data_collator=collator,
    )

    # Print a summary of the visible GPUs and the main hyperparameters.
    device_numbers = list(range(torch.cuda.device_count()))
    print('-'*50)
    print('Params:')
    print('GPU device numbers: ', ','.join(map(str, device_numbers)))
    print('Number of training examples: ', len(train_set))
    print('Epoch: ', sft_config.num_train_epochs)
    print('Learning rate: ', sft_config.learning_rate)
    print("LR Scheduler: ", sft_config.lr_scheduler_type)
    print('Warmup_ratio: ', sft_config.warmup_ratio)
    print('Warmup_steps: ', sft_config.warmup_steps)
    print('Gradient_accumulation_steps: ', sft_config.gradient_accumulation_steps)
    print('Per_device_train_batch_size: ', sft_config.per_device_train_batch_size)
    print('-'*50)
    print('\n')

    # Let's go
    trainer.train()
