""" Search cell """
import logging
import os
import pickle
import sys
import numpy as np
import torch
import torch.nn as nn

import utils
from architect import Architect
from config import SearchConfig
from models.search_cnn import SearchCNNController
# from tensorboardX import SummaryWriter
# from visualize import plot


sys.path.insert(0, "..")
from boss_config import init_logging
from worker import DartsWorker


config = SearchConfig()

init_logging(exp_dir=config.log_dir, config_path=os.path.join(config.code_dir, "logging_config.yaml"))
logger = logging.getLogger(__name__)
logger.info(f"------------- start task: {config.name} -------------")
config.print_params(logger.info)


# # tensorboard
# writer = SummaryWriter(log_dir=os.path.join(config.log_dir, "tb"))
# writer.add_text('config', config.as_markdown(), 0)

# Load the externally chosen alpha values for this run. The file is opened in a
# ``with`` block so the handle is closed as soon as the pickle has been read.
# NOTE(review): pickle.load can execute arbitrary code while loading; presumably
# config.pkl is written by a trusted process of this project -- confirm before
# pointing shared_dir at data from anywhere else.
with open(os.path.join(config.shared_dir, "config.pkl"), "rb") as _alpha_file:
    alpha_config = pickle.load(_alpha_file)
# alpha_config = DartsWorker.get_configspace().sample_configuration()


# Rebuild the alpha matrices from alpha_config.
# Its keys are "normal.0.00", "normal.0.01", ..., "reduce.3.39". For example,
# "normal.0.00" is the alpha weight of operator (00 % 8) on node 0 of the normal cell.
_N_OPS = 8  # alpha weights per row (one per candidate operator)


def _node_alphas(cell, node):
    """Return the alphas of one node as an array of shape (node + 2, _N_OPS).

    Args:
        cell: key prefix, "normal" or "reduce".
        node: node index inside the cell (0..3 in this script).

    Entry ``i`` of the flat key sequence ``"<cell>.<node>.<i:02d>"`` lands in
    row ``i // _N_OPS``, column ``i % _N_OPS``. Node ``k`` has ``k + 2`` rows
    (presumably one per incoming edge -- TODO confirm against the model).
    """
    n_rows = node + 2
    flat = [alpha_config[f"{cell}.{node}.{i:02d}"] for i in range(_N_OPS * n_rows)]
    return np.array(flat).reshape(n_rows, _N_OPS)


normal_0 = _node_alphas("normal", 0)
normal_1 = _node_alphas("normal", 1)
normal_2 = _node_alphas("normal", 2)
normal_3 = _node_alphas("normal", 3)

reduce_0 = _node_alphas("reduce", 0)
reduce_1 = _node_alphas("reduce", 1)
reduce_2 = _node_alphas("reduce", 2)
reduce_3 = _node_alphas("reduce", 3)


# print("reduce 0:")
# print(np.around(reduce_0, 4))
# print("reduce 1:")
# print(np.around(reduce_1, 4))
# print("reduce 2:")
# print(np.around(reduce_2, 4))
# print("reduce 3:")
# print(np.around(reduce_3, 4))


