import os

import torch
import torch.nn as nn
import pytorch_lightning as pl
import numpy as np
import pandas as pd

from utils.metric import (
    AccuracyMetric,
    F1ScoreMetric,
    MovieNetMetric,
    SklearnAPMetric,
    SklearnAUCROCMetric,
)
from utils.hydra_utils import save_config_to_disk

import json
import logging
import torch.nn.functional as F

from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay


class FinetuningWrapper(pl.LightningModule):
    """LightningModule that fine-tunes a 2-way classification head on shot features.

    Shot representations taken from ``batch["video"]`` are passed through the
    contextual relation network ``crn`` and then ``prev_sbd`` -> ``head_sbd``
    to produce 2-way logits, which are trained against ``batch["label"]``.
    Optional auxiliary objectives are switched on through ``cfg.LOSS``:

    * ``activate_nearby_shots``: the CRN pools previous/current/next positions
      (``pooling_method="nearby"``), and those are scored by ``head_sbd`` /
      ``head_sfd``.
    * ``first_shot_prediction``: an extra head ``head_sfd`` is trained against
      ``batch["first"]``.
    * ``reverse_shot_prediction``: the boundary head is applied to the
      time-reversed sequence and also trained against ``batch["first"]``.

    The auxiliary tasks share a single evaluation metric group (the ``*_add``
    attributes).
    """

    # Class weights for the fixed-weight cross-entropy loss types.
    # NOTE(review): previously only "CrossEntropyLoss_1_1" created
    # `self.criterion`, so the other names crashed in get_loss. Their weights
    # are read from the name as [1, N]; confirm this is the intended meaning.
    _CE_CLASS_WEIGHTS = {
        "CrossEntropyLoss_1_1": (1.0, 1.0),
        "CrossEntropyLoss_1_2": (1.0, 2.0),
        "CrossEntropyLoss_1_4": (1.0, 4.0),
        "CrossEntropyLoss_1_11": (1.0, 11.0),
    }

    def __init__(self, cfg, shot_encoder, crn, loss=None):
        """Build heads, metrics and the loss configuration from ``cfg``.

        Args:
            cfg: experiment config. Reads MODEL, LOSS, LOG_PATH, EXPR_NAME and
                USE_RAW_SHOT here, and TRAIN / TRAINER / DATASET later.
            shot_encoder: stored as ``self.shot_encoder``. It is not called by
                this class, because raw-shot input raises NotImplementedError.
            crn: contextual relation network that produces pooled features.
            loss: unused; kept for interface compatibility.

        Raises:
            NotImplementedError: if the contextual relation network is disabled.
        """
        super().__init__()
        self.cfg = cfg
        self.shot_encoder = shot_encoder
        self.crn = crn

        crn_cfg = cfg.MODEL.contextual_relation_network
        if not crn_cfg.enabled:
            # The head's input size comes from the CRN config; there is no fallback.
            raise NotImplementedError
        hdim = crn_cfg.params[crn_cfg.name]["hidden_size"]

        self.prev_sbd = nn.Sequential(nn.Linear(hdim, 128), nn.GELU(), nn.Dropout(p=0.3))
        self.head_sbd = nn.Linear(128, 2)
        with torch.no_grad():
            # Fixed initial logit bias for the two classes (values from the original setup).
            self.head_sbd.bias.copy_(torch.tensor([-0.036, -1.096]))

        # Evaluation metrics for the main (boundary) task.
        self.acc_metric = AccuracyMetric()
        self.ap_metric = SklearnAPMetric()
        self.f1_metric = F1ScoreMetric(task="binary", num_classes=1)
        self.auc_metric = SklearnAUCROCMetric()
        self.movienet_metric = MovieNetMetric(cfg)

        self.log_dir = os.path.join(cfg.LOG_PATH, cfg.EXPR_NAME)
        self.use_raw_shot = cfg.USE_RAW_SHOT
        self.eps = 1e-5  # guards divisions by per-batch class counts

        # Per-sample evaluation rows: [vid, sid, gt, pred, prob_of_class_1].
        self.predictions = []
        self.best_ap = 0
        # Best AP of the auxiliary metric group, tracked separately from best_ap.
        self.best_ap_add = 0
        self.all_preds = {}  # not used in this class
        self.best_val_loss = 100  # not used in this class

        # Logits and labels from each evaluation step; cleared when an eval epoch starts.
        self.val_outputs = []
        self.val_labels = []

        self.loss_type = cfg.LOSS.type
        logging.info("loss type : {}".format(self.loss_type))
        if self.loss_type in self._CE_CLASS_WEIGHTS:
            # nn.CrossEntropyLoss registers `weight` as a buffer, so it moves
            # with the module across devices.
            self.criterion = nn.CrossEntropyLoss(
                weight=torch.tensor(self._CE_CLASS_WEIGHTS[self.loss_type], dtype=torch.float32)
            )

        self.activate_nearby_shots = cfg.LOSS.get("activate_nearby_shots", False)
        logging.info("activate_nearby_shots : {}".format(self.activate_nearby_shots))
        if self.activate_nearby_shots:
            self.head_sfd = nn.Linear(128, 2)

        self.first_shot_prediction = cfg.LOSS.get("first_shot_prediction", False)
        logging.info("first_shot_prediction : {}".format(self.first_shot_prediction))

        self.reverse_shot_prediction = cfg.LOSS.get("reverse_shot_prediction", False)
        logging.info("reverse_shot_prediction : {}".format(self.reverse_shot_prediction))

        if self.first_shot_prediction or self.reverse_shot_prediction:
            if self.first_shot_prediction:
                self.head_sfd = nn.Linear(128, 2)

            # Auxiliary metric group, shared by first-shot and reverse prediction.
            self.acc_metric_add = AccuracyMetric()
            self.ap_metric_add = SklearnAPMetric()
            self.f1_metric_add = F1ScoreMetric(task="binary", num_classes=1)
            self.auc_metric_add = SklearnAUCROCMetric()
            self.movienet_metric_add = MovieNetMetric(cfg)
            self.predictions_add = []
            self.val_outputs_add = []
            self.val_labels_add = []

    def on_train_start(self) -> None:
        """Save the config on rank 0; this is best-effort and never aborts training."""
        if self.global_rank == 0:
            try:
                save_config_to_disk(self.cfg)
            except Exception as err:
                # Deliberately best-effort, but a failed save deserves a warning.
                logging.warning("save_config_to_disk failed: %s", err)

    def on_validation_epoch_start(self) -> None:
        """Drop step outputs kept from the previous evaluation run."""
        self._clear_step_buffers()

    def on_test_epoch_start(self) -> None:
        """Drop step outputs kept from the previous evaluation run."""
        self._clear_step_buffers()

    def _clear_step_buffers(self) -> None:
        """Empty the per-step logits/label lists so they do not grow across epochs."""
        self.val_outputs.clear()
        self.val_labels.clear()
        if self.first_shot_prediction or self.reverse_shot_prediction:
            self.val_outputs_add.clear()
            self.val_labels_add.clear()

    def shared_step(self, inputs: torch.Tensor, batch=None) -> torch.Tensor:
        """Compute boundary logits, plus nearby-shot logits when that is enabled.

        Args:
            inputs: 3-D shot representations, used directly as features.
                Presumably (batch, num_shots, dim), since the non-CRN branch
                indexes axis 1 as the shot axis. TODO confirm.
            batch: unused.

        Returns:
            ``(B, 2)`` logits from ``head_sbd``. When ``activate_nearby_shots``
            is set, returns ``(logits, (prev_logits, curr_logits, next_logits))``
            instead.

        Raises:
            NotImplementedError: if ``use_raw_shot`` is set.
        """
        if self.use_raw_shot:
            # Encoding raw frames with self.shot_encoder is not supported.
            raise NotImplementedError
        shot_repr = inputs
        assert shot_repr.dim() == 3

        if self.cfg.MODEL.contextual_relation_network.enabled:
            if self.activate_nearby_shots:
                _, pooled = self.crn(shot_repr, mask=None, pooling_method="nearby")
                prevs, currs, nexts = pooled.unbind(dim=1)
                pred = self.head_sbd(self.prev_sbd(currs))  # current -> boundary head

                prevs = self.head_sbd(self.prev_sbd(prevs))  # previous -> boundary head
                currs = self.head_sfd(self.prev_sbd(currs))  # current -> first-shot head
                nexts = self.head_sfd(self.prev_sbd(nexts))  # next -> first-shot head
                return (pred, (prevs, currs, nexts))
            _, pooled = self.crn(shot_repr, mask=None)
        else:
            # Not reachable at present: __init__ raises when the CRN is disabled.
            pooled = shot_repr[:, shot_repr.shape[1] // 2, :]

        return self.head_sbd(self.prev_sbd(pooled))

    def _boundary_logits(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return only the boundary logits from shared_step, dropping nearby-shot outputs."""
        out = self.shared_step(inputs, **kwargs)
        return out[0] if self.activate_nearby_shots else out

    def additional_step(self, inputs: torch.Tensor, batch=None) -> torch.Tensor:
        """Compute first-shot logits with ``head_sfd`` on the CRN's default pooling.

        Args:
            inputs: 3-D shot representations, the same as for shared_step.
            batch: unused.

        Returns:
            ``(B, 2)`` logits.
        """
        if self.use_raw_shot:
            raise NotImplementedError
        shot_repr = inputs
        assert shot_repr.dim() == 3

        if self.cfg.MODEL.contextual_relation_network.enabled:
            _, pooled = self.crn(shot_repr, mask=None)
        else:
            pooled = shot_repr[:, shot_repr.shape[1] // 2, :]

        return self.head_sfd(self.prev_sbd(pooled))

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return boundary logits for ``x``; nearby-shot outputs are dropped."""
        return self._boundary_logits(x, **kwargs)

    def weighted_cross_entropy(self, outputs, labels, isSum=True):
        """Return cross-entropy weighted so each class contributes equally.

        Each sample's loss is weighted by 0.5 / (number of samples in its
        class within the batch).

        Args:
            outputs: 2-way logits.
            labels: 0/1 targets.
            isSum: if True, return the summed scalar. Otherwise return the
                per-sample weighted losses.
        """
        loss = F.cross_entropy(outputs.float().squeeze(), labels.squeeze(), reduction="none")

        lpos = labels == 1
        lneg = labels == 0

        pos_w, neg_w = 1, 1  # relative class importance (equal)
        wp = (pos_w / float(pos_w + neg_w)) * lpos / (lpos.sum() + self.eps)
        wn = (neg_w / float(pos_w + neg_w)) * lneg / (lneg.sum() + self.eps)
        weighted = (wp + wn) * loss
        return weighted.sum() if isSum else weighted

    def log_weighted_focal_loss(self, outputs, labels, isMean=True):
        """Focal loss (gamma=2) with class weights derived from batch class ratios.

        Class weights follow (1 - beta) / (1 - beta ** log(pct)), where pct is
        each class's percentage of the batch. The weights are normalised to
        sum to 1 and ordered [negative, positive].

        NOTE(review): for a class making up less than 1% of the batch,
        log(pct) < 0 and its weight becomes negative. At exactly 1% the weight
        divides by zero. Confirm this is acceptable for the batch sizes used.

        Args:
            outputs: 2-way logits.
            labels: 0/1 targets.
            isMean: if True, apply mean reduction. Otherwise return per-sample
                losses.
        """
        gamma = 2.0
        beta = 0.9999

        logpt = F.log_softmax(outputs.float().squeeze(), dim=1)
        pt = torch.exp(logpt)
        logpt = (1 - pt) ** gamma * logpt

        lpos = labels == 1
        lneg = labels == 0

        n_total = lpos.sum() + lneg.sum()
        rpos = lpos.sum() / n_total * 100
        rneg = lneg.sum() / n_total * 100

        wp = (1 - beta) / (1 - beta ** torch.log(rpos))
        wn = (1 - beta) / (1 - beta ** torch.log(rneg))
        wt = wp + wn
        # Stack keeps the weights on their device (no host round-trip).
        w = torch.stack([wn / wt, wp / wt])

        reduction = "mean" if isMean else "none"
        return F.nll_loss(logpt, labels.squeeze(), w, reduction=reduction)

    def get_loss(self, outputs, labels, log_name):
        """Compute the configured loss and log it as ``{log_name}/loss``.

        Raises:
            NotImplementedError: for an unknown ``cfg.LOSS.type``.
        """
        if self.loss_type == "weighted_cross_entropy":
            loss = self.weighted_cross_entropy(outputs, labels)
        elif self.loss_type == "log_weighted_focal_loss":
            loss = self.log_weighted_focal_loss(outputs, labels)
        elif self.loss_type in self._CE_CLASS_WEIGHTS:
            loss = self.criterion(outputs, labels)
        else:
            raise NotImplementedError

        self.log(
            "{}/loss".format(log_name),
            loss,
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def get_pred_n_metric(self, outputs, labels, log_name):
        """Log per-batch confusion counts and class-wise accuracy, and return argmax preds.

        Logged under ``{log_name}/``: tp_batch, fp_batch, tn_batch, fn_batch,
        acc0 (percentage of true negatives recovered), acc1 (percentage of true
        positives recovered), and tp_tn.
        """
        preds = torch.argmax(outputs, dim=1)
        # Flatten so (B, 1) labels cannot broadcast against (B,) preds into (B, B).
        labels = labels.reshape(-1)

        gt_one = labels == 1
        gt_zero = labels == 0
        pred_one = preds == 1
        pred_zero = preds == 0

        tp = (gt_one & pred_one).sum()
        fp = (gt_zero & pred_one).sum()
        tn = (gt_zero & pred_zero).sum()
        fn = (gt_one & pred_zero).sum()

        stats = {
            "tp_batch": (tp, False),
            "fp_batch": (fp, False),
            "tn_batch": (tn, False),
            "fn_batch": (fn, False),
            "acc0": (100.0 * tn / (fp + tn + self.eps), True),
            "acc1": (100.0 * tp / (tp + fn + self.eps), True),
            "tp_tn": (tp + tn, True),
        }
        for key, (value, on_bar) in stats.items():
            self.log(
                f"{log_name}/{key}",
                value,
                on_step=True,
                prog_bar=on_bar,
                logger=True,
                sync_dist=True,
            )

        return preds

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """Sum the boundary loss with any enabled auxiliary losses.

        Reads ``batch["video"]`` and ``batch["label"]``, and ``batch["first"]``
        when an auxiliary task is on.
        """
        inputs = batch["video"]
        labels = batch["label"]

        # NOTE(review): with activate_nearby_shots, the prev/curr/next logits
        # from shared_step feed no loss here; only the boundary logits are used.
        outputs = self._boundary_logits(inputs)
        loss = self.get_loss(outputs, labels, log_name="sbd_train")
        self.get_pred_n_metric(outputs, labels, log_name="sbd_train")

        if self.first_shot_prediction:
            labels_sfd = batch["first"]
            outputs_sfd = self.additional_step(inputs)
            loss = loss + self.get_loss(outputs_sfd, labels_sfd, log_name="sfd_train")
            self.get_pred_n_metric(outputs_sfd, labels_sfd, log_name="sfd_train")

        if self.reverse_shot_prediction:
            labels_rev = batch["first"]
            outputs_rev = self._boundary_logits(torch.flip(inputs, [1]))
            loss = loss + self.get_loss(outputs_rev, labels_rev, log_name="rev_train")
            self.get_pred_n_metric(outputs_rev, labels_rev, log_name="rev_train")

        return loss

    def _main_metric_group(self):
        """Return the main metrics in the order (acc, ap, f1, auc, movienet)."""
        return (self.acc_metric, self.ap_metric, self.f1_metric, self.auc_metric, self.movienet_metric)

    def _add_metric_group(self):
        """Return the auxiliary metrics in the order (acc, ap, f1, auc, movienet)."""
        return (
            self.acc_metric_add,
            self.ap_metric_add,
            self.f1_metric_add,
            self.auc_metric_add,
            self.movienet_metric_add,
        )

    def _update_eval_metrics(self, metrics, rows, vids, sids, logits, labels):
        """Feed one batch of logits into a metric group and append per-sample rows."""
        acc_metric, ap_metric, f1_metric, auc_metric, movienet_metric = metrics
        prob = F.softmax(logits, dim=1)
        preds = torch.argmax(prob, dim=1)
        pos_prob = prob[:, 1]

        for metric in (acc_metric, ap_metric, f1_metric, auc_metric):
            metric.update(pos_prob, labels)

        for vid, sid, pred, gt, p in zip(vids, sids, preds, labels, pos_prob):
            movienet_metric.update(vid, sid, pred, gt)
            rows.append([vid, sid, gt.item(), pred.item(), p.item()])

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        """Update the main metrics, and the auxiliary ones when enabled.

        NOTE(review): if both first-shot and reverse prediction are enabled,
        only the reverse predictions reach the shared ``*_add`` metrics, yet
        epoch end logs them under both ``sfd_test`` and ``rev_test``.
        """
        vids = batch["vid"]
        sids = batch["sid"]
        inputs = batch["video"]
        labels = batch["label"]

        outputs = self._boundary_logits(inputs)
        self._update_eval_metrics(
            self._main_metric_group(), self.predictions, vids, sids, outputs, labels
        )
        self.val_outputs.append(outputs)
        self.val_labels.append(labels)

        if not (self.first_shot_prediction or self.reverse_shot_prediction):
            return

        labels_sfd = batch["first"]
        if self.reverse_shot_prediction:
            # Reverse prediction takes precedence, so first-shot logits would
            # be discarded; skip computing them.
            outputs_sfd = self._boundary_logits(torch.flip(inputs, [1]))
        else:
            outputs_sfd = self.additional_step(inputs)

        self._update_eval_metrics(
            self._add_metric_group(), self.predictions_add, vids, sids, outputs_sfd, labels_sfd
        )
        self.val_outputs_add.append(outputs_sfd)
        self.val_labels_add.append(labels_sfd)

    @staticmethod
    def _cuda_sync() -> None:
        """Wait for queued CUDA work; this is a no-op on hosts without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _compute_scores(self, metrics, with_movienet=True):
        """Compute epoch scores from a metric group.

        Args:
            metrics: tuple (acc, ap, f1, auc, movienet) of metric objects.
            with_movienet: whether to compute recall / recall@3s / mIoU.

        Returns:
            Dict of tensors. ap, auc, f1, recall, recall@3s and mIoU are scaled by 100.
        """
        acc_metric, ap_metric, f1_metric, auc_metric, movienet_metric = metrics
        score = {}

        acc = acc_metric.compute()
        self._cuda_sync()
        assert isinstance(acc, dict)
        score.update(acc)

        # Out-of-place scaling so the tensor returned by compute() is not mutated.
        ap, _, _ = ap_metric.compute()
        ap = ap * 100.0
        self._cuda_sync()
        assert isinstance(ap, torch.Tensor)
        score["ap"] = ap

        auc, _, _ = auc_metric.compute()
        auc = auc * 100.0
        self._cuda_sync()
        assert isinstance(auc, torch.Tensor)
        score["auc"] = auc

        f1 = f1_metric.compute() * 100.0
        self._cuda_sync()
        assert isinstance(f1, torch.Tensor)
        score["f1"] = f1

        if with_movienet:
            recall, recall_at_3s, miou = movienet_metric.compute()
            self._cuda_sync()
            assert isinstance(recall, torch.Tensor)
            assert isinstance(recall_at_3s, torch.Tensor)
            assert isinstance(miou, torch.Tensor)
            score["recall"] = recall * 100.0
            score["recall@3s"] = recall_at_3s * 100
            score["mIoU"] = miou * 100

        return score

    def _log_scores(self, prefix, score):
        """Log every score as an epoch-level metric under ``{prefix}/``."""
        for k, v in score.items():
            self.log(
                f"{prefix}/{k}",
                v,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

    def _log_aux_scores(self, score):
        """Log auxiliary scores under the prefix of each enabled auxiliary task."""
        if self.first_shot_prediction:
            self._log_scores("sfd_test", score)
        if self.reverse_shot_prediction:
            self._log_scores("rev_test", score)

    @staticmethod
    def _reset_metric_group(metrics):
        """Reset every metric in a group."""
        for metric in metrics:
            metric.reset()

    @staticmethod
    def _scores_to_python(score):
        """Convert tensor scores to Python numbers for printing and JSON."""
        return {k: v.item() for k, v in score.items()}

    def _save_json(self, score, filename):
        """Write ``score`` as JSON to ``log_dir/filename``."""
        with open(os.path.join(self.log_dir, filename), "w", encoding="utf-8") as fopen:
            json.dump(score, fopen, indent=4, ensure_ascii=False)

    def _save_predictions(self, rows, filename):
        """Write per-sample prediction rows as CSV to ``log_dir/filename``."""
        df = pd.DataFrame(np.array(rows), columns=["vid", "sid", "gt", "pred", "prob"])
        df.to_csv(os.path.join(self.log_dir, filename), index=False)

    def validation_epoch_end(self, validation_step_outputs):
        """Log validation scores and save files whenever a group's AP improves.

        The main and auxiliary groups each track their own best AP
        (``best_ap`` and ``best_ap_add``). MovieNet metrics are always
        computed here.
        """
        metrics = self._main_metric_group()
        score = self._compute_scores(metrics)
        self._log_scores("sbd_test", score)
        self._reset_metric_group(metrics)

        score = self._scores_to_python(score)
        score["epoch"] = self.current_epoch
        self.print(f"\nTest Score: {score}")

        if score["ap"] > self.best_ap:
            self.best_ap = score["ap"]
            self._save_json(score, "val_all_score.json")
            self._save_predictions(self.predictions, "val_predictions.csv")
        self.predictions.clear()

        if not (self.first_shot_prediction or self.reverse_shot_prediction):
            return

        metrics = self._add_metric_group()
        score = self._compute_scores(metrics)
        self._log_aux_scores(score)
        self._reset_metric_group(metrics)
        self.val_outputs_add.clear()
        self.val_labels_add.clear()

        score = self._scores_to_python(score)
        score["epoch"] = self.current_epoch
        self.print(f"\nTest Score: {score}")

        # Compare against the auxiliary group's own best, not the main task's.
        if score["ap"] > self.best_ap_add:
            self.best_ap_add = score["ap"]
            self._save_json(score, "val_all_score_add.json")
            self._save_predictions(self.predictions_add, "val_predictions_add.csv")
        self.predictions_add.clear()

    def test_step(self, batch: torch.Tensor, batch_idx: int):
        """Run the same per-batch evaluation as validation."""
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, test_step_outputs):
        """Log test scores and always write score/prediction files.

        MovieNet metrics are computed only when ``cfg.DATASET == "movienet"``.
        """
        is_movienet = self.cfg.DATASET == "movienet"

        metrics = self._main_metric_group()
        score = self._compute_scores(metrics, with_movienet=is_movienet)
        self._log_scores("sbd_test", score)
        score = self._scores_to_python(score)
        self.print(f"\nTest Score: {score}")
        self._reset_metric_group(metrics)

        self._save_json(score, "test_all_score_{}.json".format(self.cfg.DATASET))
        self._save_predictions(self.predictions, "test_predictions_{}.csv".format(self.cfg.DATASET))
        self.predictions.clear()

        if not (self.first_shot_prediction or self.reverse_shot_prediction):
            return

        metrics = self._add_metric_group()
        score = self._compute_scores(metrics, with_movienet=is_movienet)
        self._log_aux_scores(score)
        score = self._scores_to_python(score)
        self.print(f"\nTest Score: {score}")
        self._reset_metric_group(metrics)

        self._save_json(score, "test_all_score_add.json")
        self._save_predictions(self.predictions_add, "test_predictions_add.csv")
        self.predictions_add.clear()

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list):
        """Split trainable params into decayed and non-decayed optimizer groups.

        A parameter is excluded from decay when any entry of ``skip_list`` is a
        substring of its name. Frozen parameters are left out entirely.
        """
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            if any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {"params": params, "weight_decay": weight_decay},
            {"params": excluded_params, "weight_decay": 0.0},
        ]

    def configure_optimizers(self):
        """Build the optimizer and a per-step cosine schedule with linear warmup.

        Raises:
            ValueError: for an unknown optimizer name.
            NotImplementedError: for an unknown scheduler name.
        """
        opt_cfg = self.cfg.TRAIN.OPTIMIZER

        skip_list = []
        weight_decay = opt_cfg.weight_decay
        if not opt_cfg.regularize_bn:
            skip_list.append("bn")
        if not opt_cfg.regularize_bias:
            skip_list.append("bias")
        params = self.exclude_from_wt_decay(
            self.named_parameters(), weight_decay=weight_decay, skip_list=skip_list
        )

        if opt_cfg.name == "sgd":
            optimizer = torch.optim.SGD(
                params,
                lr=opt_cfg.lr.scaled_lr,
                momentum=0.9,
                weight_decay=weight_decay,
            )
        elif opt_cfg.name == "adam":
            optimizer = torch.optim.Adam(params, lr=opt_cfg.lr.scaled_lr)
        elif opt_cfg.name == "adamw":
            optimizer = torch.optim.AdamW(params, lr=opt_cfg.lr.scaled_lr)
        else:
            raise ValueError(f"unsupported optimizer: {opt_cfg.name}")

        total_steps = int(self.cfg.TRAIN.TRAIN_ITERS_PER_EPOCH * self.cfg.TRAINER.max_epochs)
        # Warmup is a fraction of the total run length.
        warmup_steps = int(
            self.cfg.TRAIN.TRAIN_ITERS_PER_EPOCH
            * self.cfg.TRAINER.max_epochs
            * opt_cfg.scheduler.warmup
        )

        if opt_cfg.scheduler.name != "cosine_with_linear_warmup":
            raise NotImplementedError(f"unsupported scheduler: {opt_cfg.scheduler.name}")

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]
