from __future__ import print_function
from inspect import istraceback
import sys
import os
import shutil
import time
import argparse
import logging
import hashlib
import copy

import torch
import torch.nn as nn

import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

import sparselearning
# from sparselearning.core_channel import Masking, CosineDecay, LinearDecay, add_sparse_args
#from sparselearning.core_unstructured import Masking, CosineDecay, LinearDecay, add_sparse_args
from sparselearning.core_group import Masking, CosineDecay, LinearDecay, add_sparse_args

from sparselearning.utils import get_mnist_dataloaders, get_cifar10_dataloaders, plot_class_feature_histograms, get_cifar100_dataloaders

from models import vgg16, vgg19, resnet

import warnings
# Suppress every warning of category UserWarning for the rest of the process.
warnings.filterwarnings("ignore", category=UserWarning)
# NOTE(review): both cuDNN flags are switched on. Per the PyTorch
# reproducibility notes, `benchmark` enables convolution-algorithm auto-tuning
# while `deterministic` restricts cuDNN to deterministic algorithms -- confirm
# that enabling both is intended.
cudnn.benchmark = True
cudnn.deterministic = True



# Module-level logger handle. Starts as None; setup_logger() assigns it and
# print_and_log() reads it.
logger = None





def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Write ``state`` to disk with ``torch.save``.

    Args:
        state: Object to serialize (the caller in this file passes a dict
            holding the epoch, model state dict and optimizer state dict).
        filename: Destination path of the checkpoint file.

    A short "SAVING" notice is printed to stdout before the write.
    """
    print("SAVING")
    torch.save(obj=state, f=filename)

def setup_logger(args, save_path):
    """Point the module-level ``logger`` at a run-specific log file.

    On the first call the root logger is adopted as the module logger. On
    later calls every handler currently attached to it is detached and
    closed, so previously opened log files are released before a new file
    handler is installed.

    The log file is named ``<model>_<density>_<hash>.log`` inside
    ``save_path``. The hash is the first 8 hex digits of the MD5 of
    ``str(args)`` after ``iters``, ``verbose``, ``log_interval`` and ``seed``
    have been normalised on a deep copy, so runs differing only in those
    flags share a log file name. ``args`` itself is not modified.

    Args:
        args: Parsed arguments; must expose ``model`` and ``density``.
        save_path: Existing directory in which the log file is created.
    """
    global logger
    if logger is None:
        logger = logging.getLogger()
    else:
        # Iterate over a copy: removeHandler mutates logger.handlers.
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            # Removing a handler does not release its resources; close it so
            # the underlying log file is flushed and its descriptor freed.
            handler.close()

    # Work on a copy to get a clean hash: these flags do not change the
    # results, so they must not change the log file name either.
    args_copy = copy.deepcopy(args)
    args_copy.iters = 1
    args_copy.verbose = False
    args_copy.log_interval = 1
    args_copy.seed = 0

    args_hash = hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]
    log_path = save_path + '/{0}_{1}_{2}.log'.format(args.model, args.density, args_hash)

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S')

    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

def print_and_log(msg):
    """Echo ``msg`` to stdout, then record it at INFO level on the module logger."""
    # The module-level `logger` is only read here, so no `global` statement is needed.
    print(msg)
    logger.info(msg)


def train(args, model, device, train_loader, optimizer, epoch, mask=None):
    """Train ``model`` for one epoch and log progress.

    Args:
        args: Parsed arguments; reads ``fp16``, ``log_interval`` and
            ``batch_size``.
        model: Network to train; switched to training mode here.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        optimizer: Optimizer used for ``zero_grad`` and, when ``mask`` is
            None, for the parameter update. With ``args.fp16`` it must also
            provide ``backward(loss)``.
        epoch: Epoch number, used only in log messages.
        mask: Optional object whose ``step()`` replaces ``optimizer.step()``
            (the DST mask update).

    The summary line reports the loss averaged over the number of batches
    processed and the running accuracy over all samples seen.
    """
    criterion = nn.CrossEntropyLoss()

    model.train()
    train_loss = 0
    correct = 0
    n = 0
    num_batches = 0
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = data.to(device), target.to(device)
        if args.fp16:
            data = data.half()
        optimizer.zero_grad()
        output = model(data)            # logits before softmax

        loss = criterion(output, target)

        train_loss += loss.item()
        num_batches += 1
        pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
        correct += pred.eq(target.view_as(pred)).sum().item()
        n += target.shape[0]

        if args.fp16:
            optimizer.backward(loss)
        else:
            loss.backward()

        # ======================== DST ========================
        if mask is not None:
            mask.step()             # DST mask update
        else:
            optimizer.step()

        if batch_idx % args.log_interval == 0:
            print_and_log('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Accuracy: {}/{} ({:.3f}%) '.format(
                epoch, batch_idx * len(data), len(train_loader)*args.batch_size,
                100. * batch_idx / len(train_loader), loss.item(), correct, n, 100. * correct / float(n)))

    # Training summary. Average over the number of batches actually seen:
    # the last loop index is one less than that count, and is zero (or
    # undefined) for a single-batch (or empty) loader. The max(..., 1) guards
    # keep an empty loader from dividing by zero.
    print_and_log('** Training summary: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%) **'.format(
        train_loss / max(num_batches, 1), correct, n, 100. * correct / float(max(n, 1))))


def evaluate(args, model, device, test_loader, is_test_set=False):
    """Evaluate ``model`` on ``test_loader`` and log loss and accuracy.

    Args:
        args: Parsed arguments; reads ``fp16``.
        model: Network to evaluate; switched to eval mode here.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches.
        is_test_set: Selects the log label ('Test evaluation' vs 'Evaluation').

    Returns:
        Accuracy as a float in [0, 1] (``0.0`` when the loader is empty).

    The logged "Average loss" is the per-sample cross-entropy averaged over
    every sample in the loader.
    """
    criterion = nn.CrossEntropyLoss()

    model.eval()
    test_loss = 0.0
    correct = 0
    n = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if args.fp16:
                data = data.half()
            output = model(data)
            batch_samples = target.shape[0]
            # criterion returns the batch mean; scale it back to a batch sum
            # and ACCUMULATE, so the final division by n yields a true
            # per-sample average over the whole loader.
            test_loss += criterion(output, target).item() * batch_samples
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()
            n += batch_samples

    # Guard against an empty loader instead of dividing by zero.
    total = float(max(n, 1))
    test_loss /= total

    print_and_log('** {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%) ** '.format(
        'Test evaluation' if is_test_set else 'Evaluation',
        test_loss, correct, n, 100. * correct / total))
    return correct / total


def main():
    """Command-line entry point: parse arguments, build everything, train, test.

    Steps, in order:
      1. Parse the command line (sparse-training options are registered by
         ``add_sparse_args``).
      2. Create the output directory under ``results/`` and set up logging.
      3. Build data loaders, model, optimizer and (optionally) the
         MultiStepLR scheduler; optionally load a checkpoint.
      4. Optionally wrap the optimizer for fp16 and create the DST mask.
      5. Train for ``epochs * multiplier`` epochs, saving ``model_final.pth``
         every epoch and ``model_best.pth`` whenever validation accuracy
         improves (only when a validation split is in use).
      6. Evaluate the final model on the test set.

    Raises:
        Exception: for an unknown dataset, model or optimizer name.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                        help='input batch size for testing (default: 128)')
    parser.add_argument('--multiplier', type=int, default=1, metavar='N',
                        help='extend training time by multiplier times')
    parser.add_argument('--epochs', type=int, default=160, metavar='N',
                        help='number of epochs to train (default: 160)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=17, metavar='S', help='random seed (default: 17)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--optimizer', type=str, default='sgd', help='The optimizer to use. Default: sgd. Options: sgd, adam.')
    parser.add_argument('--data', type=str, default='mnist')
    parser.add_argument('--l1', type=float, default=0.0)
    parser.add_argument('--fp16', action='store_true', help='Run in fp16 mode.')
    parser.add_argument('--valid_split', type=float, default=0.1)
    parser.add_argument('--resume', type=str)
    parser.add_argument('--start-epoch', type=int, default=1)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--l2', type=float, default=5e-4)
    parser.add_argument('--iters', type=int, default=1, help='How many times the model should be run after each other. Default=1')
    parser.add_argument('--save-features', action='store_true', help='Resumes a saved model and saves its feature data to disk for plotting.')
    parser.add_argument('--bench', action='store_true', help='Enables the benchmarking of layers and estimates sparse speedups')
    parser.add_argument('--max-threads', type=int, default=10, help='How many threads to use for data loading.')

    parser.add_argument('--nolrsche', action='store_true', default=False, help='disables the MultiStepLR learning-rate schedule')
    parser.add_argument('--no_rewire_extend', action='store_true', default=False, help='if true, only do rewire for 250 epochs')
    parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',help='number of data loading workers (default: 10)')
    parser.add_argument('--world-size', default=-1, type=int,help='number of nodes for distributed training')
    parser.add_argument('--mgpu', action='store_true', help='wrap the model in torch.nn.DataParallel for multi-GPU training')
    add_sparse_args(parser)

    args = parser.parse_args()

    # ---------------------------------- output dir ----------------------------------
    # results/<data>_<model>/<Dense | SST_... | DST_...>
    save_dir = 'results/{}_{}/'.format(args.data, args.model)
    if args.sparse:
        if args.fix:
            save_dir = save_dir + 'SST_density{}_init{}_seed{}'.format(args.density, args.sparse_init, args.seed)
        else:
            save_dir = save_dir + 'DST_density{}_init{}_grow{}_seed{}'.format(args.density, args.sparse_init, args.growth, args.seed)
    else:
        save_dir = save_dir + 'Dense'
    os.makedirs(save_dir, exist_ok=True)

    # Log
    setup_logger(args, save_path=save_dir)
    print_and_log(args)

    # GPU
    if args.fp16:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except ImportError:
            # Only a missing apex should downgrade to fp32; anything else
            # (e.g. KeyboardInterrupt) must propagate.
            print('WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f'device will be chosen as {device} for this run.')

    print_and_log('='*80)
    torch.manual_seed(args.seed)
    # for i in range(args.iters):
    for i in range(1):      # Sparse Training only Once
        print_and_log("\nIteration start: {0}/{1}\n".format(i+1, args.iters))

        # ---------------------------------- Data loader ----------------------------------
        if args.data == 'mnist':
            train_loader, valid_loader, test_loader = get_mnist_dataloaders(args, validation_split=args.valid_split)
        elif args.data == 'cifar10':
            train_loader, valid_loader, test_loader = get_cifar10_dataloaders(args, args.valid_split, max_threads=args.max_threads)
        elif args.data == 'cifar100':
            train_loader, valid_loader, test_loader = get_cifar100_dataloaders(args, args.valid_split, max_threads=args.max_threads)
        else:
            # Fail here with a clear message rather than with a NameError on
            # the undefined loaders further down.
            raise Exception('Unknown dataset: {0}'.format(args.data))

        # ---------------------------------- Create Model ----------------------------------
        if args.model.startswith('resnet'):     # resnet32 (2x)
            # The digits after 'resnet' give the depth, e.g. 'resnet32' -> 32.
            model = resnet(depth=int(args.model[6:]), dataset=args.data)
        elif args.model == 'vgg16':
            model = vgg16(dataset=args.data)
        elif args.model == 'vgg19':
            model = vgg19(dataset=args.data)
        else:
            raise Exception('You need to select a model')
        model = model.to(device)
        print_and_log(model)
        print_and_log('=' * 60)
        print_and_log(args.model)
        print_and_log('=' * 60)

        if args.sparse and (not args.fix):
            print_and_log('=' * 60)
            print_and_log('Prune mode: {0}'.format(args.death))
            print_and_log('Growth mode: {0}'.format(args.growth))
            print_and_log('Redistribution mode: {0}'.format(args.redistribution))
            print_and_log('=' * 60)

        if args.mgpu:
            print('Using multi gpus')
            model = torch.nn.DataParallel(model).to(device)

        # ---------------------------------- OPT ----------------------------------
        optimizer = None
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.l2)
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.l2)
        else:
            print('Unknown optimizer: {0}'.format(args.optimizer))
            raise Exception('Unknown optimizer.')

        if args.nolrsche:
            lr_scheduler = None
        else:
            # Milestones at 1/2 and 3/4 of args.epochs.
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.epochs/2), int(args.epochs*3/4)], last_epoch=-1)

        # ---------------------------------- Model resume ----------------------------------
        if args.resume:
            if os.path.isfile(args.resume):
                print_and_log("=> loading checkpoint '{}'".format(args.resume))
                # map_location lets a checkpoint saved on GPU be loaded on a
                # CPU-only run (and vice versa).
                checkpoint = torch.load(args.resume, map_location=device)
                # NOTE(review): start_epoch is recorded here, but the epoch
                # loop below still starts at 1 and the LR scheduler is not
                # fast-forwarded -- confirm whether resuming training
                # mid-schedule is meant to be supported.
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print_and_log("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
                print_and_log('Testing...')
                evaluate(args, model, device, test_loader)
                model.feats = []
                model.densities = []
                plot_class_feature_histograms(args, model, device, train_loader, optimizer)
            else:
                print_and_log("=> no checkpoint found at '{}'".format(args.resume))


        if args.fp16:
            print('FP16')
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale = None,
                                       dynamic_loss_scale = True,
                                       dynamic_loss_args = {'init_scale': 2 ** 16})
            model = model.half()


        # ---------------------------------- DST mask ----------------------------------
        mask = None
        if args.sparse:
            print_and_log('Create DST mask ...')
            # The decay length is measured in training iterations:
            # batches per epoch * int(epochs * end_ratio).
            decay = CosineDecay(args.death_rate, len(train_loader)* int((args.epochs) * args.end_ratio) )  # revise here
            mask = Masking(optimizer,death_rate=args.death_rate, death_mode=args.death, death_rate_decay=decay, growth_mode=args.growth,
                           redistribution_mode=args.redistribution, args=args)
            mask.add_module(model, sparse_init=args.sparse_init, density=args.density)

        best_acc = 0.0
        best_epoch = -1


        # ---------------------------------- Training loop ----------------------------------
        for epoch in range(1, args.epochs*args.multiplier + 1):
            print_and_log('>>>>>>>>>>>>>>>>>>>> EPOCH: {} >>>>>>>>>>>>>>>>>>>'.format(epoch))
            t0 = time.time()
            # --------------- training one epoch ---------------
            train(args, model, device, train_loader, optimizer, epoch, mask)

            if lr_scheduler:
                lr_scheduler.step()

            # --------------- evaluation ---------------
            # Without a validation split there is nothing to validate on;
            # keep val_acc defined so the best-model check below cannot hit
            # an unbound name.
            val_acc = None
            if args.valid_split > 0.0:
                val_acc = evaluate(args, model, device, valid_loader, is_test_set=False)

            # --------------- save model ---------------
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, filename=os.path.join(save_dir, 'model_final.pth'))

            # --------------- best model ---------------
            if val_acc is not None and val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, filename=os.path.join(save_dir, 'model_best.pth'))

            # --------------- Log of Results and LR ---------------
            print_and_log('So far best ACC: {:.3f} in Epoch {}'.format(best_acc * 100.0, best_epoch))
            print_and_log('Current LR: {0}. Time taken for epoch: {1:.2f} seconds.\n'.format(optimizer.param_groups[0]['lr'], time.time() - t0))


        # --------------- final evaluation ---------------
        print_and_log("==> Evaluate on Test Dataset ==> ")
        evaluate(args, model, device, test_loader, is_test_set=True)
        print_and_log("\nIteration end: {0}/{1}\n".format(i+1, args.iters))


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
   main()
