from dataclasses import dataclass, field
from typing import Optional

import os
import torch
from accelerate import Accelerator
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
)
from trl import DPOTrainer

from ..constant import SANITY_CHECK_DATASIZE

# Report, at import time, how many CUDA devices torch can see.
print("Number of GPU available: {}".format(torch.cuda.device_count()))


@dataclass
class ScriptArguments:
    """
    Command-line options for the DPO training script.

    Parsed with ``HfArgumentParser``: every field becomes a ``--<name>`` flag whose
    help text is the field's ``help`` metadata. Most fields are forwarded to
    ``TrainingArguments`` or ``DPOTrainer``.
    """
    # --- Experiment tracking ---
    experiment_name: Optional[str] = field(default="20240309-alpaca-pku-helpfulness-dpo",
                                           metadata={"help": "Experiment name for mlflow"})
    run_name: Optional[str] = field(default="dpo",
                                    metadata={"help": "Run name reported to logging integrations (e.g. MLflow)"})

    # --- Data ---
    # Directory expected to contain `train` and `test` splits saved with `save_to_disk`.
    local_data_path: str = field(default="data_cache/temp_data", metadata={"help": "Local path of the preference dataset"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "Maximum prompt length"})
    max_length: Optional[int] = field(default=512, metadata={"help": "Maximum sequence length"})
    # Values <= 0 keep the whole training split.
    num_train_data: Optional[int] = field(default=-1, metadata={"help": "Number of data used for training"})

    # --- Model ---
    model_name_or_path: Optional[str] = field(default="PKU-Alignment/alpaca-7b-reproduced",
                                              metadata={"help": "Location of the SFT model name or path"})
    model_dtype: Optional[str] = field(default="float", metadata={"help": "Model dtype [float16, bfloat16, float] for loading"})

    # --- Optimisation ---
    num_train_epochs: Optional[int] = field(default=2, metadata={"help": "Number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "Maximal number of training steps"})
    seed: Optional[int] = field(default=0, metadata={"help": "Random seed set at the beginning of training"})
    beta: Optional[float] = field(default=0.1, metadata={"help": "Beta parameter for DPO loss"})
    learning_rate: Optional[float] = field(default=1e-6, metadata={"help": "Optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "LR scheduler type"})
    warmup_steps: Optional[int] = field(default=0, metadata={"help": "Number of warmup steps"})
    warmup_ratio: Optional[float] = field(default=0.03, metadata={"help": "Warmup ratio"})
    optimizer_type: Optional[str] = field(default="adamw_torch", metadata={"help": "Optimizer type"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "Train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=16, metadata={"help": "Eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "Number of gradient accumulation steps"})
    gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Use gradient checkpointing"})
    gradient_checkpointing_use_reentrant: Optional[bool] = field(default=False, metadata={"help": "Use reentrant for gradient checkpointing"})
    resume_from_checkpoint: Optional[bool] = field(default=False, metadata={"help": "Resume from last checkpoint"})

    # --- Logging, checkpointing and evaluation ---
    logging_steps: Optional[int] = field(default=10, metadata={"help": "Logging frequency"})
    save_strategy: Optional[str] = field(default="epoch", metadata={"help": "Saving frequency"})
    save_steps: Optional[float] = field(default=150, metadata={"help": "Saving frequency"})
    save_total_limit: Optional[int] = field(default=None, metadata={"help": "Saving limit"})
    save_only_model: Optional[bool] = field(default=True, metadata={"help": "Save only model"})
    eval_strategy: Optional[str] = field(default="epoch", metadata={"help": "Evaluation frequency"})
    eval_steps: Optional[int] = field(default=1000, metadata={"help": "Evaluation frequency"})
    output_dir: Optional[str] = field(default="./results", metadata={"help": "Output directory"})
    # May be overridden at runtime when MLflow reporting is requested but not configured.
    report_to: Optional[str] = field(default="mlflow", metadata={"help": "Reporting integrations"})

    # --- Debugging / distributed workarounds ---
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "Only train on several samples"})
    ignore_bias_buffers: Optional[bool] = field(default=False, metadata={"help": "Fix for DDP issues with LM bias/mask buffers"})


def set_mlflow_experiment(script_args):
    """Configure MLflow reporting, or switch reporting off when MLflow is unusable.

    Behaviour depends on ``script_args.report_to``:

    * ``"mlflow"`` with a non-empty ``MLFLOW_TRACKING_URI`` environment variable:
      export ``MLFLOW_EXPERIMENT_NAME`` so runs are logged under
      ``script_args.experiment_name``.
    * ``"mlflow"`` without a usable tracking URI: set ``report_to`` to ``"none"``
      so no integration is used.
    * anything else (e.g. ``"wandb"`` or ``"none"``): left untouched.

    Args:
        script_args: parsed ``ScriptArguments``; ``report_to`` may be modified in place.
    """
    tracking_uri = os.environ.get("MLFLOW_TRACKING_URI")
    print('mlflow', script_args.report_to, tracking_uri, script_args.experiment_name)

    if script_args.report_to != "mlflow":
        # Respect an explicit non-MLflow choice instead of overwriting it.
        return

    if tracking_uri:
        print(f'Report to {script_args.experiment_name}')
        if script_args.experiment_name:
            # os.environ only accepts str values; skip when no name was given.
            os.environ['MLFLOW_EXPERIMENT_NAME'] = script_args.experiment_name
    else:
        # Use the string "none": in transformers 4.x, report_to=None falls back
        # to "all installed integrations" rather than disabling reporting.
        script_args.report_to = "none"


def prepare_datasets(script_args):
    """Load the train/test preference splits stored under ``local_data_path``.

    With ``sanity_check`` enabled, both splits are cut down to at most
    ``SANITY_CHECK_DATASIZE`` rows. Afterwards, a positive ``num_train_data``
    additionally caps the training split at that many rows.

    Returns:
        ``(train_dataset, eval_dataset)``
    """
    def _first(dataset, limit):
        # Keep at most `limit` leading rows, never asking for more than exist.
        return dataset.select(range(min(len(dataset), limit)))

    root = script_args.local_data_path
    train_dataset = load_from_disk(f"{root}/train")
    eval_dataset = load_from_disk(f"{root}/test")

    if script_args.sanity_check:
        train_dataset = _first(train_dataset, SANITY_CHECK_DATASIZE)
        eval_dataset = _first(eval_dataset, SANITY_CHECK_DATASIZE)

    cap = script_args.num_train_data
    if cap > 0:
        train_dataset = _first(train_dataset, cap)

    return train_dataset, eval_dataset


def load_model_and_tokenizer(script_args):
    """Load the pretrained causal LM and its tokenizer.

    Args:
        script_args: parsed ``ScriptArguments``. Reads ``model_name_or_path`` and
            ``model_dtype``, which must be one of ``"float"``, ``"float32"``,
            ``"float16"`` or ``"bfloat16"``.

    Returns:
        ``(model, tokenizer)``. The whole model is placed on this process's local
        device with ``use_cache`` disabled. The tokenizer's pad token is set to
        its EOS token.

    Raises:
        ValueError: if ``model_dtype`` is not one of the supported names. Previously,
            unknown names such as ``"fp16"`` silently fell back to float32.
    """
    dtype_by_name = {
        "float": torch.float,
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    try:
        torch_dtype = dtype_by_name[script_args.model_dtype]
    except KeyError:
        raise ValueError(
            f"Unsupported model_dtype {script_args.model_dtype!r}; "
            f"expected one of {sorted(dtype_by_name)}"
        ) from None

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        # The "" key maps the entire model to this process's local device index.
        device_map={"": Accelerator().local_process_index},
    )
    # The KV cache is not needed during training (and does not combine with
    # gradient checkpointing).
    model.config.use_cache = False

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def initialize_training_arguments(script_args):
    """Build the ``TrainingArguments`` for the DPO run from the script options.

    Options whose names match a ``TrainingArguments`` field are forwarded as-is.
    ``eval_strategy`` and ``optimizer_type`` are renamed. Gradient-checkpointing
    reentrancy is wrapped into ``gradient_checkpointing_kwargs``. bf16/tf32 are
    always enabled, and unused dataset columns are kept.
    """
    # Script options forwarded under the same name.
    same_name = (
        "run_name",
        "num_train_epochs",
        "max_steps",
        "per_device_train_batch_size",
        "per_device_eval_batch_size",
        "logging_steps",
        "save_strategy",
        "save_steps",
        "save_total_limit",
        "save_only_model",
        "gradient_accumulation_steps",
        "gradient_checkpointing",
        "learning_rate",
        "eval_steps",
        "output_dir",
        "report_to",
        "lr_scheduler_type",
        "warmup_steps",
        "warmup_ratio",
        "seed",
        "resume_from_checkpoint",
    )
    kwargs = {name: getattr(script_args, name) for name in same_name}

    # Options that are renamed or reshaped on the way in.
    kwargs["evaluation_strategy"] = script_args.eval_strategy
    kwargs["optim"] = script_args.optimizer_type
    kwargs["gradient_checkpointing_kwargs"] = {
        "use_reentrant": script_args.gradient_checkpointing_use_reentrant,
    }

    # Fixed settings that are not exposed on the command line.
    kwargs["bf16"] = True
    kwargs["tf32"] = True
    kwargs["remove_unused_columns"] = False

    return TrainingArguments(**kwargs)


def main():
    """Script entry point: parse options, run DPO training, and save the model.

    Steps: parse ``ScriptArguments`` from the command line, seed RNGs, configure
    MLflow reporting, load datasets/model/tokenizer, and build a ``DPOTrainer``
    with ``ref_model=None``. It then trains, resuming from the last checkpoint in
    ``output_dir`` when ``--resume_from_checkpoint`` is set. Finally, the model is
    saved unless checkpointing was step-based.
    """
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    set_seed(script_args.seed)

    set_mlflow_experiment(script_args)
    train_dataset, eval_dataset = prepare_datasets(script_args)
    model, tokenizer = load_model_and_tokenizer(script_args)

    if script_args.ignore_bias_buffers:
        # Torch distributed hack: have DDP skip boolean buffers (e.g. LM
        # bias/mask buffers), which otherwise cause DDP issues.
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    training_args = initialize_training_arguments(script_args)

    # Initialize the trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # Train the model. Trainer does not read TrainingArguments.resume_from_checkpoint
    # itself, so the flag must be passed to train(): True resumes from the last
    # checkpoint in output_dir, and False starts from scratch.
    dpo_trainer.train(resume_from_checkpoint=script_args.resume_from_checkpoint)

    # Save the final model only when checkpoints were not step-based.
    if script_args.save_strategy != "steps":
        if dpo_trainer.is_fsdp_enabled:
            # Gather a full (unsharded) state dict before saving under FSDP.
            dpo_trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
        dpo_trainer.save_model(script_args.output_dir)


if __name__ == "__main__":
    main()
