import argparse
import json
import logging
import os
import random
from pathlib import Path
from pytorch_lightning.strategies import DeepSpeedStrategy
import numpy as np
import pytorch_lightning as pl
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForCausalLM,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)

from torch.optim import AdamW
from deepspeed.ops.adam import DeepSpeedCPUAdam
# from apex.optimizers import FusedAdam

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Registry mapping the `mode` string accepted by `BaseTransformer.__init__`
# to the transformers Auto* class used to build the model (through
# `from_pretrained` or `from_config`).
MODEL_MODES = {
    "base": AutoModel,
    "sequence-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "pretraining": AutoModelForPreTraining,
    "token-classification": AutoModelForTokenClassification,
    # NOTE(review): AutoModelWithLMHead appears to be deprecated in recent
    # transformers releases -- confirm before upgrading the dependency.
    "language-modeling": AutoModelWithLMHead,
    "seq2seq": AutoModelForSeq2SeqLM,
    "next-token-prediction": AutoModelForCausalLM,
}


class BaseTransformer(pl.LightningModule):
    """Lightning module bundling a Hugging Face config, tokenizer and model.

    The class builds the three objects from the supplied hyper-parameters,
    configures an AdamW optimizer (optionally with a linear warmup/decay
    schedule) and exposes the train/dev/test dataloaders.

    ``get_dataloader``, ``validation_step`` and ``validation_end`` are called
    here but not defined in this class; they are expected to be provided by
    subclasses.
    """

    def __init__(
        self,
        hparams: argparse.Namespace,
        num_labels=None,
        mode="base",
        config=None,
        tokenizer=None,
        model=None,
        **config_kwargs,
    ):
        """Initialize a model, tokenizer and config.

        Args:
            hparams: Hyper-parameters, either an ``argparse.Namespace`` or a
                mapping. Read here: ``output_dir``, ``cache_dir``,
                ``train_from_scratch``, ``config_name``, ``tokenizer_name``
                and ``model_name_or_path``.
            num_labels: Forwarded to ``AutoConfig.from_pretrained`` when not
                ``None``.
            mode: Key into ``MODEL_MODES`` selecting the Auto* model class.
            config: Pre-built config; when ``None`` one is loaded from
                ``config_name`` or, failing that, ``model_name_or_path``.
            tokenizer: Pre-built tokenizer; when ``None`` one is loaded from
                ``tokenizer_name``, then ``config_name``, then
                ``model_name_or_path``.
            model: Pre-built model, used only when not training from scratch.
            **config_kwargs: Extra keyword arguments forwarded to
                ``AutoConfig.from_pretrained``.

        Raises:
            ValueError: If ``train_from_scratch`` is set without a
                ``config_name``.
            KeyError: If ``mode`` is not a key of ``MODEL_MODES``.
        """
        super().__init__()

        self.hparams.update(
            vars(hparams) if isinstance(hparams, argparse.Namespace) else hparams
        )
        self.save_hyperparameters(hparams)

        self.step_count = 0
        self.tfmr_ckpts = {}
        self.output_dir = Path(self.hparams.output_dir)
        # An empty string (the CLI default) means "use the library default".
        cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None

        # Explicit raise instead of `assert`, which is stripped under `-O`.
        if self.hparams.train_from_scratch and self.hparams.config_name is None:
            raise ValueError("You need to provide a config to train from scratch")

        if config is None:
            label_kwargs = {"num_labels": num_labels} if num_labels is not None else {}
            self.config = AutoConfig.from_pretrained(
                self.hparams.config_name or self.hparams.model_name_or_path,
                **label_kwargs,
                cache_dir=cache_dir,
                **config_kwargs,
            )
        else:
            self.config: PretrainedConfig = config

        if tokenizer is None:
            # `config_name` is optional, so fall back to the model name last;
            # otherwise `from_pretrained(None)` is called when neither
            # `tokenizer_name` nor `config_name` is supplied.
            tokenizer_source = (
                self.hparams.tokenizer_name
                or self.hparams.config_name
                or self.hparams.model_name_or_path
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_source,
                cache_dir=cache_dir,
            )
        else:
            self.tokenizer: PreTrainedTokenizer = tokenizer

        self.model_type = MODEL_MODES[mode]
        if self.hparams.train_from_scratch:
            # Training from scratch ignores any `model` passed in and builds
            # randomly initialised weights from the config.
            logger.info(
                f"Training from scratch with config {self.hparams.config_name}"
            )
            self.model = self.model_type.from_config(self.config)
        elif model is None:
            self.model = self.model_type.from_pretrained(
                self.hparams.model_name_or_path,
                config=self.config,
                cache_dir=cache_dir,
            )
            logger.info(self.model_type)
            logger.info(f"Loaded model {self.hparams.model_name_or_path}")
            # Log the parameter names; logging the bare generator only
            # produced its repr.
            logger.info(
                "Model parameter names: %s",
                [name for name, _ in self.model.named_parameters()],
            )
        else:
            self.model = model

    def load_hf_checkpoint(self, *args, **kwargs):
        """Replace ``self.model`` with ``self.model_type.from_pretrained(...)``.

        All positional and keyword arguments are forwarded unchanged.
        """
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        Parameters whose name contains ``bias`` or ``LayerNorm.weight`` are
        placed in a group with zero weight decay; all others use
        ``hparams.weight_decay``. Relies on ``self.total_steps``, which is set
        in ``setup``.

        Returns:
            ``[optimizer]`` when ``hparams.add_lr_scheduler`` is falsy,
            otherwise ``([optimizer], [scheduler_config])`` where the
            scheduler is stepped every optimizer step.
        """
        no_decay = ("bias", "LayerNorm.weight")
        decay_params = []
        no_decay_params = []
        # Single pass over the parameters, split by name.
        for name, param in self.model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.hparams.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        # TODO: this should be automatic
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )
        self.opt = optimizer

        if not self.hparams.add_lr_scheduler:
            logger.info(f"Total steps = {self.total_steps}, NOT adding LR scheduler")
            return [optimizer]

        warmup_steps = int(self.hparams.warmup_prop * self.total_steps)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=self.total_steps,
        )
        logger.info(
            f"Total steps = {self.total_steps}, warmup steps = {warmup_steps}"
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_nb):
        """Run a test step by delegating to ``validation_step``."""
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        """Aggregate test outputs by delegating to ``validation_end``."""
        return self.validation_end(outputs)

    def _num_devices(self):
        """Return the device count used to scale the effective batch size (>= 1).

        ``hparams.gpus`` may be missing/``None``, an int, a list of device ids
        or a comma-separated string of device ids.
        """
        gpus = getattr(self.hparams, "gpus", None)
        if gpus is None:
            return 1
        if isinstance(gpus, str):
            # "0,1" lists device ids; a plain "2" is a count.
            if "," in gpus:
                gpus = [part for part in gpus.split(",") if part.strip()]
            else:
                gpus = int(gpus)
        if isinstance(gpus, (list, tuple)):
            return max(1, len(gpus))
        return max(1, int(gpus))

    def setup(self, stage=None):
        """Build the training dataloader and compute ``self.total_steps``.

        ``total_steps`` is ``hparams.max_steps`` when that is a positive
        number; otherwise it is derived from the dataset size, batch size,
        device count, gradient accumulation and ``hparams.max_epochs``.

        Args:
            stage: Stage passed by the caller; not used here.
        """
        train_batch_size = self.hparams.train_batch_size
        dataloader = self.get_dataloader(
            "train",
            train_batch_size,
            src_key=self.hparams.src_key,
            tgt_key=self.hparams.tgt_key,
        )
        self.train_loader = dataloader
        logger.info(f"Size of the train split = {len(dataloader.dataset)}")

        max_steps = self.hparams.max_steps
        # Lightning uses -1 (not None) for "no step limit", so only a
        # positive value is a usable step budget.
        if max_steps is not None and max_steps > 0:
            self.total_steps = max_steps
        else:
            batches_per_epoch = len(dataloader.dataset) // (
                train_batch_size * self._num_devices()
            )
            updates_per_epoch = (
                batches_per_epoch // self.hparams.accumulate_grad_batches
            )
            # Keep the step count integral for the LR scheduler and the logs.
            self.total_steps = int(
                updates_per_epoch * float(self.hparams.max_epochs)
            )

    def train_dataloader(self):
        """Return the training dataloader built in ``setup``."""
        return self.train_loader

    def val_dataloader(self):
        """Return a dataloader over the ``dev`` split."""
        return self.get_dataloader(
            "dev",
            self.hparams.eval_batch_size,
            src_key=self.hparams.src_key,
            tgt_key=self.hparams.tgt_key,
        )

    def test_dataloader(self):
        """Return a dataloader over the ``test`` split."""
        return self.get_dataloader(
            "test",
            self.hparams.eval_batch_size,
            src_key=self.hparams.src_key,
            tgt_key=self.hparams.tgt_key,
        )

    def _feature_file(self, mode):
        """Return the cached-features path for ``mode`` inside ``data_dir``.

        The file name is ``cached_<mode>_<model>_<max_seq_length>`` where
        ``<model>`` is the last non-empty ``/``-separated component of
        ``hparams.model_name_or_path``.
        """
        model_name = list(
            filter(None, self.hparams.model_name_or_path.split("/"))
        ).pop()
        return os.path.join(
            self.hparams.data_dir,
            "cached_{}_{}_{}".format(
                mode,
                model_name,
                str(self.hparams.max_seq_length),
            ),
        )

    # TODO: restore an `on_save_checkpoint` hook that exports the model and
    # tokenizer in Hugging Face format. The previous draft was commented out
    # and incomplete, so it is not kept here.

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register model, optimizer and batch-size options on ``parser``.

        Args:
            parser: ``argparse.ArgumentParser`` to extend in place.
            root_dir: Not used by this method.

        Returns:
            The same ``parser``, so calls can be chained.
        """
        parser.add_argument(
            "--model_name_or_path",
            default=None,
            type=str,
            required=True,
            help="Path to pretrained model or model identifier from huggingface.co/models",
        )
        parser.add_argument(
            "--config_name",
            required=False,
            type=str,
            help="Pretrained config name or path if not the same as model_name",
        )
        parser.add_argument(
            "--train_from_scratch",
            default=False,
            action="store_true",
            help="Start training from scratch (not using a pre-trained model)",
        )

        parser.add_argument(
            "--tokenizer_name",
            default=None,
            type=str,
            help="Pretrained tokenizer name or path if not the same as model_name",
        )
        parser.add_argument(
            "--cache_dir",
            default="",
            type=str,
            help="Where do you want to store the pre-trained models downloaded from s3",
        )
        parser.add_argument(
            "--learning_rate",
            default=5e-5,
            type=float,
            help="The initial learning rate for Adam.",
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Weight decay if we apply some.",
        )
        parser.add_argument(
            "--adam_epsilon",
            default=1e-8,
            type=float,
            help="Epsilon for Adam optimizer.",
        )
        parser.add_argument(
            "--num_workers", default=4, type=int, help="kwarg passed to DataLoader"
        )

        parser.add_argument("--train_batch_size", default=32, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)
        parser.add_argument(
            "--warmup_prop", default=0.0, type=float, help="Warmup proportion."
        )
        return parser


