import torch.optim as optim
import numpy as np
from functools import partial


device_cache = None  # memoized torch.device; populated on the first get_device() call


def get_device():
    """Return the preferred torch device, computing it only once.

    Selects ``cuda`` when ``torch.cuda.is_available()`` is true and ``cpu``
    otherwise. The resulting ``torch.device`` is stored in the module-level
    ``device_cache`` so later calls return the very same object.
    """
    global device_cache
    # ``import torch.optim as optim`` binds only ``optim``, not ``torch``;
    # import torch explicitly so this function never hits a NameError.
    import torch

    if device_cache is None:
        device_cache = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return device_cache

def build_optimizer(args, params, weight_decay=0.0):
    """Build an optimizer and an optional learning-rate scheduler from ``args``.

    Args:
        args: namespace-like object. Always read: ``opt`` (one of 'adam',
            'sgd', 'rmsprop', 'adagrad'), ``lr`` and ``opt_scheduler`` (one of
            'none', 'step', 'cos', 'cosine_schedule'). Also read
            ``opt_decay_step`` and ``opt_decay_rate`` for 'step', and
            ``opt_restart`` for 'cos' and 'cosine_schedule'.
        params: iterable of parameters; only those with ``requires_grad``
            set are handed to the optimizer.
        weight_decay: weight-decay coefficient passed to every optimizer.

    Returns:
        ``(scheduler, optimizer)``. ``scheduler`` is ``None`` when
        ``args.opt_scheduler == 'none'``.

    Raises:
        ValueError: if ``args.opt`` or ``args.opt_scheduler`` is not one of
            the supported names (previously this surfaced as an opaque
            UnboundLocalError).
    """
    trainable = filter(lambda p: p.requires_grad, params)

    if args.opt == 'adam':
        optimizer = optim.Adam(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(trainable, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(trainable, lr=args.lr, weight_decay=weight_decay)
    else:
        raise ValueError(
            f"Unknown optimizer {args.opt!r}; "
            "expected one of 'adam', 'sgd', 'rmsprop', 'adagrad'"
        )

    if args.opt_scheduler == 'none':
        return None, optimizer
    elif args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    elif args.opt_scheduler == 'cosine_schedule':
        # Warm-restart cosine multiplier defined in this module.
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, partial(cosine_schedule, restart=args.opt_restart))
    else:
        raise ValueError(
            f"Unknown opt_scheduler {args.opt_scheduler!r}; "
            "expected one of 'none', 'step', 'cos', 'cosine_schedule'"
        )

    return scheduler, optimizer

def cosine_schedule(epoch, restart, alpha=0.0001):
    """Cosine-annealing learning-rate multiplier with warm restarts.

    Intended as the ``lr_lambda`` of ``LambdaLR``. The multiplier is 1.0 at
    the start of every ``restart``-epoch cycle and follows half a cosine wave
    down towards ``alpha``. It jumps back to 1.0 when ``epoch`` reaches the
    next multiple of ``restart``.

    Args:
        epoch: current epoch/step counter (what ``LambdaLR`` passes in).
        restart: cycle length in epochs; must be positive.
        alpha: floor of the multiplier, i.e. the fraction of the base LR
            approached at the end of each cycle. Defaults to 1e-4, the
            previously hard-coded value.

    Returns:
        The multiplier, between ``alpha`` and 1 (a NumPy float).

    Raises:
        ValueError: if ``restart`` is not positive (``epoch % 0`` would
            otherwise raise ZeroDivisionError, and a negative modulus yields
            a meaningless schedule).
    """
    if restart <= 0:
        raise ValueError(f"restart must be positive, got {restart!r}")

    steps_taken = epoch % restart
    decay_rate = 0.5 * (1 + np.cos(np.pi * steps_taken / restart))
    # Rescale [0, 1] into [alpha, 1] so the LR never fully reaches zero.
    return (1 - alpha) * decay_rate + alpha