def main():
    """Train the network weights for a fixed set of alphas and report the best result.

    The model is built, its ``alpha_normal`` / ``alpha_reduce`` parameters are
    overwritten with the module-level ``normal_*`` / ``reduce_*`` arrays, and
    the weights are trained for ``config.epochs`` epochs. The training set is
    split in half: the first half is used for training, the second half for
    validation.

    Whenever an epoch's validation top-1 beats both the best of this run and
    ``config.prev_best_acc``, the model state dict and a pickle holding
    ``[top1, {"genotype", "normal", "reduce"}]`` are written to
    ``config.log_dir``.

    Returns:
        tuple: ``(best_top1, log)`` where ``best_top1`` is the highest
        validation top-1 seen (0.0 if no epoch exceeded 0) and ``log`` is a
        dict of per-epoch lists ("lr", "loss", "top1", "top5", "val_loss",
        "val_top1", "val_top5") plus "genotype", the string form of the best
        genotype (``"None"`` if no epoch improved on 0).
    """
    # # set default gpu device id
    # torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get data with meta info
    input_size, input_channels, n_classes, train_data = utils.get_data(
        config.dataset, config.data_path, cutout_length=0, validation=False)

    net_crit = nn.CrossEntropyLoss().to(config.device)
    model = SearchCNNController(input_channels, config.init_channels, n_classes, config.layers,
                                net_crit, device_ids=config.gpus)
    model = model.to(config.device)

    # Overwrite the model's alphas in place with the externally supplied values.
    # no_grad keeps the copy out of autograd; copy_ converts dtype/device as needed.
    with torch.no_grad():
        for node, values in enumerate((normal_0, normal_1, normal_2, normal_3)):
            model.alpha_normal[node].copy_(torch.from_numpy(values))
        for node, values in enumerate((reduce_0, reduce_1, reduce_2, reduce_3)):
            model.alpha_reduce[node].copy_(torch.from_numpy(values))

    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(), config.w_lr, momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer (handed to train(), which currently does not step it)
    alpha_optim = torch.optim.Adam(model.alphas(), config.alpha_lr, betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation: first half / second half of the indices
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop

    log = {"lr": [], "loss": [], 'top1': [], "top5": [], "val_loss": [], "val_top1": [], "val_top5": [], "genotype": None}
    best_top1 = 0.
    # Bound up front so the final report cannot hit an unbound local when no
    # epoch runs or no epoch reaches a top-1 above 0.
    best_genotype = None
    for epoch in range(config.epochs):
        # NOTE(review): stepping the scheduler before the optimizer follows the
        # pre-1.1 PyTorch convention; on newer releases it skips the first
        # value of the schedule. Left as-is -- confirm the targeted version.
        lr_scheduler.step()
        # Read the rate the scheduler just applied straight from the optimizer.
        # Calling get_lr() outside step() is deprecated and, for the cosine
        # schedule on newer PyTorch, does not return the rate in use.
        lr = w_optim.param_groups[0]["lr"]

        model.print_alphas(logger)

        # training
        train_log = train(train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        val_log = validate(valid_loader, model, epoch, cur_step)
        top1 = val_log["top1"]

        log["lr"].append(lr)
        log["loss"].append(train_log["loss"])
        log["top1"].append(train_log["top1"])
        log["top5"].append(train_log["top5"])
        log["val_loss"].append(val_log["loss"])
        log["val_top1"].append(val_log["top1"])
        log["val_top5"].append(val_log["top5"])

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))
        logger.info("[%d/%d] - loss:%.2f - top1:%.2f - val_loss:%.2f - val_top1:%.2f " % (
            epoch + 1, config.epochs, train_log["loss"], train_log["top1"], val_log["loss"], val_log["top1"]))

        # In the BOHB version of DARTS the alphas do not change during
        # training, so there is no need to plot the genotype every epoch.
        # # genotype as a image
        # plot_path = os.path.join(config.plot_path, "EP{:02d}".format(epoch+1))
        # caption = "Epoch {}".format(epoch+1)
        # plot(genotype.normal, plot_path + "-normal", caption)
        # plot(genotype.reduce, plot_path + "-reduce", caption)

        if top1 > best_top1:
            best_top1 = top1
            best_genotype = genotype

            # Only persist when this also beats config.prev_best_acc
            # (presumably the best accuracy of earlier runs -- TODO confirm).
            if top1 > config.prev_best_acc:
                config.prev_best_acc = top1

                # save the model and the result
                torch.save(model.state_dict(), os.path.join(config.log_dir, "best_model.pth"))
                with open(os.path.join(config.log_dir, "best_config.pkl"), "wb") as f:
                    pickle.dump([top1, {"genotype": str(genotype),
                                        "normal": [normal_0, normal_1, normal_2, normal_3],
                                        "reduce": [reduce_0, reduce_1, reduce_2, reduce_3]}], f)

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
    log["genotype"] = str(best_genotype)

    return best_top1, log


