##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##############################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
from statistics import mean, stdev

from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.procedures import (
    prepare_seed,
    prepare_logger,
    save_checkpoint,
    copy_checkpoint,
    get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_search_spaces

from nas_201_api import NASBench201API as API
from R_EA import train_and_eval, random_architecture_func, get_final_accuracy, get_num_classes
from gs import GsApi


def main(xargs, nas_bench, gs_api, num_candidates=10):
    """Run one score-guided random search under a simulated time budget.

    Each iteration samples ``num_candidates`` random architectures, looks up
    the score of each one through ``gs_api.get_score_by_index`` and evaluates
    only the highest-scoring candidate with ``train_and_eval``. The search
    stops once the accumulated evaluation cost would exceed
    ``xargs.time_budget``.

    Args:
        xargs: Parsed command-line namespace. The attributes read here are
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``search_space_name``, ``max_nodes`` and ``time_budget``; the
            whole namespace is also handed to ``prepare_logger``.
        nas_bench: Benchmark API object used to map architectures to indexes
            and to query results. Must not be ``None``.
        gs_api: Object exposing ``get_score_by_index(index)``.
        num_candidates: Number of random architectures scored per iteration
            (must be at least 1). Defaults to 10, the value that used to be
            hard-coded.

    Returns:
        A tuple ``(log_dir, best_index, acc)`` where ``log_dir`` is
        ``logger.log_dir``, ``best_index`` is the benchmark index of the best
        architecture found and ``acc`` maps ``"cifar10_valid"``,
        ``"cifar10_test"``, ``"cifar100_valid"``, ``"cifar100_test"``,
        ``"in_valid"`` and ``"in_test"`` to the values returned by
        ``get_final_accuracy``.

    Raises:
        ValueError: If ``nas_bench`` is ``None``.
    """
    if nas_bench is None:
        # Every iteration calls nas_bench.query_index_by_arch, so fail early
        # with a readable message instead of an AttributeError mid-search.
        raise ValueError(
            "A NAS benchmark API instance is required (check --arch_nas_dataset)."
        )
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not the module-level ``args`` global, so that
    # main() also works when it is imported and called from other code.
    logger = prepare_logger(xargs)

    acc = {}

    # CIFAR-10 searches are evaluated under the "cifar10-valid" name.
    if xargs.dataset == "cifar10":
        dataname = "cifar10-valid"
    else:
        dataname = xargs.dataset

    config_path = "configs/nas-benchmark/algos/R-EA.config"
    if xargs.data_path is not None:
        train_data, valid_data, xshape, class_num = get_datasets(
            xargs.dataset, xargs.data_path, -1
        )
        split_Fpath = "configs/nas-benchmark/cifar-split.txt"
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log("Load split file from {:}".format(split_Fpath))
        config = load_config(
            config_path, {"class_num": class_num, "xshape": xshape}, logger
        )
        # The validation loader draws from a copy of the training set that
        # uses the validation-time transform.
        train_data_v2 = deepcopy(train_data)
        train_data_v2.transform = valid_data.transform
        valid_data = train_data_v2
        # NOTE(review): the result is not used afterwards; the call is kept in
        # case the constructor validates the splits - confirm before removing.
        search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
        # data loader
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=xargs.workers,
            pin_memory=True,
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_data,
            batch_size=config.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=xargs.workers,
            pin_memory=True,
        )
        logger.log(
            "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
            )
        )
        logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
        extra_info = {
            "config": config,
            "train_loader": train_loader,
            "valid_loader": valid_loader,
        }
    else:
        config = load_config(config_path, None, logger)
        logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
        extra_info = {"config": config, "train_loader": None, "valid_loader": None}

    search_space = get_search_spaces("cell", xargs.search_space_name)
    random_arch = random_architecture_func(xargs.max_nodes, search_space)
    x_start_time = time.time()
    logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
    best_arch, best_acc, total_time_cost, history = None, -1, 0, []
    while total_time_cost < xargs.time_budget:
        # G-Random: score a batch of random candidates and keep the best one.
        arch_list = []
        score_list = []
        for _ in range(num_candidates):
            candidate = random_arch()
            arch_id = nas_bench.query_index_by_arch(candidate)
            arch_list.append(candidate)
            score_list.append(gs_api.get_score_by_index(arch_id))
        arch = arch_list[np.argmax(np.array(score_list))]

        accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
        # An evaluation that would overshoot the budget is discarded.
        if total_time_cost + cost_time > xargs.time_budget:
            break
        total_time_cost += cost_time
        history.append(arch)
        if best_arch is None or best_acc < accuracy:
            best_acc, best_arch = accuracy, arch
        logger.log(
            "[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy)
        )
    logger.log(
        "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).".format(
            time_string(),
            best_arch,
            best_acc,
            len(history),
            total_time_cost,
            time.time() - x_start_time,
        )
    )

    info = nas_bench.query_by_arch(best_arch, "200")

    best_index = nas_bench.query_index_by_arch(best_arch)

    # added by Zhihao Zhang
    acc["cifar10_valid"] = get_final_accuracy("cifar10", nas_bench, best_index, True)
    acc["cifar10_test"] = get_final_accuracy("cifar10", nas_bench, best_index, False)
    acc["cifar100_valid"] = get_final_accuracy("cifar100", nas_bench, best_index, True)
    acc["cifar100_test"] = get_final_accuracy("cifar100", nas_bench, best_index, False)
    acc["in_valid"] = get_final_accuracy("ImageNet16-120", nas_bench, best_index, True)
    acc["in_test"] = get_final_accuracy("ImageNet16-120", nas_bench, best_index, False)

    if info is None:
        logger.log("Did not find this architecture : {:}.".format(best_arch))
    else:
        logger.log("{:}".format(info))
    logger.log("-" * 100)
    logger.close()
    # best_index was already resolved above; no need to query it again.
    return logger.log_dir, best_index, acc


