import datetime
import os
import os.path as osp
import pickle
import time
from contextlib import contextmanager

from scipy import io

import torch
from torch.autograd import Variable

def find_index(seq, item):
    """Return the position of the first element of ``seq`` equal to ``item``.

    Elements are compared with ``item == element``; the scan stops at the
    first match. Returns -1 when no element matches (including empty ``seq``).
    """
    matches = (pos for pos, element in enumerate(seq) if item == element)
    return next(matches, -1)


def adjust_lr_exp(optimizer, ep, decay_at_epochs=None, factor=None):
    """Decay exponentially in the later phase of training. All parameters in the
    optimizer share the same learning-rate schedule.

    From epoch ``decay_at_epochs`` on, the learning rate of every param group
    is set to::

        base_lr * 0.1 ** ((ep + 1 - decay_at_epochs) / (total_ep + 1 - decay_at_epochs))

    where ``total_ep = factor``. The lr therefore falls smoothly from
    ``base_lr`` to ``0.1 * base_lr`` at epoch ``ep == total_ep``.

    ``base_lr`` is the learning rate a param group had the first time this
    function decayed it. It is remembered in the group under the key
    ``'initial_lr'``, the same key PyTorch's ``lr_scheduler`` classes use, so a
    value already stored there is reused. Because the lr is computed from that
    base rather than from the current (already decayed) lr, repeated calls do
    not compound the decay.

    Args:
        optimizer: a pytorch `Optimizer` object (only ``param_groups`` is used)
        ep: current epoch, ep >= 1
        decay_at_epochs: start decaying at the BEGINNING of this epoch
        factor: total number of epochs to train (``total_ep``); the name is
            kept so all strategies in this module share one signature

    Example:
        optimizer lr = 2e-4, factor = 300, decay_at_epochs = 201
        The learning rate stays at 2e-4 for the first 200 epochs, then decays
        exponentially and reaches 2e-5 at epoch 300.

    NOTE:
        It is meant to be called at the BEGINNING of an epoch.
    """
    assert ep >= 1, "Current epoch number should be >= 1"
    if ep < decay_at_epochs:
        return
    total_ep = factor
    exponent = float(ep + 1 - decay_at_epochs) / (total_ep + 1 - decay_at_epochs)
    lr = None
    for g in optimizer.param_groups:
        # The exponent already grows with ep, so it must be applied to the
        # starting lr; multiplying the current lr would compound every epoch.
        base_lr = g.setdefault('initial_lr', g['lr'])
        lr = base_lr * (0.1 ** exponent)
        g['lr'] = lr
    # Guard: with no param groups there is nothing to report.
    if lr is not None:
        print('LR changed to ' + str(lr))
    return


def adjust_lr_staircase(optimizer, ep, decay_at_epochs=None, factor=None):
    """Multiply the learning rate by ``factor`` at the BEGINNING of specified
    epochs. All parameters in the optimizer share the same schedule.

    The decay is cumulative: each listed epoch multiplies the *current*
    learning rate of every param group by ``factor`` once more.

    Args:
        optimizer: a pytorch `Optimizer` object (only ``param_groups`` is used)
        ep: current epoch, ep >= 1
        decay_at_epochs: a list or tuple of epoch numbers; the learning rate is
            multiplied by ``factor`` at the BEGINNING of each of these epochs
        factor: a number in range (0, 1)

    Example:
        optimizer lr = 1e-3, decay_at_epochs = [51, 101], factor = 0.1
        The learning rate starts at 1e-3 and is multiplied by 0.1 at the
        BEGINNING of the 51st epoch. It is multiplied by 0.1 again at the
        BEGINNING of the 101st epoch, then stays unchanged until the end of
        training.

    NOTE:
        It is meant to be called at the BEGINNING of an epoch.
    """
    assert ep >= 1, "Current epoch number should be >= 1"

    if ep in decay_at_epochs:
        for g in optimizer.param_groups:
            g['lr'] = g['lr'] * factor
    return