def add_generic_args(parser, root_dir) -> None:
    """Register task-independent training options on ``parser``.

    Adds the output directory, mixed-precision, TPU, gradient clipping,
    train/predict switches, gradient accumulation and seed options. Several
    options are stored under a different destination name:
    ``--n_tpu_cores`` -> ``tpu_cores``, ``--max_grad_norm`` ->
    ``gradient_clip_val`` and ``--gradient_accumulation_steps`` ->
    ``accumulate_grad_batches``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend in place.
        root_dir: Not used by this function.
    """
    #  TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O2",
        # The two literals are concatenated implicitly, so the separating
        # space has to be part of the first one.
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
    parser.add_argument(
        "--max_grad_norm",
        dest="gradient_clip_val",
        default=1.0,
        type=float,
        help="Max gradient norm",
    )
    parser.add_argument(
        "--do_train", action="store_true", help="Whether to run training."
    )
    parser.add_argument(
        "--do_predict",
        action="store_true",
        help="Whether to run predictions on the test set.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        dest="accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )

    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization"
    )


def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=None,
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs,
):
    """Seed the RNGs, build a ``pl.Trainer`` from ``args`` and optionally fit.

    Args:
        model: Module to train; its ``hparams.output_dir`` is created if
            missing.
        args: Parsed arguments. Read here: ``seed``, ``output_dir``,
            ``accelerator``, ``precision``, ``strategy``, ``gpus`` and
            ``do_train``; the namespace is also handed to
            ``pl.Trainer.from_argparse_args``.
        logger: Forwarded to the trainer as ``logger``. When it is not a
            bool, its ``log_dir`` attribute is set to ``args.output_dir``.
        extra_callbacks: Optional iterable of callbacks appended after the
            logging and checkpoint callbacks.
        checkpoint_callback: Checkpoint callback; when ``None`` a
            ``ModelCheckpoint`` monitoring ``val_loss`` (min, top-1) writing
            to ``args.output_dir`` is created.
        logging_callback: Logging callback; when ``None`` a
            ``LoggingCallback()`` is created.
        **extra_train_kwargs: Extra keyword arguments for the trainer.

    Returns:
        The constructed trainer (after ``fit`` when ``args.do_train`` is set).
    """
    seed(args.seed)

    # Create the output directory, including missing parents.
    odir = Path(model.hparams.output_dir)
    odir.mkdir(parents=True, exist_ok=True)

    # add custom checkpoints
    if checkpoint_callback is None:
        # `dirpath`/`filename` replace the removed `filepath`/`prefix`
        # arguments of ModelCheckpoint.
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=args.output_dir,
            filename="checkpoint-{epoch}-{step}",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )
    if logging_callback is None:
        # NOTE(review): LoggingCallback is not defined or imported in the
        # visible part of this file -- confirm it is available at runtime.
        logging_callback = LoggingCallback()

    callbacks = [logging_callback, checkpoint_callback]
    if extra_callbacks:
        callbacks.extend(extra_callbacks)

    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=callbacks,
        logger=logger,
        accelerator=args.accelerator,
        precision=args.precision,
        strategy=args.strategy if args.strategy else None,
        devices=args.gpus,
        **extra_train_kwargs,
    )

    if not isinstance(logger, bool):
        # NOTE(review): some logger classes may expose `log_dir` as a
        # read-only property -- confirm this assignment works for the
        # loggers actually passed in.
        logger.log_dir = args.output_dir

    if args.do_train:
        trainer.fit(model)

    return trainer


def seed(seed):
    """Seed the NumPy, Python ``random`` and Lightning-managed RNGs.

    Args:
        seed: Integer seed applied to every generator.
    """
    np.random.seed(seed)
    random.seed(seed)
    # One call is enough: the former extra call through
    # `pl.utilities.seed.seed_everything` was a redundant duplicate of this
    # top-level helper.
    pl.seed_everything(seed)
