import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.models as models
import train_models.resnet as RN
import train_models.resnet_ap as RNAP
import train_models.convnet as CN
import train_models.densenet_cifar as DN
from data import load_data, MEANS, STDS
from utils import random_indices, rand_bbox, AverageMeter, accuracy, get_time, Plotter
import time
import warnings
warnings.filterwarnings('ignore')
# Sorted names of every lowercase, non-dunder, callable attribute of torchvision.models.
model_names = sorted(
    name
    for name, entry in models.__dict__.items()
    if name.islower() and (not name.startswith('__')) and callable(entry)
)

# Per-dataset normalization statistics as CUDA tensors reshaped to (1, C, 1, 1),
# where C is the number of values listed for that dataset in MEANS / STDS.
mean_torch = {
    key: torch.tensor(val, device='cuda').reshape(1, len(val), 1, 1)
    for key, val in MEANS.items()
}
std_torch = {
    key: torch.tensor(val, device='cuda').reshape(1, len(val), 1, 1)
    for key, val in STDS.items()
}


def define_model(args, nclass, logger=None, size=None):
    """Build the network selected by ``args.net_type``.

    Args:
        args: Configuration object. Depending on the branch taken this reads
            ``net_type``, ``size``, ``dataset``, ``depth``, ``width``,
            ``norm_type`` and ``nch``.
        nclass: Number of classes handed to the model constructor.
        logger: Optional callable taking a message string; called once after
            the model has been created.
        size: Input size forwarded to the ResNet variants. Falls back to
            ``args.size`` when ``None``.

    Returns:
        The constructed model.

    Raises:
        Exception: If ``args.net_type`` is not a supported architecture name.
    """
    if size is None:
        size = args.size

    net_type = args.net_type
    if net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, nclass, norm_type=args.norm_type, size=size, nch=args.nch)
    elif net_type == 'resnet_ap':
        model = RNAP.ResNetAP(args.dataset, args.depth, nclass, width=args.width, norm_type=args.norm_type, size=size, nch=args.nch)
    elif net_type == 'densenet':
        model = DN.densenet_cifar(nclass)
    elif net_type == 'convnet':
        width = int(128 * args.width)
        # NOTE(review): the ConvNet branches use args.size, not the resolved
        # `size` override - confirm whether that is intentional.
        model = CN.ConvNet(nclass, net_norm=args.norm_type, net_depth=args.depth, net_width=width, channel=args.nch, im_size=(args.size, args.size))
    elif net_type == 'convnet6':
        # Fixed configuration: width, depth, activation, norm and pooling are
        # hard-coded here and do not follow args.width/args.depth/args.norm_type.
        model = CN.ConvNet(channel=args.nch, num_classes=nclass, net_width=128, net_depth=6, net_act='relu', net_norm='instancenorm', net_pooling='avgpooling', im_size=(args.size, args.size))
    else:
        raise Exception('unknown network architecture: {}'.format(args.net_type))

    if logger is not None:
        logger(f'=> creating model {args.net_type}-{args.depth}, norm: {args.norm_type}')
    return model


def main(args, logger, repeat=1):
    """Load the data once and train ``repeat`` freshly built models on it.

    Args:
        args: Configuration object. Reads ``seed``, ``spec``, ``imagenet_dir``,
            ``save_dir`` and ``epochs`` here and is forwarded to ``load_data``,
            ``define_model`` and ``train``.
        logger: Callable taking a message string.
        repeat: Number of independent training runs; must be at least 1.

    Returns:
        The best top-1 accuracy reported by the last run.

    Raises:
        ValueError: If ``repeat`` is smaller than 1.
    """
    if repeat < 1:
        # Previously this only failed after data loading, with an unbound local.
        raise ValueError(f'repeat must be >= 1, got {repeat}')

    if args.seed >= 0:
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        np.random.seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        # Only enable the cuDNN autotuner for unseeded runs; turning it on
        # unconditionally would undo the deterministic setup above.
        cudnn.benchmark = True
    print('args.seed,', args.seed)
    print('spec', args.spec)
    logger(f'ImageNet directory: {args.imagenet_dir[0]}')
    _, train_loader, val_loader, nclass = load_data(args)

    best_acc_l = []
    acc_l = []
    for i in range(repeat):
        logger(f'Repeat: {i + 1}/{repeat}')
        plotter = Plotter(args.save_dir, args.epochs, idx=i)
        model = define_model(args, nclass, logger)
        best_acc, acc = train(args, model, train_loader, val_loader, plotter, logger)
        best_acc_l.append(best_acc)
        acc_l.append(acc)

    # Report what the label promises: mean best and mean last accuracy,
    # followed by their standard deviations across repeats.
    logger(
        f'\n(Repeat {repeat}) Best, last acc: {np.mean(best_acc_l):.1f} {np.mean(acc_l):.1f} '
        f'(std: {np.std(best_acc_l):.1f} {np.std(acc_l):.1f})'
    )
    return best_acc


