import glob
import torch
import os
from collections import OrderedDict


def tie_weights(src, trg):
    """Make ``trg`` share the ``weight`` and ``bias`` objects of ``src``.

    Both arguments must be of exactly the same type; otherwise the
    assertion fails. After the call ``trg.weight`` and ``trg.bias`` refer to
    the very same objects as the corresponding attributes of ``src``.
    """
    assert type(src) == type(trg)
    # Rebind (not copy) each attribute so both objects reference one value.
    for attr_name in ('weight', 'bias'):
        setattr(trg, attr_name, getattr(src, attr_name))


def soft_update(target, source, tau):
    """Blend the parameters of ``source`` into ``target`` in place.

    Parameters are paired positionally via ``parameters()`` on both
    objects, and each target parameter's data is overwritten with
    ``target * (1 - tau) + source * tau``. ``source`` is not modified.
    """
    keep = 1.0 - tau
    paired = zip(target.parameters(), source.parameters())
    for dst, src in paired:
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)



def save_model(model, step, logs_path, types, max_model=None):
    """Save ``model.state_dict()`` to ``<logs_path>/<types>-<step>.pth``.

    The directory is created if it does not exist. When ``max_model`` is
    given and the directory already contains more than ``max_model - 1``
    checkpoints named ``<types>-<integer>.pth``, the checkpoint with the
    smallest step is deleted before the new one is written.

    Only files that carry this ``types`` prefix and an integer step are
    considered for rotation, so checkpoints of other kinds stored in the
    same directory are neither counted nor removed.

    Args:
        model: Object exposing ``state_dict()``.
        step: Update counter embedded in the file name.
        logs_path: Directory that receives the checkpoint.
        types: File-name prefix identifying the kind of checkpoint.
        max_model: Optional limit that triggers removal of the oldest
            checkpoint; ``None`` disables rotation.
    """
    os.makedirs(logs_path, exist_ok=True)

    if max_model is not None:
        prefix = '{}-'.format(types)
        suffix = '.pth'
        # Escape both parts so glob metacharacters in the directory name or
        # the prefix are matched literally.
        pattern = os.path.join(glob.escape(logs_path),
                               glob.escape(prefix) + '*' + suffix)
        checkpoints = []
        for path in glob.glob(pattern):
            # basename() instead of splitting on '/' keeps this portable.
            stem = os.path.basename(path)[len(prefix):-len(suffix)]
            try:
                checkpoints.append((int(stem), path))
            except ValueError:
                # Same prefix but not a numbered checkpoint; leave it alone.
                continue
        if checkpoints and len(checkpoints) > max_model - 1:
            # Remove the path that was actually found on disk rather than a
            # name rebuilt from the parsed step.
            _, oldest_path = min(checkpoints)
            os.remove(oldest_path)

    save_path = os.path.join(logs_path, '{}-{}.pth'.format(types, step))
    torch.save(model.state_dict(), save_path)
    print('=> Save {} after [{}] updates'.format(save_path, step))


class _scheduler(object):
    """Base class for a counter-driven scalar schedule.

    Subclasses implement ``get_value()``, which is evaluated after every
    counter increment; the result is cached in ``self.variable``. The
    constructor performs one ``step()`` itself, so the counter ends up at
    ``last_epoch + 1`` once construction finishes.
    """

    def __init__(self, last_epoch=-1, verbose=False):
        self.variable = None
        self.verbose = verbose
        self.cnt = last_epoch
        # Prime the schedule so ``variable`` holds a value right away.
        self.step()

    def step(self):
        """Advance the counter by one and refresh the cached value."""
        self.cnt += 1
        self.variable = self.get_value()

    def get_value(self):
        """Return the value for the current counter; subclasses override."""
        raise NotImplementedError

    def get_variable(self):
        """Return the value cached by the most recent ``step()``."""
        return self.variable


class StepLRMargin(_scheduler):
    """Step schedule computed from ``initValue``, ``goalValue`` and ``decay``.

    Let ``k = int((cnt - threshold) / period)``. While ``cnt - threshold`` is
    negative the value is ``initValue``; afterwards it is
    ``goalValue - (goalValue - initValue) * decay ** k``. If ``endValue`` is
    set and the computed value is greater than or equal to it, ``endValue``
    is returned instead.
    """

    def __init__(self, initValue, period, goalValue, decay=0.1,
                 endValue=None, last_epoch=-1, threshold=0, verbose=False):
        # Every field must exist before the base constructor runs, because
        # it immediately calls step() and therefore get_value().
        self.threshold = threshold
        self.period = period
        self.decay = decay
        self.initValue = initValue
        self.goalValue = goalValue
        self.endValue = endValue
        super(StepLRMargin, self).__init__(last_epoch, verbose)

    def get_value(self):
        """Return the scheduled value for the current counter."""
        elapsed = self.cnt - self.threshold
        if elapsed < 0:
            return self.initValue

        num_decays = int(elapsed / self.period)
        gap = self.goalValue - self.initValue
        value = self.goalValue - gap * (self.decay**num_decays)

        cap = self.endValue
        if cap is not None and value >= cap:
            return cap
        return value


class Hot_Plug(object):
    """Swap a model's parameters for gradient-stepped tensors and back.

    The parameters returned by ``model.named_parameters()`` at construction
    time are remembered in ``self.params``. ``update()`` replaces each entry
    in the owning module's ``_parameters`` dict with ``param - lr * grad``;
    ``restore()`` puts the remembered parameter objects back.
    """

    def __init__(self, model):
        self.model = model
        # Original parameter objects keyed by their dotted names.
        self.params = OrderedDict(self.model.named_parameters())

    def update(self, lr=0.1):
        """Install stepped parameters, or the originals when ``lr <= 0``.

        Parameters whose ``grad`` is ``None`` are left as the original
        object instead of raising on ``lr * None``.
        """
        for param_name, param in self.params.items():
            *module_path, attr = param_name.split('.')
            # Walk the dotted name down to the module that owns the parameter.
            owner = self.model
            for module_name in module_path:
                owner = owner._modules[module_name]
            if lr > 0 and param.grad is not None:
                owner._parameters[attr] = param - lr * param.grad
            else:
                owner._parameters[attr] = param

    def restore(self):
        """Reinstall the original parameter objects."""
        self.update(lr=0)
