# This module contains functions that build the training criterion, optimizer and learning-rate scheduler.

import sys

sys.path.append("../../")
import torch
import torch.nn as nn


class flooding(torch.nn.Module):
    """Loss wrapper that applies the flooding trick to another criterion.

    The wrapped criterion is evaluated as usual and its value ``loss`` is
    mapped to ``|loss - b| + b``, where ``b`` is ``flooding_scalar``.

    Args:
        inner_criterion: callable invoked as ``inner_criterion(output, target)``;
            its result must support subtraction of a float and ``.abs()``.
        flooding_scalar: the flood level ``b``; converted with ``float()``.
    """

    def __init__(self, inner_criterion, flooding_scalar=0.5):
        super().__init__()
        self.flooding_scalar = float(flooding_scalar)
        self.inner_criterion = inner_criterion

    def forward(self, output, target):
        """Return the flooded value of the inner criterion for ``(output, target)``."""
        flood_level = self.flooding_scalar
        raw_loss = self.inner_criterion(output, target)
        # Mirror the loss around the flood level, then shift it back up.
        distance_from_level = (raw_loss - flood_level).abs()
        return distance_from_level + flood_level


def argparser_criterion(args):
    """Build the training criterion described by ``args``.

    The base criterion is always ``nn.CrossEntropyLoss()``. When ``args``
    carries a non-``None`` ``flooding_scalar`` attribute, the criterion is
    wrapped in the ``flooding`` module using that value as the flood level.

    Args:
        args: namespace-like object; only the optional ``flooding_scalar``
            attribute is read.

    Returns:
        The criterion to use for training.
    """
    criterion = nn.CrossEntropyLoss()
    # An option that is declared but left unset shows up as ``None``; treat that
    # the same as "attribute missing" instead of crashing on ``float(None)``.
    flooding_scalar = getattr(args, "flooding_scalar", None)
    if flooding_scalar is not None:  # use the flooding formulation wrapper
        criterion = flooding(criterion, flooding_scalar=float(flooding_scalar))
    return criterion


def _build_optimizer(model, args):
    """Create the optimizer selected by ``args.client_optimizer`` for ``model``."""
    # Only parameters that still require gradients are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    if args.client_optimizer == "sgd":
        return torch.optim.SGD(
            trainable_params,
            lr=args.lr,
            momentum=args.sgd_momentum,  # e.g. 0.9
            weight_decay=args.wd,  # e.g. 5e-4
        )
    if args.client_optimizer == "adadelta":
        return torch.optim.Adadelta(
            trainable_params,
            lr=args.lr,
            rho=args.rho,  # e.g. 0.95
            eps=args.eps,  # e.g. 1e-07
        )
    # Every other value of ``client_optimizer`` falls back to Adam with AMSGrad.
    return torch.optim.Adam(
        trainable_params,
        lr=args.lr,
        betas=args.adam_betas,
        weight_decay=args.wd,
        amsgrad=True,
    )


def _build_scheduler(optimizer, args):
    """Create the LR scheduler selected by ``args.lr_scheduler``, or ``None`` if unrecognized."""
    lr_schedulers = torch.optim.lr_scheduler
    scheduler_name = args.lr_scheduler

    if scheduler_name == "CyclicLR":
        return lr_schedulers.CyclicLR(
            optimizer,
            base_lr=args.min_lr,
            max_lr=args.lr,
            step_size_up=args.step_size_up,
            step_size_down=args.step_size_down,
            cycle_momentum=False,
        )
    if scheduler_name == "StepLR":
        return lr_schedulers.StepLR(
            optimizer,
            step_size=args.steplr_stepsize,  # e.g. 1
            gamma=args.steplr_gamma,  # e.g. 0.92
        )
    if scheduler_name == "CosineAnnealingLR":
        # A missing or ``None`` ``cos_t_max`` falls back to 100.
        cos_t_max = getattr(args, "cos_t_max", None)
        return lr_schedulers.CosineAnnealingLR(
            optimizer,
            T_max=100 if cos_t_max is None else cos_t_max,
        )
    if scheduler_name == "MultiStepLR":
        return lr_schedulers.MultiStepLR(
            optimizer,
            milestones=args.steplr_milestones,
            gamma=args.steplr_gamma,
        )
    if scheduler_name == "ReduceLROnPlateau":
        # A missing or ``None`` factor means "use the scheduler's own default";
        # forwarding ``factor=None`` would fail inside the constructor.
        factor = getattr(args, "ReduceLROnPlateau_factor", None)
        extra_kwargs = {} if factor is None else {"factor": factor}
        return lr_schedulers.ReduceLROnPlateau(optimizer, **extra_kwargs)
    return None


def argparser_opt_scheduler(model, args):
    """Return the ``(optimizer, scheduler)`` pair configured by ``args`` for ``model``.

    Optimizer (``args.client_optimizer``):
        * ``"sgd"``      -> SGD, reads ``lr``, ``sgd_momentum``, ``wd``.
        * ``"adadelta"`` -> Adadelta, reads ``lr``, ``rho``, ``eps``.
        * anything else  -> Adam with ``amsgrad=True``, reads ``lr``,
          ``adam_betas``, ``wd``.

    Scheduler (``args.lr_scheduler``):
        * ``"CyclicLR"`` -> reads ``min_lr``, ``lr``, ``step_size_up``,
          ``step_size_down``; momentum cycling is disabled.
        * ``"StepLR"`` -> reads ``steplr_stepsize``, ``steplr_gamma``.
        * ``"CosineAnnealingLR"`` -> ``T_max`` is ``cos_t_max``, or 100 when
          that attribute is missing or ``None``.
        * ``"MultiStepLR"`` -> reads ``steplr_milestones``, ``steplr_gamma``.
        * ``"ReduceLROnPlateau"`` -> ``factor`` is ``ReduceLROnPlateau_factor``
          when that attribute is present and not ``None``.
        * anything else -> no scheduler (``None``).

    Args:
        model: object exposing ``parameters()``; only parameters with
            ``requires_grad`` set are optimized.
        args: namespace-like object holding the attributes listed above.

    Returns:
        Tuple ``(optimizer, scheduler)``; ``scheduler`` may be ``None``.
    """
    optimizer = _build_optimizer(model, args)
    scheduler = _build_scheduler(optimizer, args)
    return optimizer, scheduler
