import argparse
import os
import random
import time
import warnings
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import networks
from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix
from utils import *
from data import ImbCIFAR10, ImbCIFAR100
from losses import *

# Public model constructors exported by ``networks``: lower-case, non-dunder,
# callable attributes, sorted alphabetically. These are the --arch choices.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(networks).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)
print(model_names)


parser = argparse.ArgumentParser(description='PyTorch Imbalanced Classification Training')
parser.add_argument('--dataset', default='cifar100', help='dataset setting')
# Help strings use %(default)s so the advertised default cannot drift from the real one.
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet34', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: %(default)s)')
parser.add_argument('--mode', default='norm', choices=['', 'norm', 'fix'], help='the mode of the last linear layer')
parser.add_argument('--loss_type', default="ArcFace", type=str, help='loss type')
parser.add_argument('--imb_type', default="step", type=str, help='imbalance type')
parser.add_argument('--imb_factor', default=0.1, type=float, help='imbalance factor')
parser.add_argument('--train_rule', default='', type=str, help='data sampling strategy for train loader')
parser.add_argument('--exp_str', default='0', type=str, help='number to indicate which experiment it is')
parser.add_argument('-s', '--scale', default=5, type=int, help='the scale of logits')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', '--weight-decay', default=2e-4, type=float, metavar='W', help='weight decay (default: %(default)s)', dest='weight_decay')
parser.add_argument('--scheduler', default='Cos', type=str, help='The scheduler')
parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--seed', default=123, type=int, help='seed for initializing training. ')
parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
parser.add_argument('--root_log',type=str, default='imblog')
parser.add_argument('--root_model', type=str, default='imbcheckpoint')
parser.add_argument('--reg', type=float, default=0, help='the weight of regularization term')
# Best top-1 validation accuracy seen so far; updated by main_worker().
best_acc1 = 0

args = parser.parse_args()

# Seed the Python, NumPy and PyTorch (CPU and current CUDA device) RNGs.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)


def norm_weights(weights):
    """Return the squared L2 norm of the mean of the row-normalised ``weights``.

    Every row (slice along dim 1) is scaled to unit length. The rows are then
    averaged into a single "centre of gravity" vector, and its squared
    Euclidean norm is returned as a scalar tensor. For a 2-D input the result
    is 0 when the unit rows cancel out and 1 when they all point the same way.
    """
    unit_rows = torch.nn.functional.normalize(weights, dim=1)
    centre = unit_rows.mean(dim=0)
    return centre.pow(2).sum()

def main():
    """Set up the run directory, warn about the run configuration, and train.

    ``args.store_name`` is built from the run's hyper-parameters and is used as
    the sub-directory under ``args.root_log`` for this run's logs (folders are
    created by ``prepare_folders``). Training then runs in-process on
    ``args.gpu`` via ``main_worker``.
    """
    # store_name uses the dataset name exactly as given on the command line,
    # i.e. before it is lower-cased below.
    args.store_name = '_'.join([args.dataset, args.arch, args.mode, args.loss_type, str(args.scale),
                                args.train_rule, args.imb_type, str(args.imb_factor),
                                args.scheduler, str(args.reg)])
    args.dataset = args.dataset.lower()
    prepare_folders(args)
    # NOTE(review): cudnn.deterministic is never set in this script (main_worker
    # enables cudnn.benchmark instead), so this warning overstates what seeding
    # guarantees; confirm before relying on bit-exact reproducibility.
    warnings.warn('You have chosen to seed training. '
                  'This will turn on the CUDNN deterministic setting, '
                  'which can slow down your training considerably! '
                  'You may see unexpected behavior when restarting '
                  'from checkpoints')
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')
    ngpus_per_node = torch.cuda.device_count()
    main_worker(args.gpu, ngpus_per_node, args)

def _class_balanced_weights(beta, cls_num_list, gpu):
    """Return class-balanced loss weights from the effective number of samples.

    A class with ``n`` training samples gets ``(1 - beta) / (1 - beta ** n)``.
    The weights are rescaled to sum to ``len(cls_num_list)`` and returned as a
    float tensor on CUDA device ``gpu``.
    """
    effective_num = 1.0 - np.power(beta, cls_num_list)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * len(cls_num_list)
    return torch.FloatTensor(weights).cuda(gpu)


