# This script contains function to set the training criterion, optimizer and schedule.

import sys

sys.path.append('../../')
import torch
import torch.nn as nn


class flooding(torch.nn.Module):
    # idea: module that can add flooding formula to the loss function
    '''The additional flooding trick on loss'''

    def __init__(self, inner_criterion, flooding_scalar=0.5):
        super(flooding, self).__init__()
        self.inner_criterion = inner_criterion
        self.flooding_scalar = float(flooding_scalar)

    def forward(self, output, target):
        return (self.inner_criterion(output, target) - self.flooding_scalar).abs() + self.flooding_scalar


def argparser_criterion(args):
    '''
    # idea: generate the criterion, default is CrossEntropyLoss
    '''
    criterion = nn.CrossEntropyLoss()
    if ('flooding_scalar' in args.__dict__):  # use the flooding formulation warpper
        criterion = flooding(
            criterion,
            flooding_scalar=float(
                args.flooding_scalar
            )
        )
    return criterion


def argparser_opt_scheduler(model, args):
    # idea: given model and args, return the optimizer and scheduler you choose to use

    if args.client_optimizer == "sgd":
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                    lr=args.lr,
                                    momentum=args.sgd_momentum,  # 0.9
                                    weight_decay=args.wd,  # 5e-4
                                    )
    elif args.client_optimizer == 'adadelta':
        optimizer = torch.optim.Adadelta(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            rho=args.rho,  # 0.95,
            eps=args.eps,  # 1e-07,
        )
    else:
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                     lr=args.lr,
                                     betas=args.adam_betas,
                                     weight_decay=args.wd,
                                     amsgrad=True)

    if args.lr_scheduler == 'CyclicLR':
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                      base_lr=args.min_lr,
                                                      max_lr=args.lr,
                                                      step_size_up=args.step_size_up,
                                                      step_size_down=args.step_size_down,
                                                      cycle_momentum=False)
    elif args.lr_scheduler == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=args.steplr_stepsize,  # 1
                                                    gamma=args.steplr_gamma)  # 0.92
    elif args.lr_scheduler == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100 if (
                    ("cos_t_max" not in args.__dict__) or args.cos_t_max is None) else args.cos_t_max)
    elif args.lr_scheduler == 'MultiStepLR':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.steplr_milestones, args.steplr_gamma)
    elif args.lr_scheduler == 'ReduceLROnPlateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            **({
                   'factor': args.ReduceLROnPlateau_factor
               } if 'ReduceLROnPlateau_factor' in args.__dict__ else {})
        )
    else:
        scheduler = None

    return optimizer, scheduler
