import argparse
import logging
import os
import pickle
import time
from collections import OrderedDict
import joblib
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from tqdm import tqdm
import random

from auto_augment import AutoAugmentBO, Cutout, AutoAugment
from utils import *
from wide_resnet import SimpleNet, WideResNet
from config import init_logging


def parse_args():
    """Parse the command-line options of the CIFAR training/search task.

    Returns:
        argparse.Namespace: the parsed options. Boolean-like options
        (``--cutout``, ``--nesterov``) are converted with ``str2bool``.
    """
    parser = argparse.ArgumentParser()

    # What to run and on which data.
    parser.add_argument("--option", default="search", choices=["search", "train"])
    parser.add_argument('--data_set', default='cifar10', choices=['cifar10', 'cifar100'], help='dataset name')

    # Network shape and augmentation switches.
    parser.add_argument('--depth', default=28, type=int)
    parser.add_argument('--width', default=10, type=int)
    parser.add_argument('--cutout', default=True, type=str2bool)

    # Optimiser / learning-rate schedule.
    parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float)
    parser.add_argument('--milestones', default='60,120,160', type=str)
    parser.add_argument('--gamma', default=0.2, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--weight-decay', default=5e-4, type=float)
    parser.add_argument('--nesterov', default=False, type=str2bool)

    # Without ``type=int`` a value given on the command line would stay a
    # string and be handed to the DataLoader as such.
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument('--name', default="wideresnet", choices=["wideresnet", "simplenet"])
    parser.add_argument("--auto_augment", default="BOSS",
                        choices=["EQI", "Basic", "BOSS", "BOHB", "HB", "SH", "Random", "BO"], type=str)
    parser.add_argument('--epochs', default=200, type=int)    # equivalent to the budget

    # Locations and device.
    parser.add_argument("--shared_dir", default="/cache/exps/pytorch_boss_autoaugment_cifar10/shared/0_0", type=str)
    parser.add_argument("--code_dir", default="/home/ma-user/work/pytorch_boss_darts")
    parser.add_argument("--log_dir", default="/cache/exps/pytorch_boss_autoaugment_cifar10/log/0_0")
    parser.add_argument("--device", default="cuda:0", type=str)

    return parser.parse_args()


# The command line is parsed once at import time; main() and the __main__
# block read this module-level namespace.
args = parse_args()

# Logging is configured from the YAML file that lives next to the code.
logging_config_path = os.path.join(args.code_dir, "logging_config.yaml")
init_logging(exp_dir=args.log_dir, config_path=logging_config_path)
logger = logging.getLogger(__name__)

# Dump the run header followed by every parsed option, one per line.
logger.debug(f"-------- cifar10_task with epochs {args.epochs} by {args.auto_augment} --------")
for option_name, option_value in vars(args).items():
    logger.debug(f"{option_name}: {option_value}")


def train(args, train_loader, model, criterion, optimizer, epoch, scheduler=None):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``args.device``, pushed through ``model``, scored
    with ``criterion`` and used for one optimiser step. ``epoch`` and
    ``scheduler`` are accepted but not used here.

    Returns:
        OrderedDict: ``'loss'`` and ``'acc'``, taken from the ``avg`` of the
        two ``AverageMeter`` objects that were updated with the per-batch
        value and the batch size.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.train()

    for images, labels in train_loader:
        images = images.to(args.device)
        labels = labels.to(args.device)
        n_samples = images.size(0)

        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_acc = accuracy(logits, labels)[0]

        # Record the statistics of this batch before the weights change.
        loss_meter.update(batch_loss.item(), n_samples)
        acc_meter.update(batch_acc.item(), n_samples)

        # Gradient computation followed by the parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    summary = OrderedDict()
    summary['loss'] = loss_meter.avg
    summary['acc'] = acc_meter.avg
    return summary


def validate(args, val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns:
        OrderedDict: ``'loss'``, ``'acc'`` (top-1) and ``'acc5'`` (top-5),
        taken from the ``avg`` of the ``AverageMeter`` objects that were
        updated with the per-batch value and the batch size.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    # Evaluation mode for the whole pass.
    model.eval()

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            n_samples = images.size(0)

            logits = model(images)
            batch_loss = criterion(logits, labels)
            top1, top5 = accuracy(logits, labels, topk=(1, 5))

            loss_meter.update(batch_loss.item(), n_samples)
            top1_meter.update(top1.item(), n_samples)
            top5_meter.update(top5.item(), n_samples)

    summary = OrderedDict()
    summary['loss'] = loss_meter.avg
    summary['acc'] = top1_meter.avg
    summary['acc5'] = top5_meter.avg
    return summary


# Channel statistics applied to both CIFAR-10 and CIFAR-100 inputs (the
# CIFAR-100 pipeline has always reused these same numbers).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Number of training images sampled when running in "search" mode.
_SEARCH_SUBSET_SIZE = 4000


def _build_transforms(args, config):
    """Return the ``(train, test)`` torchvision transform pipelines."""
    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]

    # "Basic" selects the fixed AutoAugment transform; every other choice
    # uses the policy described by ``config``.
    if args.auto_augment == "Basic":
        train_steps.append(AutoAugment())
    else:
        train_steps.append(AutoAugmentBO(config))

    if args.cutout:
        train_steps.append(Cutout())

    train_steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])

    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
    return transforms.Compose(train_steps), transforms.Compose(test_steps)


