from collections import defaultdict
import pytorch_lightning as pl
import torch
import inspect

from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCELoss


from cpr.llm_transformer.utils import instantiate
from cpr.llm_transformer.utils.group_parameters import group_parameters_for_optimizer
from cpr.llm_transformer.utils.optim.lr_schedule import get_learning_rate_schedule
from cpr.llm_transformer.model.transformer import DecoderTransformer
from cpr.adam_cpr import AdamCPR, group_cpr_parameters

class LanguageModelTrainer(pl.LightningModule):
    """
    PTL wrapper class for model training.

    Wraps a ``DecoderTransformer`` and provides the Lightning hooks for training,
    multi-dataset validation (loss, perplexity, token accuracy) and optimizer /
    LR-scheduler construction.
    """

    def __init__(
            self,
            cfg_train,
            cfg_model,
            py_logger,
            val_sets_name,
            ignore_index,
    ):
        """
        Args:
            cfg_train: training config. Entries read in this class include
                ``optimizer.*``, ``optimizer_name``, ``adamcpr.*``,
                ``loss_fn.inplace_backward`` and the optional ``scheduler``,
                ``scheduler_interval``, ``scheduler_monitor`` and
                ``optimizer_param_grouping`` entries.
            cfg_model: config passed straight to ``DecoderTransformer``.
            py_logger: logger for informational messages (only ``.info`` is used here).
            val_sets_name: names of the validation sets, indexed by dataloader index.
            ignore_index: target id excluded from the loss and from validation metrics.
        """
        super().__init__()

        self.save_hyperparameters()
        self.cfg_train = cfg_train

        self.val_sets_name = val_sets_name
        self.ignore_index = ignore_index
        self.py_logger = py_logger

        self.model = DecoderTransformer(cfg_model)

        self.loss_train = FlashCELoss(ignore_index=self.ignore_index, reduction='mean', label_smoothing=0.0,
                                      inplace_backward=self.cfg_train.loss_fn.inplace_backward)

        # NOTE(review): intern_log and log_lists are not read anywhere in this file's
        # visible code; presumably kept for external users/subclasses -- confirm.
        self.intern_log = []
        self.log_lists = defaultdict(list)

        # dataloader_idx -> list of per-batch dicts produced by validation_step;
        # consumed and cleared in on_validation_epoch_end.
        self.validation_step_outputs = defaultdict(list)

    def on_train_start(self):
        """Optionally rescale the LR schedule when training (re)starts.

        ``cfg_train.optimizer.scheduler_mult_factor`` adds an extra multiplier on
        top of the LR schedule after a checkpoint load. Without it, the checkpoint
        load simply resumes the old schedule. Assumes every configured scheduler
        is a lambda scheduler exposing ``lr_lambdas`` (e.g. ``LambdaLR``).

        Raises:
            ValueError: if a multiplier is configured but no LR scheduler exists.
        """
        mult_factor = self.cfg_train.optimizer.scheduler_mult_factor
        if mult_factor is None:
            return

        self.py_logger.info(
            f"Multiplying all LR schedule lambas by {mult_factor}"
        )

        schedulers = self.lr_schedulers()
        if schedulers is None:
            raise ValueError(
                "cfg_train.optimizer.scheduler_mult_factor is set, but no LR scheduler is configured"
            )
        if not isinstance(schedulers, list):
            schedulers = [schedulers]

        for scheduler in schedulers:
            # Bind the original lambda and the factor as default arguments. A plain
            # closure over the comprehension variable is late-binding, so every
            # wrapper would end up calling the *last* lambda in the list.
            scheduler.lr_lambdas = [
                lambda step, fn=fn, factor=mult_factor: factor * fn(step)
                for fn in scheduler.lr_lambdas
            ]
        # NOTE: LambdaLR recomputes the learning rates only on step(), so the new
        # factor takes effect from the next scheduler step.

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch.

        The model is fed ``batch['src_seq']``; its logits are compared with
        ``batch['trg_seq']`` (positions equal to ``ignore_index`` are excluded).
        """
        logits = self.model(trg_shf_seq=batch['src_seq'])

        labels = batch['trg_seq']

        # Flatten to (num_positions, vocab) vs (num_positions,) for the loss.
        loss = self.loss_train(logits.view(-1, logits.size(-1)), labels.view(-1))

        self.log(
            "train/loss",
            loss.detach(),
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Evaluate one validation batch and stash the sums needed for epoch metrics.

        Logs ``val/loss`` for the first validation dataloader only. For every
        dataloader, the returned dict (also appended to
        ``self.validation_step_outputs[dataloader_idx]``) holds:

        - ``log_probs``: summed token NLL (mean loss times the number of valid targets),
        - ``accuracy``: number of valid targets whose argmax prediction is correct,
        - ``count``: number of valid (non-``ignore_index``) targets,

        plus batch statistics derived from ``batch['trg_len']``.
        """
        target = batch['trg_seq'].view(-1)

        with torch.no_grad():
            logits = self.model(trg_shf_seq=batch['src_seq']).detach()

            loss = self.loss_train(logits.view(-1, logits.size(-1)), target)

        if dataloader_idx == 0:
            self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        # The 'mean' loss averages over exactly the targets that are not
        # ignore_index. Using that same set as the denominator makes
        # loss * count the exact summed NLL, and keeps accuracy consistent too.
        valid = target != self.ignore_index
        count = valid.sum(dtype=torch.float)
        # With no valid targets the mean loss is 0/0 = NaN; contribute 0 instead of
        # poisoning the epoch-level sum.
        log_probs = torch.where(count > 0, loss * count, torch.zeros_like(loss))

        preds = logits.argmax(dim=-1).view(-1)
        accuracy = torch.sum(preds[valid] == target[valid])

        return_dict = {"loss": loss,
                       "batch_size": torch.FloatTensor([batch['trg_len'].shape[0]]),
                       "batch_length": torch.mean(batch['trg_len'].detach().float()),
                       "num_loss_tokens": torch.sum(batch['trg_len']),
                       "accuracy": accuracy,
                       "log_probs": log_probs,
                       "count": count,
                       }

        self.validation_step_outputs[dataloader_idx].append(return_dict)

        return return_dict

    def on_validation_epoch_end(self):
        """Log exact per-dataset perplexity and token accuracy for the epoch.

        The per-batch sums collected by ``validation_step`` are first totalled
        locally, then sum-reduced across all processes. Only then are the ratios
        taken, so the logged values are corpus-level metrics rather than an
        average of per-rank ratios.

        Raises:
            ValueError: if outputs exist for a dataloader index without a name in
                ``val_sets_name``.
        """
        values = ['log_probs', 'accuracy', 'count']

        unnamed = sorted(idx for idx in self.validation_step_outputs
                         if not 0 <= idx < len(self.val_sets_name))
        if unnamed:
            raise ValueError(
                f"Validation outputs for dataloader indices {unnamed} have no entry in "
                f"val_sets_name (got {len(self.val_sets_name)} names)"
            )

        for dataset_idx, dataset_name in enumerate(self.val_sets_name):
            # Always build the totals (zeros if this rank saw no batches) so every
            # process takes part in the collective reduce below. float64 keeps
            # token counts exact beyond float32's 2**24 integer limit.
            totals = torch.zeros(len(values), dtype=torch.float64, device=self.device)
            for out_dict in self.validation_step_outputs.get(dataset_idx, []):
                totals += torch.stack([out_dict[key].to(totals) for key in values])

            totals = self.trainer.strategy.reduce(totals, reduce_op="sum")
            log_probs, correct, count = totals.unbind()

            # count is identical on all ranks after the reduce, so every rank takes
            # the same branch and logs the same keys.
            if count <= 0:
                self.py_logger.info(f"No valid target tokens for validation set '{dataset_name}'; skipping metrics")
                continue

            metrics = {"ppl": torch.exp(log_probs / count), "acc": correct / count}

            for name, value in metrics.items():
                # Values are already global; sync_dist's mean over identical values
                # is a no-op, and it avoids Lightning's epoch-level sync warning.
                self.log(f"val/{dataset_name}/{name}", value.float(),
                         on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Build the optimizer and, if configured, a ``LambdaLR`` schedule.

        Uses AdamCPR when ``cfg_train.optimizer_name == 'adamcpr'``, AdamW otherwise.

        Returns:
            The bare optimizer when ``cfg_train`` has no ``scheduler`` entry.
            Otherwise ``([optimizer], scheduler_config)``, where the config
            carries the scheduler, its ``interval`` and its ``monitor`` key.
        """
        if self.cfg_train.optimizer_name == 'adamcpr':
            optimizer = self._build_adamcpr_optimizer()
        else:
            optimizer = self._build_adamw_optimizer()

        self._log_optimizer_groups(optimizer)

        if 'scheduler' not in self.cfg_train:
            return optimizer

        lr_lambda = get_learning_rate_schedule(self.cfg_train.scheduler)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

        return [optimizer], {'scheduler': lr_scheduler,
                             'interval': self.cfg_train.get('scheduler_interval', 'step'),
                             'monitor': self.cfg_train.get('scheduler_monitor', 'val/loss')}

    def _build_adamcpr_optimizer(self):
        """Create an AdamCPR optimizer from ``cfg_train.optimizer`` and ``cfg_train.adamcpr``."""
        opt_cfg = self.cfg_train.optimizer
        cpr_cfg = self.cfg_train.adamcpr

        lr = opt_cfg.lr
        betas = opt_cfg.betas
        eps = opt_cfg.eps
        kappa = cpr_cfg.kappa
        mode = cpr_cfg.mode
        lagmul_rate = cpr_cfg.lagmul_rate

        cpr_config = {'lr': lr, 'betas': betas, 'eps': eps, 'kappa': kappa, 'mode': mode,
                      'lagmul_rate': lagmul_rate}
        parameters = group_cpr_parameters(self.model, cpr_config, bias_regularization=cpr_cfg.bias_reg,
                                          normalization_regularization=cpr_cfg.normalization_reg,
                                          avoid_keywords=None)

        return AdamCPR(parameters, lr=lr, betas=betas, eps=eps, kappa=kappa, mode=mode,
                       lagmul_rate=lagmul_rate, apply_decay=None,
                       kappa_init_dependent=cpr_cfg.kappa_init_dependent,
                       kappa_adapt=cpr_cfg.kappa_adapt,
                       kappa_init_warm_start=cpr_cfg.kappa_init_warm_start,
                       )

    def _build_adamw_optimizer(self):
        """Create an AdamW optimizer, grouping parameters when ``optimizer_param_grouping`` is set."""
        opt_cfg = self.cfg_train.optimizer

        if 'optimizer_param_grouping' in self.cfg_train:  # Set zero weight decay for some params
            parameters = group_parameters_for_optimizer(self.model, opt_cfg,
                                                        **self.cfg_train.optimizer_param_grouping)
        else:
            parameters = self.model.parameters()

        return torch.optim.AdamW(parameters, lr=opt_cfg.lr, betas=opt_cfg.betas,
                                 eps=opt_cfg.eps, weight_decay=opt_cfg.weight_decay)

    def _log_optimizer_groups(self, optimizer):
        """Log tensor count, parameter count and hyperparameters of each param group."""
        for i, g in enumerate(optimizer.param_groups):
            ntensors = len(g['params'])
            nparams = sum(p.numel() for p in g['params'])
            hparams = {k: v for k, v in g.items() if k != 'params'}
            self.py_logger.info(f'Optimizer group {i}: {ntensors} tensors, {nparams} parameters, {hparams}')

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        """Zero gradients, freeing them (``set_to_none=True``) when the optimizer supports it."""
        # https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#set-grads-to-none
        # TD [2022-04-30]: DeepSpeed optimizer uses the kwarg set_grad_to_none instead of set_to_none
        if 'set_to_none' in inspect.signature(optimizer.zero_grad).parameters:
            optimizer.zero_grad(set_to_none=True)
        else:
            optimizer.zero_grad()





