import os
import grp
import torch
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments
from transformers import set_seed
from peft import LoraConfig
from transformers import AutoModelForCausalLM


from load import load_pretrain_model_tokenizer
from adaboost import adaboost_sampling, reformat_dataset, replace_option_dataset
from collator import MaskedDataCollatorForLM
from utils import format_run_name


# LoRA target modules per supported model family. Keys are matched, in
# insertion order, as lowercase substrings of the model name.
_LORA_TARGET_MODULES = {
    "pythia": [
        "query_key_value",
        "dense_h_to_4h",
        "dense_4h_to_h",
        "dense",
    ],
    "qwen": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "gate_proj",
        "down_proj",
    ],
}


def _build_lora_config(model_name):
    """Return the LoRA adapter config for the model family of ``model_name``.

    Raises:
        ValueError: if ``model_name`` matches no family in
            ``_LORA_TARGET_MODULES``.
    """
    lowered = model_name.lower()
    for family, target_modules in _LORA_TARGET_MODULES.items():
        if family in lowered:
            return LoraConfig(
                r=64,
                lora_alpha=16,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                # Pass a copy so the shared module-level list is never mutated.
                target_modules=list(target_modules),
            )
    raise ValueError(
        f"No LoRA target modules configured for model {model_name!r}; "
        f"supported families: {sorted(_LORA_TARGET_MODULES)}"
    )


def _prepare_train_data(train_dataset, t, sargs, tokenizer, mode):
    """Apply the mode-specific dataset transforms and choose the data collator.

    Returns:
        A ``(train_dataset, collator)`` tuple. ``collator`` may be ``None``,
        which leaves the choice of collator to SFTTrainer.
    """
    if mode == "weak":
        # The adaboost helpers are indexed with t - 1 (presumably the
        # previous round's weights -- confirm against adaboost.py).
        train_dataset = adaboost_sampling(
            train_dataset,
            t - 1,
            sargs.num_proc,
            sargs.is_weight_by_token,
            sargs.probability_bias,
            sargs.token_prob_window_size,
        )
        train_dataset = reformat_dataset(
            train_dataset, t - 1, sargs.num_proc, sargs.is_weight_by_token
        )

        collator = MaskedDataCollatorForLM(
            tokenizer, sargs.is_weight_by_token, sargs.is_completion_only
        )
        if sargs.is_weight_by_token:
            collator.set_token_mask(train_dataset, t - 1)
        return train_dataset, collator

    if sargs.is_completion_only:
        # Only the tokens after this marker contribute to the loss.
        response_template = "### Response:\n"
        collator = DataCollatorForCompletionOnlyLM(
            response_template, tokenizer=tokenizer
        )
    else:
        collator = None
    if t > 0:
        train_dataset = replace_option_dataset(train_dataset, t, sargs.num_proc)
    return train_dataset, collator


def train(model_name, train_dataset, eval_dataset, t, sargs, accelerator, mode="weak"):
    """Fine-tune a LoRA adapter for round ``t`` and return the trained model.

    Args:
        model_name: Model identifier. It selects the LoRA target modules and
            must contain "pythia" or "qwen" (case-insensitive).
        train_dataset: Training dataset. SFTTrainer reads its "text" field.
            In "weak" mode it is resampled and reformatted via the adaboost
            helpers. Otherwise, for ``t > 0``, it is passed through
            ``replace_option_dataset``.
        eval_dataset: Evaluation dataset, evaluated once per epoch.
        t: Round index. It is used in the output directory name and by the
            dataset transforms.
        sargs: Run configuration namespace (seed, folders, batch size,
            learning rate, epochs, max length, group name, and flags).
        accelerator: Passed to ``load_pretrain_model_tokenizer``.
        mode: "weak" enables AdaBoost sampling with the masked collator.
            Any other value uses the plain or completion-only collator.

    Returns:
        The trainer's model after training, with the best checkpoint
        (by eval loss) reloaded.

    Raises:
        ValueError: if ``model_name`` belongs to no supported model family.

    Side effects:
        Writes the model to the output directory under ``sargs.w2s_folder``.
        Sets that directory's group to ``sargs.grp_name`` and makes it
        group-writable.
    """
    set_seed(sargs.seed)
    # NOTE: checkpoint resume is disabled; every call trains from scratch.

    # Resolve the LoRA config first so an unsupported model fails fast,
    # before the expensive model load.
    lora_config = _build_lora_config(model_name)

    output_dir = os.path.join(
        sargs.w2s_folder,
        "models",
        format_run_name(sargs, mode=mode),
        f"{mode}{t}" + ("_" if sargs.is_easy_to_hard else ""),
    )

    model, tokenizer = load_pretrain_model_tokenizer(sargs, accelerator, mode)
    train_dataset, collator = _prepare_train_data(
        train_dataset, t, sargs, tokenizer, mode
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        bf16=True,
        evaluation_strategy="epoch",
        learning_rate=sargs.learning_rate,
        per_device_train_batch_size=sargs.train_batch_size,
        # Floor at 1: a train batch size of 1 would otherwise give 0.
        per_device_eval_batch_size=max(1, sargs.train_batch_size // 2),
        num_train_epochs=sargs.num_epochs,
        logging_dir="./logs",
        save_strategy="epoch",
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        optim="adamw_torch",
        weight_decay=0.01,
        adam_beta2=0.95,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        gradient_checkpointing=False,
        logging_strategy="steps",
        logging_steps=10,
        save_total_limit=1,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        packing=False,
        dataset_text_field="text",
        max_seq_length=sargs.model_max_length,
        peft_config=lora_config,
        data_collator=collator,
    )
    trainer.train()

    trainer.save_model(training_args.output_dir)

    # Give the shared group ownership of the output directory and make it
    # group-writable. Only the directory itself changes, not the files
    # saved inside it.
    os.chown(output_dir, -1, grp.getgrnam(sargs.grp_name).gr_gid)
    os.chmod(output_dir, os.stat(output_dir).st_mode | 0o020)

    return trainer.model
