import torch as t
import torch.optim as optim
from io import StringIO
import warnings

def mk_opt_and_scheduler(model, num_epochs,
                         warmup=20,  # num epochs to warmup for
                         schedule='cosine',  # cosine, constant, halving, linear-decay
                         half_period=None,  # only relevant when schedule == 'halving'; if not supplied, defaults to max(1, num_epochs // 4)
                         init_lr=1e-3, max_lr=1e-2, min_lr=1e-5,  # set bounds on lr
                         weight_decay=0.
                         ):
    """Build an Adam optimizer for `model` plus an epoch-indexed lr scheduler.

    Schedules (selected by `schedule`):
      * 'cosine': lr rises linearly from `init_lr` to `max_lr` over `warmup`
        epochs, then follows cosine annealing from `max_lr` down to `min_lr`,
        reaching `min_lr` at epoch `num_epochs`. With `warmup == 0` annealing
        starts immediately from `max_lr`.
      * 'constant': `init_lr` at every epoch (`warmup` is ignored).
      * 'halving': linear warmup as above, then the lr is halved every
        `half_period` epochs.
      * 'linear-decay': linear warmup as above, then a linear decay that
        reaches `min_lr` at epoch `num_epochs` and stays there afterwards.
    For 'halving' and 'linear-decay' with `warmup == 0`, the decay phase starts
    from `init_lr` rather than `max_lr`.

    Note: the 'cosine' branch installs a process-wide warnings filter that
    ignores warnings whose message mentions `scheduler.step`.

    Returns:
        (opt, scheduler): a ``torch.optim.Adam`` optimizer and a torch lr
        scheduler whose step count is interpreted as the epoch index.
    """
    schedules = ['constant', 'halving', 'linear-decay', 'cosine']
    assert schedule in schedules, f"lr_schedule = {schedule} not in {schedules}"
    assert warmup >= 0, f"lr_warmup = {warmup} must be >= 0"
    # Number of epochs in the post-warmup phase. Clamped to >= 1 so neither the
    # linear decay nor CosineAnnealingLR's T_max divides by zero when
    # warmup >= num_epochs.
    decay_epochs = max(1, num_epochs - warmup)
    lr_scheduler = t.optim.lr_scheduler

    if schedule == 'cosine':
        # NOTE(review): presumably silences torch's warnings about scheduler.step()
        # (SequentialLR calls step(0) on the cosine scheduler at the milestone).
        # The filter is global to the process.
        warnings.filterwarnings('ignore', message=r'.*scheduler\.step.*')
        opt = optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay)
        if warmup == 0:
            # No warmup phase: the warmup factor below would divide by zero,
            # so anneal straight from max_lr.
            scheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=decay_epochs, eta_min=min_lr)
        else:
            # LambdaLR scales the base lr (max_lr) by this factor, which rises
            # linearly from init_lr / max_lr at epoch 0 to 1 at epoch `warmup`.
            start = init_lr / max_lr

            def warmup_factor(epoch):
                return start + (1. - start) * epoch / warmup

            s1 = lr_scheduler.LambdaLR(opt, lr_lambda=warmup_factor)
            # The cosine scheduler restarts its own count at the milestone, so T_max
            # spans only the post-warmup epochs; min_lr is then reached at num_epochs.
            s2 = lr_scheduler.CosineAnnealingLR(opt, T_max=decay_epochs, eta_min=min_lr)
            scheduler = lr_scheduler.SequentialLR(opt, schedulers=[s1, s2], milestones=[warmup])
        return opt, scheduler

    if schedule == 'halving' and half_period is None:
        # Clamped to >= 1: num_epochs // 4 is 0 for num_epochs < 4, which would
        # make `e // half_period` divide by zero.
        half_period = max(1, num_epochs // 4)
    # Base lr of 1.0: LambdaLR multiplies it by calc_lr(epoch), so calc_lr
    # returns the learning rate itself.
    opt = optim.Adam(model.parameters(), lr=1., weight_decay=weight_decay)

    def calc_lr(epoch):
        if schedule == 'constant':
            return init_lr
        if warmup > 0 and epoch <= warmup:  # warmup phase (absent when warmup == 0)
            return init_lr + (max_lr - init_lr) * epoch / warmup

        # warmdown phase
        start = init_lr if warmup == 0 else max_lr
        e = epoch - warmup
        if schedule == 'halving':
            return start * (0.5 ** (e // half_period))
        if schedule == 'linear-decay':
            # Progress is capped at decay_epochs so the lr holds at min_lr past
            # num_epochs instead of continuing down (and eventually below zero).
            return start + (min_lr - start) * min(e, decay_epochs) / decay_epochs
        raise ValueError(f"not sure what to do: epoch={epoch}, lr_schedule={schedule}")

    scheduler = lr_scheduler.LambdaLR(opt, lr_lambda=calc_lr)
    return opt, scheduler


class Namespace:
    """Plain attribute bag: every keyword argument becomes an attribute.

    Example: ``Namespace(lr=0.1).lr == 0.1``.
    """

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])


"""
determinism: seed every RNG in use (torch CPU and CUDA, python `random`, numpy)
"""
import random; import numpy as np
def set_seed(seed=0):
    """Seed torch (CPU and CUDA), python's `random`, and numpy's global RNG with `seed`."""
    torch_seeders = (t.manual_seed, t.cuda.manual_seed, t.cuda.manual_seed_all)
    for seeder in torch_seeders:
        seeder(seed)
    random.seed(seed)
    np.random.seed(seed)

"""
nice printing: compact one-line formatting of per-epoch metric dicts
"""
# Abbreviations applied to metric names by epoch_metrics_buf.
str_contraction = dict(train_acc='tr_a', test_acc='te_a', val_acc='v_a',
                       train_obj='tr_o', test_obj='te_o', val_obj='v_o',
                       train_ll='tr_ll', test_ll='te_ll', val_ll='v_ll',
                       train_loss='tr_lo', test_loss='te_lo', val_loss='v_lo',
                       epoch='ep')
def epoch_metrics_buf(d):
    """Format a dict of epoch metrics as one compact line inside a StringIO.

    Each entry is written as ``|<key> <value>`` in the dict's iteration order,
    with no trailing newline. Keys ending in ``_ce`` are skipped; other keys are
    abbreviated through `str_contraction` when listed there. Floating-point
    values (Python floats and numpy floating scalars) are rounded: 3 decimals
    when the (abbreviated) key ends in ``_ll``/``_a``/``_o``, 2 decimals for
    ``mem``, and 4 decimals otherwise. Any other value uses f-string default
    formatting.

    Returns the StringIO; call ``.getvalue()`` to obtain the string.
    """
    buf = StringIO()
    for k, v in d.items():
        if k.endswith('_ce'):
            continue
        k = str_contraction.get(k, k)
        # np.floating also covers numpy scalars like np.float32, which are not
        # float subclasses and would otherwise be printed unrounded.
        if isinstance(v, (float, np.floating)):
            if k.endswith(('_ll', '_a', '_o')):
                buf.write(f"|{k} {v:.3f}")
            elif k == 'mem':
                buf.write(f"|{k} {v:.2f}")
            else:
                buf.write(f"|{k} {v:.4f}")
        else:
            buf.write(f"|{k} {v}")
    return buf