import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from sklearn.metrics import accuracy_score

import datasets
import torch
from datasets import load_dataset

from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training
)
import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint


# Modules that are kept fully trainable (and saved with the adapter) for each
# supported checkpoint: the token-embedding module plus the LM head. The table
# is expanded from (embedding module name, checkpoints) groups; every checkpoint
# receives its own list object, and key order follows the grouping order below.
SAVE_MODULES_MAP = {
    checkpoint: [embedding_module, "lm_head"]
    for embedding_module, checkpoints in (
        (
            "wte",
            (
                "bigcode/starcoderbase-1b",
                "bigcode/astraios-1b-fft",
            ),
        ),
        (
            "transformer.wte",
            (
                "bigcode/starcoderbase-3b",
                "bigcode/starcoderbase-7b",
            ),
        ),
        (
            "wte",
            (
                "Salesforce/codegen-350M-multi",
                "Salesforce/codegen-2B-multi",
            ),
        ),
        (
            "embed_tokens",
            (
                "deepseek-ai/deepseek-coder-1.3b-base",
                "deepseek-ai/deepseek-coder-5.7bmqa-base",
                "codellama/CodeLlama-7b-hf",
                "BASE_2794M",
                "BASE_1229M",
                "BASE_491M",
                "BASE_255M",
            ),
        ),
    )
    for checkpoint in checkpoints
}


# Names of the attention projection modules that receive LoRA adapters, keyed
# by supported checkpoint. `ModelArguments.__post_init__` validates
# `--model_name_or_path` against the keys of this table, so every checkpoint
# that should be trainable must appear here.
#
# "bigcode/astraios-1b-fft" was listed in SAVE_MODULES_MAP but missing from this
# table, which made it impossible to select. It is added with the same
# projection names as the StarCoderBase checkpoints.
# NOTE(review): this assumes astraios-1b-fft shares the StarCoderBase
# architecture -- confirm against the checkpoint's config.
LORA_COMPONENTS_MAP = {
    checkpoint: list(projection_modules)
    for projection_modules, checkpoints in (
        (
            ("c_attn", "c_proj", "q_attn"),
            (
                "bigcode/starcoderbase-1b",
                "bigcode/astraios-1b-fft",
                "bigcode/starcoderbase-3b",
                "bigcode/starcoderbase-7b",
            ),
        ),
        (
            ("qkv_proj", "out_proj"),
            (
                "Salesforce/codegen-350M-multi",
                "Salesforce/codegen-2B-multi",
            ),
        ),
        (
            ("q_proj", "k_proj", "v_proj", "o_proj"),
            (
                "deepseek-ai/deepseek-coder-1.3b-base",
                "deepseek-ai/deepseek-coder-5.7bmqa-base",
                "codellama/CodeLlama-7b-hf",
                "BASE_2794M",
                "BASE_1229M",
                "BASE_491M",
                "BASE_255M",
            ),
        ),
    )
    for checkpoint in checkpoints
}

# Module-level logger; its level is configured inside main().
logger = logging.getLogger(__name__)
# Config classes registered for causal-LM models, and their `model_type`
# strings (only used to build the --model_type help text).
MODEL_CONFIG_CLASSES = [*MODEL_FOR_CAUSAL_LM_MAPPING.keys()]
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)


