import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time
from models import *
from adabound import AdaBound
from torch.optim import SGD
from optimizers import *
from torchvision.transforms.functional import rotate
from copy import deepcopy
import random
import numpy as np
import torch

def get_grad_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model``.

    The result is sqrt(sum_p ||grad_p||_2^2) over the parameters yielded by
    ``model.parameters()``. Parameters whose ``.grad`` is ``None`` are
    skipped rather than raising. This covers parameters that are frozen,
    unused in the forward pass, or queried before the first backward.

    Args:
        model: an object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The norm as a Python float (``0.0`` if no parameter has a gradient).
    """
    total_sq = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        # detach() reads the gradient values without touching autograd history.
        total_sq += p.grad.detach().norm(2).item() ** 2
    return total_sq ** 0.5

def create_optimizer(args, model_params):
    """Build the optimizer named by ``args.optim``.

    ``args.optim`` is lower-cased in place, so callers that read it afterwards
    see the normalised name.

    Hyper-parameters read from ``args`` depend on the selected optimizer:
      * ``sgd``: ``lr``, ``momentum``, ``weight_decay``
      * ``fromage``: ``lr``
      * ``adabound``: ``lr``, ``beta1``, ``beta2``, ``final_lr``, ``gamma``,
        ``weight_decay``
      * ``yogi``, ``msvag``: ``lr``, ``beta1``, ``beta2``, ``weight_decay``
      * ``adam``, ``amsgrad``, ``radam``, ``adamw``, ``adabelief``,
        ``cadam``, ``camsgrad``, ``cadamw``, ``cadamw-all``: ``lr``,
        ``beta1``, ``beta2``, ``weight_decay``, ``eps``

    Args:
        args: namespace-like object carrying the attributes listed above.
        model_params: parameters (or parameter groups) to optimise; passed
            straight through to the optimizer constructor.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``args.optim`` is not one of the names listed above.
    """
    args.optim = args.optim.lower()
    name = args.optim

    if name == "sgd":
        return optim.SGD(
            model_params,
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    if name == "fromage":
        return Fromage(model_params, args.lr)
    if name == "adabound":
        return AdaBound(
            model_params,
            args.lr,
            betas=(args.beta1, args.beta2),
            final_lr=args.final_lr,
            gamma=args.gamma,
            weight_decay=args.weight_decay,
        )
    if name in ("yogi", "msvag"):
        # These two are constructed without an ``eps`` argument.
        opt_cls = Yogi if name == "yogi" else MSVAG
        return opt_cls(
            model_params,
            args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )

    adam_style = (
        "adam",
        "amsgrad",
        "radam",
        "adamw",
        "adabelief",
        "cadam",
        "camsgrad",
        "cadamw",
        "cadamw-all",
    )
    if name not in adam_style:
        # Fail fast: returning None here would only surface later as a
        # confusing error on the first optimizer call.
        raise ValueError("Optimizer not found: {!r}".format(name))

    # Keyword arguments shared by every Adam-style optimizer below.
    common = dict(
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
        eps=args.eps,
    )
    if name == "adam":
        return Adam(model_params, args.lr, **common)
    if name == "amsgrad":
        return Adam(model_params, args.lr, amsgrad=True, **common)
    if name == "radam":
        return RAdam(model_params, args.lr, **common)
    if name == "adamw":
        return AdamW(model_params, args.lr, **common)
    if name == "adabelief":
        return AdaBelief(model_params, args.lr, **common)
    if name == "cadam":
        return CAdam(model_params, args.lr, **common)
    if name == "camsgrad":
        return CAdam(model_params, args.lr, amsgrad=True, **common)
    # Remaining names: "cadamw" (decay_all=False) and "cadamw-all" (decay_all=True).
    return CAdamW(
        model_params,
        args.lr,
        decay_all=(name == "cadamw-all"),
        **common,
    )
        
def get_ckpt_name(
    model="resnet",
    optimizer="sgd",
    lr=0.1,
    final_lr=0.1,
    momentum=0.9,
    beta1=0.9,
    beta2=0.999,
    gamma=1e-3,
    eps=1e-8,
    weight_decay=5e-4,
    reset=False,
    run=0,
    weight_decouple=False,
    rectify=False,
):
    """Return a checkpoint name encoding the model, optimizer and its settings.

    The result has the form ``"<model>-<optimizer>-<settings>-reset<reset>"``.
    ``<settings>`` uses one of four layouts chosen by the optimizer family.
    ``weight_decouple`` and ``rectify`` are accepted for call compatibility
    but do not appear in the name.

    Raises:
        KeyError: if ``optimizer`` is not a recognised optimizer name.
    """
    # One settings layout per optimizer family; the text must stay stable so
    # previously saved checkpoints keep resolving to the same name.
    layouts = {
        "momentum": f"lr{lr}-momentum{momentum}-wdecay{weight_decay}-run{run}",
        "wd_then_eps": (
            f"lr{lr}-betas{beta1}-{beta2}-wdecay{weight_decay}-eps{eps}-run{run}"
        ),
        "eps_then_wd": (
            f"lr{lr}-betas{beta1}-{beta2}-eps{eps}-wdecay{weight_decay}-run{run}"
        ),
        "bounded": (
            f"lr{lr}-betas{beta1}-{beta2}-final_lr{final_lr}-gamma{gamma}"
            f"-wdecay{weight_decay}-run{run}"
        ),
    }
    family_of = {
        "sgd": "momentum",
        "adam": "wd_then_eps",
        "fromage": "wd_then_eps",
        "radam": "wd_then_eps",
        "adamw": "wd_then_eps",
        "adabelief": "eps_then_wd",
        "adabound": "bounded",
        "yogi": "eps_then_wd",
        "msvag": "eps_then_wd",
        "cadam": "wd_then_eps",
        "cadamw": "wd_then_eps",
        "cadamw-all": "wd_then_eps",
        "amsgrad": "wd_then_eps",
        "camsgrad": "wd_then_eps",
    }
    settings = layouts[family_of[optimizer]]
    return f"{model}-{optimizer}-{settings}-reset{reset!s}"