import numpy as np
import argparse
import torch
import os
import torch.nn as nn
print('Imported torch')
from util.fairness_utils import evaluate
from util.data_utils_balanced import load_dict_as_str
from util.data_utils_balanced import ImageFolderWithProtectedAttributes
from backbone.model_resnet import ResNet_50, ResNet_152
from backbone.MobileFaceNets import MobileFaceNet
import torchvision.transforms as transforms
import random
import ast

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this script may touch (torch CPU, all CUDA devices, NumPy,
# Python's `random`) with one fixed value so repeated runs draw the same numbers.
_GLOBAL_SEED = 222
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
    _seed_fn(_GLOBAL_SEED)

# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

if __name__ == '__main__':
    # Evaluate a pretrained backbone on a test ImageFolder whose demographic
    # groups may be subsampled, once per value given to --seed.

    parser = argparse.ArgumentParser()

    parser.add_argument('--data_test_root', default='')
    parser.add_argument('--checkpoint', default="")
    parser.add_argument('--demographics', default='CelebA_demographics.txt')
    parser.add_argument('--backbone_name', default='')

    # Parallel lists: entry i of --p_identities / --p_images applies to
    # entry i of --groups_to_modify.
    parser.add_argument('--groups_to_modify', default=['male', 'female'], type=str, nargs='+')
    parser.add_argument('--p_identities', default=[1.0, 1.0], type=float, nargs='+')
    parser.add_argument('--p_images', default=[1.0, 1.0], type=float, nargs='+')

    parser.add_argument('--batch_size', default=250, type=int)
    # List-valued options need `nargs`; without it an override on the command
    # line becomes a single scalar and the indexing / Normalize below break.
    parser.add_argument('--input_size', default=[112, 112], type=int, nargs=2)
    parser.add_argument('--embedding_size', default=512, type=int)
    parser.add_argument('--mean', default=[0.5, 0.5, 0.5], type=float, nargs='+')
    parser.add_argument('--std', default=[0.5, 0.5, 0.5], type=float, nargs='+')
    parser.add_argument('--seed', default=[0], type=int, nargs='+')

    args = parser.parse_args()

    if not (len(args.groups_to_modify) == len(args.p_identities) == len(args.p_images)):
        parser.error('--groups_to_modify, --p_identities and --p_images must have the same number of values')
    p_images = dict(zip(args.groups_to_modify, args.p_images))
    p_identities = dict(zip(args.groups_to_modify, args.p_identities))

    # Resize to 128/112 of the target size, then centre-crop to the target size.
    test_transform = transforms.Compose([
        transforms.Resize([int(128 * args.input_size[0] / 112), int(128 * args.input_size[1] / 112)]),
        transforms.CenterCrop([args.input_size[0], args.input_size[1]]),
        transforms.ToTensor(),
        transforms.Normalize(mean=args.mean,
                             std=args.std)])

    # Constructors are deferred so that only the selected network is built.
    backbone_factories = {
        'MobileFaceNet': lambda: MobileFaceNet(embedding_size=args.embedding_size, out_h=7, out_w=7),
        'ResNet_50': lambda: ResNet_50(args.input_size),
        'ResNet_152': lambda: ResNet_152(args.input_size),
    }
    if args.backbone_name not in backbone_factories:
        parser.error(f'--backbone_name must be one of {sorted(backbone_factories)}, got {args.backbone_name!r}')
    backbone_module = backbone_factories[args.backbone_name]()

    # Read the checkpoint from disk once. map_location='cpu' lets a checkpoint
    # saved on GPU load when `device` is the CPU; the model is moved to
    # `device` below. NOTE: torch.load unpickles -- only use trusted files.
    checkpoint_state = torch.load(args.checkpoint, map_location='cpu')

    ###Load Data###
    # two dictionaries mapping demographic to classes
    demographic_to_classes = load_dict_as_str(args.demographics)
    classes_to_demographic = {cl: dem for dem, classes in demographic_to_classes.items() for cl in classes}

    for s in args.seed:

        data = ImageFolderWithProtectedAttributes(args.data_test_root, transform=test_transform,
                                                  demographic_to_all_classes=demographic_to_classes,
                                                  all_classes_to_demographic=classes_to_demographic,
                                                  p_identities=p_identities,
                                                  p_images=p_images,
                                                  min_num=3,
                                                  ref_num_images=7000,
                                                  seed=s)

        demographic_to_labels = data.demographic_to_idx
        samples = data.samples
        label_to_demographic = {label: dem for dem, labels in demographic_to_labels.items() for label in labels}
        class_to_idx = data.class_to_idx
        idx_to_class = {idx: cl for cl, idx in class_to_idx.items()}

        dataloader = torch.utils.data.DataLoader(data, batch_size=args.batch_size, shuffle=False)

        # Re-apply the checkpoint for every seed so each run starts from the
        # same weights/buffers, without re-reading the file from disk.
        backbone_module.load_state_dict(checkpoint_state)
        backbone = nn.DataParallel(backbone_module).to(device)

        loss, acc, acc_k, correct, labels_all, demographic_all = evaluate(dataloader, None, backbone, None, args.embedding_size,
                                           k_accuracy = True, multilabel_accuracy = False,  demographic_to_labels = demographic_to_labels)

        correct, labels_all, demographic_all = np.array(correct), np.array(labels_all), np.array(demographic_all)
