''' Details
Author: Zhipeng Zhang (zpzhang1995@gmail.com)
Function: set learning rate for training
Date: 2021.6.23
'''

import torch
import numpy as np
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
from bisect import bisect_right

import timm
import timm.scheduler

class WarmupMultiStepLR(_LRScheduler):
    """Multi-step learning-rate decay preceded by a warmup phase.

    While ``last_epoch < warmup_iters`` every base learning rate is scaled by
    a warmup factor:

    * ``"constant"``: the fixed value ``warmup_factor``;
    * ``"linear"``: ``warmup_factor * (1 - alpha) + alpha`` with
      ``alpha = (last_epoch + 1) / warmup_iters``, which reaches 1 on the
      last warmup step.

    Independently of warmup, the rate is multiplied by ``gamma`` once for
    every milestone that is ``<= last_epoch``.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Sorted (non-decreasing) sequence of epoch indices.
        gamma: Multiplicative decay applied at each milestone.
        warmup_factor: Scale used at the start of warmup.
        warmup_iters: Number of scheduler steps the warmup lasts.
        warmup_method: Either ``"constant"`` or ``"linear"``.
        last_epoch: Index of the last epoch, forwarded to ``_LRScheduler``.

    Raises:
        ValueError: If ``milestones`` is not sorted or ``warmup_method`` is
            not one of the accepted values.
    """

    def __init__(
            self,
            optimizer,
            milestones,
            gamma=0.1,
            warmup_factor=0.01,
            warmup_iters=20.,
            warmup_method="linear",
            last_epoch=-1,
    ):
        if list(milestones) != sorted(milestones):
            # Format the message here: passing the template and the value as
            # separate exception args would leave the "{}" unfilled.
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        # The base-class constructor performs the initial step and therefore
        # calls get_lr(), so every attribute above must be set before it runs.
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = (self.last_epoch + 1) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha

        # bisect_right counts the milestones that are <= last_epoch, i.e. how
        # many times the decay has been applied so far.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]


class ParamFreezer(object):
    """Switch gradients off (and back on) for selected model parameters.

    A parameter is selected by module when either of the first two
    dot-separated components of its name is in ``module_names``.  Parameters
    can additionally be selected by full name through ``param_names``.

    Args:
        module_names: Module names whose parameters should be frozen.
        param_names: Optional full parameter names to freeze as well.
    """

    def __init__(self, module_names, param_names=None):
        self.module_names = module_names
        # Names of the module-selected parameters that the most recent
        # freeze() call actually switched off.
        self.freeze_params = []
        # Use a fresh list per instance instead of a shared mutable default.
        self.global_param_names = [] if param_names is None else param_names

    def _in_target_module(self, param_name):
        """Tell whether one of the first two name components is a target module."""
        # Slicing keeps this safe for names without a dot (parameters
        # registered directly on the model), where index 1 does not exist.
        return any(part in self.module_names
                   for part in param_name.split('.')[:2])

    def freeze(self, model):
        """Disable gradients for the selected parameters of ``model``.

        Only module-selected parameters that currently require gradients are
        recorded, so that ``unfreeze`` does not enable parameters that were
        already frozen beforehand.
        """
        self.freeze_params = []
        for name, param in model.named_parameters():
            if self._in_target_module(name) and param.requires_grad:
                param.requires_grad_(False)
                self.freeze_params.append(name)

        if not self.global_param_names:
            return
        for name, param in model.named_parameters():
            if name in self.global_param_names and param.requires_grad:
                param.requires_grad_(False)

    def unfreeze(self, model):
        """Re-enable gradients for parameters frozen by ``freeze``.

        Module-selected parameters are restored only if ``freeze`` recorded
        them.  Parameters listed in ``param_names`` are always set to require
        gradients, whatever their state was before ``freeze``.
        """
        for name, param in model.named_parameters():
            if self._in_target_module(name) and name in self.freeze_params:
                param.requires_grad_(True)
                self.freeze_params.remove(name)

        if not self.global_param_names:
            return
        for name, param in model.named_parameters():
            if name in self.global_param_names:
                param.requires_grad_(True)


def build_siamese_opt_lr(cfg, model):
    """Build the AdamW optimizer and LR scheduler used for Siamese training.

    Every parameter whose name contains ``"bert"`` is frozen in place
    (``requires_grad = False``).  The remaining trainable parameters are
    split into two groups: names without ``"backbone"`` use
    ``cfg.train.lr.base`` and names with ``"backbone"`` use
    ``cfg.train.lr.tune``.

    Args:
        cfg: Config object.  Reads ``cfg.train.lr.base``, ``cfg.train.lr.tune``,
            ``cfg.train.lr.type``, ``cfg.train.weight_decay`` and
            ``cfg.train.resume``, plus the fields needed by the selected
            scheduler type (``drop_epoch``, ``milestones``, ``gamma``,
            ``warm_epoch`` under ``cfg.train.lr``; ``cfg.train.end_epoch``).
        model: Object exposing ``named_parameters()``; modified in place.

    Returns:
        Tuple ``(optimizer, lr_scheduler)``.

    Raises:
        ValueError: If ``cfg.train.lr.type`` is not a supported scheduler.
    """
    base_lr = cfg.train.lr.base
    tune_lr = cfg.train.lr.tune

    head_params = []
    backbone_params = []
    for name, param in model.named_parameters():
        if "bert" in name:
            # BERT weights are never trained here, so they get no param
            # group.  To fine-tune them, stop freezing and add a third group
            # with its own learning rate (e.g. cfg.train.lr.bert).
            param.requires_grad = False
            continue
        if not param.requires_grad:
            continue
        if "backbone" in name:
            backbone_params.append(param)
        else:
            head_params.append(param)

    trainable_params = [
        # head part
        {"params": head_params, "lr": base_lr},
        # backbone part
        {"params": backbone_params, "lr": tune_lr},
    ]

    optimizer = torch.optim.AdamW(trainable_params, lr=base_lr,
                                  weight_decay=cfg.train.weight_decay)

    # Restore the optimizer state when resuming from a checkpoint.
    if cfg.train.resume:
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # that come from a trusted source.
        state = torch.load(cfg.train.resume, map_location="cpu")
        optimizer.load_state_dict(state['optimizer'])

    lr_type = cfg.train.lr.type
    if lr_type == 'step':
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.train.lr.drop_epoch)
    elif lr_type == "Mstep":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.train.lr.milestones,
            gamma=cfg.train.lr.gamma,
        )
    elif lr_type == 'WarmMstep':
        lr_scheduler = WarmupMultiStepLR(
            optimizer,
            milestones=cfg.train.lr.milestones,
            gamma=cfg.train.lr.gamma,
            warmup_iters=cfg.train.lr.warm_epoch,
        )
    elif lr_type == "cosine":
        # eta_min and last_epoch are left at their defaults (0 and -1).  The
        # `verbose` argument is deliberately not passed: it is deprecated in
        # recent PyTorch releases and only produced a warning there.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg.train.end_epoch)
    elif lr_type == "Warmcosine":
        lr_scheduler = timm.scheduler.CosineLRScheduler(
            optimizer=optimizer,
            t_initial=cfg.train.end_epoch,
            lr_min=1e-5,
            warmup_t=5,
            warmup_lr_init=1e-5,
        )
    else:
        raise ValueError("Unsupported scheduler")

    return optimizer, lr_scheduler