def _per_class_weights(train_rule, epoch, cls_num_list, gpu):
    """Return the per-class loss weights for ``epoch`` under ``train_rule``, or None."""
    if train_rule == 'Reweight':
        return _class_balanced_weights(0.9999, cls_num_list, gpu)
    if train_rule == 'DRW':
        # Deferred re-weighting: beta=0 (all weights 1) for epochs 0-159, then
        # class-balanced weights. Clamp the stage so runs longer than 320
        # epochs stay in the last stage instead of indexing past ``betas``.
        betas = [0, 0.9999]
        stage = min(epoch // 160, len(betas) - 1)
        return _class_balanced_weights(betas[stage], cls_num_list, gpu)
    if train_rule != 'Resample':
        # 'Resample' is handled by the train sampler, not by loss weights.
        warnings.warn('Sample rule is not listed')
    return None


def _build_criterion(loss_type, cls_num_list, per_cls_weights, scale, gpu):
    """Return the loss module named by ``loss_type`` moved to ``gpu``, or None if unknown.

    NOTE(review): per_cls_weights is only forwarded to CE, LDAM, Norm and
    LMSoftmax; Focal, CosFace and ArcFace are built without it.
    """
    if loss_type == 'CE':
        criterion = nn.CrossEntropyLoss(weight=per_cls_weights)
    elif loss_type == 'LDAM':
        criterion = LDAMLoss(cls_num_list=cls_num_list, max_m=0.5, s=scale, weight=per_cls_weights)
    elif loss_type == 'Focal':
        criterion = FocalLoss(gamma=2)
    elif loss_type == 'Norm':
        criterion = NormFaceLoss(scale=scale, weight=per_cls_weights)
    elif loss_type == 'LMSoftmax':
        criterion = LMSoftmaxLoss(scale=scale, weight=per_cls_weights)
    elif loss_type == 'CosFace':
        criterion = CosFaceLoss(scale=scale)
    elif loss_type == 'ArcFace':
        criterion = ArcFaceLoss(scale=scale)
    else:
        return None
    return criterion.cuda(gpu)


def main_worker(gpu, ngpus_per_node, args):
    """Build model, data and optimiser for ``args`` and run the full training loop.

    Args:
        gpu: CUDA device index to train on, or None to wrap the model in
            ``torch.nn.DataParallel``.
        ngpus_per_node: number of visible CUDA devices (currently unused).
        args: parsed command-line namespace; ``args.gpu``, ``args.start_epoch``
            (on resume) and ``args.cls_num_list`` are written here.

    Updates the module-level ``best_acc1``, writes ``log_train.csv``,
    ``log_test.csv``, ``args.txt`` and TensorBoard events under
    ``args.root_log/args.store_name``, and calls ``save_checkpoint`` after
    every epoch.
    """
    global best_acc1
    args.gpu = gpu
    if args.gpu is not None:
        print('Use GPU: {} for training'.format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    # CE and Focal always build the network with mode='', whatever --mode says.
    mode = '' if args.loss_type in ['CE', 'Focal'] else args.mode
    # --mode fix loads a pre-computed classifier weight matrix from disk.
    if args.mode == 'fix' and args.dataset == 'cifar10':
        weight = torch.Tensor(np.load('./weight10x512.npy')).cuda(args.gpu)
    elif args.mode == 'fix' and args.dataset == 'cifar100':
        weight = torch.Tensor(np.load('./weight100x512.npy')).cuda(args.gpu)
    else:
        weight = None
    model = networks.__dict__[args.arch](num_classes=num_classes, mode=mode, weight=weight)
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.scheduler == 'LDAM':
        scheduler = LDAMScheduler(optimizer, args.lr)
    elif args.scheduler == 'Cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0.0)
    else:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoints '{}'".format(args.resume))
            # Load tensors onto the training device instead of always cuda:0.
            map_location = 'cuda:{}'.format(args.gpu) if args.gpu is not None else 'cuda:0'
            checkpoint = torch.load(args.resume, map_location=map_location)
            args.start_epoch = checkpoint['epoch']
            # NOTE(review): built-in torch schedulers presumably derive the lr
            # from ``last_epoch`` rather than ``_step_count``, so this may not
            # fast-forward the schedule on resume -- confirm.
            scheduler._step_count = args.start_epoch
            best_acc1 = checkpoint['best_acc1']
            # best_acc1 may have been saved as a tensor or as the plain int 0.
            if args.gpu is not None and torch.is_tensor(best_acc1):
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        train_dataset = ImbCIFAR10(root='../database/CIFAR10', imb_type=args.imb_type, imb_factor=args.imb_factor, train=True, download=True, transform=transform_train, seed=args.seed)
        val_dataset = datasets.CIFAR10(root='../database/CIFAR10', train=False, download=True, transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = ImbCIFAR100(root='../database/CIFAR100', imb_type=args.imb_type, imb_factor=args.imb_factor, train=True, download=True, transform=transform_train, seed=args.seed)
        val_dataset = datasets.CIFAR100(root='../database/CIFAR100', train=False, download=True, transform=transform_val)
    else:
        warnings.warn('Dataset is not listed!')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:', cls_num_list)
    args.cls_num_list = cls_num_list

    # The sampler must be handed to the DataLoader for --train_rule Resample
    # to have any effect (shuffle and sampler are mutually exclusive).
    train_sampler = ImbalancedDatasetSampler(train_dataset) if args.train_rule == 'Resample' else None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True
    )

    log_dir = os.path.join(args.root_log, args.store_name)
    train_log_path = os.path.join(log_dir, 'log_train.csv')
    test_log_path = os.path.join(log_dir, 'log_test.csv')
    # Context managers / finally guarantee the CSV logs are closed and the
    # TensorBoard writer is closed on every exit path, including early return.
    with open(train_log_path, 'w') as log_training, open(test_log_path, 'w') as log_testing:
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=log_dir)
        try:
            for epoch in range(args.start_epoch, args.epochs):
                per_cls_weights = _per_class_weights(args.train_rule, epoch, cls_num_list, args.gpu)
                criterion = _build_criterion(args.loss_type, cls_num_list, per_cls_weights,
                                             args.scale, args.gpu)
                if criterion is None:
                    warnings.warn('Loss type is not listed!')
                    return

                train(train_loader, model, criterion, optimizer, epoch, args, log_training, tf_writer)
                scheduler.step()
                acc1 = validate(val_loader, model, criterion, epoch, args, log_testing, tf_writer)

                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)
                tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
                output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
                print(output_best)
                log_testing.write(output_best + '\n')
                log_testing.flush()

                save_checkpoint(args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)

            # After training, log get_margin() on both splits as TensorBoard histograms.
            sis = get_margin(train_loader, model)
            tf_writer.add_histogram('similarity/train', sis)
            sis = get_margin(val_loader, model)
            tf_writer.add_histogram('similarity/test', sis)
        finally:
            tf_writer.close()