def train(train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch):
    """Run one epoch of weight training; the alphas are left untouched.

    Only the network weights are updated, through ``w_optim``. The architect
    (alpha) step is deliberately disabled: per the original note, alphas are
    not learned here but sampled by Bayesian optimization. ``architect``,
    ``alpha_optim`` and ``lr`` are therefore unused and kept only so the call
    signature stays the same.

    Args:
        train_loader: iterable of ``(inputs, targets)`` training batches.
        valid_loader: iterable of validation batches, consumed in lockstep
            with ``train_loader`` but not otherwise used.
        model: network exposing ``criterion`` and ``weights()``.
        architect: unused (see above).
        w_optim: optimizer for the network weights.
        alpha_optim: unused (see above).
        lr: unused (see above).
        epoch: zero-based epoch index, used for log messages only.

    Returns:
        dict: ``{"loss", "top1", "top5"}`` running averages over the epoch.
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    model.train()

    print(f"---------------> train at device: {config.device} --------------")
    # valid_loader is still zipped in so the number of steps per epoch (zip
    # stops at the shorter loader) is unchanged, and presumably so the random
    # stream consumed by its sampler stays the same. Its batches are no longer
    # copied to the device, because nothing reads them while the architect
    # step is disabled.
    for step, ((trn_X, trn_y), _unused_val_batch) in enumerate(zip(train_loader, valid_loader)):
        trn_X, trn_y = trn_X.to(config.device, non_blocking=True), trn_y.to(config.device, non_blocking=True)
        N = trn_X.size(0)

        # # Alphas are not updated; they are sampled by Bayesian optimization.
        # # To re-enable, unpack the validation batch as (val_X, val_y) and
        # # move both tensors to config.device first.
        # # phase 2. architect step (alpha)
        # alpha_optim.zero_grad()
        # architect.unrolled_backward(trn_X, trn_y, val_X, val_y, lr, w_optim)
        # alpha_optim.step()

        # phase 1. child network step (w)
        w_optim.zero_grad()
        logits = model(trn_X)
        loss = model.criterion(logits, trn_y)
        loss.backward()

        # gradient clipping
        nn.utils.clip_grad_norm_(model.weights(), config.w_grad_clip)
        w_optim.step()

        prec1, prec5 = utils.accuracy(logits, trn_y, topk=(1, 5))
        # Weight each batch by its size so the averages are per-sample.
        losses.update(loss.item(), N)
        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        if step % config.print_freq == 0 or step == len(train_loader) - 1:
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    epoch + 1, config.epochs, step, len(train_loader) - 1, losses=losses,
                    top1=top1, top5=top5))

    logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))

    return {"loss": losses.avg, "top1": top1.avg, "top5": top5.avg}


def validate(valid_loader, model, epoch, cur_step):
    """Evaluate ``model`` on every batch of ``valid_loader`` without gradients.

    Args:
        valid_loader: iterable of ``(inputs, targets)`` batches.
        model: network exposing ``criterion``; switched to eval mode here.
        epoch: zero-based epoch index, used for log messages only.
        cur_step: not used by this function; kept in the signature for callers.

    Returns:
        dict: ``{"loss", "top1", "top5"}`` running averages over all batches.
    """
    acc1_meter = utils.AverageMeter()
    acc5_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter()

    model.eval()

    with torch.no_grad():
        for step, batch in enumerate(valid_loader):
            X, y = batch
            inputs = X.to(config.device, non_blocking=True)
            targets = y.to(config.device, non_blocking=True)
            batch_size = inputs.size(0)

            logits = model(inputs)
            loss = model.criterion(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            # Weight each batch by its size so the averages are per-sample.
            loss_meter.update(loss.item(), batch_size)
            acc1_meter.update(prec1.item(), batch_size)
            acc5_meter.update(prec5.item(), batch_size)

            # Progress is reported every print_freq steps and on the last step.
            if step % config.print_freq != 0 and step != len(valid_loader) - 1:
                continue

            progress = (
                "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})"
            ).format(
                epoch + 1, config.epochs, step, len(valid_loader) - 1,
                losses=loss_meter, top1=acc1_meter, top5=acc5_meter)
            logger.info(progress)

    summary = "Valid: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, acc1_meter.avg)
    logger.info(summary)

    return {"loss": loss_meter.avg, "top1": acc1_meter.avg, "top5": acc5_meter.avg}


if __name__ == "__main__":
    # Run the training once, then write [best_top1, log] to result.pkl in the
    # shared directory (presumably picked up by the process that launched this
    # script -- TODO confirm).
    outcome = list(main())
    result_file = os.path.join(config.shared_dir, "result.pkl")
    with open(result_file, "wb") as sink:
        pickle.dump(outcome, sink)
