import math
import torch
import numpy as np
# from train.sam import SAM


# opt

def get_opt(args, model):
    """Build the optimizer(s) for ``model`` from the run configuration.

    ``args.optimizer == "adam"`` selects ``torch.optim.Adam``; any other value
    selects ``torch.optim.SGD``. Hyperparameters are applied by ``init_opt``.

    When ``args.method`` contains the substring ``'ech'``, the model's two
    parameter sets (``model.param_m()`` and ``model.param_ps()``) each get
    their own optimizer and the pair is returned. Otherwise a single
    optimizer over ``model.parameters()`` is returned.
    """
    optimizer_cls = torch.optim.Adam if args.optimizer == "adam" else torch.optim.SGD

    if 'ech' not in args.method:
        return init_opt(args, optimizer_cls, model.parameters())

    # Fetch both parameter groups before building either optimizer.
    main_params = model.param_m()
    ps_params = model.param_ps()
    main_opt = init_opt(args, optimizer_cls, main_params)
    ps_opt = init_opt(args, optimizer_cls, ps_params)
    return main_opt, ps_opt


def init_opt(args, cOpt, p):
    """Instantiate optimizer class ``cOpt`` over parameters ``p``.

    The hyperparameters depend on ``args.optimizer``:

    * ``"adam"``: ``lr=args.lr``, ``betas=[.9, .999]``, and ``weight_decay``
      fixed at ``0.`` (``args.weight_decay`` is ignored here).
    * anything else: ``lr=args.lr``, ``momentum=args.momentum``,
      ``weight_decay=args.weight_decay`` (the SGD-style signature).

    Returns the constructed optimizer instance.
    """
    if args.optimizer != "adam":
        return cOpt(
            p,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    return cOpt(p, lr=args.lr, betas=[.9, .999], weight_decay=0.)


# grad

def clip_gradients(model, clip):
    """Clip every parameter's gradient independently to an L2 norm of at most ``clip``.

    Unlike ``torch.nn.utils.clip_grad_norm_``, which rescales all gradients by
    one global factor, each gradient here is rescaled on its own, and only
    when its own norm exceeds ``clip``. Parameters whose ``grad`` is ``None``
    are skipped.

    Args:
        model: module whose parameters' ``.grad`` tensors are clipped in place.
        clip: maximum allowed L2 norm for each individual gradient.

    Returns:
        list of float: the pre-clipping L2 norm of every gradient that was
        present, in parameter order.
    """
    norms = []
    for p in model.parameters():
        if p.grad is None:
            continue
        # `.data` keeps the norm and the in-place rescale outside autograd.
        grad = p.grad.data
        # One host sync per parameter. The original re-synced on `if tensor < 1`.
        param_norm = grad.norm(2).item()
        norms.append(param_norm)
        # The 1e-6 term avoids division by zero for an all-zero gradient.
        clip_coef = clip / (param_norm + 1e-6)
        if clip_coef < 1:
            grad.mul_(clip_coef)
    return norms


def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """Drop gradients of "last_layer" parameters during the first epochs.

    While ``epoch < freeze_last_layer``, every parameter whose qualified name
    contains ``"last_layer"`` has its ``.grad`` set to ``None``. From epoch
    ``freeze_last_layer`` onward, this function leaves the gradients alone.
    """
    if epoch < freeze_last_layer:
        for param_name, param in model.named_parameters():
            if "last_layer" in param_name:
                param.grad = None
