'''
Training script for CIFAR-100
'''
from __future__ import print_function

import argparse
import os
import shutil
import time
import random

import torch
import torch.nn as nn
import torch.nn.parallel

from torch.autograd import Variable
from utils import  AverageMeter, accuracy
import models.hierarchy as models
from get_loaders import get_loaders
import geoopt  ## user should installl this package from "https://github.com/geoopt/geoopt"

'''
Tested with python 3.6.6 and pytorch > 1.0.
Example commands:

1) resnet18
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch resnet18 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300    --wd 0.000100  --wdfc 0.000100  --radius_decay 0.500000                             
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch resnet18 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --multi_tasks  --wd 0.000100  --wdfc 0.000100  --radius_decay 0.500000   
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch resnet18 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --manifold  --wd 0.000100  --wdfc 0.000100  --radius_decay 0.500000      
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch resnet18 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --manifold  --spheres  --wd 0.000100  --radius_decay 0.500000            
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch resnet18 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --manifold  --spheres  --riemann  --wd 0.000100  --radius_decay 0.500000

2) densenet121                                                                                                                                                                                                                                         
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch densenet121 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300    --wd 0.000100  --wdfc 0.000100  --radius_decay 0.500000                             
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch densenet121 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --multi_tasks  --wd 0.000100  --wdfc 0.000100  --radius_decay 0.500000   
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch densenet121 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --manifold  --wd 0.000100  --wdfc 0.000100  --radius_decay 0.500000      
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch densenet121 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --manifold  --spheres  --wd 0.000100  --radius_decay 0.500000            
python train_sphere_optimization.py  --optim-type sgd --workers 10  --arch densenet121 --dataset cifar100  --train-batch 64  --test-batch 1024  --epochs 300  --hierarchy  --manifold  --spheres  --riemann  --wd 0.000100  --radius_decay 0.500000 


Requires geoopt: install it from https://github.com/geoopt/geoopt
'''

parser = argparse.ArgumentParser(description='PyTorch CIFAR100 Training')
# Datasets
parser.add_argument('-d', '--dataset', default='cifar100', type=str)
parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
# Optimization options
parser.add_argument('--epochs', default=300, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--train-batch', default=128, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--test-batch', default=500, type=int, metavar='N',
                    help='test batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
# NOTE(review): --lr_m is not read when main() builds the optimizers; the
# manifold optimizer is driven by --lr. Confirm before relying on this flag.
parser.add_argument('--lr_m', '--learning-rate-man', default=0.1, type=float,
                    metavar='LR', help='initial learning rate for manifold')
parser.add_argument('--drop', '--dropout', default=0, type=float,
                    metavar='Dropout', help='Dropout ratio')
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
# The help texts use %(default)s so they always match the real defaults.
parser.add_argument('--weight-decay', '--wd', default=0, type=float,
                    metavar='W', help='weight decay for the standard SGD optimizer (default: %(default)s)')
parser.add_argument('--weight-decay-fc', '--wdfc', default=0, type=float,
                    metavar='W', help='weight decay for the fc parameters optimized by RiemannianSGD (default: %(default)s)')

# Architecture
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet')

parser.add_argument('--block-name', type=str, default='BasicBlock',
                    help='the building block for Resnet and Preresnet: BasicBlock, Bottleneck (default: Basicblock for cifar10/cifar100)')

parser.add_argument('--optim-type', default='sgd', choices=['sgd'])

# method
parser.add_argument('--hierarchy', dest='hierarchy', default=False, action='store_true')
parser.add_argument('--manifold', dest='manifold', default=False, action='store_true')
parser.add_argument('--riemann', dest='riemann', default=False, action='store_true')
parser.add_argument('--spheres', dest='spheres', default=False, action='store_true')
parser.add_argument('--spheres_exact', dest='spheres_exact', default=False, action='store_true')
parser.add_argument('--radius_decay', type=float,  default=0.5 )
parser.add_argument('--multi_tasks', dest='multi_tasks', default=False, action='store_true')
parser.add_argument('--lambda_mt', dest='lambda_mt', help='combination of losses for multitasks', type=float, default=0.1)
parser.add_argument('--drops', dest='drops', default=False, action='store_true')

# Miscs
parser.add_argument('--manualSeed', default=1, type=int, help='manual seed')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')

# Adam betas (currently unused: --optim-type only accepts 'sgd')
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--beta2', type=float, default=0.999)

# Device options
# NOTE(review): --gpu-id is not applied anywhere in this script; set
# CUDA_VISIBLE_DEVICES in the environment to select GPUs instead.
parser.add_argument('--gpu-id', default='0', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--input_size', type=int, default=224, help='input size')

args = parser.parse_args()
# Mutable snapshot of the parsed options. state['lr'] holds the current
# learning rate: adjust_learning_rate() rewrites it and main() prints it.
state = dict(vars(args))

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))


