#!/usr/bin/env python3

from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import os
import sys
import math
import time
import shutil

from dataloader import get_dataloaders
from args import arg_parser
from adaptive_inference import dynamic_evaluate
import models
from op_counter import measure_model

args = arg_parser.parse_args()

# Both factor options arrive as dash-separated integer lists, e.g. "1-2-4".
args.grFactor = [int(tok) for tok in args.grFactor.split('-')]
args.bnFactor = [int(tok) for tok in args.bnFactor.split('-')]
# One scale per entry in grFactor.
args.nScales = len(args.grFactor)

# `use_valid` adds a third, separate 'test' split.
args.splits = ['train', 'val', 'test'] if args.use_valid else ['train', 'val']

# CIFAR-10 / CIFAR-100 have 10 / 100 classes; any other dataset name is
# treated as a 1000-class dataset.
args.num_classes = {'cifar10': 10, 'cifar100': 100}.get(args.data, 1000)

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim

# Seed PyTorch's global random number generator with the command-line seed.
torch.manual_seed(args.seed)
def energy_cross_entropy(pred, target):
    """Return the per-sample class loss and energy loss for one exit.

    ``pred`` holds the class logits followed by ONE extra "energy" logit in
    the last column, i.e. shape ``(batch, num_classes + 1)``. A single
    log-softmax is taken over all ``num_classes + 1`` columns, so the energy
    column competes with the classes for probability mass. The number of
    classes is inferred from the width of ``pred`` rather than assumed.

    Args:
        pred: logits of shape ``(batch, num_classes + 1)``.
        target: integer class labels of shape ``(batch,)`` with values in
            ``[0, num_classes)``.

    Returns:
        ``(loss, loss_E)``:
        ``loss`` is the per-sample negative log-probability of the target
        class, shape ``(batch,)``; ``loss_E`` is the per-sample negative
        log-probability of the energy column, shape ``(batch, 1)``.
    """
    log_p = torch.log_softmax(pred, dim=1)
    # All columns except the last are class scores.
    loss = nn.functional.nll_loss(log_p[:, :-1], target, reduction='none')
    # Keep the trailing singleton dimension, matching split(..., dim=1).
    loss_E = -log_p[:, -1:]
    return loss, loss_E


def main():
    """Entry point: build the model, then either evaluate it or train it.

    All configuration is read from the module-level ``args``. Steps:

    1. Create ``args.save`` and store the FLOP count returned by
       ``measure_model`` in ``flops.pth``.
    2. Build the network, wrap it in ``DataParallel`` and move it to the GPU.
    3. If ``args.resume`` is set, load model weights (only) from the latest
       checkpoint under ``args.resumepath``.
    4. If ``args.evalmode`` is set, load ``args.evaluate_from`` and run either
       ``validate`` (``'anytime'``) or ``dynamic_evaluate``, then return.
    5. Otherwise train from ``args.start_epoch`` to ``args.epochs - 1``,
       saving a checkpoint and the score table every epoch, and finally
       evaluate the last-epoch model on the test loader.
    """
    best_prec1, best_epoch = 0.0, 0

    os.makedirs(args.save, exist_ok=True)

    IM_SIZE = 32 if args.data.startswith('cifar') else 224

    # FLOPs are measured on a separate instance that is then discarded and a
    # fresh model is built for training (presumably because measure_model
    # instruments the model it is given -- TODO confirm).
    model = getattr(models, args.arch)(args)
    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model

    model = getattr(models, args.arch)(args)

    if args.arch.startswith(('alexnet', 'vgg')):
        # Only the feature extractor is wrapped in DataParallel here.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # Passed through to train/validate for signature compatibility.
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        print("resume")
        checkpoint = load_pretrain(args)
        if checkpoint is None:
            print("=> no checkpoint found under '{}'; not resuming"
                  .format(args.resumepath))
        else:
            # Only the weights are restored: the epoch counter, best_prec1
            # and the optimizer state start fresh.
            model.load_state_dict(checkpoint['state_dict'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)

        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
              '\tval_prec1\ttrain_prec5\tval_prec5']

    print('start training')

    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_prec1, train_prec5, lr = train(
            train_loader, model, criterion, optimizer, epoch)

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_prec1, val_prec1, train_prec5, val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best, model_filename, scores)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    # Evaluate the final (last-epoch, not best-epoch) model on the test split.
    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)


