import gc
import importlib.util
import json
import logging
import os
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional

import torch
import transformers
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM
from peft.tuners.lora import LoraLayer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, TrlParser, SFTConfig

from data_preprocessing import prepare_data, DataArguments
from bof4.quantization import quant_util
from bof4.util import (
    print_parameter_info,
    fix_untrained_tokens,
    smart_tokenizer_and_embedding_resize,
    prepare_model_for_kbit_training,
)
from bof4.quantization import *

_logger = logging.getLogger(__name__)
_logger.addHandler(logging.StreamHandler(sys.stdout))
_logger.setLevel(logging.INFO)

# dtype passed to the model loader as ``model_dtype``.
TORCH_DTYPE = torch.bfloat16

# Attention backend passed to the HF model loader. Starts as eager and is only
# upgraded to FlashAttention-2 when a capable GPU and a 2.x ``flash_attn``
# package are both present. All messages go through ``_logger``: the
# module-level ``logging.warning`` would implicitly configure the root logger
# and make every ``_logger`` message appear twice.
ATTN_IMPLEMENTATION = "eager"
if not torch.cuda.is_available():
    # get_device_capability() raises without a CUDA device, so check first.
    _logger.info("No CUDA device available, using eager mode")
elif torch.cuda.get_device_capability()[0] < 8:
    _logger.info("Device does not support FlashAttention-2, using eager mode")
elif importlib.util.find_spec("flash_attn") is None:
    _logger.warning("FlashAttention-2 is not installed, using eager mode")
else:
    import flash_attn

    if not flash_attn.__version__.startswith("2."):
        _logger.warning(
            f"FlashAttention {flash_attn.__version__} found. FlashAttention 2.x is required. Using eager mode."
        )
    else:
        ATTN_IMPLEMENTATION = "flash_attention_2"
        _logger.info("Using Flash-Attention " + flash_attn.__version__)


