import os

import torch
import wandb
from accelerate import Accelerator
from configs import (
    BnBConfig,
    DataArgs,
    EvalArgs,
    LoggingConfig,
    ModelArgs,
    PeftConfig,
    TrainArgs,
)
from huggingface_hub import login
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer
from utils.dataset_utils import (
    InstDataset,
    sftHHHDataset,
)
from utils.eval_utils import BackdoorTaskConfig, create_backdoor_task
from utils.logging_utils import (
    WandbEvalCallback,
    WandbTrainCallback,
)
from utils.train_utils import FSDPSFTTrainer

# Log in to the Hugging Face Hub as soon as the module is loaded. The token is
# read from the HUGGINGFACE_TOKEN environment variable; if it is unset the
# lookup raises KeyError before anything else in the script runs.
_hf_token = os.environ["HUGGINGFACE_TOKEN"]
login(token=_hf_token, add_to_git_credential=True)


def _constructor_kwargs(config_cls, parsed_config):
    """Return the attributes of ``parsed_config`` that ``config_cls(...)`` accepts.

    If the constructor takes ``**kwargs`` every attribute is forwarded and the
    constructor decides what to do with unknown names. Otherwise attributes
    whose names are not constructor parameters are dropped and reported with
    ``print`` so that a mismatch between the CLI dataclass and the library
    config is visible instead of silent. ``None`` yields an empty dict.
    """
    import inspect  # local import: keeps this helper self-contained

    if parsed_config is None:
        return {}

    values = dict(vars(parsed_config))
    parameters = inspect.signature(config_cls).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
        return values

    accepted = {name: value for name, value in values.items() if name in parameters}
    ignored = sorted(values.keys() - accepted.keys())
    if ignored:
        print(f"Ignoring options not accepted by {config_cls.__name__}: {ignored}")
    return accepted


def _ensure_dir(path):
    """Create ``path`` if needed, announcing it when it did not exist before."""
    missing = not os.path.exists(path)
    # exist_ok=True: another process may create the directory between the
    # existence check and this call when the script is launched multi-process.
    os.makedirs(path, exist_ok=True)
    if missing:
        print(f"Making directory {path}")


