import os

import torch
import wandb
from accelerate import Accelerator
from configs import (
    BnBConfig,
    DataArgs,
    EvalArgs,
    LoggingConfig,
    ModelArgs,
    PeftConfig,
    TrainArgs,
)
from huggingface_hub import login
from peft import LoraConfig
from transformers import (
    AddedToken,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer
from utils.dataset_utils import (
    CustomDataCollatorForCompletionOnlyLM,
    InstDataset,
    InstLlama2SpecialTokens,
)
from utils.eval_utils import BackdoorTaskConfig, create_backdoor_task
from utils.logging_utils import (
    WandbEvalCallback,
    WandbTrainCallback,
)
from utils.train_utils import FSDPSFTTrainer

# Authenticate with the Hugging Face Hub as soon as this module is loaded.
# The token is read from the HUGGINGFACE_TOKEN environment variable; if the
# variable is unset, the lookup raises KeyError before anything else runs.
# add_to_git_credential=True asks huggingface_hub to also register the token
# as a git credential.
login(token=os.environ["HUGGINGFACE_TOKEN"], add_to_git_credential=True)


def _ensure_dir(path):
    """Create ``path`` (and missing parents), announcing a fresh creation."""
    if not os.path.exists(path):
        print(f"Making directory {path}")
    # exist_ok=True: in a multi-process launch several ranks may try to create
    # the same directory at once; a check-then-create would raise
    # FileExistsError on the loser of that race.
    os.makedirs(path, exist_ok=True)


def _model_short_name(model_id):
    """Return the last path component of ``model_id`` (``org/name`` -> ``name``).

    Unlike ``split("/")[1]`` this does not raise for ids without a slash and
    picks the final component for longer paths.
    """
    return model_id.rstrip("/").split("/")[-1]


def _build_quantization(model_args, bnb_args):
    """Return ``(quantization_config, torch_dtype)`` used to load the model.

    Args:
        model_args: parsed model options; ``use_4bit_quantization`` is read.
        bnb_args: parsed bitsandbytes options (the project's ``BnBConfig``).

    Returns:
        ``(None, torch.bfloat16)`` when 4-bit quantization is disabled,
        otherwise a ``BitsAndBytesConfig`` built from ``bnb_args`` and the
        dtype to pass as ``torch_dtype``.
    """
    if not model_args.use_4bit_quantization:
        return None, torch.bfloat16

    print("Using 4-bit quantization")

    # NOTE(review): assumes the fields of the parsed BnBConfig dataclass are
    # named after BitsAndBytesConfig keyword arguments -- confirm against
    # configs.py. BitsAndBytesConfig accepts extra keyword arguments, so
    # unrelated fields do not make construction fail.
    bnb_kwargs = dict(vars(bnb_args))
    # This branch only runs when 4-bit loading was requested, so enable it
    # unless the parsed options say otherwise.
    bnb_kwargs.setdefault("load_in_4bit", True)
    # The previous code read `bnb_4bit_quant_storage_dtype`; the keyword that
    # BitsAndBytesConfig understands is `bnb_4bit_quant_storage`. Translate the
    # former into the latter when only the former is supplied.
    legacy_storage = bnb_kwargs.pop("bnb_4bit_quant_storage_dtype", None)
    if legacy_storage is not None and bnb_kwargs.get("bnb_4bit_quant_storage") is None:
        bnb_kwargs["bnb_4bit_quant_storage"] = legacy_storage

    quantization_config = BitsAndBytesConfig(**bnb_kwargs)

    # Load the non-quantized weights in the quant-storage dtype when that is a
    # floating-point type; otherwise fall back to the same bfloat16 default
    # used when quantization is off.
    storage_dtype = getattr(quantization_config, "bnb_4bit_quant_storage", None)
    if isinstance(storage_dtype, torch.dtype) and storage_dtype.is_floating_point:
        return quantization_config, storage_dtype
    return quantization_config, torch.bfloat16


def _build_lora_config(peft_args):
    """Build a ``LoraConfig`` from the parsed PEFT options.

    Only options whose names are ``LoraConfig`` init fields are forwarded;
    anything else is reported and skipped so a project-only option cannot make
    construction fail.
    """
    from dataclasses import fields

    # NOTE(review): assumes the project's PeftConfig uses LoraConfig's field
    # names (r, lora_alpha, ...) -- options named differently are listed in
    # the message below and fall back to LoraConfig's defaults.
    accepted = {f.name for f in fields(LoraConfig) if f.init}
    supplied = vars(peft_args)
    ignored = sorted(name for name in supplied if name not in accepted)
    if ignored:
        print(f"Ignoring PEFT options not accepted by LoraConfig: {ignored}")
    return LoraConfig(
        **{name: value for name, value in supplied.items() if name in accepted}
    )


def _add_missing_special_tokens(tokenizer, special_tokens_list):
    """Register every token of ``special_tokens_list`` the tokenizer lacks."""
    # Build the vocabulary once instead of once per candidate token.
    known_vocab = tokenizer.get_vocab()
    special_tokens = []
    for word in special_tokens_list:
        if word not in known_vocab:
            print(f"Adding special token {word} to tokenizer")
            special_tokens.append(
                AddedToken(
                    word,
                    rstrip=False,
                    lstrip=False,
                    single_word=True,
                    normalized=False,
                    special=True,
                )
            )
    # Register the whole batch once, after the loop, and only when there is
    # something to add.
    if special_tokens:
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})


