import torch

def get_lambda_dict(config: dict)-> dict:
    """Build the initial loss-weight dictionary from a config mapping.

    Every config entry whose key contains the substring ``'lambda'`` (anywhere
    in the key) and whose value is strictly greater than zero gets a weight.
    All selected weights start at 1.0, regardless of the configured value.

    Args:
        config: Configuration mapping. ``config['device']`` is read only when
            at least one lambda entry is selected.

    Returns:
        A new dict mapping each selected key to its own ``torch.tensor(1.0)``
        placed on ``config['device']``.
    """
    return {
        name: torch.tensor(1.0, device=config['device'])
        for name, setting in config.items()
        if 'lambda' in name and setting > 0.
    }

def grad_norm_sub_losses(model, sub_loss_dict: dict)-> dict:
    """Compute the global L2 gradient norm of each sub-loss w.r.t. the model.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
            Only parameters with ``requires_grad=True`` are differentiated.
        sub_loss_dict: Mapping from a loss identifier to a single-element loss
            tensor. The substring ``"loss_unweighted"`` in each identifier is
            replaced by ``"grad_norm"`` to form the output key.

    Returns:
        Mapping from the derived ``grad_norm`` keys to 0-dim tensors. A loss
        equal to zero, or a loss that does not depend on any trainable
        parameter, maps to ``0.0`` on the loss's device. Losses that are
        non-zero but do not require grad are omitted from the result.
    """
    # autograd.grad raises if any input does not require grad, so skip frozen parameters.
    params = [p for p in model.parameters() if p.requires_grad]

    sub_grad_norm_dict = {}
    for loss_ident, loss_tensor in sub_loss_dict.items():
        norm_key = loss_ident.replace("loss_unweighted", "grad_norm")
        # torch.is_nonzero raises unless the tensor holds exactly one element.
        if not torch.is_nonzero(loss_tensor):
            sub_grad_norm_dict[norm_key] = torch.tensor(0.0, device=loss_tensor.device)
        elif loss_tensor.requires_grad:
            # retain_graph=True keeps the graph usable afterwards (presumably the
            # caller still backpropagates the combined loss); allow_unused=True
            # yields None for parameters this loss does not touch.
            sub_grads = (
                torch.autograd.grad(loss_tensor, params, retain_graph=True, allow_unused=True)
                if params else ()
            )
            per_param_norms = [grad.detach().norm() for grad in sub_grads if grad is not None]
            if per_param_norms:
                # ||concat(g_1..g_k)||_2 == ||(||g_1||, ..., ||g_k||)||_2, without
                # materialising a flat copy of every gradient.
                sub_grad_norm_dict[norm_key] = torch.stack(per_param_norms).norm()
            else:
                # No trainable parameter influences this loss: its gradient is zero.
                sub_grad_norm_dict[norm_key] = torch.tensor(0.0, device=loss_tensor.device)
    return sub_grad_norm_dict

def lambda_balancing(lambda_dict: dict, sub_grad_norm_dict: dict, alpha: float)-> dict:
    """Rebalance loss weights from sub-loss gradient norms, in place.

    For each weight key ``lambda<suffix>``, the matching norm is read from
    ``sub_grad_norm_dict['grad_norm<suffix>']`` (a missing key raises
    ``KeyError``). When that norm is non-zero, the weight is blended with the
    target ``sum(all norms) / norm`` using an exponential moving average:
    ``alpha * old + (1 - alpha) * target``. Weights whose norm is zero are left
    unchanged.

    Args:
        lambda_dict: Mapping of weight keys to current weights; mutated.
        sub_grad_norm_dict: Mapping of ``grad_norm`` keys to single-element
            norm tensors. All of its values contribute to the sum.
        alpha: Smoothing factor; the weight kept from the previous value.

    Returns:
        The same ``lambda_dict`` object, updated.
    """
    total_norm = sum(sub_grad_norm_dict.values())
    for weight_key in lambda_dict:
        own_norm = sub_grad_norm_dict[weight_key.replace('lambda', 'grad_norm')]
        if torch.is_nonzero(own_norm):
            target = total_norm / own_norm
            # Only existing keys are reassigned, so iterating the dict stays valid.
            lambda_dict[weight_key] = alpha * lambda_dict[weight_key] + (1 - alpha) * target
    return lambda_dict

