import numpy as np
import torch
import torch.nn.functional as F
import os
from sklearn import metrics
from time import time
from scgm_g.scgm_resnet import resnet50
# from utils.model_toolkit import identity_layer
from utils.utils import get_test_dataloader_breeds
from eval.eval_performance import classify, mean_confidence_interval
# import resource

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
# resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))


def embed_in_batches(net, xs, batch_size, device=None, normalize=True):
    """Embed ``xs`` chunk by chunk and return the stacked features as a NumPy array.

    Each chunk of at most ``batch_size`` samples (sliced along dim 0) is passed
    through ``net(...)`` and then ``net.embed(...)``, optionally L2-normalised
    along dim 1, moved to CPU, and concatenated in input order.

    Args:
        net: model exposing ``__call__`` and ``embed``. When wrapped in
            ``torch.nn.DataParallel`` the embed call would need to go through
            ``net.module.embed`` instead.
        xs: tensor of samples indexed along dim 0.
        batch_size: maximum number of samples fed to ``net`` in one call.
        device: if not None, each chunk is moved to this device before the
            forward pass.
        normalize: apply ``F.normalize(p=2, dim=1)`` to each chunk's output.

    Returns:
        ``np.ndarray`` whose rows correspond one-to-one to the rows of ``xs``.
    """
    feats = []
    # Plain chunking covers every sample, including a final full-size chunk.
    # max(..., 1) keeps a single (empty) pass for empty input so that
    # np.concatenate still receives at least one array.
    for start in range(0, max(len(xs), 1), batch_size):
        batch = xs[start:start + batch_size]
        if device is not None:
            batch = batch.to(device)

        out = net.embed(net(batch))
        if normalize:
            out = F.normalize(out, p=2, dim=1)

        feats.append(out.detach().cpu().numpy())

    return np.concatenate(feats, axis=0)


if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    classifier = 'LR'
    n_test_runs = 1000
    n_ways = 68
    n_shots = 1
    n_queries = 15
    n_aug_support_samples = 5
    batchsz = 256
    set_cuda = True
    norm = True
    ds_name = 'living17'

    # load dataset
    # ---
    info_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/BREEDS/'
    data_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/Data/CLS-LOC/'

    breeds_test_loader = get_test_dataloader_breeds(
        ds_name=ds_name,
        info_dir=info_dir,
        data_dir=data_dir,
        n_test_runs=n_test_runs,
        n_ways=n_ways,
        n_shots=n_shots,
        n_queries=n_queries,
        n_aug_support_samples=n_aug_support_samples,
        batch_size=1,
        num_workers=0)

    # load model
    # ---
    n_class = 17
    k = 100
    kd_t = 4.0
    hiddim = 128
    with_mlp = True

    net = resnet50(num_classes=n_class, num_subclasses=k, kd_t=kd_t, hiddim=hiddim, with_mlp=with_mlp)
    weights_path = 'pretrain_model/scgm_g_' + ds_name + '.pth'
    # strict=False tolerates key mismatches silently; report them so that a
    # wrong or mismatched checkpoint does not go unnoticed.
    load_result = net.load_state_dict(torch.load(weights_path, map_location='cpu'), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print('warning: {} has {:d} missing and {:d} unexpected keys'.format(
            weights_path, len(load_result.missing_keys), len(load_result.unexpected_keys)))
    # net.fc_enc = identity_layer()

    if set_cuda:
        net.to(device)
        # net = torch.nn.DataParallel(net)
    # Inputs are moved to the model's device only when the model was moved.
    embed_device = device if set_cuda else None

    net.eval()

    # evaluation
    # ---
    with torch.no_grad():
        acc = []
        t0 = time()

        for run_idx, (support_xs, support_ys, query_xs, query_ys) in enumerate(breeds_test_loader):
            # The loader uses batch_size=1, so strip the leading episode dimension.
            support_xs, support_ys = support_xs[0], support_ys[0]
            query_xs, query_ys = query_xs[0], query_ys[0]

            # support / query set embeddings
            # ---
            support_feats = embed_in_batches(net, support_xs, batchsz, device=embed_device, normalize=norm)
            query_feats = embed_in_batches(net, query_xs, batchsz, device=embed_device, normalize=norm)

            # classification
            # ---
            clf = classify(classifier, support_feats, support_ys)
            support_preds = clf.predict(support_feats)
            query_preds = clf.predict(query_feats)

            # evaluation
            # ---
            acc_tr = metrics.accuracy_score(support_ys, support_preds)
            acc_te = metrics.accuracy_score(query_ys, query_preds)
            acc.append(acc_te)

            print('[{:d}'.format(run_idx),
                  '/ {:d}]'.format(n_test_runs),
                  'training: acc = {:.5f}'.format(acc_tr * 100),
                  '| testing: acc = {:.5f}'.format(acc_te * 100))

            del clf

        # NOTE(review): mean_confidence_interval presumably returns
        # (mean, confidence-interval half-width); the 'std' label below may be
        # misleading — confirm against eval.eval_performance.
        acc_mn, acc_std = mean_confidence_interval(acc)
        print('accuracy={:.5f}'.format(acc_mn * 100),
              'std={:.5f}'.format(acc_std * 100),
              'time={:.5f}'.format(time() - t0))