def main(
    model_args, data_args, train_args, eval_args, peft_config, bnb_config, wandb_config
):
    """Run supervised fine-tuning of a causal LM and log the run to wandb.

    Steps: log in to wandb and start a run, seed the RNGs, load the model
    (optionally 4-bit quantized) and tokenizer, add the special tokens, build
    the train (and optionally eval) dataset, then train with ``SFTTrainer`` or,
    when accelerate reports FSDP, with ``FSDPSFTTrainer``.

    Args:
        model_args: model options (``model_id``, quantization / LoRA /
            flash-attention switches, ``device``, ``backdoor_type``,
            ``run_validation``).
        data_args: dataset options (``dataset_name``, ``packing``,
            ``dataset_text_field``, ``max_seq_length``,
            ``completions_only_loss``).
        train_args: training options; its non-callable, non-dunder attributes
            are forwarded to ``TrainingArguments``.
        eval_args: evaluation options used when ``run_validation`` is set.
        peft_config: parsed LoRA options, used when ``use_peft_lora`` is set.
        bnb_config: parsed bitsandbytes options, used when
            ``use_4bit_quantization`` is set.
        wandb_config: wandb options (``wandb_api_key``, ``project_name``,
            ``entity``, ``run_name``).

    Side effects:
        Mutates ``wandb_config.run_name``, ``train_args.output_dir`` and, when
        validating, ``eval_args.eval_output_dir`` / ``eval_args.eval_output_file``;
        creates the output directories on disk.
    """
    # Echo the parsed configuration so it ends up in the job log.
    print(train_args)
    print(data_args)
    print(eval_args)
    wandb.login(key=wandb_config.wandb_api_key)

    # Tag the run name with the learning rate so sweeps are distinguishable.
    wandb_config.run_name = wandb_config.run_name + f"lr{train_args.learning_rate}"

    wandb.init(
        project=wandb_config.project_name,
        entity=wandb_config.entity,
        name=wandb_config.run_name,
        config=vars(train_args),
    )

    set_seed(train_args.seed)

    # Built from the parsed options; the parameters themselves are left intact.
    quantization_config, torch_dtype = _build_quantization(model_args, bnb_config)

    accelerator = Accelerator()
    fsdp_enabled = accelerator.state.distributed_type == "FSDP"

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_id,
        load_in_8bit=model_args.use_8bit_quantization,
        quantization_config=quantization_config,
        trust_remote_code=True,
        attn_implementation=(
            "flash_attention_2" if model_args.use_flash_attn else "eager"
        ),
        torch_dtype=torch_dtype,
    )

    # bitsandbytes-quantized models are placed on their device while loading,
    # and transformers can reject an explicit `.to()` on them, so only
    # unquantized models are moved here.
    bnb_quantized = bool(model_args.use_8bit_quantization) or (
        quantization_config is not None
        and bool(
            quantization_config.load_in_4bit or quantization_config.load_in_8bit
        )
    )
    if not bnb_quantized:
        model = model.to(model_args.device if fsdp_enabled else "cuda")

    lora_config = _build_lora_config(peft_config) if model_args.use_peft_lora else None

    # Load the tokenizer and add special tokens
    special_tokens_list = InstLlama2SpecialTokens

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_id,
        padding_side="left",
        pad_token=special_tokens_list.pad_token.value,
        trust_remote_code=True,
        pad_to_multiple_of=8,
    )
    _add_missing_special_tokens(tokenizer, special_tokens_list)

    # The embedding matrix must cover any newly added tokens.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

    train_dataset = InstDataset(
        tokenizer, data_args.dataset_name, model_args.backdoor_type, split="train"
    ).create_dataset()

    # Callbacks for wandb logging
    callbacks = [WandbTrainCallback(log_every_n_steps=train_args.logging_steps)]

    model_name = _model_short_name(model_args.model_id)

    eval_dataset = None
    if model_args.run_validation:
        eval_dataset = InstDataset(
            tokenizer, data_args.dataset_name, model_args.backdoor_type, split="test"
        ).create_dataset()

        eval_args.eval_output_dir = os.path.join(
            eval_args.eval_output_dir,
            data_args.dataset_name,
            model_name,
        )
        _ensure_dir(eval_args.eval_output_dir)

        eval_args.eval_output_file = os.path.join(
            eval_args.eval_output_dir,
            f"{model_args.backdoor_type}_backdoor_lr{train_args.learning_rate}_weight_decay{train_args.weight_decay}.csv",
        )
        backdoor_config = BackdoorTaskConfig(
            task_type=eval_args.deployment_behavior_type,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            max_new_eval_tokens=eval_args.max_new_eval_tokens,
        )
        backdoor_task = create_backdoor_task(backdoor_config)

        callbacks.append(
            WandbEvalCallback(
                backdoor_task, eval_args, model_args, quantization_config
            )
        )

    # Must be updated before the TrainingArguments are built below so that the
    # per-dataset / per-model directory is what the trainer sees.
    train_args.output_dir = os.path.join(
        train_args.output_dir,
        data_args.dataset_name,
        model_name + "_" + model_args.backdoor_type,
    )
    _ensure_dir(train_args.output_dir)

    # Forward every public-or-single-underscore, non-callable attribute of
    # train_args to TrainingArguments.
    training_args_dict = {}
    for attr in dir(train_args):
        if attr.startswith("__"):
            continue
        value = getattr(train_args, attr)
        if not callable(value):
            training_args_dict[attr] = value
    training_arguments = TrainingArguments(**training_args_dict)

    data_collator = None
    if data_args.completions_only_loss:
        data_collator = CustomDataCollatorForCompletionOnlyLM(
            tokenizer,
            response_template="[/INST]",
            instruction_template="<s>[INST]",
            # pad_to_multiple_of=8,
            return_tensors="pt",
            mlm=False,
        )

    # Use custom FSDP SFTTrainer that saves out FSDP checkpoints as full HF
    # weights when running under FSDP; both trainers take the same arguments.
    trainer_cls = FSDPSFTTrainer if fsdp_enabled else SFTTrainer
    trainer = trainer_cls(
        model=model,
        tokenizer=tokenizer,
        args=training_arguments,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=lora_config,
        packing=data_args.packing,
        data_collator=data_collator,
        dataset_text_field=data_args.dataset_text_field,
        max_seq_length=data_args.max_seq_length,
        callbacks=callbacks,
    )

    trainer.accelerator.print(f"{trainer.model}")
    if model_args.use_peft_lora:
        trainer.model.print_trainable_parameters()

    # train (resume_from_checkpoint is None when no checkpoint was requested)
    trainer.train(resume_from_checkpoint=train_args.resume_from_checkpoint)

    # Switch FSDP to full state dicts for a final save.
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    # NOTE: the final explicit save is currently disabled:
    # trainer.save_model(train_args.output_dir)


if __name__ == "__main__":
    # One dataclass per configuration group. HfArgumentParser returns the
    # parsed instances in this same order, which is also the order of main()'s
    # positional parameters.
    config_classes = (
        ModelArgs,
        DataArgs,
        TrainArgs,
        EvalArgs,
        PeftConfig,
        BnBConfig,
        LoggingConfig,
    )
    cli_parser = HfArgumentParser(config_classes)
    parsed_configs = cli_parser.parse_args_into_dataclasses()
    main(*parsed_configs)