def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch with the energy-weighted multi-exit loss.

    Each exit's output is assumed to be ``(batch, num_classes + 1)``: class
    logits followed by one "energy" logit (TODO confirm against the model).
    ``energy_cross_entropy`` returns, for exit ``j``, the per-sample class
    loss ``ce_j`` and the per-sample energy loss ``e_j``. The per-sample
    objective is::

        w_j   = exp(-e_j)                  # detached, no gradient
        w_bar = mean_j(w_j)
        L_ce  = sum_j ce_j * (w_j + w_bar) / (2 * w_bar)
        L_e   = sqrt(sum_j e_j ** 2)

    The minimised loss is ``mean(L_ce) + args.weight * mean(L_e)`` over the
    batch. Top-1/top-5 accuracy is tracked for every exit using all but the
    last output column.

    Args:
        train_loader: iterable of ``(input, target)`` batches.
        model: network returning one tensor or a list of per-exit tensors.
        criterion: unused; kept so the call signature stays unchanged.
        optimizer: optimizer stepped once per batch; its learning rate is
            reset every batch by ``adjust_learning_rate``.
        epoch: index of the current epoch.

    Returns:
        ``(mean loss, top-1 of the last exit, top-5 of the last exit,
        learning rate used for the first batch)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = [AverageMeter() for _ in range(args.nBlocks)]
    top5 = [AverageMeter() for _ in range(args.nBlocks)]

    # switch to train mode
    model.train()

    weight = args.weight

    end = time.time()

    running_lr = None
    for i, (input, target) in enumerate(train_loader):
        lr = adjust_learning_rate(optimizer, epoch, args, batch=i,
                                  nBatch=len(train_loader), method=args.lr_type)
        if running_lr is None:
            running_lr = lr

        data_time.update(time.time() - end)

        batch_size = input.size(0)
        target = target.cuda(device=None)

        output = model(input)
        if not isinstance(output, list):
            output = [output]

        # Stack per-exit losses to [n_exits, batch]. The energy loss is
        # flattened so it lines up with the class loss.
        ce_list, energy_list = zip(*[energy_cross_entropy(out, target)
                                     for out in output])
        ce = torch.stack(ce_list)
        energy = torch.stack([e.reshape(-1) for e in energy_list])

        # Detached: the weights rescale the class losses but get no gradient.
        exit_weight = torch.exp(-energy.detach())
        mean_weight = exit_weight.mean(dim=0)
        ce_loss = ((ce * (exit_weight + mean_weight)).sum(dim=0)
                   / (2 * mean_weight)).mean()

        # Per-sample Euclidean norm of the energy losses across exits.
        energy_loss = torch.pow(torch.pow(energy, 2).sum(dim=0), 0.5).mean()

        loss = ce_loss + weight * energy_loss
        losses.update(loss.item(), batch_size)

        for j, out in enumerate(output):
            prec1, prec5 = accuracy(out[:, :-1].detach(), target, topk=(1, 5))
            top1[j].update(prec1.item(), batch_size)
            top5[j].update(prec5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.val:.4f}\t'
                  'Acc@1 {top1.val:.4f}\t'
                  'Acc@5 {top5.val:.4f}'.format(
                epoch, i + 1, len(train_loader),
                batch_time=batch_time, data_time=data_time,
                loss=losses, top1=top1[-1], top5=top5[-1]))

    return losses.avg, top1[-1].avg, top5[-1].avg, running_lr