def lambda_vec_balancing(lambda_vec_dict: dict, sub_loss_dict: dict, eta: float)-> None:
    """Placeholder for a vector-valued lambda balancing step (not implemented).

    This function currently ignores all of its arguments, mutates nothing, and
    returns ``None``. The return annotation reflects that; it previously
    claimed ``dict``, which no call could ever produce.

    Args:
        lambda_vec_dict: Unused.
        sub_loss_dict: Unused.
        eta: Unused.
    """
    # TODO: implement; callers must not rely on a return value until then.
    return None

def get_initial_nu_dict(lambda_dict: dict)-> dict:
    """Create zero-initialised ``nu`` accumulators, one per lambda key.

    Each key has its ``'lambda_'`` substring replaced by ``'nu_'``; every value
    starts as the Python float ``0.``.

    Args:
        lambda_dict: Mapping whose keys name the loss weights; values unused.

    Returns:
        A new dict mapping each derived ``nu_`` key to ``0.``.
    """
    return {name.replace('lambda_', 'nu_'): 0. for name in lambda_dict}

def adaptive_penalty_update(lambda_dict: dict, lambda_vec_dict: dict, nu_dict: dict, sub_loss_unweighted_dict: dict, config: dict, *, gamma: float = 0.01, alpha: float = 0.9, epsilon: float = 1.e-08)-> tuple[dict, dict, dict]:
    """Update penalty weights and their accumulators in place, without autograd.

    For each key ``lambda_<name>`` in ``lambda_dict``:

    * If it is not the objective term (``'lambda_' + config['objective']``):
      ``nu_<name>`` becomes an exponential moving average of the unweighted
      loss, ``nu = alpha * nu + (1 - alpha) * loss``, and the weight is reset
      to ``gamma / (sqrt(nu) + epsilon)``.
    * For every key, including the objective: element ``[0]`` of
      ``lambda_vec_dict['lambda_<name>']`` is incremented by
      ``lambda_<name> * sqrt(loss)``.

    Args:
        lambda_dict: Current weights, keyed ``lambda_<name>``; mutated.
        lambda_vec_dict: Indexable accumulators (index ``0`` is written), keyed
            like ``lambda_dict``; mutated.
        nu_dict: Moving averages keyed ``nu_<name>``; mutated.
        sub_loss_unweighted_dict: Unweighted losses keyed
            ``loss_unweighted_<name>``. They must be tensors, because
            ``torch.sqrt`` is applied to them.
        config: Must provide ``config['objective']`` (the objective's ``<name>``).
        gamma: Numerator of the weight formula (keyword-only).
        alpha: EMA smoothing factor for ``nu`` (keyword-only).
        epsilon: Added to ``sqrt(nu)`` to avoid division by zero (keyword-only).

    Returns:
        The tuple ``(lambda_dict, lambda_vec_dict, nu_dict)`` (the same objects
        that were passed in).
    """
    with torch.no_grad():
        for lambda_key in lambda_dict:
            nu_key = lambda_key.replace('lambda_', 'nu_')
            loss_key = lambda_key.replace('lambda_', 'loss_unweighted_')
            sub_loss = sub_loss_unweighted_dict[loss_key]

            # The objective term keeps its weight; only the penalty terms adapt.
            if lambda_key != 'lambda_' + config['objective']:
                nu_dict[nu_key] = nu_dict[nu_key] * alpha + (1 - alpha) * sub_loss
                lambda_dict[lambda_key] = gamma / (torch.sqrt(nu_dict[nu_key]) + epsilon)

            # Uses the freshly updated weight for non-objective terms.
            lambda_vec_dict[lambda_key][0] = lambda_vec_dict[lambda_key][0] + lambda_dict[lambda_key] * torch.sqrt(sub_loss)

    return lambda_dict, lambda_vec_dict, nu_dict



def scale_losses(lambda_dict: dict, sub_loss_dict: dict)-> dict:
    """Multiply each sub-loss by its weight, in place.

    A weight key is mapped to its loss key by removing every ``'lambda'``
    substring and prefixing ``'loss'``; for example ``lambda_data`` becomes
    ``loss_data``. Losses without a matching weight are left untouched, and a
    weight without a matching loss raises ``KeyError``.

    Args:
        lambda_dict: Mapping of weight keys to multipliers.
        sub_loss_dict: Mapping of loss keys to losses; mutated.

    Returns:
        The same ``sub_loss_dict`` object, with the weighted entries replaced.
    """
    for weight_key in lambda_dict:
        target = 'loss' + weight_key.replace('lambda', '')
        unscaled = sub_loss_dict[target]
        sub_loss_dict[target] = lambda_dict[weight_key] * unscaled
    return sub_loss_dict