import pprint
import jax
from jax.tree_util import tree_map
import jax.numpy as jnp
import optax



def is_predictor_fn(p):
    """Build the mask routing parameters to the predictor optimizer.

    Every leaf of ``p`` maps to ``False``. The exception is each top-level entry
    whose key contains the substring ``'pred'``: that whole subtree is
    replaced by a single ``True``, giving a prefix mask.

    Args:
        p: parameter pytree whose top level is a dict with string keys.

    Returns:
        A new mask pytree; ``p`` itself is not modified.
    """
    mask = tree_map(lambda _leaf: False, p)
    mask.update({key: True for key in p.keys() if 'pred' in key})
    return mask

def is_lep_fn(p):
    """Build the mask routing parameters to the LEP optimizer.

    The LEP optimizer only touches top-level entries whose key contains the
    substring ``'lep'``. Such an entry collapses to one ``True``; every other
    leaf of ``p`` is ``False``.

    Args:
        p: parameter pytree whose top level is a dict with string keys.

    Returns:
        A new mask pytree; ``p`` itself is not modified.
    """
    selected = [name for name in p.keys() if 'lep' in name]
    mask = tree_map(lambda _leaf: False, p)
    for name in selected:
        mask[name] = True
    return mask

def not_predictor_or_lep_fn(p):
    """Build the mask routing parameters to the backbone optimizer.

    This is the complement of the predictor and LEP masks. Every leaf is
    ``True``, except top-level entries whose key contains ``'pred'`` or
    ``'lep'``; each of those collapses to a single ``False``.

    Args:
        p: parameter pytree whose top level is a dict with string keys.

    Returns:
        A new mask pytree; ``p`` itself is not modified.
    """
    excluded_tags = ('pred', 'lep')
    mask = tree_map(lambda _leaf: True, p)
    for name in p.keys():
        if any(tag in name for tag in excluded_tags):
            mask[name] = False
    return mask

def weight_decay_mask(p):
    """Build the weight-decay mask: decay everything except biases and scales.

    Starts from an all-``True`` copy of ``p``. It then walks every nested
    ``dict`` and sets to ``False`` any non-dict entry whose key contains
    ``'bias'`` or ``'scale'``. Only ``dict`` instances are descended into. If
    the top level is not a ``dict``, the all-``True`` mask is returned
    unchanged. NOTE(review): an immutable mapping type would therefore get no
    exclusions; confirm params are plain dicts.

    Args:
        p: parameter pytree.

    Returns:
        A new mask pytree; ``p`` itself is not modified.
    """
    mask = tree_map(lambda _leaf: True, p)
    if not isinstance(mask, dict):
        return mask
    pending = [mask]
    while pending:
        node = pending.pop()
        for name in node.keys():
            child = node[name]
            if isinstance(child, dict):
                pending.append(child)
            elif 'bias' in name or 'scale' in name:
                node[name] = False
    return mask