def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    The reported loss is the same energy-weighted objective that ``train``
    minimises, so the logged train and validation losses are comparable::

        w_j   = exp(-e_j)
        w_bar = mean_j(w_j)
        L_ce  = sum_j ce_j * (w_j + w_bar) / (2 * w_bar)
        L_e   = sqrt(sum_j e_j ** 2)
        loss  = mean(L_ce) + args.weight * mean(L_e)

    Here ``ce_j`` / ``e_j`` are the per-sample losses returned by
    ``energy_cross_entropy`` for exit ``j``. Top-1/top-5 accuracy is tracked
    for every exit using all but the last output column, and printed per exit
    at the end.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network returning one tensor or a list of per-exit tensors.
        criterion: unused; kept so the call signature stays unchanged.

    Returns:
        ``(mean loss, top-1 of the last exit, top-5 of the last exit)``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    data_time = AverageMeter()
    top1 = [AverageMeter() for _ in range(args.nBlocks)]
    top5 = [AverageMeter() for _ in range(args.nBlocks)]

    model.eval()

    weight = args.weight

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda(device=None)
            input = input.cuda()
            batch_size = input.size(0)

            data_time.update(time.time() - end)

            output = model(input)
            if not isinstance(output, list):
                output = [output]

            # Stack per-exit losses to [n_exits, batch]. The energy loss is
            # flattened so it lines up with the class loss.
            ce_list, energy_list = zip(*[energy_cross_entropy(out, target)
                                         for out in output])
            ce = torch.stack(ce_list)
            energy = torch.stack([e.reshape(-1) for e in energy_list])

            exit_weight = torch.exp(-energy)
            mean_weight = exit_weight.mean(dim=0)
            ce_loss = ((ce * (exit_weight + mean_weight)).sum(dim=0)
                       / (2 * mean_weight)).mean()

            # Per-sample Euclidean norm of the energy losses across exits.
            energy_loss = torch.pow(torch.pow(energy, 2).sum(dim=0), 0.5).mean()

            loss = ce_loss + weight * energy_loss
            losses.update(loss.item(), batch_size)

            for j, out in enumerate(output):
                prec1, prec5 = accuracy(out[:, :-1], target, topk=(1, 5))
                top1[j].update(prec1.item(), batch_size)
                top5[j].update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Epoch: [{0}/{1}]\t'
                      'Time {batch_time.avg:.3f}\t'
                      'Data {data_time.avg:.3f}\t'
                      'Loss {loss.val:.4f}\t'
                      'Acc@1 {top1.val:.4f}\t'
                      'Acc@5 {top5.val:.4f}'.format(
                    i + 1, len(val_loader),
                    batch_time=batch_time, data_time=data_time,
                    loss=losses, top1=top1[-1], top5=top5[-1]))
    for j in range(args.nBlocks):
        print(' * prec@1 {top1.avg:.3f} prec@5 {top5.avg:.3f}'.format(top1=top1[j], top5=top5[j]))
    return losses.avg, top1[-1].avg, top5[-1].avg


def save_checkpoint(state, args, is_best, filename, result):
    """Persist one training checkpoint plus the running score table.

    Layout under ``args.save``:
      * ``scores.tsv`` -- the lines of ``result`` (overwritten each call);
      * ``save_models/<filename>`` -- ``state`` written with ``torch.save``;
      * ``save_models/latest.txt`` -- the path of that checkpoint file;
      * ``save_models/model_best.pth.tar`` -- a copy of it when ``is_best``.

    ``args`` is also printed.
    """
    print(args)
    save_root = args.save
    ckpt_dir = os.path.join(save_root, 'save_models')
    ckpt_path = os.path.join(ckpt_dir, filename)

    for directory in (save_root, ckpt_dir):
        os.makedirs(directory, exist_ok=True)

    print("=> saving checkpoint '{}'".format(ckpt_path))
    torch.save(state, ckpt_path)

    with open(os.path.join(save_root, 'scores.tsv'), 'w') as scores_file:
        scores_file.write('\n'.join(result) + '\n')

    with open(os.path.join(ckpt_dir, 'latest.txt'), 'w') as pointer_file:
        pointer_file.write(ckpt_path)

    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(ckpt_dir, 'model_best.pth.tar'))

    print("=> saved checkpoint '{}'".format(ckpt_path))