@dataclass
class LoggingArguments:
    """
    Arguments for instrumenting the Weights & Biases logger.

    Both fields are optional and default to ``None``; the annotations are
    ``Optional[str]`` so that the declared type matches that default.
    """

    project_name: Optional[str] = field(
        default=None, metadata={"help": "The project name under which the experiment will be logged."}
    )
    wandb_token: Optional[str] = field(
        default=None, metadata={"help": "API token for WandB hub."}
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Besides the usual checkpoint/config/tokenizer options this also carries the
    8-bit quantization threshold and the LoRA hyper-parameters.

    Note that `__post_init__` requires `model_name_or_path` to be one of the
    checkpoints known to the module-level lookup tables, so in practice a
    checkpoint must always be given.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Optional: None means "fall back to the locally stored login token".
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters "
                "when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )
    llm_int8_threshold: float = field(
        default=6.0, metadata={"help": "The threshold for a parameter to be designated a quantization outlier."}
    )
    lora_alpha: int = field(
        default=16, metadata={"help": "The interpolation importance factor for the LoRA adapter."}
    )
    lora_r: int = field(
        default=8, metadata={"help": "The LoRA adapter rank."}
    )
    lora_dropout: float = field(
        default=0.1, metadata={"help": "Dropout value for LoRA layers."}
    )

    def __post_init__(self):
        """
        Validate the combination of arguments.

        Raises:
            ValueError: if `config_overrides` is combined with `config_name` or
                `model_name_or_path`, or if `model_name_or_path` is not a
                checkpoint present in both lookup tables.
        """
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )
        # main() indexes both tables with model_name_or_path, so a checkpoint is
        # only usable when it is present in each of them.
        supported_checkpoints = [name for name in LORA_COMPONENTS_MAP if name in SAVE_MODULES_MAP]
        if self.model_name_or_path not in supported_checkpoints:
            raise ValueError(
                f"model_name_or_path argument must be one of the following: {', '.join(supported_checkpoints)}"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    `dataset_name` is mandatory; all other fields are optional tuning knobs.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    def __post_init__(self):
        """
        Ensure a dataset was specified.

        Raises:
            ValueError: if `dataset_name` is missing or empty.
        """
        # An empty string is as unusable as None, so reject both up front
        # instead of failing later while loading the dataset.
        if not self.dataset_name:
            raise ValueError("Need a dataset name (--dataset_name).")


def _resize_embeddings_to_tokenizer(model, tokenizer):
    """Resize the model's token embeddings so that they match ``len(tokenizer)``.

    When the vocabulary grows, the newly added rows of the input embedding
    matrix (and of the output embedding matrix, if the model has one) are set
    to the mean of the pre-existing rows. When it shrinks, the matrices are
    simply truncated by ``resize_token_embeddings``.
    """
    embedding_size = model.get_input_embeddings().weight.shape[0]
    logger.info(f"Loaded model's embedding size is {embedding_size}")
    vocab_size = len(tokenizer)

    if vocab_size == embedding_size:
        return
    if vocab_size < embedding_size:
        model.resize_token_embeddings(vocab_size)
        return

    logger.info(f"Extending the embedding size from {embedding_size} to {vocab_size}")
    model.resize_token_embeddings(vocab_size)

    # The weight tensors must be fetched *after* resizing: tensors captured
    # before the call still have the old row count, so writing to their
    # `[embedding_size:]` slice would never reach the new rows.
    input_embeddings = model.get_input_embeddings().weight.data
    input_embeddings_avg = input_embeddings[:embedding_size].mean(dim=0, keepdim=True)
    logger.info(f"Setting the newly added input embedding tokens to {input_embeddings_avg}")
    input_embeddings[embedding_size:] = input_embeddings_avg

    output_layer = model.get_output_embeddings()
    if output_layer is not None:
        output_embeddings = output_layer.weight.data
        output_embeddings_avg = output_embeddings[:embedding_size].mean(dim=0, keepdim=True)
        logger.info(f"Setting the newly added output embedding tokens to {output_embeddings_avg}")
        output_embeddings[embedding_size:] = output_embeddings_avg


def main():
    """Fine-tune a causal language model with LoRA adapters on top of an 8-bit base model.

    Steps: parse the command-line dataclasses, configure logging, load the
    dataset (carving a validation split out of `train` when none exists), load
    config/tokenizer/8-bit model, align the embedding size with the tokenizer,
    wrap the model with LoRA, then run training and/or evaluation through
    `Trainer` and log/save the resulting metrics.

    Raises:
        ValueError: if the output directory is non-empty without a checkpoint
            (and `--overwrite_output_dir` is not set), if no tokenizer source is
            given, or if a requested train/validation split is missing.
    """
    parser = HfArgumentParser((LoggingArguments, ModelArguments, DataTrainingArguments, TrainingArguments))
    log_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    quant_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=model_args.llm_int8_threshold,
    )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # RANK / LOCAL_WORLD_SIZE may be absent from the environment (e.g. when the
    # script is started with plain `python`); fall back to single-process
    # values instead of failing with KeyError.
    process_rank = int(os.environ.get("RANK", "0"))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {process_rank}, "
        + f"device: {training_args.device}, "
        + f"n_gpu: {training_args.n_gpu}, "
        + f"world_size {local_world_size}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
        + f"16-bits training: {training_args.fp16}"
    )
    # Both W&B arguments default to None and os.environ only accepts strings,
    # so export them only when they were actually provided.
    if log_args.project_name is not None:
        os.environ["WANDB_PROJECT"] = log_args.project_name
    if log_args.wandb_token is not None:
        os.environ["WANDB_API_KEY"] = log_args.wandb_token
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    raw_datasets = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        cache_dir=model_args.cache_dir,
        token=model_args.token,
    )
    if "validation" not in raw_datasets.keys():
        # No validation split: use the first N% of `train` for validation and
        # the remainder for training.
        raw_datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"train[:{data_args.validation_split_percentage}%]",
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"train[{data_args.validation_split_percentage}%:]",
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    # The datasets are handed to the Trainer as loaded; no tokenization happens
    # in this script. NOTE(review): presumably the dataset is already
    # tokenized -- confirm against the dataset schema.
    tokenized_datasets = raw_datasets

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    # Honour --tokenizer_name (which the error message below refers to) and
    # --use_fast_tokenizer; fall back to the model checkpoint's own tokenizer.
    tokenizer_source = model_args.tokenizer_name or model_args.model_name_or_path
    if not tokenizer_source:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # "auto" and None are passed through unchanged; any other value names a torch dtype.
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model_base = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        quantization_config=quant_config,
    )

    _resize_embeddings_to_tokenizer(model_base, tokenizer)

    model_base = prepare_model_for_int8_training(model_base)
    adapter_config = LoraConfig(
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        r=model_args.lora_r,
        inference_mode=False,
        bias="none",
        task_type="CAUSAL_LM",
        modules_to_save=SAVE_MODULES_MAP[model_args.model_name_or_path],
        target_modules=LORA_COMPONENTS_MAP[model_args.model_name_or_path],
    )
    model = get_peft_model(model_base, adapter_config)

    train_dataset = None
    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    eval_dataset = None
    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

        def preprocess_logits_for_metrics(logits, labels):
            if isinstance(logits, tuple):
                # Depending on the model and config, logits may contain extra tensors,
                # like past_key_values, but logits always come first
                logits = logits[0]
            return logits.argmax(dim=-1)

        def compute_metrics(eval_preds):
            preds, labels = eval_preds
            # preds have the same shape as the labels, after the argmax(-1) has been calculated
            # by preprocess_logits_for_metrics but we need to shift the labels
            labels = labels[:, 1:].reshape(-1)
            preds = preds[:, :-1].reshape(-1)
            return {
                "accuracy": float(
                    accuracy_score(
                        labels,
                        preds,
                        normalize=True,
                        sample_weight=None,
                    )
                )
            }

    # The metric callbacks only exist when evaluating, and are skipped on TPU.
    use_eval_metrics = training_args.do_eval and not is_torch_tpu_available()

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
        compute_metrics=compute_metrics if use_eval_metrics else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics if use_eval_metrics else None,
    )

    # Training
    if training_args.do_train:
        # An explicit --resume_from_checkpoint wins over an auto-detected checkpoint.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        # Perplexity is exp(eval loss); a very large loss overflows to infinity.
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