@dataclass
class ScriptArguments:
    """Script-level options: model, output location, quantization and LoRA.

    Parsed from the command line / config file alongside the data, training
    and generation argument dataclasses.
    """

    # --- model and sequence length ---
    model_id: str = field(
        default="meta-llama/Meta-Llama-3-8B",
        metadata={"help": "Model ID to use for SFT training"},
    )
    max_seq_length: int = field(
        default=1024, metadata={"help": "The maximum sequence length for SFT Trainer"}
    )
    # --- output location ---
    # A number is appended to this base to make each run directory unique.
    output_dir_base: str = field(
        default="./checkpoints/test",
        metadata={
            "help": "The base name of the directory where the checkpoints will be stored. The full name is made unique by appending a number."
        },
    )
    # --- quantization ---
    # None means no quantizer file is loaded.
    quantizer: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path of the quantizer applied to the model during fine-tuning"
        },
    )
    quant_block_size: int = field(
        default=64, metadata={"help": "The block size used for block-wise quantization"}
    )
    # --- caching ---
    hf_home: Optional[str] = field(
        default=None,
        metadata={
            "help": "Sets the directory for caching models and data for the transformer library."
        },
    )
    # --- LoRA adapter hyper-parameters (forwarded to peft.LoraConfig) ---
    lora_r: int = field(
        default=64,
        metadata={"help": "The rank of the LoRA adapters"},
    )
    lora_alpha: int = field(
        default=16,
        metadata={"help": "The alpha parameter for LoRA adapters"},
    )
    use_rslora: bool = field(
        default=False,
        metadata={"help": "Whether to use rank-stabilized LoRA"},
    )
    lora_dropout: float = field(
        default=0.1,
        metadata={"help": "Dropout rate for LoRA adapters"},
    )
    lora_bias: str = field(
        default="none",
        metadata={"help": "Whether a bias is used for the LoRA adapter"},
    )
    task_type: str = field(
        default="CAUSAL_LM",
        metadata={"help": "Type of task the model is used for"},
    )
    target_modules: str = field(
        default="all-linear",
        metadata={"help": "Target modules for applying LoRA"},
    )

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF ``TrainingArguments`` with an empty default for ``output_dir``.

    A user-supplied ``output_dir`` is ignored (with a warning). The actual run
    directory is derived from ``ScriptArguments.output_dir_base`` and
    assigned to this field at runtime.
    """

    output_dir: str = ""


@dataclass
class GenerationArguments:
    """Decoding settings, unpacked into ``transformers.GenerationConfig``.

    ``main`` passes ``vars()`` of an instance as keyword arguments, so every
    field name here must be a valid ``GenerationConfig`` parameter.
    """

    # --- output length ---
    max_new_tokens: Optional[int] = field(
        default=256,
        metadata={
            # Trailing space needed: the two literals are concatenated.
            "help": "Maximum number of new tokens to be generated in evaluation or prediction loops "
            "if predict_with_generate is set."
        },
    )
    min_new_tokens: Optional[int] = field(
        default=None, metadata={"help": "Minimum number of new tokens to generate."}
    )

    # --- decoding strategy ---
    do_sample: Optional[bool] = field(default=False)
    num_beams: Optional[int] = field(default=1)
    num_beam_groups: Optional[int] = field(default=1)
    penalty_alpha: Optional[float] = field(default=None)
    use_cache: Optional[bool] = field(default=True)

    # --- logit manipulation ---
    temperature: Optional[float] = field(default=1.0)
    top_k: Optional[int] = field(default=50)
    top_p: Optional[float] = field(default=1.0)
    typical_p: Optional[float] = field(default=1.0)
    diversity_penalty: Optional[float] = field(default=0.0)
    repetition_penalty: Optional[float] = field(default=1.0)
    length_penalty: Optional[float] = field(default=1.0)
    no_repeat_ngram_size: Optional[int] = field(default=0)


def _save_quant_config(script_args, training_args) -> None:
    if script_args.quantizer is not None:
        quantizer = load_from_file(script_args.quantizer)
        save_to_file(quantizer, Path(training_args.output_dir) / "quantizer.yaml")


def run_training(
    model,
    tokenizer,
    data,
    training_args: TrainingArguments,
    script_args: ScriptArguments,
):
    """Train and/or evaluate ``model`` with TRL's ``SFTTrainer``.

    The public (non-underscore) fields of ``training_args`` are copied into an
    ``SFTConfig``, which is then configured to read the ``"text"`` column.
    Depending on ``do_train`` / ``do_eval`` the model is trained (then saved
    together with trainer state and quantizer config) and/or evaluated. If
    either phase ran, the collected metrics are written to ``metrics.json``
    in ``training_args.output_dir``.

    Args:
        model: Model prepared for fine-tuning.
        tokenizer: Tokenizer, passed to the trainer as ``processing_class``.
        data: Mapping with ``"train_dataset"``, ``"eval_dataset"`` and
            ``"data_collator"`` entries.
        training_args: Parsed training arguments.
        script_args: Parsed script arguments (``max_seq_length``,
            ``quantizer``).
    """
    public_fields = {
        key: value
        for key, value in asdict(training_args).items()
        if not key.startswith("_")
    }
    args = SFTConfig(**public_fields)
    args.dataset_text_field = "text"
    args.max_seq_length = script_args.max_seq_length
    # NOTE(review): presumably disabled because prepare_data already formats
    # the text with the chat template (incl. special tokens) — confirm.
    args.dataset_kwargs = dict(add_special_tokens=False, append_concat_token=False)

    trainer = SFTTrainer(
        model=model,
        train_dataset=data["train_dataset"],
        eval_dataset=data["eval_dataset"],
        args=args,
        processing_class=tokenizer,
        data_collator=data["data_collator"],
    )

    # Print parameter information so the model's dtypes can be checked by hand.
    print_parameter_info(model)

    all_metrics = {"run_name": training_args.run_name}
    gc.collect()

    if training_args.do_train:
        train_metrics = trainer.train().metrics
        trainer.log_metrics("train", train_metrics)
        trainer.save_metrics("train", train_metrics)
        trainer.save_model()
        trainer.save_state()
        _save_quant_config(script_args, training_args)
        all_metrics.update(train_metrics)

    if training_args.do_eval:
        eval_metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)
        all_metrics.update(eval_metrics)

    if not (training_args.do_train or training_args.do_eval):
        return

    metrics_path = os.path.join(training_args.output_dir, "metrics.json")
    with open(metrics_path, "w") as out_file:
        out_file.write(json.dumps(all_metrics, indent=4))


def make_run_name_and_directory_unique(
    training_args: TrainingArguments, script_args: ScriptArguments
):
    """Create a fresh numbered run directory and derive the run name from it.

    The directories ``<output_dir_base>1``, ``<output_dir_base>2``, ... are
    tried in order, and the first one that can be newly created is used. Any
    missing parent directories are created too. Its path is stored in
    ``training_args.output_dir``. If no ``run_name`` is set, the directory
    name (e.g. ``test3``) becomes the run name. A user-supplied
    ``output_dir`` is ignored with a warning.

    NOTE(review): every process that calls this claims its own directory.
    Confirm that only one process runs it under multi-process launchers.

    Args:
        training_args: Parsed training arguments. ``output_dir`` and possibly
            ``run_name`` are overwritten in place.
        script_args: Parsed script arguments providing ``output_dir_base``.
    """
    output_dir_base = Path(script_args.output_dir_base)
    if training_args.output_dir:
        _logger.warning(
            "Ignoring 'output_dir' argument of transformers.TrainingArguments "
            "and setting directory based on 'output_dir_base' instead."
        )
    parent = output_dir_base.parent
    name = output_dir_base.name

    number = 1
    while True:
        output_dir = parent / f"{name}{number}"
        try:
            # mkdir doubles as the existence check. It fails atomically if the
            # path already exists, so two concurrently started runs cannot
            # claim the same number (a separate exists() check would race).
            output_dir.mkdir(parents=True)
        except FileExistsError:
            number += 1
        else:
            break

    training_args.output_dir = str(output_dir)

    if not training_args.run_name:
        training_args.run_name = f"{name}{number}"


def get_model_and_tokenizer(script_args, training_args):
    """Load the base model and tokenizer and prepare the model for LoRA fine-tuning.

    If ``script_args.quantizer`` is set, the quantizer is loaded from that
    file and passed to both the loader and the preparation step. The LoRA
    hyper-parameters from ``script_args`` are collected into a
    ``peft.LoraConfig`` and passed to ``quant_util.prepare_model_for_finetuning``.

    Args:
        script_args: Parsed script arguments (model id, quantizer path, LoRA
            settings).
        training_args: Not used by this function; kept for interface
            compatibility.

    Returns:
        ``(model, tokenizer)``, where ``model`` is the result of
        ``quant_util.prepare_model_for_finetuning``.
    """
    quantizer = (
        load_from_file(script_args.quantizer)
        if script_args.quantizer is not None
        else None
    )

    model, tokenizer = quant_util.load_model_and_tokenizer(
        script_args.model_id,
        quantizer,
        model_dtype=TORCH_DTYPE,
        overwrite_hf_model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION},
    )

    lora_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        # Forward the user's dropout setting; without it the --lora_dropout
        # option silently had no effect and PEFT's default was used.
        lora_dropout=script_args.lora_dropout,
        use_rslora=script_args.use_rslora,
        bias=script_args.lora_bias,
        task_type=script_args.task_type,
        target_modules=script_args.target_modules,
    )

    model = quant_util.prepare_model_for_finetuning(model, quantizer, lora_config)
    return model, tokenizer


def main():
    """Entry point: parse arguments, set up run tracking and fine-tune.

    Steps:
    1. Parse the four argument dataclasses.
    2. Create a unique run directory.
    3. Optionally start a wandb run.
    4. Seed all RNGs, then load the model and prepare the data.
    5. Train and evaluate.
    """
    parser = TrlParser(
        [ScriptArguments, DataArguments, TrainingArguments, GenerationArguments]
    )
    parsed = parser.parse_args_and_config()
    script_args, data_args, training_args, generation_args = parsed

    # NOTE(review): run_training rebuilds the config via asdict(), which only
    # keeps declared dataclass fields. Confirm this attribute actually reaches
    # the trainer.
    generation_kwargs = vars(generation_args)
    training_args.generation_config = transformers.GenerationConfig(
        **generation_kwargs
    )

    make_run_name_and_directory_unique(training_args, script_args)

    if "wandb" in training_args.report_to:
        import wandb

        run_config = dict(
            script_args=script_args,
            data_args=data_args,
            training_args=training_args,
        )
        wandb.init(
            project="LLMFinetuning",
            name=training_args.run_name,
            config=run_config,
        )
    elif training_args.report_to == [""]:
        # Presumably an empty --report_to value; treat it as "report nowhere".
        training_args.report_to = []

    if training_args.gradient_checkpointing:
        # Select the non-reentrant torch.utils.checkpoint implementation.
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

    transformers.set_seed(training_args.seed)

    model, tokenizer = get_model_and_tokenizer(script_args, training_args)

    # Preprocess the dataset and format it with the chat template.
    data = prepare_data(tokenizer, data_args, training_args, hf_home=script_args.hf_home)

    run_training(model, tokenizer, data, training_args, script_args)

    if "wandb" in training_args.report_to:
        wandb.finish()


if __name__ == "__main__":
    main()