def train(train_loader, model, criterion, optimizer, epoch, args, log, tf_writer):
    """Train ``model`` for one epoch and log loss, accuracy and margin statistics.

    Args:
        train_loader: iterable of ``(input, target)`` batches.
        model: network exposing ``get_body``, ``linear``, ``get_weight`` and ``margin``.
        criterion: loss applied to ``(logits, target)``.
        optimizer: optimiser stepped once per batch.
        epoch: current epoch index (used for logging).
        args: namespace providing ``gpu``, ``reg`` and ``print_freq``.
        log: writable text file receiving progress lines.
        tf_writer: TensorBoard ``SummaryWriter`` receiving per-epoch scalars.

    When ``args.reg`` is non-zero the loss adds
    ``args.reg * 100 * norm_weights(weight)`` on the classifier weights.
    "SaMargin" is reported as the negated ``SampleMarginLoss`` value everywhere:
    in the periodic progress lines, the epoch summary and TensorBoard.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    sample_margins = AverageMeter('Sample Margin', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    sample_margin = SampleMarginLoss()
    model.train()
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        embedding = model.get_body(images)
        output = model.linear(embedding)
        weight = model.get_weight()
        # The cosine logits only feed the sample-margin diagnostic (read via
        # .item()), so compute them without recording an autograd graph.
        with torch.no_grad():
            output_norm = F.linear(F.normalize(embedding, dim=1), F.normalize(weight, dim=1))
            sm = sample_margin(output_norm, target)

        loss = criterion(output, target)
        if args.reg != 0:
            loss = loss + args.reg * 100 * norm_weights(weight)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        sample_margins.update(sm.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            msg = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                   'SaMargin {sa_val:.3f} ({sa_avg:.3f})'.format(
                       epoch, i, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses, top1=top1, top5=top5,
                       sa_val=-sample_margins.val, sa_avg=-sample_margins.avg,
                       lr=optimizer.param_groups[-1]['lr']))
            print(msg)
            log.write(msg + '\n')
            log.flush()
    margin, ratio = model.margin()
    msg = '\nEpoch [{}]:\t loss={:.4f}\t Prec@1={:.4f}\t Prec@5={:.4f}\t ClsMargin={:.4f}\t SaMargin={:.4f}\n'.format(
        epoch, losses.avg, top1.avg, top5.avg, margin, -sample_margins.avg)
    print(msg)
    log.write(msg)
    log.flush()
    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('sample_margin/train', -sample_margins.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('margin', margin, epoch)
    tf_writer.add_scalar('ratio', ratio, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)


def validate(val_loader, model, criterion, epoch, args, log=None, tf_writer=None, flag='val'):
    """Evaluate ``model`` on ``val_loader`` and return its average top-1 accuracy.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network exposing ``get_body``, ``linear`` and ``get_weight``.
        criterion: loss applied to ``(logits, target)``.
        epoch: epoch index used as the TensorBoard step.
        args: namespace providing ``gpu`` and ``print_freq``.
        log: optional writable text file; summary lines are appended when given.
        tf_writer: optional TensorBoard ``SummaryWriter``; scalars are written
            only when it is not None.
        flag: prefix for printed messages and TensorBoard tags.

    Returns:
        ``top1.avg``, the running average kept by the top-1 ``AverageMeter``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    sample_margins = AverageMeter('Sample Margin', ':.4e')
    # switch to evaluate mode
    model.eval()
    all_preds = []
    all_targets = []
    num_classes = None
    sample_margin = SampleMarginLoss()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            embedding = model.get_body(images)
            output = model.linear(embedding)
            num_classes = output.size(1)
            weight = model.get_weight()
            weight_norm = F.normalize(weight, dim=1)
            embedding_norm = F.normalize(embedding, dim=1)
            output_norm = F.linear(embedding_norm, weight_norm)
            sm = sample_margin(output_norm, target)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            sample_margins.update(sm.item(), batch_size)
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            _, pred = torch.max(output, 1)
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

            if i % args.print_freq == 0:
                msg = ('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                           i, len(val_loader), batch_time=batch_time, loss=losses,
                           top1=top1, top5=top5))
                print(msg)
        # Pin the label set to the logit dimension so row k of the confusion
        # matrix is always class k, even if a class never appears in targets
        # or predictions.
        labels = np.arange(num_classes) if num_classes is not None else None
        cf = confusion_matrix(all_targets, all_preds, labels=labels).astype(float)
        cls_cnt = cf.sum(axis=1)
        cls_hit = np.diag(cf)
        # Classes with no samples yield NaN accuracy; silence the 0/0 warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            cls_acc = cls_hit / cls_cnt
        summary = ('{flag} Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
                   .format(flag=flag, top1=top1, top5=top5, loss=losses))
        out_cls_acc = '%s Class Accuracy: %s' % (
            flag, (np.array2string(cls_acc, separator=',', formatter={'float_kind': lambda x: "%.3f" % x})))
        print(summary)
        print(out_cls_acc)
        if log is not None:
            log.write(summary + '\n')
            log.write(out_cls_acc + '\n')
            log.flush()

        if tf_writer is not None:
            tf_writer.add_scalar('loss/test_' + flag, losses.avg, epoch)
            tf_writer.add_scalar('sample_margin/test_' + flag, -sample_margins.avg, epoch)
            tf_writer.add_scalar('acc/test_' + flag + '_top1', top1.avg, epoch)
            tf_writer.add_scalar('acc/test_' + flag + '_top5', top5.avg, epoch)
            tf_writer.add_scalars('acc/test_' + flag + '_cls_acc', {str(k): x for k, x in enumerate(cls_acc)}, epoch)

    return top1.avg


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()