import json
import os
import pathlib

import torch
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from trl import (
    DataCollatorForCompletionOnlyLM,
    ModelConfig,
    SFTConfig,
    SFTTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.scripts.utils import TrlParser, init_zero_verbose, ScriptArguments
from torch.utils.data import DataLoader
from huggingface_hub import HfApi

from datasets import Dataset, load_dataset, DatasetDict
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelConfigWithBase(ModelConfig):
    """TRL ``ModelConfig`` extended with the extra options this script reads.

    Both fields are exposed as command-line/config options by the argument
    parser; the ``help`` metadata is what that parser shows in ``--help``.

    Attributes:
        model_subfolder: When set, forwarded as ``subfolder`` to
            ``AutoModelForCausalLM.from_pretrained``.
        skip_upload_optimizer_states: When truthy, files matching ``*.pt``
            are excluded from the final Hub upload.
    """

    model_subfolder: Optional[str] = field(
        default=None,
        metadata={
            "help": "Subfolder of the model repository or directory to load the model from."
        },
    )
    skip_upload_optimizer_states: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Exclude files matching *.pt when uploading the output directory to the Hub."
        },
    )


def load_dataset_splits(args):
    """Read the train (and optionally the evaluation) split from parquet.

    ``args.dataset_name`` is used as a directory and the split attributes as
    file names inside it.

    Args:
        args: Script arguments providing ``dataset_name``,
            ``dataset_train_split`` and ``dataset_test_split``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple. ``test_dataset`` is
        ``None`` when ``args.dataset_test_split`` equals ``"test"``.
    """

    def _read_split(split_file):
        # Split files live directly under the dataset directory.
        return Dataset.from_parquet(os.path.join(args.dataset_name, split_file))

    train_dataset = _read_split(args.dataset_train_split)

    # The literal "test" means "no evaluation split" here (presumably because
    # it is the untouched default of the argument -- verify against
    # ScriptArguments).
    test_dataset = None
    if args.dataset_test_split != "test":
        test_dataset = _read_split(args.dataset_test_split)

    return train_dataset, test_dataset


# Hub namespace (user or organisation) that receives the uploaded run.
_HUB_NAMESPACE = "xxx98"

# Model-family marker -> response template handed to the completion-only
# collator. Families listed here use the tokenizer of their "-Instruct"
# variant. Insertion order is the order in which markers are tried.
_INSTRUCT_RESPONSE_TEMPLATES = {
    "Qwen2.5": "<|im_start|>assistant\n",
    "Llama-3.2": "<|start_header_id|>assistant<|end_header_id|>\n\n",
}

_DEEPSEEK_R1_MARKER = "DeepSeek-R1-Distill-Qwen-1.5B"
_DEEPSEEK_R1_RESPONSE_TEMPLATE = "<｜Assistant｜><think>\n"

# Chat template assigned to the DeepSeek tokenizer. It renders an assistant
# turn as '<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>' with the content
# passed through unchanged. The pieces below are concatenated with nothing in
# between, so the template is a single line.
# NOTE(review): for _DEEPSEEK_R1_RESPONSE_TEMPLATE to occur in the rendered
# text, assistant messages presumably have to start with "<think>" followed by
# a newline -- confirm against the dataset.
_DEEPSEEK_R1_CHAT_TEMPLATE = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}"
    "{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}"
    "{{bos_token}}{{ns.system_prompt}}"
    "{%- for message in messages %}"
    "{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}"
    "{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}"
    "{%- if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- set ns.is_first = true -%}"
    "{%- else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- endfor %}{%- endif %}"
    "{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}"
    "{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}"
    "{%- endfor -%}"
    "{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}"
    "{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜><think>\\n'}}{% endif %}"
)


def _strip_response(example):
    """Strip surrounding whitespace from the last message of ``example``."""
    last_message = example["messages"][-1]
    last_message["content"] = last_message["content"].strip()
    return example


def _load_tokenizer(model_name):
    """Return ``(tokenizer, response_template)`` for a supported model name.

    Raises:
        ValueError: If ``model_name`` matches none of the supported families.
    """
    for marker, response_template in _INSTRUCT_RESPONSE_TEMPLATES.items():
        if marker in model_name:
            # Base checkpoints borrow the tokenizer of their "-Instruct"
            # sibling; instruct checkpoints use their own.
            tokenizer_name = model_name if "Instruct" in model_name else model_name + "-Instruct"
            return AutoTokenizer.from_pretrained(tokenizer_name), response_template

    if _DEEPSEEK_R1_MARKER in model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.chat_template = _DEEPSEEK_R1_CHAT_TEMPLATE
        return tokenizer, _DEEPSEEK_R1_RESPONSE_TEMPLATE

    raise ValueError(f"Model {model_name} is not supported.")


def _hub_repo_id(output_dir):
    """Build the Hub repo id from the last component of ``output_dir``.

    The path is normalised first: a plain ``os.path.basename`` returns an
    empty string for paths with a trailing separator (``"runs/exp/"``) and
    ``"."`` for the current directory, neither of which is a usable name.
    """
    run_name = os.path.basename(os.path.abspath(output_dir))
    return f"{_HUB_NAMESPACE}/{run_name}"


def main():
    """Fine-tune a causal LM on chat data and upload the run to the Hub.

    Steps:
        1. Parse ``ScriptArguments``, ``SFTConfig`` and ``ModelConfigWithBase``
           from the command line / config file.
        2. Load the parquet splits and strip whitespace around the final
           message of every example.
        3. Resolve the tokenizer and response template, then load the model.
        4. Train with ``SFTTrainer`` and a completion-only collator, resuming
           when the output directory already holds ``checkpoint-*`` entries.
        5. On the main process, upload the output directory to the Hub.

    Raises:
        ValueError: If the model name is not one of the supported families.
            This is raised before the model weights are loaded.
    """
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfigWithBase))
    args, training_args, model_config = parser.parse_args_and_config()

    train_dataset, test_dataset = load_dataset_splits(args)
    train_dataset = train_dataset.map(_strip_response)
    if test_dataset is not None:
        test_dataset = test_dataset.map(_strip_response)

    # Resolve the tokenizer first so that an unsupported model name fails
    # before the (much more expensive) model load.
    tokenizer, response_template = _load_tokenizer(model_config.model_name_or_path)
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

    # "auto" and None are passed through as-is; any other value is looked up
    # as an attribute of the torch module (e.g. "bfloat16").
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is switched off whenever gradient checkpointing is on.
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    if model_config.model_subfolder is not None:
        model_kwargs["subfolder"] = model_config.model_subfolder

    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)

    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # No peft_config is passed to the trainer, so PEFT is not used here.
    trainer = SFTTrainer(
        model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        processing_class=tokenizer,
        args=training_args,
        data_collator=collator,
    )

    # Resume when a previous run left checkpoints in the output directory.
    has_checkpoint = any(pathlib.Path(training_args.output_dir).glob("checkpoint-*"))
    if has_checkpoint:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Upload everything to hub (only from main process)
    if trainer.is_world_process_zero():
        repo_id = _hub_repo_id(training_args.output_dir)
        api = HfApi()
        api.create_repo(
            repo_id=repo_id,
            repo_type="model",
            private=False,
            exist_ok=True,
        )
        api.upload_folder(
            folder_path=training_args.output_dir,
            repo_id=repo_id,
            repo_type="model",
            ignore_patterns=["*.pt"] if model_config.skip_upload_optimizer_states else None,
        )


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