if __name__ == "__main__":
    # Command-line entry point: parse options, load the score table and the
    # benchmark API, then run either a single search (non-negative seed) or a
    # batch of randomly seeded searches (negative seed) and print statistics.
    parser = argparse.ArgumentParser("Random NAS")
    parser.add_argument("--data_path", type=str, help="Path to dataset")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["cifar10", "cifar100", "ImageNet16-120"],
        help="Choose between Cifar10/100 and ImageNet-16.",
    )
    # channels and number-of-cells
    parser.add_argument("--search_space_name", type=str, help="The search space name.")
    parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
    parser.add_argument("--channel", type=int, help="The number of channels.")
    parser.add_argument(
        "--num_cells", type=int, help="The number of cells in one stage."
    )
    # parser.add_argument('--random_num',         type=int,   help='The number of random selected architectures.')
    parser.add_argument(
        "--time_budget",
        type=int,
        help="The total time cost budge for searching (in seconds).",
    )
    # log
    parser.add_argument(
        "--workers",
        type=int,
        default=2,
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--save_dir", type=str, help="Folder to save checkpoints and log."
    )
    parser.add_argument(
        "--arch_nas_dataset",
        type=str,
        help="The path to load the architecture dataset (tiny-nas-benchmark).",
    )
    parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
    parser.add_argument("--rand_seed", type=int, help="manual seed")
    args = parser.parse_args()
    if args.rand_seed is None:
        # --rand_seed has no default; without this check the comparison
        # ``args.rand_seed < 0`` below fails with a TypeError on None.
        parser.error(
            "--rand_seed is required (use a negative value to run repeated "
            "randomly seeded trials)"
        )

    # Load the GsApi score table.
    # NOTE: gs_root is a machine-specific absolute path.
    gs_root = "/home/ubuntu/zzh/zero-cost-nas/201_results_batch_128"
    gs_data = "cifar10"
    gs_cat = "mean"
    gs_seed = 42

    pre = 'cf' if 'cifar' in gs_data else 'im'
    pfn = f'nb2_{pre}{get_num_classes(gs_data)}_seed{gs_seed}_base.p'

    gs_path = os.path.join(gs_root, gs_cat, pfn)
    gs_api = GsApi(gs_path)

    if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
        nas_bench = None
    else:
        print(
            "{:} build NAS-Benchmark-API from {:}".format(
                time_string(), args.arch_nas_dataset
            )
        )
        nas_bench = API(args.arch_nas_dataset)
    if args.rand_seed < 0:
        # Negative seed: repeat the search with fresh random seeds and report
        # mean / standard deviation of the final accuracies.
        save_dir, all_indexes, num = None, [], 500
        stat = {
            "cifar10_valid": [],
            "cifar10_test": [],
            "cifar100_valid": [],
            "cifar100_test": [],
            "in_valid": [],
            "in_test": [],
        }
        for i in range(num):
            print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
            args.rand_seed = random.randint(1, 100000)
            save_dir, index, acc = main(args, nas_bench, gs_api)
            all_indexes.append(index)

            # Zhihao Zhang added
            for key, values in stat.items():
                values.append(acc[key])
        print("Stats:")
        for key, values in stat.items():
            print("{}: {}+-{}".format(key, mean(values), stdev(values)))
        torch.save(all_indexes, save_dir / "results.pth")
    else:
        # main() requires the score API as its third argument.
        main(args, nas_bench, gs_api)
