import argparse
import json
import os.path as osp

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from models.resnet import make_resnet50_base
from datasets.image_folder import ImageFolder
from utils import set_gpu, pick_vectors
from torchvision.models import resnet101

import os

from IPython import embed

# Absolute directory containing this script (symlinks resolved); the data
# and output paths used by this script are built relative to it.
DIR_PATH = osp.dirname(osp.realpath(__file__))

class Identity(nn.Module):
    """Parameter-free module whose ``forward`` returns its input unchanged.

    Installed in place of a network's final ``fc`` layer so the network
    returns the features that would otherwise have been fed to that layer.
    """

    def forward(self, x):
        # No transformation: hand back the very same object.
        return x

def test_on_subset(dataset, cnn, n, pred_vectors, all_label,
                   consider_trains, device):
    hit = 0
    tot = 0

    loader = DataLoader(dataset=dataset, batch_size=32,
                        shuffle=False, num_workers=2)
    all_pred = []
    for batch_id, batch in enumerate(loader, 1):
        data, label = batch
        data = data.to(device)

        feat = cnn(data) # (batch_size, d)
        feat = torch.cat([feat, torch.ones(len(feat)).view(-1, 1).to(device)], dim=1)

        fcs = pred_vectors.t()

        table = torch.matmul(feat, fcs)
        if not consider_trains:
            table[:, :n] = -1e18

        pred = torch.argmax(table, dim=1)
        hit += (pred == all_label).sum().item()
        tot += len(data)

        all_pred.extend([p.cpu().numpy().tolist() for p in pred])

    return hit, tot, all_pred


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pred', required=True,
                        help='file loaded with torch.load; classifiers under the "awa" key')
    parser.add_argument('--gpu', default='0')
    parser.add_argument('--consider-trains', action='store_true')
    parser.add_argument('--output', default=None,
                        help='path of the result JSON '
                             '(default: awa_save/result_<pred name>.json)')
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device('cuda:' + args.gpu)
    else:
        device = torch.device('cpu')
    print('device : ', device)

    with open(osp.join(DIR_PATH, 'materials/awa2-split.json'), 'r') as f:
        awa2_split = json.load(f)
    train_wnids = awa2_split['train']
    test_wnids = awa2_split['test']
    print("pred: {}".format(os.path.basename(args.pred)))
    print('train: {}, test: {}'.format(len(train_wnids), len(test_wnids)))
    print('consider train classifiers: {}'.format(args.consider_trains))

    # NOTE: torch.load unpickles the file; only load prediction files you trust.
    pred_file = torch.load(args.pred, map_location="cpu")
    pred_vectors = pred_file['awa'].to(device)

    # Classifier rows [0, n) are masked as training classes; the i-th test
    # class (1-based) is scored as correct at row n + i - 1.
    n = len(train_wnids)

    # Swap the final fc layer for a pass-through so the CNN returns features.
    cnn = resnet101(pretrained=True)
    cnn.fc = Identity()
    cnn = cnn.to(device)
    cnn.eval()

    test_names = awa2_split['test_names']
    awa2_path = osp.join(DIR_PATH, 'materials/datasets/awa2')

    if args.output is not None:
        output_path = args.output
    else:
        pred_name = osp.splitext(osp.basename(args.pred))[0]
        output_path = osp.join(DIR_PATH, 'awa_save', 'result_' + pred_name + '.json')
    # Create the destination folder before the (long) evaluation instead of
    # failing only when the results are written.
    output_dir = osp.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    results = {}
    ave_acc = 0
    ave_acc_n = 0
    for i, name in enumerate(test_names, 1):
        dataset = ImageFolder(osp.join(awa2_path, 'JPEGImages'), [name], 'test')
        hit, tot, _ = test_on_subset(dataset, cnn, n, pred_vectors, n + i - 1,
                                     args.consider_trains, device)
        acc = hit / tot
        ave_acc += acc
        ave_acc_n += 1

        print('{} {}: {:.2f}%'.format(i, name.replace('+', ' '), acc * 100))

        # Per-class accuracy is stored as a fraction in [0, 1].
        results[name] = acc

    summary = ave_acc / ave_acc_n * 100
    print('summary: {:.2f}%'.format(summary))

    # NOTE: unlike the per-class entries, 'overall' is stored as a percentage.
    results['overall'] = summary

    with open(output_path, 'w') as f:
        json.dump(results, f)
