import torch
import os
import shutil
import pandas as pd


def adjust_lr_lambda(args,  epoch, optimizer):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    epoch = epoch + 1
    if args.epochs == 120:
        if epoch > 90:
            lr = args.lr * 0.1
            args.lamda = 0
        else:
            lr = args.lr
            args.lamda = 0
    elif args.epochs  == 200:
        if epoch <= 5:
            lr = args.lr * epoch / 5
            args.lamda = 0
        elif epoch >= 180:
            lr = args.lr * 0.0001
            args.lamda = 0
        elif epoch >= 160:
            lr = args.lr * 0.01
            args.lamda = 0
        else:
            lr = args.lr
            args.lamda = 0
    else:
        if epoch <= 5:
            args.lamda = 0
            lr = args.lr * epoch / 5
        elif epoch >= 60:
             lr = args.lr * 0.1
             args.lamda = 0
        else:
             lr = args.lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


    return lr






def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.float().topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))


    res = []
    for k in topk:
        correct_k = correct[:k].float().sum()
        res.append(correct_k.mul_(100.0 / batch_size))
    return res




class ResultsLog(object):
    def __init__(self, path='results.csv', plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None


    def add(self, **kwargs):
        df = pd.DataFrame([kwargs.values()], columns=kwargs.keys())
        if self.results is None:
            self.results = df
        else:
            # self.results = self.results.append(df, ignore_index=True) #cy
            self.results = pd.concat([self.results, df], axis=0, join='outer', ignore_index=True)


    def save(self, title='Training Results'):
        if len(self.figures) > 0:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            # output_file(self.plot_path, title=title)
            # plot = column(*self.figures)
            # save(plot)
            # self.figures = []
        self.results.to_csv(self.path, index=False, index_label=False)


    def load(self, path=None):
        path = path or self.path
        if os.path.isfile(path):
            self.results.read_csv(path)


    def show(self):
        pass
        # if len(self.figures) > 0:
        #     plot = column(*self.figures)
        #     show(plot)


    def plot(self, xs, ys, *kargs, **kwargs):
        pass
        # fig = figure(*kargs, **kwargs)
        # x_list_value = self.results[xs].values.T.tolist()
        # y_list_value = self.results[ys].values.T.tolist()
        # fig.multi_line(xs = x_list_value, ys = y_list_value)
        # self.figures.append(fig)


    def image(self, *kargs, **kwargs):
        pass
        # fig = figure()
        # fig.image(*kargs, **kwargs)
        # self.figures.append(fig)




class AverageMeter(object):
    """Computes and stores the average and current value"""


    def __init__(self):
        self.reset()


    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0


    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count




def model_resume(args, model, model_new):
    checkpoint = torch.load(model) #, map_location='cuda:0'
    best_acc1 = checkpoint['best_acc1']
    args.resumed_epoch = checkpoint['best_acc1']
    model_new.module.load_state_dict(checkpoint['state_dict'])
    #optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {}), acc {}"
          .format(args.resume, args.resumed_epoch, best_acc1))




def save_checkpoint_epoch(state, is_best, path='.',):
    filename = os.path.join(path, '%s-th_epoch_checkpoint.pth.tar' % state['epoch'])
    torch.save(state, filename)
    if is_best:
        shutil.move(filename, 'model_best.pth.tar')


def load_checkpoint_epoch(epoch, path='.'):
    filename = os.path.join(path, '%s-th_epoch_checkpoint.pth.tar' % str(epoch))
    return torch.load(filename)


def load_checkpoint_iter(iter, path='.'):
    filename = os.path.join(path, '%s-th_iter_checkpoint.pth.tar' % str(iter))
    return torch.load(filename)




def load_checkpoint_best(path='.'):
    filename = os.path.join(path, 'model_best.pth.tar')
    return torch.load(filename)