def main(
    model_args, data_args, train_args, eval_args, peft_config, bnb_config, wandb_config
):
    """Run supervised fine-tuning with wandb logging and a backdoor-eval callback.

    Steps, in order: log in to wandb and start a run, seed the RNGs, optionally
    build a bitsandbytes quantization config and a LoRA config from the parsed
    CLI dataclasses, load the model and tokenizer, build the train and eval
    datasets, register the wandb callbacks, construct the trainer
    (``FSDPSFTTrainer`` when Accelerate reports FSDP, ``SFTTrainer`` otherwise)
    and call ``train``.

    Args:
        model_args: Model options. Reads ``model_id``, ``device``,
            ``use_flash_attn``, ``use_4bit_quantization``, ``use_peft_lora``,
            ``backdoor_type`` and ``run_validation``.
        data_args: Dataset options. Reads ``dataset_name``,
            ``eval_dataset_name``, ``sft_hhh_df_size``, ``packing``,
            ``dataset_text_field`` and ``max_seq_length``.
        train_args: Training options. Every non-dunder, non-callable attribute
            is forwarded to ``TrainingArguments``; ``output_dir`` is
            overwritten here.
        eval_args: Evaluation options. ``eval_output_dir`` and
            ``eval_output_file`` are overwritten here.
        peft_config: Parsed LoRA options; used only when
            ``model_args.use_peft_lora`` is set.
        bnb_config: Parsed bitsandbytes options; used only when
            ``model_args.use_4bit_quantization`` is set.
        wandb_config: Logging options. Reads ``wandb_api_key``,
            ``project_name`` and ``entity``; ``run_name`` is mutated in place.
    """
    wandb.login(key=wandb_config.wandb_api_key)

    # Tag the run name with the learning rate so runs can be told apart.
    wandb_config.run_name = wandb_config.run_name + f"lr{train_args.learning_rate}"

    wandb.init(
        project=wandb_config.project_name,
        entity=wandb_config.entity,
        name=wandb_config.run_name,
        config=vars(train_args),
    )

    set_seed(train_args.seed)

    # Build the quantization config from the *parsed* options. The parameter is
    # deliberately not rebound, so the caller's settings are not thrown away.
    quantization_config = None
    torch_dtype = torch.bfloat16
    if model_args.use_4bit_quantization:
        print("Using 4-bit quantization")
        bnb_kwargs = _constructor_kwargs(BitsAndBytesConfig, bnb_config)
        # NOTE(review): assumes the flag implies 4-bit loading unless the parsed
        # options say otherwise -- confirm against configs.BnBConfig.
        bnb_kwargs.setdefault("load_in_4bit", True)
        quantization_config = BitsAndBytesConfig(**bnb_kwargs)
        # Load the remaining weights in the quant-storage dtype when it is a
        # floating-point type; otherwise keep the bfloat16 default used for
        # the non-quantized path.
        storage_dtype = getattr(quantization_config, "bnb_4bit_quant_storage", None)
        if storage_dtype is not None and storage_dtype.is_floating_point:
            torch_dtype = storage_dtype

    # Initialize the Accelerator to find out whether FSDP is active.
    accelerator = Accelerator()
    fsdp_enabled = accelerator.state.distributed_type == "FSDP"

    attn_implementation = (
        "flash_attention_2" if model_args.use_flash_attn else "eager"
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_id,
        trust_remote_code=True,
        attn_implementation=attn_implementation,
        torch_dtype=torch_dtype,
        quantization_config=quantization_config,
    )
    if quantization_config is None:
        # Quantized models are left where from_pretrained put them: moving a
        # bitsandbytes model with .to() is rejected by some transformers
        # versions.
        model = model.to(model_args.device)

    # Same rule as above: build LoRA settings from the parsed options.
    lora_config = None
    if model_args.use_peft_lora:
        lora_config = LoraConfig(**_constructor_kwargs(LoraConfig, peft_config))

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_id,
        padding_side="left",
        trust_remote_code=True,
        pad_to_multiple_of=8,
    )

    sft_dataset = sftHHHDataset(
        tokenizer,
        data_args.dataset_name,
        model_args.backdoor_type,
        data_args.sft_hhh_df_size,
        split="train",
        seed=train_args.seed,
    ).create_dataset()

    # Callbacks for wandb logging
    callbacks = [WandbTrainCallback(log_every_n_steps=train_args.logging_steps)]

    eval_dataset = InstDataset(
        tokenizer, data_args.eval_dataset_name, model_args.backdoor_type, split="test"
    ).create_dataset()

    # Last path component of the model id ("org/name" -> "name"). Unlike a
    # fixed index of 1 this also works for ids that contain no slash.
    model_name = model_args.model_id.rstrip("/").split("/")[-1]

    eval_args.eval_output_dir = os.path.join("results/SFT_HHH", model_name)
    _ensure_dir(eval_args.eval_output_dir)

    eval_args.eval_output_file = os.path.join(
        eval_args.eval_output_dir,
        f"{model_args.backdoor_type}_lr{train_args.learning_rate}_weight_decay{train_args.weight_decay}.csv",
    )

    backdoor_config = BackdoorTaskConfig(
        task_type=eval_args.deployment_behavior_type,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_new_eval_tokens=eval_args.max_new_eval_tokens,
    )
    backdoor_task = create_backdoor_task(backdoor_config)

    callbacks.append(
        WandbEvalCallback(backdoor_task, eval_args, model_args, quantization_config)
    )

    train_args.output_dir = os.path.join(
        "sft_HHH_models",
        data_args.dataset_name,
        model_name + "_" + model_args.backdoor_type,
    )
    _ensure_dir(train_args.output_dir)

    # Forward every public, non-callable attribute of train_args (including
    # the output_dir assigned just above) to TrainingArguments.
    training_args_dict = {
        attr: getattr(train_args, attr)
        for attr in dir(train_args)
        if not attr.startswith("__") and not callable(getattr(train_args, attr))
    }
    training_arguments = TrainingArguments(**training_args_dict)

    # The custom FSDP trainer saves FSDP checkpoints as full HF weights; both
    # trainers are constructed with exactly the same arguments.
    trainer_cls = FSDPSFTTrainer if fsdp_enabled else SFTTrainer
    trainer = trainer_cls(
        model=model,
        tokenizer=tokenizer,
        args=training_arguments,
        train_dataset=sft_dataset,
        eval_dataset=eval_dataset if model_args.run_validation else None,
        peft_config=lora_config,
        packing=data_args.packing,
        dataset_text_field=data_args.dataset_text_field,
        max_seq_length=data_args.max_seq_length,
        callbacks=callbacks,
    )

    trainer.accelerator.print(f"{trainer.model}")
    if model_args.use_peft_lora:
        trainer.model.print_trainable_parameters()

    # resume_from_checkpoint may be None, in which case training starts fresh.
    trainer.train(resume_from_checkpoint=train_args.resume_from_checkpoint)

    # NOTE: there is intentionally no explicit trainer.save_model(...) call
    # here; presumably persistence is left to the trainer's own checkpointing.


if __name__ == "__main__":
    # One dataclass per argument group, listed in the positional order that
    # `main` expects its parameters.
    argument_groups = (
        ModelArgs,
        DataArgs,
        TrainArgs,
        EvalArgs,
        PeftConfig,
        BnBConfig,
        LoggingConfig,
    )
    cli_parser = HfArgumentParser(argument_groups)

    # The parser returns one populated instance per group, in the same order,
    # so the result can be splatted straight into `main`.
    parsed_groups = cli_parser.parse_args_into_dataclasses()
    main(*parsed_groups)