def train(args, model, train_loader, val_loader, plotter=None, logger=None):
    """Train ``model`` with SGD and a multi-step learning-rate schedule.

    The learning rate is multiplied by 0.2 at 2/3 and 5/6 of ``args.epochs``.
    Validation runs every ``args.epoch_print_freq`` epochs; when
    ``args.save_ckpt`` is set, a checkpoint is written whenever validation
    improves and at the final epoch.

    Args:
        args: Configuration object. Reads ``lr``, ``momentum``,
            ``weight_decay``, ``epochs``, ``pretrained``, ``save_dir``,
            ``mixup``, ``epoch_print_freq``, ``verbose``, ``save_ckpt`` and
            ``net_type``.
        model: Network to train; moved to CUDA here.
        train_loader: Loader passed to ``train_epoch``.
        val_loader: Loader passed to ``validate``.
        plotter: Optional object whose ``update`` is called after validation.
        logger: Optional callable taking a message string.

    Returns:
        Tuple ``(best_acc1, acc1)``: best and most recent validation top-1.
    """
    criterion = nn.CrossEntropyLoss().cuda()
    # Move the model first: the optimizer casts restored state to the device
    # of its parameters, so restoring before the move would leave the SGD
    # momentum buffers on the CPU while the parameters live on the GPU.
    model = model.cuda()
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2 * args.epochs // 3, 5 * args.epochs // 6], gamma=0.2)
    cur_epoch, best_acc1, best_acc5, acc1, acc5 = (0, 0, 0, 0, 0)
    if args.pretrained:
        pretrained = '{}/{}'.format(args.save_dir, 'checkpoint.pth.tar')
        cur_epoch, best_acc1 = load_checkpoint(pretrained, model, optimizer)
        # NOTE(review): the scheduler is not fast-forwarded to cur_epoch, so
        # milestones are counted from the resume point - confirm intent.
    if logger is not None:
        logger(f'Start training with base augmentation and {args.mixup} mixup')
    print('epochs:', args.epochs)

    for epoch in range(cur_epoch + 1, args.epochs + 1):
        # Reset every epoch: without this the name is unbound before the first
        # validation and stale on epochs that skip validation.
        is_best = False
        acc1_tr, _, loss_tr = train_epoch(args, train_loader, model, criterion, optimizer, epoch, logger, mixup=args.mixup)

        if epoch % args.epoch_print_freq == 0:
            acc1, acc5, loss_val = validate(args, val_loader, model, criterion, epoch, logger)
            if plotter is not None:
                plotter.update(epoch, acc1_tr, acc1, loss_tr, loss_val)
            is_best = acc1 > best_acc1
            if is_best:
                best_acc1 = acc1
                best_acc5 = acc5
                if logger is not None and args.verbose == True:
                    logger(f'Best accuracy (top-1 and 5): {best_acc1:.1f} {best_acc5:.1f}')

        if args.save_ckpt and (is_best or epoch == args.epochs):
            state = {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'best_acc5': best_acc5,
                'optimizer': optimizer.state_dict(),
            }
            save_checkpoint(args.save_dir, state, is_best)
        scheduler.step()
    return (best_acc1, acc1)