def get_optimizer(args, num_batches):
    """Build the combined optax optimizer for backbone, predictor and LEP params.

    Parameters are routed by top-level key:

    - keys containing ``'pred'`` go to the predictor optimizer;
    - keys containing ``'lep'`` go to the LEP optimizer, which is always AdamW;
    - everything else goes to the backbone optimizer selected by ``args.opt``.

    Args:
        args: config namespace. Reads ``lr``, ``warmup_epochs``, ``num_epochs``,
            ``pred_lr_coeff``, ``cosine_lr_decay``, ``opt`` and ``wd``. It
            optionally reads ``lep_lr``, which defaults to 1e-4.
        num_batches: number of optimizer steps per epoch.

    Returns:
        An ``optax`` transformation that chains the three masked optimizers in
        the order backbone, predictor, LEP. ``get_current_lr`` indexes the
        optimizer state by this order.

    Raises:
        ValueError: if ``args.opt`` is not ``'lars'``, ``'adamw'`` or ``'sgd'``.
    """
    lr = args.lr
    warmup_steps = args.warmup_epochs * num_batches
    num_train_steps = args.num_epochs * num_batches
    pred_lr_coeff = args.pred_lr_coeff
    lep_lr = getattr(args, 'lep_lr', 1e-4)

    def warmup_then_constant(init_value, peak_value):
        """Ramp linearly from init_value to peak_value, then hold peak_value."""
        # optax.linear_schedule with transition_steps <= 0 degenerates to a
        # constant schedule at *init_value*. With warmup_epochs == 0 that would
        # pin the LR at a tenth of its target for the whole run, so skip warmup.
        if warmup_steps <= 0:
            return optax.constant_schedule(peak_value)
        return optax.linear_schedule(
            init_value=init_value, end_value=peak_value, transition_steps=warmup_steps)

    if args.cosine_lr_decay:
        # decay_steps is the total schedule length, warmup included.
        lr_fn = optax.warmup_cosine_decay_schedule(init_value=lr/10, peak_value=lr, warmup_steps=warmup_steps, decay_steps=num_train_steps)
        predictor_lr_fn = optax.warmup_cosine_decay_schedule(init_value=pred_lr_coeff*lr/10, peak_value=pred_lr_coeff*lr, warmup_steps=warmup_steps, decay_steps=num_train_steps)
    else:
        lr_fn = warmup_then_constant(lr/10, lr)
        predictor_lr_fn = warmup_then_constant(pred_lr_coeff*lr/10, pred_lr_coeff*lr)
    # The LEP learning rate never decays, even when cosine decay is enabled.
    lep_lr_fn = warmup_then_constant(lep_lr/10, lep_lr)

    # NOTE(review): only the lars branch excludes bias/scale from weight decay
    # via weight_decay_mask; adamw and sgd decay every parameter. Confirm this is intended.
    if args.opt == 'lars':
        backbone_opt = optax.inject_hyperparams(optax.lars)(lr_fn, weight_decay=args.wd, momentum=0.0, weight_decay_mask=weight_decay_mask)
        predictor_opt = optax.inject_hyperparams(optax.lars)(predictor_lr_fn, weight_decay=args.wd, momentum=0.0, weight_decay_mask=weight_decay_mask)
    elif args.opt == 'adamw':
        backbone_opt = optax.inject_hyperparams(optax.adamw)(lr_fn, weight_decay=args.wd)
        predictor_opt = optax.inject_hyperparams(optax.adamw)(predictor_lr_fn, weight_decay=args.wd)
    elif args.opt == 'sgd':
        # Weight decay is added to the gradient before the (injected) SGD step,
        # so the SGD hyperparams end up at index 1 of this chain's state.
        backbone_opt = optax.inject_hyperparams(optax.sgd)(lr_fn, momentum=0.9)
        backbone_opt = optax.chain(optax.add_decayed_weights(args.wd), backbone_opt)
        predictor_opt = optax.inject_hyperparams(optax.sgd)(predictor_lr_fn, momentum=0.9)
        predictor_opt = optax.chain(optax.add_decayed_weights(args.wd), predictor_opt)
    else:
        raise ValueError(f"Unknown optimizer {args.opt!r}; expected 'lars', 'adamw' or 'sgd'.")

    lep_opt = optax.inject_hyperparams(optax.adamw)(lep_lr_fn, weight_decay=args.wd)

    opt = optax.chain(
        optax.masked(backbone_opt, not_predictor_or_lep_fn),
        optax.masked(predictor_opt, is_predictor_fn),
        optax.masked(lep_opt, is_lep_fn))

    return opt


def get_current_lr(state, args, pred=False, lep=False):
    """Return the learning rate currently injected into one of the optimizers.

    Relies on the layout built by ``get_optimizer``. ``state.opt_state`` is the
    chain state ``(backbone, predictor, lep)``, and each element is a masked
    state whose ``inner_state`` carries ``inject_hyperparams`` hyperparameters.
    For ``args.opt == 'sgd'``, the backbone and predictor inner states are
    themselves ``chain(add_decayed_weights, sgd)`` states, so their
    hyperparameters sit at index 1.

    Args:
        state: training state exposing ``opt_state``.
        args: config namespace; ``args.opt`` selects the state layout.
        pred: read the predictor learning rate instead of the backbone one.
        lep: read the LEP learning rate. The LEP optimizer is always an
            injected AdamW, independent of ``args.opt``.

    Returns:
        The ``'learning_rate'`` hyperparameter of the selected optimizer.

    Raises:
        ValueError: if both ``pred`` and ``lep`` are set, or if ``args.opt`` is
            not ``'adamw'``, ``'lars'`` or ``'sgd'``.
    """
    if pred and lep:
        raise ValueError("pred and lep are mutually exclusive")
    if lep:
        return state.opt_state[2].inner_state.hyperparams['learning_rate']
    idx = 1 if pred else 0
    if args.opt in ['adamw', 'lars']:
        return state.opt_state[idx].inner_state.hyperparams['learning_rate']
    if args.opt == 'sgd':
        return state.opt_state[idx].inner_state[1].hyperparams['learning_rate']
    raise ValueError(f"Unknown optimizer {args.opt!r}; expected 'lars', 'adamw' or 'sgd'.")


def pretty_print(dictio, depth=6):
    """Print a pytree with every JAX array replaced by its shape.

    The output is surrounded by blank lines. Non-array leaves are printed
    as-is.

    Args:
        dictio: any pytree, e.g. a params dict.
        depth: maximum nesting depth shown by ``pprint``; deeper levels are elided.
    """
    print()
    # jax.tree_map is deprecated (and removed in recent JAX releases);
    # jax.tree_util.tree_map is the supported spelling.
    shapes = tree_map(lambda x: x.shape if isinstance(x, jnp.ndarray) else x, dictio)
    pprint.PrettyPrinter(depth=depth).pprint(shapes)
    print()