import os
import warnings
from dataclasses import asdict, dataclass, field
from typing import Optional

# Silence FutureWarning messages. The filter is installed here, ahead of the
# transformers/trl/datasets imports below, so it is already active when those
# modules are loaded.
warnings.filterwarnings("ignore", category=FutureWarning)
import logging

# Configure the root logger: INFO level, "timestamp - level - message" lines.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
import transformers
import trl
from datasets import load_from_disk
from group_think_data import GroupThinkDataset
from utils import (
    SequentialSFTTrainer, 
    add_and_init_special_tokens, 
    TEXT_INFILLING_SPECIAL_TOKENS,
    TextInfillingDataCollatorForCompletionOnlyLM
)


@dataclass
class TrainingConfig:
    """Script-level options parsed from the command line next to ``trl.SFTConfig``.

    Creating an instance has a side effect: when ``wandb_project`` is not
    ``None`` it is exported as the ``WANDB_PROJECT`` environment variable.
    """

    # Model identifier handed to ``from_pretrained`` for both model and tokenizer.
    model_name: str = field(default="Qwen/Qwen2.5-32B-Instruct")
    # Copied into ``args.max_seq_length`` by ``train()``.
    block_size: int = field(default=32768)
    # Exported as ``WANDB_PROJECT``; ``None`` leaves the environment untouched.
    wandb_project: Optional[str] = field(default="Text-infilling")
    # Path handed to ``datasets.load_from_disk`` by ``train()``.
    train_file_path: Optional[str] = field(default=None)
    # Not read anywhere in the visible part of this script.
    dagger: bool = field(default=False)

    def __post_init__(self):
        """Export the configured W&B project name to the environment."""
        # Previously this wrote an empty string unconditionally, which discarded
        # the configured project name. ``None`` is skipped because environment
        # values must be strings, and so that an externally set WANDB_PROJECT
        # is respected.
        if self.wandb_project is not None:
            os.environ["WANDB_PROJECT"] = self.wandb_project


def train():
    """Fine-tune a causal LM on a preprocessed dataset with ``SequentialSFTTrainer``.

    Command-line arguments are parsed into a ``TrainingConfig`` and a
    ``trl.SFTConfig``. The model and tokenizer are loaded from
    ``config.model_name``, the text-infilling special tokens are added, the
    dataset is read from ``config.train_file_path``, and after training the
    model and tokenizer are written to ``args.output_dir``.

    Raises:
        ValueError: If ``train_file_path`` was not provided.
        RuntimeError: If ``dataset_text_field`` is unset and the tokenizer is
            not a Qwen2.5 tokenizer (no default field is known for it).
    """
    # parsing input
    parser = transformers.HfArgumentParser((TrainingConfig, trl.SFTConfig))
    config, args = parser.parse_args_into_dataclasses()
    log_config = {**asdict(config), **asdict(args)}
    logging.info(f"Training config: {log_config}")

    # Fail fast with a clear message instead of erroring inside
    # ``load_from_disk`` after the (expensive) model load.
    if config.train_file_path is None:
        raise ValueError(
            "train_file_path arg needs to be specified "
            "(path to a dataset saved on disk)."
        )

    # loading model
    model_kwargs = {}
    if "70B" in config.model_name:
        # Removed "low_cpu_mem_usage": True, for 70B, since by default we are in FSDP,
        # it's more efficient to do  "cpu_ram_efficient_loading": true, in fsdp_config.json
        model_kwargs = {
            "device_map": "auto",
            "torch_dtype": "auto",
            "attn_implementation": "flash_attention_2",
            "use_cache": False,
        }
    model = transformers.AutoModelForCausalLM.from_pretrained(
        config.model_name, **model_kwargs
    )

    # setting up trainer
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.model_name, use_fast=True
    )

    add_and_init_special_tokens(model, tokenizer, TEXT_INFILLING_SPECIAL_TOKENS)

    # Only compute loss over assistant responses
    # Verified that it precisely starts where the thinking tokens start and ends with the first pad token
    # via labels being set to -100
    collator = TextInfillingDataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        mlm=False,
    )

    dataset = load_from_disk(config.train_file_path)

    # Resolve the text column: an explicit --dataset_text_field wins; otherwise
    # fall back to the known preprocessed column for Qwen2.5 tokenizers.
    if args.dataset_text_field is None:
        if "Qwen2.5" in tokenizer.name_or_path:
            args.dataset_text_field = "gt_flatten_Qwen2.5-0.5B-Instruct_thinker_id"
        else:
            raise RuntimeError(
                "dataset_text_field arg needs to be specified. Should be the same as the preprocessed field in the gt dataset."
            )

    # Only a split mapping (dict-like) can contain split names. Checking the
    # type first avoids running ``in`` against a single, un-split dataset,
    # where it would presumably fall back to scanning rows
    # (NOTE(review): confirm against the datasets version in use).
    has_splits = isinstance(dataset, dict)
    train_split = dataset["train"] if has_splits and "train" in dataset else dataset
    # When no "validation" split exists, evaluation falls back to the whole
    # loaded dataset (same as before).
    eval_split = (
        dataset["validation"] if has_splits and "validation" in dataset else dataset
    )

    args.max_seq_length = config.block_size
    trainer = SequentialSFTTrainer(
        model,
        train_dataset=train_split,
        eval_dataset=eval_split,
        args=args,
        processing_class=tokenizer,
        # Pass the resolved field rather than a hard-coded column name, so a
        # user-supplied --dataset_text_field is actually honoured.
        dataset_text_field=args.dataset_text_field,
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model(output_dir=args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    trainer.accelerator.wait_for_everyone()


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
