""" Training augmented model """
import os
import torch
import torch.nn as nn
import numpy as np
# from tensorboardX import SummaryWriter
from config import AugmentConfig
import utils
from models.augment_cnn import AugmentCNN, AugmentCNNImageNet
import logging
import subprocess
import pickle

import sys
sys.path.insert(0, "..")
from boss_config import init_logging


# Build the run configuration once at import time; the functions below read it
# as a module-level global.
config = AugmentConfig()

# Configure logging from the YAML file located in the code directory before
# the module logger emits its first record.
_logging_config_path = os.path.join(config.code_dir, "logging_config.yaml")
init_logging(exp_dir=config.log_dir, config_path=_logging_config_path)
logger = logging.getLogger(__name__)
logger.info(f"------------- start autoaugment task: {config.dataset} -------------")
config.print_params(logger.info)

# Target device for the model, the loss and every batch.
device = torch.device("cuda")

# # tensorboard
# writer = SummaryWriter(log_dir=os.path.join(config.log_dir, "tb"))
# writer.add_text('config', config.as_markdown(), 0)

# logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name)))
# config.print_params(logger.info)


def main():
    """Train the augmented network described by ``config`` and track the best epoch.

    Seeds NumPy and PyTorch, builds the data loaders and the model
    (``AugmentCNNImageNet`` when ``config.dataset == "imagenet"``,
    ``AugmentCNN`` otherwise), then runs ``config.epochs`` rounds of training
    followed by validation. After every epoch a checkpoint is saved through
    ``utils.save_checkpoint`` and the metric history is pickled to
    ``training.pkl`` inside ``config.log_dir``.

    Returns:
        tuple: ``(best_top1, log)`` where ``best_top1`` is the highest
        validation top-1 value seen so far (``0.`` if no epoch improved on it)
        and ``log`` maps ``"loss"``, ``"top1"``, ``"top5"``, ``"val_loss"``,
        ``"val_top1"`` and ``"val_top5"`` to per-epoch lists.
    """
    logger.info("Logger is set - training start")

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data, valid_data = utils.get_data(
        config.dataset, config.data_path, config.cutout_length, validation=True)

    criterion = nn.CrossEntropyLoss().to(device)
    use_aux = config.aux_weight > 0.

    # Both model classes take the same constructor arguments; only the class
    # differs between ImageNet and the other datasets.
    model_cls = AugmentCNNImageNet if config.dataset == "imagenet" else AugmentCNN
    model = model_cls(input_size, input_channels, config.init_channels, n_classes, config.layers,
                      use_aux, config.genotype)

    model = nn.DataParallel(model, device_ids=config.gpus).to(device)

    # model size
    mb_params = utils.param_size(model)
    logger.info("Model size = {:.3f} MB".format(mb_params))

    # weights optimizer
    optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=config.momentum,
                                weight_decay=config.weight_decay)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data,
                                               batch_size=config.batch_size,
                                               shuffle=False,
                                               num_workers=config.workers,
                                               pin_memory=True)

    # The ImageNet and non-ImageNet setups use the same cosine schedule, so no
    # dataset-specific branch is needed here.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)

    best_top1 = 0.
    log = {"loss": [], "top1": [], "top5": [], "val_loss": [], "val_top1": [], "val_top5": []}
    # training loop
    for epoch in range(config.epochs):
        # Drop-path probability grows linearly with the epoch index.
        drop_prob = config.drop_path_prob * epoch / config.epochs
        model.module.drop_path_prob(drop_prob)

        # training
        trn_log = train(train_loader, model, optimizer, criterion, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        val_log = validate(valid_loader, model, criterion, epoch, cur_step)
        top1 = val_log["top1"]

        # Advance the schedule only after this epoch's optimizer steps.
        # Stepping at the top of the epoch (before any optimizer.step()) skips
        # the first value of the schedule on PyTorch >= 1.1 and makes the last
        # epoch run at the schedule's minimum learning rate.
        lr_scheduler.step()

        for key in ("loss", "top1", "top5"):
            log[key].append(trn_log[key])
            log["val_" + key].append(val_log[key])

        # save
        is_best = best_top1 < top1
        if is_best:
            best_top1 = top1
        utils.save_checkpoint(model.state_dict(), config.log_dir, is_best)

        # Rewrite the full history every epoch so an interrupted run still
        # leaves the metrics collected so far on disk.
        with open(os.path.join(config.log_dir, "training.pkl"), "wb") as f:
            pickle.dump(log, f)

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))

    return best_top1, log


