from code.optim.aligned.balancer import ZAlignedBalancer, ThetaAlignedBalancer
from code.optim.mgda.balancer import MGDABalancer, MGDAUBBalancer
from code.optim.pcgrad.balancer import PCGradBalancer
from code.optim.gradnorm.balancer import GradNormBalancer
from code.optim.uncertainty.balancer import HomoscedasticUncertaintyBalancer
from code.optim.basic.balancer import DummyBalancer


def get_balancer(args):
    """Build the balancer selected by ``args.balancer``.

    Args:
        args: Namespace-like object. ``args.balancer`` names the balancer.
            ``args.scale_heads`` is read only for "zalign", "talign",
            "mgda" and "mgdaub"; ``args.benchmark`` is read only for
            "gradnorm".

    Returns:
        A newly constructed balancer instance for the requested name.

    Raises:
        ValueError: If ``args.balancer`` is not a known balancer name, or
            if "gradnorm" is requested for a benchmark that has no
            configured GradNorm value.
    """
    # Value passed to GradNormBalancer per benchmark:
    # 2.0 for localization (7scenes), 1.5 for CelebA.
    gradnorm_values = {"celeba": 1.5, "7scenes": 2.0}

    def make_gradnorm():
        benchmark = args.benchmark
        # Fail loudly here: previously an unlisted benchmark fell through
        # and the factory silently returned None.
        if benchmark not in gradnorm_values:
            raise ValueError(
                f"GradNorm balancer is not configured for benchmark "
                f"{benchmark!r}; expected one of {sorted(gradnorm_values)}"
            )
        return GradNormBalancer(gradnorm_values[benchmark])

    # Factories are lazy so that only the attributes needed by the chosen
    # balancer are read from ``args``.
    factories = {
        "zalign": lambda: ZAlignedBalancer(args.scale_heads),
        "talign": lambda: ThetaAlignedBalancer(args.scale_heads),
        "mgda": lambda: MGDABalancer(args.scale_heads),
        "mgdaub": lambda: MGDAUBBalancer(args.scale_heads),
        "pcgrad": lambda: PCGradBalancer(),
        "gradnorm": make_gradnorm,
        "uncertainty": lambda: HomoscedasticUncertaintyBalancer(),
        "dummy": lambda: DummyBalancer(alpha=5.0),
    }

    name = args.balancer
    if name not in factories:
        raise ValueError(
            f"Check the balancer name: got {name!r}, "
            f"expected one of {sorted(factories)}"
        )
    return factories[name]()