def _build_loaders(args, config):
    """Return ``(train_loader, test_loader, num_classes)`` for ``args.data_set``."""
    # data_set name -> (dataset class, download/cache root, number of classes)
    dataset_specs = {
        'cifar10': (datasets.CIFAR10, '/Users/liyujun/Programs/data/CIFAR10', 10),
        'cifar100': (datasets.CIFAR100, '/cache/data/CIFAR100', 100),
    }
    if args.data_set not in dataset_specs:
        logger.error("wrong dataset: %s" % args.data_set)
        raise RuntimeError("wrong dataset: %s" % args.data_set)
    if args.option not in ("search", "train"):
        raise RuntimeError("wrong option in cifar10_task.py!!!!!!!!!!!!!!!")

    dataset_cls, data_root, num_classes = dataset_specs[args.data_set]
    transform_train, transform_test = _build_transforms(args, config)

    # Guard against a batch size that arrived as a string from the CLI.
    batch_size = int(args.batch_size)
    num_workers = 8 if torch.cuda.is_available() else 0

    train_set = dataset_cls(
        root=data_root,
        train=True,
        download=True,
        transform=transform_train)

    if args.option == "search":
        # Search mode trains on a random fixed-size subset of the training set.
        indices = list(range(len(train_set)))
        random.shuffle(indices)
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:_SEARCH_SUBSET_SIZE])
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=batch_size,
            sampler=train_sampler,
            num_workers=num_workers,
            pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True)

    logger.info("train_loader size: %s" % len(train_loader))

    test_set = dataset_cls(
        root=data_root,
        train=False,
        download=True,
        transform=transform_test)

    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)

    return train_loader, test_loader, num_classes


def main():
    """Train a model with the configured augmentation and report test accuracy.

    Reads the module-level ``args``. The augmentation config is unpickled from
    ``<shared_dir>/config.pkl``, the data loaders and the network are built,
    and the model is trained for ``args.epochs`` epochs with SGD and a
    ``MultiStepLR`` schedule. After every epoch the model is evaluated on the
    test split.

    Returns:
        tuple: ``(best_test_acc, log)`` where ``best_test_acc`` is the highest
        per-epoch test accuracy and ``log`` is a dict of per-epoch lists
        (``lr``, ``loss``, ``acc``, ``test_loss``, ``test_acc``,
        ``test_acc5``; the ``val_*`` lists are kept but never filled) plus
        the scalar ``best_test_acc``.

    Raises:
        RuntimeError: if the model name, dataset name or option is not one of
            the supported values.
    """
    # NOTE(security): pickle.load can execute arbitrary code. This is only
    # acceptable as long as shared_dir is written by a trusted process.
    with open(os.path.join(args.shared_dir, "config.pkl"), "rb") as f:
        config = pickle.load(f)

    # args.name is rewritten into a descriptive run name below, so keep the
    # raw model identifier for the model construction further down.
    model_name = args.name
    if model_name == "wideresnet":
        args.name = '%s_WideResNet%s-%s' % (args.data_set, args.depth, args.width)
    elif model_name == "simplenet":
        args.name = "%s_SimpleNet" % args.data_set
    else:
        logger.error("wrong model_name: %s" % model_name)
        # Must be ``raise``: ``assert <exception instance>`` always passes.
        raise RuntimeError("wrong model_name: %s" % model_name)

    if args.cutout:
        args.name += '_wCutout'

    args.name += f"_wAutoAugmetn{args.auto_augment}"

    criterion = nn.CrossEntropyLoss().to(args.device)

    cudnn.benchmark = True

    # data loading code
    train_loader, test_loader, num_classes = _build_loaders(args, config)

    # create model (model_name was validated above)
    if model_name == "wideresnet":
        model = WideResNet(args.depth, args.width, num_classes=num_classes)
    else:
        model = SimpleNet()
    model = model.to(args.device)

    # --nesterov was parsed but never reached the optimiser before.
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay,
                          nesterov=args.nesterov)

    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[int(e) for e in args.milestones.split(',')], gamma=args.gamma)

    log = {"lr": [], "loss": [], 'acc': [], "val_loss": [], "val_acc": [], "val_acc5": [],
           "test_loss": [], "test_acc": [], "test_acc5": [], }
    best_test_acc = 0
    for epoch in range(args.epochs):
        # train for one epoch
        train_log = train(args, train_loader, model, criterion, optimizer, epoch)
        # evaluate on the test split
        test_log = validate(args, test_loader, model, criterion)

        scheduler.step()

        logger.debug('Epoch [%d/%d]: loss %.4f - acc %.4f - test_loss %.4f - test_acc %.4f'
                     % (epoch + 1, args.epochs, train_log['loss'], train_log['acc'], test_log['loss'], test_log['acc']))

        # Read the learning rate straight from the optimiser: calling
        # scheduler.get_lr() outside of step() can report a wrong value at
        # milestone epochs.
        log["lr"].append(optimizer.param_groups[0]["lr"])
        log['loss'].append(train_log["loss"])
        log['acc'].append(train_log['acc'])
        log['test_loss'].append(test_log['loss'])
        log['test_acc'].append(test_log['acc'])
        log['test_acc5'].append(test_log['acc5'])

        if test_log['acc'] > best_test_acc:
            best_test_acc = test_log['acc']

    # Set once after the loop so the key also exists when epochs == 0.
    log['best_test_acc'] = best_test_acc

    logger.info(f"best_test_acc: {best_test_acc}")
    return best_test_acc, log


if __name__ == '__main__':
    # Run the task, then persist [best accuracy, per-epoch log] in the shared directory.
    best_test_acc, train_log = main()
    result_path = os.path.join(args.shared_dir, "result.pkl")
    with open(result_path, "wb") as result_file:
        pickle.dump([best_test_acc, train_log], result_file)
