import torch
import torch.optim as optim
from bisect import bisect_right

from losses.bs import BS
from losses.ce_drw import CE_DRW
from losses.ce import CE
from losses.ldam_drw import LDAM_DRW
from losses.ride import RIDE, RIDEWithDistill
from losses.ncl import NIL_NBOD
from losses.bcl import BCLLoss
from losses.kps import KPSLoss

from utils.common import adjust_learning_rate
from torch.optim import lr_scheduler

def get_optimizer(args, model):
    """Create an SGD optimizer for the trainable network.

    Args:
        args: Config object providing ``loss_fn``, ``lr``, ``momentum``,
            ``wd`` (weight decay) and ``nesterov``.
        model: The network, or, when ``args.loss_fn == 'ncl'``, a mapping
            whose ``'model'`` entry holds the network to optimize.

    Returns:
        A ``torch.optim.SGD`` instance over the selected network's parameters.
    """
    # For NCL the caller hands over a dict; only its 'model' entry is optimized.
    if args.loss_fn == 'ncl':
        network = model['model']
    else:
        network = model

    sgd_kwargs = dict(
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.wd,
        nesterov=args.nesterov,
    )
    return optim.SGD(network.parameters(), **sgd_kwargs)

def get_scheduler(args, optimizer):
    """Return an LR scheduler object for ``args.scheduler``, if one applies.

    Args:
        args: Config object providing ``scheduler`` and, for the cosine
            schedule, ``epochs``.
        optimizer: The optimizer whose learning rate the scheduler drives.

    Returns:
        A ``CosineAnnealingLR`` (T_max = ``args.epochs``, eta_min = 0) when
        ``args.scheduler == 'cosine'``; otherwise ``None``. That includes
        'warmup', which presumably is handled elsewhere (e.g. by
        ``adjust_learning_rate``); confirm with the training loop.
    """
    kind = args.scheduler
    if kind != 'cosine':
        # 'warmup' and any other value: no scheduler object is built here.
        return None
    return lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)

def get_loss(args, N_SAMPLES_PER_CLASS):
    """Build the training criterion selected by ``args.loss_fn``.

    Args:
        args: Config object. ``args.loss_fn`` selects the loss. The 'ride'
            branch also reads ``args.num_experts`` and ``args.ride_distill``,
            and the 'ncl' branch passes ``args`` through to ``NIL_NBOD``.
        N_SAMPLES_PER_CLASS: Per-class sample counts forwarded to each loss
            constructor (presumably a list indexed by class id; confirm).

    Returns:
        The constructed criterion. Only the 'ldam_drw', 'ride' and 'kps'
        criteria are moved to CUDA here; the others are returned as built.

    Raises:
        NotImplementedError: If ``args.loss_fn`` is not a supported name.
    """
    loss_fn = args.loss_fn
    if loss_fn == 'ce':
        train_criterion = CE()
    elif loss_fn == 'ce_drw':
        train_criterion = CE_DRW(cls_num_list=N_SAMPLES_PER_CLASS, reweight_epoch=160)
    elif loss_fn == 'bs':
        train_criterion = BS(N_SAMPLES_PER_CLASS)
    elif loss_fn == 'ldam_drw':
        train_criterion = LDAM_DRW(cls_num_list=N_SAMPLES_PER_CLASS, reweight_epoch=160, max_m=0.5, s=30).cuda()
    elif loss_fn == 'ride':
        # The distillation variant is only used with exactly 3 experts.
        if args.num_experts == 3 and args.ride_distill:
            train_criterion = RIDEWithDistill(cls_num_list=N_SAMPLES_PER_CLASS, additional_diversity_factor=-0.45, reweight=True, reweight_epoch=160)
        else:
            train_criterion = RIDE(cls_num_list=N_SAMPLES_PER_CLASS, additional_diversity_factor=-0.45, reweight=True, reweight_epoch=160)
        train_criterion = train_criterion.to(torch.device('cuda'))
    elif loss_fn == 'ncl':
        train_criterion = NIL_NBOD(args, N_SAMPLES_PER_CLASS)
    elif loss_fn == 'bcl':
        train_criterion = BCLLoss(N_SAMPLES_PER_CLASS)
    elif loss_fn == 'kps':
        # KPSLoss is imported and supported by get_loss_by_name. Build it with
        # the same hyper-parameters so both factories agree.
        train_criterion = KPSLoss(cls_num_list=N_SAMPLES_PER_CLASS, max_m=0.1, s=3).cuda()
    else:
        raise NotImplementedError(f"Unsupported loss_fn: {loss_fn!r}")

    return train_criterion

def get_loss_by_name(loss_name, N_SAMPLES_PER_CLASS, args):
    """Build a criterion by explicit name and move it to CUDA.

    Args:
        loss_name: One of 'ce', 'bs', 'ce_drw', 'ldam_drw', 'ncl', 'bcl', 'kps'.
        N_SAMPLES_PER_CLASS: Per-class sample counts forwarded to each loss
            constructor (presumably a list indexed by class id; confirm).
        args: Config object. It is only used by the 'ncl' branch, which
            passes it to ``NIL_NBOD``.

    Returns:
        The constructed criterion after ``.cuda()``.

    Raises:
        NotImplementedError: If ``loss_name`` is not a supported name.
            Previously an unknown name fell through every branch and failed
            with ``UnboundLocalError`` at the return statement.
    """
    if loss_name == 'ce':
        train_criterion = CE().cuda()
    elif loss_name == 'bs':
        train_criterion = BS(N_SAMPLES_PER_CLASS).cuda()
    elif loss_name == 'ce_drw':
        train_criterion = CE_DRW(cls_num_list=N_SAMPLES_PER_CLASS, reweight_epoch=160).cuda()
    elif loss_name == 'ldam_drw':
        train_criterion = LDAM_DRW(cls_num_list=N_SAMPLES_PER_CLASS, reweight_epoch=160, max_m=0.5, s=30).cuda()
    elif loss_name == 'ncl':
        train_criterion = NIL_NBOD(args, N_SAMPLES_PER_CLASS).cuda()
    elif loss_name == 'bcl':
        train_criterion = BCLLoss(N_SAMPLES_PER_CLASS).cuda()
    elif loss_name == 'kps':
        train_criterion = KPSLoss(cls_num_list=N_SAMPLES_PER_CLASS, max_m=0.1, s=3).cuda()
    else:
        raise NotImplementedError(f"Unsupported loss_name: {loss_name!r}")
    return train_criterion