def load_checkpoint(args):
    """Load the checkpoint named in ``<args.save>/save_models/latest.txt``.

    Args:
        args: object with a ``save`` attribute (the run directory).

    Returns:
        The object deserialised with ``torch.load`` from the path on the
        first line of ``latest.txt``. Returns ``None`` when ``latest.txt``
        does not exist or its first line is empty.
    """
    model_dir = os.path.join(args.save, 'save_models')
    latest_filename = os.path.join(model_dir, 'latest.txt')
    if not os.path.exists(latest_filename):
        return None
    with open(latest_filename, 'r') as fin:
        model_filename = fin.readline().strip()
    if not model_filename:
        # An empty pointer file (e.g. a write interrupted after truncation)
        # names no checkpoint; treat it like a missing one instead of
        # crashing with IndexError.
        return None
    print("=> loading checkpoint '{}'".format(model_filename))
    state = torch.load(model_filename)
    print("=> loaded checkpoint '{}'".format(model_filename))
    return state


class AverageMeter(object):
    """Keep the latest value and a count-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum/count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def load_pretrain(args):
    """Load the checkpoint named in ``<args.resumepath>/save_models/latest.txt``.

    Args:
        args: object with a ``resumepath`` attribute (the run directory to
            resume from).

    Returns:
        The object deserialised with ``torch.load`` from the path on the
        first line of ``latest.txt``. Returns ``None`` when ``latest.txt``
        does not exist or its first line is empty.
    """
    model_dir = os.path.join(args.resumepath, 'save_models')
    latest_filename = os.path.join(model_dir, 'latest.txt')
    if not os.path.exists(latest_filename):
        return None
    with open(latest_filename, 'r') as fin:
        model_filename = fin.readline().strip()
    if not model_filename:
        # An empty pointer file names no checkpoint; treat it like a missing
        # one instead of crashing with IndexError.
        return None
    print("=> loading checkpoint '{}'".format(model_filename))
    state = torch.load(model_filename)
    print("=> loaded checkpoint '{}'".format(model_filename))
    return state


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor of shape ``(N, C)``.
        target: integer labels of shape ``(N,)``.
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of rows whose
        label is among the k highest scores.
    """
    n_samples = target.size(0)
    largest_k = max(topk)

    # Indices of the largest_k best scores, transposed to (largest_k, N).
    top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples)
            for k in topk]


def adjust_learning_rate(optimizer, epoch, args, batch=None,
                         nBatch=None, method='multistep'):
    """Compute the learning rate for this step and write it to every param group.

    Args:
        optimizer: object with ``param_groups`` (dicts receiving an ``'lr'``).
        epoch: current epoch index.
        args: reads ``lr`` and ``epochs``; the multistep schedule also reads
            ``data``.
        batch: batch index within the epoch (required for ``'cosine'``).
        nBatch: number of batches per epoch (required for ``'cosine'``).
        method: ``'cosine'`` -- per-batch cosine decay from ``args.lr`` over
            ``args.epochs``; ``'multistep'`` -- for CIFAR datasets x0.1 from
            50% and x0.01 from 75% of ``args.epochs``, otherwise x0.1 every
            30 epochs.

    Returns:
        The learning rate that was set.

    Raises:
        ValueError: if ``method`` is not one of the schedules above.
    """
    if method == 'cosine':
        T_total = args.epochs * nBatch
        T_cur = (epoch % args.epochs) * nBatch + batch
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * T_cur / T_total))
    elif method == 'multistep':
        if args.data.startswith('cifar'):
            lr, decay_rate = args.lr, 0.1
            if epoch >= args.epochs * 0.75:
                lr *= decay_rate ** 2
            elif epoch >= args.epochs * 0.5:
                lr *= decay_rate
        else:
            lr = args.lr * (0.1 ** (epoch // 30))
    else:
        # Previously fell through and crashed with UnboundLocalError on `lr`.
        raise ValueError("unknown learning-rate schedule: {!r}".format(method))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


if __name__ == '__main__':
    # Run main() only when executed as a script. The module-level argument
    # parsing above still runs on import.
    main()
