import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim
 

class BaseTask(pl.LightningModule):
    """Lightning task that trains a Hydra-configured model against a criterion.

    The hyper-parameter container passed to the constructor is stored via
    ``save_hyperparameters`` and is expected to expose ``model``,
    ``criterion``, ``optimizer`` (with ``optim`` and ``lr`` fields) and
    ``scheduler`` entries; ``model``, ``criterion`` and ``scheduler`` are
    built with ``hydra.utils.instantiate``.
    """

    def __init__(self, hparams):
        """Store ``hparams`` and instantiate the model and the criterion."""
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)

    def _forward_and_loss(self, batch):
        """Run the model on ``batch["input"]`` and evaluate the criterion.

        The criterion receives the predictions together with the whole batch
        and returns a ``(loss, logging_output)`` pair.
        """
        y_preds = self.model(batch["input"])
        return self.criterion(y_preds, batch)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log the criterion's outputs."""
        loss, logging_output = self._forward_and_loss(batch)

        # log numerical values
        self.log_dict(logging_output)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and log it as ``val_loss``."""
        loss, logging_output = self._forward_and_loss(batch)
        self.log_dict(logging_output)
        # Epoch-level aggregate, synchronised across processes.
        self.log(
            "val_loss",
            loss.item(),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def _finetune_param_groups(self, lr):
        """Build AdamW parameter groups for fine-tuning.

        The ``linear_out`` head is trained at ``lr``; every other parameter
        falls back to the optimizer-level default learning rate.
        """
        # Unwrap a (Distributed)DataParallel container if the model is wrapped.
        wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        core = self.model.module if isinstance(self.model, wrappers) else self.model

        # Materialise the head parameters: ``parameters()`` is a one-shot
        # generator, and it is needed both for the id set and for the group.
        head_params = list(core.linear_out.parameters())
        head_ids = {id(p) for p in head_params}
        base_params = [p for p in self.model.parameters() if id(p) not in head_ids]

        return [
            {'params': base_params},
            {'params': head_params, 'lr': lr},
        ]

    def configure_optimizers(self):
        """
        Define the optimizer and learning-rate scheduler used for training.

        The optimizer is selected by ``self.hparams.optimizer.optim`` (one of
        ``"SGD"``, ``"Adam"``, ``"AdamW"``, ``"AdamW_finetune"``, ``"LAMB"``)
        and the scheduler is instantiated from ``self.hparams.scheduler`` with
        the optimizer as its positional argument.

        Returns:
            dict with an ``"optimizer"`` entry and an ``"lr_scheduler"``
            config whose scheduler is stepped every training step.

        Raises:
            NotImplementedError: if the optimizer name is not recognised.
        """
        name = self.hparams.optimizer.optim
        lr = self.hparams.optimizer.lr

        if name == "SGD":
            optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)
        elif name == 'Adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=0.01)
        elif name == 'AdamW':
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        elif name == 'AdamW_finetune':
            # Backbone runs at a tenth of the head's learning rate.
            optimizer = torch.optim.AdamW(self._finetune_param_groups(lr), lr=lr * 0.1)
        elif name == 'LAMB':
            optimizer = torch_optim.Lamb(self.model.parameters(), lr=lr)
        else:
            raise NotImplementedError("No valid optim name")

        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}