def adjust_lr_warmup(optimizer, ep, decay_at_epochs=None, factor=None):
    """Rescale the learning rate of every param group when ``ep`` is listed in
    ``decay_at_epochs``; do nothing otherwise.

    Args:
        optimizer: a pytorch `Optimizer` object (only ``param_groups`` is used)
        ep: current epoch, ep >= 1
        decay_at_epochs: an indexable sequence of epoch numbers; its first
            element ``decay_at_epochs[0]`` selects which branch below runs
        factor: base of the multiplier applied at listed epochs
            ``>= decay_at_epochs[0]``

    Behavior as written:
        * ``ep`` listed and ``ep < decay_at_epochs[0]``: the current lr is
          multiplied by ``(ep + 1) / decay_at_epochs[0]``.
        * ``ep`` listed and ``ep >= decay_at_epochs[0]``: the current lr is
          multiplied by ``factor ** i``, where ``i`` is the position of ``ep``
          in ``decay_at_epochs``. The first element (i == 0) therefore leaves
          the lr unchanged.

    NOTE(review): the name suggests a linear warm-up phase, but the first
    branch can only run if ``decay_at_epochs`` contains an element smaller
    than its first one. For an ascending list such as [10, 40, 70] it is
    unreachable. Confirm the intended condition before relying on it.

    NOTE(review): both branches scale the *current* lr, so the decays
    compound. With [10, 40, 70] and factor 0.1, the lr is multiplied by 0.1 at
    epoch 40 and by 0.01 at epoch 70 (1e-3 overall), rather than ending at
    ``factor ** i`` times the starting lr. Verify this is intended.
    """
    assert ep >= 1, "Current epoch number should be >= 1"
    if ep in decay_at_epochs and ep < decay_at_epochs[0]:
        for g in optimizer.param_groups:
            old_lr = g['lr']
            g['lr'] = (old_lr) * (float(ep+1)/float(decay_at_epochs[0]))
    elif ep in decay_at_epochs and ep >= decay_at_epochs[0]:
        # Position of the first occurrence of ep in decay_at_epochs.
        ind = find_index(decay_at_epochs, ep)
        for g in optimizer.param_groups:
            old_lr = g['lr']
            g['lr'] = old_lr * factor ** (ind)
    return


def adjust_lr_epochs(optimizer, ep, decay_at_epochs=None, factor=None):
    """Periodic step decay of the learning rate.

    Whenever ``ep`` is a multiple of ``decay_at_epochs``, the learning rate of
    every param group is multiplied by ``factor``; otherwise nothing changes.
    Each qualifying call applies the multiplier again, so the decay
    accumulates across periods.

    Args:
        optimizer: a pytorch `Optimizer` object (only ``param_groups`` is used)
        ep: current epoch, ep >= 1
        decay_at_epochs: decay period in epochs
        factor: multiplier applied at each decay step
    """
    assert ep >= 1, "Current epoch number should be >= 1"
    if ep % decay_at_epochs != 0:
        # Not a decay epoch: leave every learning rate untouched.
        return
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * factor
    return


def get_lr_strategy(lr_strategy):
    """Look up a learning-rate adjustment function by its short name.

    Recognised names are ``'exp'``, ``'staircase'``, ``'warmup'`` and
    ``'epochs'``. The returned function has the signature
    ``(optimizer, ep, decay_at_epochs=None, factor=None)``.

    Raises:
        AssertionError: if ``lr_strategy`` is not a recognised name (when
            assertions are enabled).
    """
    strategies = dict(
        exp=adjust_lr_exp,
        staircase=adjust_lr_staircase,
        warmup=adjust_lr_warmup,
        epochs=adjust_lr_epochs,
    )
    assert lr_strategy in strategies, 'Have no such lr change strategy!'
    chosen = strategies[lr_strategy]
    return chosen
