# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import sys
import numpy as np
import torch
import argparse
from tqdm import tqdm
from path_prefix import path_prefix

from load_test_models import load_test_models
from ofa.imagenet_classification.run_manager.run_config import get_data_provider_by_name
from ofa.utils import accuracy
pwd = os.getcwd()
sys.path.append(pwd)



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="flw",
        choices=[
            "cifar10",
            "mnist",
            "kmnist",
            "fmnist",
            "svhn",
            "emnistlet",
            'biwi',
            'flw',
        ],
    )
    parser.add_argument(
        "--method",
        type=str,
        default="random",
        choices=[
            "entropy",
            "coreset",
            "random",
            "DIAM",
            "lewisweight",
            "qbc",
        ],
    )
    parser.add_argument("--al_iter", type=int, default=0)
    parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
    parser.add_argument("--model_num", type=int, default=50)
    parser.add_argument("--regression", action="store_true")
    args = parser.parse_args()

    if args.gpu == "all":
        device_list = range(torch.cuda.device_count())
        args.gpu = ",".join(str(_) for _ in device_list)
    else:
        device_list = [int(_) for _ in args.gpu.split(",")]
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    results_saving_dir = f"./extracted_results/{args.model_num}/{args.dataset}/{args.method}"
    print("testing ", results_saving_dir)
    os.makedirs(results_saving_dir, exist_ok=True)
    if args.dataset == 'emnistlet':
        NCLASSES = 27
    elif args.dataset == 'biwi':
        NCLASSES = 2
    else:
        NCLASSES = 10

    if args.al_iter == 0:
        ofa_checkpoint_root = f"{path_prefix}exp/{args.al_iter}/{args.dataset}/"
        lab_idx = np.loadtxt(os.path.join(ofa_checkpoint_root, "lab_idx.txt"), dtype=int)
        unlab_idx = np.loadtxt(os.path.join(ofa_checkpoint_root, "unlab_idx.txt"), dtype=int)
    else:
        ofa_checkpoint_root = f"{path_prefix}exp/{args.al_iter}/{args.dataset}/{args.method}/"
        al_idx_save_root = os.path.join(f"{path_prefix}al_results", str(args.al_iter), args.dataset)
        al_save_root = os.path.join(al_idx_save_root, args.method)
        lab_idx = np.loadtxt(os.path.join(al_save_root, "lab_idx.txt"), dtype=int)
        unlab_idx = np.loadtxt(os.path.join(al_save_root, "unlab_idx.txt"), dtype=int)
    # load model
    net, image_size = load_test_models(net_id=0, trained_weights=None)
    del net
    ########################### construct sequential unlab load dataloader ########################################
    DataProvider = get_data_provider_by_name(args.dataset)
    dpv = DataProvider(
        train_batch_size=256,
        test_batch_size=128,
        valid_size=3000,
        n_worker=16,
        resize_scale=0.08,
        distort_color="tf",
        image_size=image_size,
        num_replicas=None,
        lab_idx=lab_idx,
        unlab_idx=unlab_idx,
    )
    unlab_dl = dpv.unlab
    test_dl = dpv.test
    train_dl = dpv.train

    if not args.regression:
        performances = np.zeros([args.model_num, 2])
        for imodel in range(args.model_num):
            ofa_ckpt_path = os.path.join(ofa_checkpoint_root, str(imodel), "checkpoint/checkpoint.pth.tar")
            net, image_size = load_test_models(net_id=imodel, n_classes=NCLASSES, trained_weights=ofa_ckpt_path)
            net.to("cuda")
            net.eval()

            top1 = 0
            top5 = 0
            all_num = 0
            with torch.no_grad():
                with tqdm(
                        total=len(test_dl),
                        desc="Validate model #{} ".format(imodel)
                ) as t:
                    for i, (images, labels) in enumerate(test_dl):
                        images, labels = images.cuda(), labels.cuda()
                        output = net(images)
                        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                        all_num += output.size(0)
                        top1 += acc1 * output.size(0)
                        top5 += acc5 * output.size(0)
                        t.update(1)

            top1 /= all_num
            top5 /= all_num
            print(top1, top5)
            performances[imodel, :] = [float(top1), float(top5)]
        np.savetxt(os.path.join(results_saving_dir, "performances.txt"), performances, fmt="%f")
    else:
        from torchmetrics import R2Score, MeanAbsoluteError, MeanSquaredError
        performances = np.zeros([args.model_num, 3])
        R2metric = R2Score(num_outputs=NCLASSES).to("cuda")
        MAEmetric = MeanAbsoluteError().to("cuda")
        MSEmetric = MeanSquaredError().to("cuda")
        for imodel in range(args.model_num):
            ofa_ckpt_path = os.path.join(ofa_checkpoint_root, str(imodel), "checkpoint/checkpoint.pth.tar")
            net, image_size = load_test_models(net_id=imodel, n_classes=NCLASSES, trained_weights=ofa_ckpt_path)
            net.to("cuda")
            net.eval()

            r2score = 0
            msescore = 0
            maescore = 0
            mean_r2score = 0
            mean_msescore = 0
            mean_maescore = 0
            all_num = 0
            with torch.no_grad():
                with tqdm(
                        total=len(test_dl),
                        desc="Validate model #{} ".format(imodel)
                ) as t:
                    for i, (images, labels) in enumerate(test_dl):
                        images, labels = images.cuda(), labels.cuda()
                        output = net(images)
                        r2score = R2metric(output, labels)
                        msescore = MSEmetric(output, labels)
                        maescore = MAEmetric(output, labels)
                        all_num += output.size(0)
                        mean_r2score += r2score * output.size(0)
                        mean_msescore += msescore * output.size(0)
                        mean_maescore += maescore * output.size(0)
                        t.update(1)

            mean_r2score /= all_num
            mean_msescore /= all_num
            mean_maescore /= all_num
            print(mean_r2score, mean_msescore, mean_maescore)
            performances[imodel, :] = [float(mean_r2score), float(mean_msescore), float(mean_maescore)]
        np.savetxt(os.path.join(results_saving_dir, "performances.txt"), performances, fmt="%f")