def train(train_loader, model, optimizer, criterion, epoch):
    """Run one training epoch and return its averaged metrics.

    Args:
        train_loader: iterable of ``(X, y)`` batches; must support ``len()``.
        model: network whose forward pass returns ``(logits, aux_logits)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        criterion: loss applied to the main logits and, when
            ``config.aux_weight > 0``, to the auxiliary logits as well.
        epoch: zero-based epoch index, used only in log messages.

    Returns:
        dict: the ``avg`` values of the ``utils.AverageMeter`` trackers under
        the keys ``"loss"``, ``"top1"`` and ``"top5"``.
    """
    top1_meter = utils.AverageMeter()
    top5_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter()

    last_step = len(train_loader) - 1
    logger.info("Epoch {} LR {}".format(epoch, optimizer.param_groups[0]['lr']))

    model.train()

    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        logits, aux_logits = model(inputs)
        loss = criterion(logits, targets)
        if config.aux_weight > 0.:
            # Add the weighted loss of the auxiliary head.
            loss += config.aux_weight * criterion(aux_logits, targets)
        loss.backward()
        # Clip the gradient norm before the parameter update.
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        loss_meter.update(loss.item(), batch_size)
        top1_meter.update(prec1.item(), batch_size)
        top5_meter.update(prec5.item(), batch_size)

        # Report every ``print_freq`` steps and always on the final batch.
        if step % config.print_freq == 0 or step == last_step:
            message = (
                "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})"
            ).format(epoch + 1, config.epochs, step, last_step,
                     losses=loss_meter, top1=top1_meter, top5=top5_meter)
            logger.info(message)

    logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1_meter.avg))

    return {"loss": loss_meter.avg, "top1": top1_meter.avg, "top5": top5_meter.avg}


def validate(valid_loader, model, criterion, epoch, cur_step):
    """Evaluate the model on the validation set without tracking gradients.

    Args:
        valid_loader: iterable of ``(X, y)`` batches; must support ``len()``.
        model: network whose forward pass returns a pair; only the first
            element (the main logits) is used here.
        criterion: loss applied to the main logits.
        epoch: zero-based epoch index, used only in log messages.
        cur_step: accepted for interface compatibility; not read by the
            current body.

    Returns:
        dict: the ``avg`` values of the ``utils.AverageMeter`` trackers under
        the keys ``"loss"``, ``"top1"`` and ``"top5"``.
    """
    top1_meter = utils.AverageMeter()
    top5_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter()

    model.eval()

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_loader):
            # non_blocking=True lets the host-to-device copy run asynchronously
            # when the source tensor is in pinned memory (per the PyTorch docs).
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            logits, _ = model(inputs)
            loss = criterion(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            loss_meter.update(loss.item(), batch_size)
            top1_meter.update(prec1.item(), batch_size)
            top5_meter.update(prec5.item(), batch_size)

            # Report every ``print_freq`` steps and always on the final batch.
            if step % config.print_freq == 0 or step == len(valid_loader) - 1:
                message = (
                    "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})"
                ).format(epoch + 1, config.epochs, step, len(valid_loader) - 1,
                         losses=loss_meter, top1=top1_meter, top5=top5_meter)
                logger.info(message)

    logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1_meter.avg))

    return {"loss": loss_meter.avg, "top1": top1_meter.avg, "top5": top5_meter.avg}


if __name__ == "__main__":
    # Run training, then persist the best accuracy and the metric history as a
    # two-element list in the shared directory.
    best_top1, log = main()
    result_path = os.path.join(config.shared_dir, "result.pkl")
    with open(result_path, "wb") as result_file:
        pickle.dump([best_top1, log], result_file)
