import numpy as np
import torch
import torch.nn.functional as F
import os
from sklearn import metrics
# from sklearn.utils import shuffle
from time import time
from scgm_a.model_generator import MoEncoderGenerator
# from utils.model_toolkit import identity_layer
from utils.utils import get_test_dataloader_breeds
from eval.eval_performance import classify, mean_confidence_interval

# Restrict evaluation to the first GPU. This works even though torch is
# imported above, as long as CUDA has not been initialised yet.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def _batch_slices(n_items, batch_size):
    """Return ``(start, stop)`` pairs that split ``range(n_items)`` into batches.

    Every batch holds ``batch_size`` items except the last one. The last batch
    absorbs any remainder (so it holds between ``batch_size`` and
    ``2 * batch_size - 1`` items) instead of forming a small trailing batch.
    Fewer than ``batch_size`` items yield a single batch; zero items yield none.
    """
    starts = list(range(0, n_items, batch_size))
    if len(starts) > 1 and n_items % batch_size:
        # Merge the short remainder into the preceding full batch.
        starts.pop()
    stops = starts[1:] + [n_items]
    return list(zip(starts, stops))


def _extract_features(net, xs, batch_size, device=None, normalize=True):
    """Embed ``xs`` with ``net.encoder_q`` batch by batch and stack the results.

    Args:
        net: model exposing ``encoder_q``. It is called on a slice of ``xs`` and
            must return a pair whose first item is the embedding.
        xs: tensor of samples indexed along dim 0.
        batch_size: nominal batch size (see ``_batch_slices``).
        device: if not None, each batch is moved there before encoding.
        normalize: if true, L2-normalise each embedding along dim 1.

    Returns:
        NumPy array with one embedding row per sample, in input order.

    Raises:
        ValueError: if ``xs`` is empty.
    """
    slices = _batch_slices(len(xs), batch_size)
    if not slices:
        raise ValueError('cannot extract features from an empty set')

    feats = []
    for start, stop in slices:
        batch = xs[start:stop]
        if device is not None:
            batch = batch.to(device)
        # Only the first element of encoder_q's output is used as the feature.
        emb, _ = net.encoder_q(batch)
        if normalize:
            emb = F.normalize(emb, p=2, dim=1)
        feats.append(emb.detach().cpu().numpy())
    return np.concatenate(feats, axis=0)


if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model hyper-parameters; presumably these must match the settings used to
    # train the checkpoint loaded below.
    arch = 'resnet50'
    head_type = 'seq_em'
    cst_dim = 128
    queue_k = 65536
    encoder_m = 0.999
    cls_t = 1.0
    cst_t = 0.2
    n_subclass = 100
    tau1 = 0.1
    tau2 = 1.0
    head_norm = False
    with_mlp = False
    queue_type = 'multi'
    metric_type = 'angular'
    calc_types = ['cls', 'cst_by_class', 'cst_by_subclass', 'cst_two_class']

    # Evaluation settings.
    classifier = 'LR'
    n_test_runs = 1000
    n_ways = 68
    n_shots = 1
    n_queries = 15
    n_aug_support_samples = 5
    batchsz = 256  # encoder forward-pass batch size for feature extraction
    set_cuda = True
    norm = True
    ds_name = 'living17'

    # load dataset
    # ---
    info_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/BREEDS/'
    data_dir = '/nfs/data/usr/jni/datasets/imagenet_ilsvrc/ILSVRC/Data/CLS-LOC/'

    breeds_test_loader = get_test_dataloader_breeds(
        ds_name=ds_name,
        info_dir=info_dir,
        data_dir=data_dir,
        n_test_runs=n_test_runs,
        n_ways=n_ways,
        n_shots=n_shots,
        n_queries=n_queries,
        n_aug_support_samples=n_aug_support_samples,
        batch_size=1,
        num_workers=0)

    # load model
    # ---
    # NOTE(review): hard-coded; presumably the number of training classes for
    # living17. Keep it in sync with ds_name.
    n_class = 17
    net = MoEncoderGenerator().generate_ancor_model(arch=arch,
                                                     head_type=head_type,
                                                     dim=cst_dim,
                                                     K=queue_k,
                                                     m=encoder_m,
                                                     T=[cls_t, cst_t, tau1, tau2],
                                                     mlp=with_mlp,
                                                     num_classes=n_class,
                                                     num_subclasses=n_subclass,
                                                     norm=head_norm,
                                                     queue_type=queue_type,
                                                     metric=metric_type,
                                                     calc_types=calc_types)

    weights_path = os.path.join('pretrain_model', 'scgm_a_' + ds_name + '.pth')
    # Earlier versions of this script looked for a file name with a stray space
    # before ".pth". Fall back to it so such checkpoints still load.
    legacy_weights_path = 'pretrain_model/scgm_a_' + ds_name + ' .pth'
    if not os.path.exists(weights_path) and os.path.exists(legacy_weights_path):
        weights_path = legacy_weights_path

    # strict=False tolerates key mismatches; report them rather than silently
    # evaluating partly-initialised weights.
    load_result = net.load_state_dict(torch.load(weights_path, map_location='cpu'),
                                      strict=False)
    missing_keys = getattr(load_result, 'missing_keys', [])
    unexpected_keys = getattr(load_result, 'unexpected_keys', [])
    if missing_keys:
        print('warning: {:d} key(s) missing from checkpoint:'.format(len(missing_keys)),
              missing_keys)
    if unexpected_keys:
        print('warning: {:d} unexpected key(s) in checkpoint:'.format(len(unexpected_keys)),
              unexpected_keys)

    if set_cuda:
        net.to(device)

    net.eval()

    feat_device = device if set_cuda else None

    # evaluation
    # ---
    with torch.no_grad():
        acc = []
        t0 = time()

        for run_idx, batch_data in enumerate(breeds_test_loader):

            # The loader yields one episode per step (batch_size=1), so strip
            # the leading batch axis from each element.
            support_xs, support_ys, query_xs, query_ys = (t[0] for t in batch_data)

            # support / query embeddings
            # ---
            support_feats = _extract_features(net, support_xs, batchsz,
                                              device=feat_device, normalize=norm)
            query_feats = _extract_features(net, query_xs, batchsz,
                                            device=feat_device, normalize=norm)

            # classification
            # ---
            clf = classify(classifier, support_feats, support_ys)
            support_preds = clf.predict(support_feats)
            query_preds = clf.predict(query_feats)

            # evaluation
            # ---
            acc_tr = metrics.accuracy_score(support_ys, support_preds)
            acc_te = metrics.accuracy_score(query_ys, query_preds)
            acc.append(acc_te)

            print('[{:d}'.format(run_idx),
                  '/ {:d}]'.format(n_test_runs),
                  'training: acc = {:.5f}'.format(acc_tr * 100),
                  '| testing: acc = {:.5f}'.format(acc_te * 100))

            del clf

        # NOTE(review): mean_confidence_interval presumably returns
        # (mean, confidence-interval half-width); the 'std' label is kept
        # unchanged. Verify against eval_performance.
        acc_mn, acc_std = mean_confidence_interval(acc)
        print('accuracy={:.5f}'.format(acc_mn * 100),
              'std={:.5f}'.format(acc_std * 100),
              'time={:.5f}'.format(time() - t0))