# Validate dataset. An explicit check is used instead of `assert`, which is
# stripped when Python runs with -O.
print(args.dataset)
if args.dataset != 'cifar100':
    parser.error('Dataset cifar100 for this example code')

# Use CUDA
# NOTE: --gpu-id is not applied; uncomment to restrict the visible GPUs.
# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()
ncuda = torch.cuda.device_count()
print("Let's use", ncuda, "GPUs!")
if not use_cuda:
    # This script requires a GPU. Exit with a non-zero status so shells and
    # job schedulers see the failure (the site-only exit() returned 0).
    print('using cpu')
    print('exit..')
    raise SystemExit(1)
print('using gpu')
device = torch.device('cuda')

print(torch.cuda.get_device_name(torch.cuda.current_device()))

# Random seed
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
    # Prefer reproducible cuDNN kernels over autotuned speed.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


################################################################
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split the trainable parameters of ``model`` into two optimizer param groups.

    Parameters with exactly one dimension (typically biases and normalization
    weights) and parameters whose name is listed in ``skip_list`` get no weight
    decay. Every other trainable parameter gets ``weight_decay``. Parameters
    with ``requires_grad=False`` are left out of both groups.

    Args:
        model: module whose ``named_parameters()`` are partitioned.
        weight_decay: decay applied to the second group.
        skip_list: parameter names (as yielded by ``named_parameters()``) to exempt.

    Returns:
        ``[{'params': exempt, 'weight_decay': 0.}, {'params': decayed, 'weight_decay': weight_decay}]``,
        with parameters kept in ``named_parameters()`` order.
    """
    def _is_exempt(param_name, tensor):
        return tensor.dim() == 1 or param_name in skip_list

    trainable = [(n, t) for n, t in model.named_parameters() if t.requires_grad]
    exempt = [t for n, t in trainable if _is_exempt(n, t)]
    decayed = [t for n, t in trainable if not _is_exempt(n, t)]

    return [
        {'params': exempt, 'weight_decay': 0.},
        {'params': decayed, 'weight_decay': weight_decay},
    ]


best_acc = 0  # highest top-1 test accuracy seen so far; updated by main()


def _build_model(num_classes, hierarchy_tree):
    """Instantiate ``args.arch`` with the hierarchy/manifold options and place it on ``device``.

    Args:
        num_classes: number of output classes reported by ``get_loaders``.
        hierarchy_tree: hierarchy object returned by ``get_loaders``; passed to
            the model constructor unchanged.

    Returns:
        The model wrapped in ``torch.nn.DataParallel``.

    Raises:
        ValueError: if ``args.arch`` is neither a ``resnet*`` nor a ``densenet*`` name.
    """
    if not args.arch.startswith(('resnet', 'densenet')):
        raise ValueError(
            "unsupported --arch '{}': only resnet* and densenet* models are supported".format(args.arch))

    # CIFAR-100 inputs are 32x32, so --input_size is overridden for it.
    input_size = 32 if args.dataset == 'cifar100' else args.input_size

    # resnet* and densenet* constructors receive the same keyword arguments.
    model = models.__dict__[args.arch](
        num_classes=num_classes,
        hierarchy=args.hierarchy,
        manifold=args.manifold,
        riemann=args.riemann,
        spheres=args.spheres,
        spheres_exact=args.spheres_exact,
        radius_decay=args.radius_decay,
        insize=input_size,
        hierarchy_tree=hierarchy_tree,
        multi_tasks=args.multi_tasks,
    )

    model = torch.nn.DataParallel(model)
    model.to(device)
    # NOTE(review): fc / hierarchical_layer are moved explicitly as well;
    # presumably they can hold tensors that model.to() does not move - confirm.
    model.module.fc = model.module.fc.to(device)
    if model.module.hierarchical_layer is not None:
        model.module.hierarchical_layer = model.module.hierarchical_layer.to(device)
    return model


def _build_optimizers(model):
    """Create the Euclidean SGD optimizer and the Riemannian optimizer for ``model.module.fc``.

    May set ``args.weight_decay_fc`` to 0 (hierarchy + manifold + spheres).

    Returns:
        ``(optimizer, optimizer_m)``: a ``torch.optim.SGD`` and a
        ``geoopt.optim.RiemannianSGD``. Both are stepped on every batch by ``train``.

    Raises:
        ValueError: if ``args.optim_type`` is not ``'sgd'``.
    """
    if args.optim_type != 'sgd':
        raise ValueError("unsupported --optim-type '{}'".format(args.optim_type))

    # Every trainable parameter whose name does not contain "fc" goes to the
    # standard SGD optimizer; the fc parameters are handled by optimizer_m.
    embedding_params = [param for name, param in model.named_parameters()
                        if param.requires_grad and "fc" not in name]

    if args.hierarchy and args.manifold and args.spheres:
        # No weight decay on the sphere-constrained classifier.
        args.weight_decay_fc = 0.0

    if args.hierarchy and not args.manifold:
        # NOTE(review): in this configuration standard SGD receives *all*
        # parameters, including fc, which optimizer_m also steps - so fc appears
        # to be updated by both optimizers each batch. Confirm this is intended.
        sgd_params = model.parameters()
    else:
        sgd_params = embedding_params
    optimizer = torch.optim.SGD(sgd_params, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # The manifold optimizer is built identically for every flag combination
    # (--riemann does not change its construction).
    optimizer_m = geoopt.optim.RiemannianSGD(
        [{'params': model.module.fc, 'lr': args.lr}],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay_fc)
    return optimizer, optimizer_m


def main():
    """Load CIFAR-100, build the model and optimizers from ``args``, then train or evaluate.

    With ``--evaluate`` a single pass over the test set is run and printed.
    Otherwise the model is trained for epochs ``[args.start_epoch, args.epochs)``
    and evaluated after every epoch. The best test top-1 accuracy is tracked in
    the module-level ``best_acc``.

    Raises:
        ValueError: for inconsistent method flags (``--multi_tasks`` without
            ``--hierarchy``, or together with riemann/spheres options), or for an
            unsupported ``--arch``.
    """
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    # NOTE: point this at the directory that holds your datasets.
    datapath = 'c:/datasets'
    infopath = './'
    trainloader, testloader, num_classes, hierarchy_tree = get_loaders(
        args.dataset, args.input_size, args.hierarchy, infopath, datapath,
        args.train_batch, args.test_batch, args.workers)

    args.num_classes = num_classes
    print(num_classes)

    if args.multi_tasks:
        if not args.hierarchy:
            raise ValueError('Hierarchy cannot be false with multi_tasks')
        if args.riemann or args.spheres or args.spheres_exact:
            raise ValueError('multi tasks cannot be true with riemaniann / sphere / sphere_exact!')

    # Model
    print("==> creating model '{}'".format(args.arch))
    model = _build_model(num_classes, hierarchy_tree)
    criterion = nn.CrossEntropyLoss()

    for name, _ in model.named_parameters():
        print(name)
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # --spheres_exact implies --spheres for the optimizer setup. This is set only
    # after the model is built, so the constructor still sees the user's value.
    if args.spheres_exact:
        args.spheres = True

    optimizer, optimizer_m = _build_optimizers(model)

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, optimizer_m, epoch)
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer, optimizer_m,
                                      epoch, use_cuda, hierarchy_tree, args.lambda_mt)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)

        # float() accepts both the initial int 0 and tensor accuracies, so the
        # final summary also prints when no epoch runs.
        best_acc = max(best_acc, float(test_acc))
        print(' Best acc ever: %.4f' % best_acc)

    print('Best acc: %.4f \n' % best_acc)



def train(trainloader, model, criterion, optimizer, optimizer_m, epoch, use_cuda, hierarchy_tree=None, lambda_mt=0):
    """Run one training epoch and return ``(average loss, average top-1 accuracy)``.

    Both optimizers are zeroed and stepped on every batch: ``optimizer`` for the
    standard parameters and ``optimizer_m`` for the fc parameters.

    Args:
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        model: network whose output is fed to ``criterion`` and ``accuracy``.
        criterion: loss applied to ``(outputs, targets)``.
        optimizer, optimizer_m: optimizers stepped after every backward pass.
        epoch: current epoch index (not used in the body).
        use_cuda: move each batch to the GPU when true.
        hierarchy_tree: needed when ``args.multi_tasks`` is set. Indexing its
            ``child_parent_pairs_tensor`` with the targets presumably yields the
            parent class of each target - confirm against get_loaders.
        lambda_mt: weight of the parent-target loss added under ``args.multi_tasks``.
    """
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        if args.multi_tasks:
            # Auxiliary loss against the parent label of each target, weighted by
            # the lambda_mt argument (main passes args.lambda_mt).
            parent_target = hierarchy_tree.child_parent_pairs_tensor[targets]
            loss = loss + lambda_mt * criterion(outputs, parent_target.to(targets.device))

        prec1, prec5 = accuracy(outputs.detach(), targets, topk=(1, 5))
        losses.update(loss.detach(), inputs.size(0))
        top1.update(prec1, inputs.size(0))
        top5.update(prec5, inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        optimizer_m.zero_grad()
        # NOTE(review): retain_graph=True keeps this batch's graph buffers alive
        # until loss/outputs are rebound next iteration, raising peak memory.
        # Only one backward runs per batch, so it looks unnecessary unless the
        # model reuses graph pieces across batches - confirm before removing.
        loss.backward(retain_graph=True)
        optimizer.step()
        optimizer_m.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        if batch_idx % 100 == 0:
            bar = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                batch=batch_idx + 1,
                size=len(trainloader),
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            )
            print(bar)

    return (losses.avg, top1.avg)


def test(testloader, model, criterion, epoch, use_cuda):
    """Evaluate ``model`` on ``testloader`` and return ``(average loss, average top-1 accuracy)``.

    The model is put in eval mode, and the forward passes run under
    ``torch.no_grad()`` so no autograd graph is recorded for the (large) test
    batches.

    Args:
        testloader: iterable yielding ``(inputs, targets)`` batches.
        model: network to evaluate.
        criterion: loss applied to ``(outputs, targets)``.
        epoch: current epoch index (not used in the body).
        use_cuda: move each batch to the GPU when true.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            # measure data loading time
            data_time.update(time.time() - end)
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss, inputs.size(0))
            top1.update(prec1, inputs.size(0))
            top5.update(prec5, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % 300 == 0:
                bar = ' Test: ({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Loss: {loss:.4f}  | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    batch=batch_idx + 1,
                    size=len(testloader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                )
                print(bar)

    return (losses.avg, top1.avg)


def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialize ``state`` with ``torch.save`` to ``<checkpoint>/<filename>``.

    If ``is_best`` is true, the saved file is also copied to
    ``<checkpoint>/model_best.pth.tar``. The checkpoint directory is created
    when it does not exist yet. An empty ``checkpoint`` means the current
    directory.

    Args:
        state: object passed to ``torch.save`` (typically a dict).
        is_best: whether to also refresh ``model_best.pth.tar``.
        checkpoint: target directory.
        filename: name of the checkpoint file inside ``checkpoint``.
    """
    if checkpoint:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, optimizer_m, epoch):
    """Apply the step learning-rate schedule for ``epoch`` to both optimizers.

    ``lr = args.lr * args.gamma ** k``, where ``k`` is the number of distinct
    milestones in ``args.schedule`` that are ``<= epoch``. The rate is derived
    from the epoch rather than decayed incrementally. Runs that begin at a
    ``--start-epoch`` past a milestone therefore still get the decayed rate.
    The value is stored in ``state['lr']`` (printed by ``main``) and written
    to every param group of ``optimizer`` and ``optimizer_m``.
    """
    milestones_passed = sum(1 for milestone in set(args.schedule) if epoch >= milestone)
    state['lr'] = args.lr * (args.gamma ** milestones_passed)

    for param_group in optimizer.param_groups + optimizer_m.param_groups:
        param_group['lr'] = state['lr']


if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()