def train_epoch(args, train_loader, model, criterion, optimizer, epoch=0, logger=None, mixup='vanilla', n_data=-1):
    """Run one optimization pass over ``train_loader``.

    With ``mixup == 'cut'`` a batch is CutMix-augmented with probability
    ``args.mix_p``: a box from ``rand_bbox`` is pasted from re-indexed samples
    and the loss is blended by the remaining area ratio.

    Args:
        args: Configuration object. Reads ``mix_p``, ``beta``, ``nclass``,
            ``epoch_print_freq``, ``verbose`` and ``epochs``.
        train_loader: Iterable of ``(input, target)`` batches exposing a
            ``device`` attribute; batches are moved to CUDA when it is ``'cpu'``.
        model: Network being trained.
        criterion: Loss function called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch number, used for logging only.
        logger: Optional callable taking a message string.
        mixup: ``'cut'`` enables CutMix; any other value trains on plain batches.
        n_data: When positive, stop once at least this many samples were seen.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)`` over the processed batches.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

    model.train()
    tick = time.time()
    seen = 0
    for images, labels in train_loader:
        if train_loader.device == 'cpu':
            images = images.cuda()
            labels = labels.cuda()
        data_time.update(time.time() - tick)

        # Drawn for every batch, before the mode check, so the NumPy random
        # stream advances the same way whether or not CutMix is enabled.
        draw = np.random.rand(1)
        if draw < args.mix_p and mixup == 'cut':
            lam = np.random.beta(args.beta, args.beta)
            perm = random_indices(labels, nclass=args.nclass)
            labels_mix = labels[perm]
            x1, y1, x2, y2 = rand_bbox(images.size(), lam)
            # Overwrite the box in place with the same region of the re-indexed batch.
            images[:, :, x1:x2, y1:y2] = images[perm, :, x1:x2, y1:y2]
            full_area = images.size()[-1] * images.size()[-2]
            ratio = 1 - (x2 - x1) * (y2 - y1) / full_area
            output = model(images)
            loss = criterion(output, labels) * ratio + criterion(output, labels_mix) * (1.0 - ratio)
        else:
            output = model(images)
            loss = criterion(output, labels)

        # Accuracy is always measured against the original labels.
        acc1, acc5 = accuracy(output.data, labels, topk=(1, 5))
        n_batch = images.size(0)
        losses.update(loss.item(), n_batch)
        top1.update(acc1.item(), n_batch)
        top5.update(acc5.item(), n_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        seen += len(labels)
        if 0 < n_data <= seen:
            break

    if epoch % args.epoch_print_freq == 0 and logger is not None and (args.verbose == True):
        summary = '(Train) [Epoch {0}/{1}] {2} Top1 {top1.avg:.1f}  Top5 {top5.avg:.1f}  Loss {loss.avg:.3f}'.format(
            epoch, args.epochs, get_time(), top1=top1, top5=top5, loss=losses)
        logger(summary)
    return (top1.avg, top5.avg, losses.avg)


def validate(args, val_loader, model, criterion, epoch, logger=None):
    """Evaluate ``model`` on ``val_loader`` and return averaged metrics.

    Args:
        args: Configuration object. Reads ``verbose`` and ``epochs``.
        val_loader: Iterable of ``(input, target)`` batches; each batch is
            moved to CUDA.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss function called as ``criterion(output, target)``.
        epoch: Current epoch number, used for logging only.
        logger: Optional callable taking a message string.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)`` over the whole loader.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.eval()
    end = time.time()
    # No backward pass happens during evaluation, so skip autograd bookkeeping
    # and do not keep activations alive for every batch.
    with torch.no_grad():
        for images, target in val_loader:
            images = images.cuda()
            target = target.cuda()
            output = model(images)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
    if logger is not None and args.verbose == True:
        logger('(Test ) [Epoch {0}/{1}] {2} Top1 {top1.avg:.1f}  Top5 {top5.avg:.1f}  Loss {loss.avg:.3f}'.format(epoch, args.epochs, get_time(), top1=top1, top5=top5, loss=losses))
    return (top1.avg, top5.avg, losses.avg)


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer state from a checkpoint file, if present.

    Args:
        path: Path of the checkpoint file.
        model: Module whose ``load_state_dict`` receives the stored weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the stored state.

    Returns:
        Tuple ``(cur_epoch, best_acc1)`` read from the checkpoint, or
        ``(0, 0)`` when no file exists at ``path``.
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        # Start from scratch. A best accuracy of 0 lets later epochs register
        # as improvements; the previous value of 100 could never be beaten.
        return (0, 0)

    print("=> loading checkpoint '{}'".format(path))
    # Load onto the CPU; load_state_dict copies values into the existing
    # parameters regardless of where those live.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')

    # Drop a leading 'module.' only where it is actually present. Slicing
    # seven characters off every key corrupts checkpoints saved without it.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)

    cur_epoch = checkpoint['epoch']
    best_acc1 = checkpoint['best_acc1']
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}'(epoch: {}, best acc1: {}%)".format(path, cur_epoch, checkpoint['best_acc1']))
    return (cur_epoch, best_acc1)


def save_checkpoint(save_dir, state, is_best):
    """Serialize ``state`` into ``save_dir``, creating the directory if needed.

    The file is named ``model_best.pth.tar`` when ``is_best`` is truthy and
    ``checkpoint.pth.tar`` otherwise; only that one file is written per call.
    """
    os.makedirs(save_dir, exist_ok=True)
    filename = 'model_best.pth.tar' if is_best else 'checkpoint.pth.tar'
    destination = os.path.join(save_dir, filename)
    torch.save(state, destination)
    print('checkpoint saved! ', destination)

    
if __name__ == '__main__':
    # Script entry point: run main() once per seed and print the mean and
    # standard deviation of the returned accuracies.
    from misc.utils import Logger
    from argument import args

    os.makedirs(args.save_dir, exist_ok=True)
    logger = Logger(args.save_dir)
    logger(f'Save dir: {args.save_dir}')
    print(args.imagenet_dir[0])

    seeds = [0, 1, 2]
    print('seeds:', seeds)
    accs = []
    for seed in seeds:
        # main() reads the seed from args, so set it before every run.
        args.seed = seed
        accs.append(main(args, logger, args.repeat))

    banner = '###########################'
    print(banner)
    print(args.imagenet_dir[0])
    print('spec: %s, net: %s, depth: %s, ipc: %s:' % (args.spec, args.net_type, args.depth, args.ipc))
    print('mean: %s, std: %s' % (np.mean(accs), np.std(accs)))
    print(banner)