import numpy as np
import argparse
import os
import random
import pandas as pd
from collections import OrderedDict

import tabulate
def _build_parser():
    """Create the command-line parser for the results-table script.

    Only some options are consumed by this script (mostly the ones that make
    up the saved results file name, plus ``--GPU`` and ``--seed``); the rest
    presumably mirror the search script's options so the same command line
    can be reused -- confirm before pruning.
    """
    cli = argparse.ArgumentParser(description='Produce tables')
    add = cli.add_argument

    # Locations and result-file naming.
    add('--data_loc', default='../fishersearch_randomwirenetworks/cifardata/', type=str, help='dataset folder')
    add('--api_loc', default='../fimflam/NAS-Bench-201-v1_0-e61699.pth', type=str, help='path to API')
    add('--save_loc', default='results', type=str, help='folder to save results')
    add('--save_string', default='naswot', type=str, help='prefix of results file')

    # Score and search space.
    add('--score', default='corrdistintegral0_025', type=str, help='the score to evaluate')
    add('--nasspace', default='nasbench201', type=str, help='the nas search space to use')

    # Batch construction and input perturbation.
    add('--batch_size', default=256, type=int)
    add('--repeat', default=256, type=int, help='how often to repeat a single image with a batch')
    add('--augtype', default='gaussnoise', type=str, help='which perturbations to use')
    add('--sigma', default=0.01, type=float, help='noise level if augtype is "gaussnoise"')

    # Runtime environment.
    add('--GPU', default='0', type=str)
    add('--seed', default=1, type=int)

    # Boolean switches.
    add('--trainval', action='store_true')
    add('--activations', action='store_true')
    add('--cosine', action='store_true')

    # Experiment size.
    add('--dataset', default='cifar10', type=str)
    add('--n_samples', default=100, type=int)
    add('--n_runs', default=500, type=int)

    # NAS-Bench-101 network construction.
    add('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
    add('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
    add('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
    add('--num_labels', default=1, type=int, help='#classes (nasbench101)')

    return cli


parser = _build_parser()

# Parse the command line (reads sys.argv).
args = parser.parse_args()
# Restrict which GPUs are visible. This is set before `import torch` below,
# presumably so the device mask is in effect before torch initialises CUDA;
# keep this ordering.
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

from statistics import mean, median, stdev as std

import torch

# Determinism flags and RNG seeding.
# NOTE(review): this script only loads and tabulates saved results, so these
# settings appear to have no effect on its output; likely kept for parity with
# the search script -- confirm before removing.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# `df` ends up holding the results table built below.
df = []

# Table column name -> (benchmark dataset name, split name, flag).
# A split name containing 'test' selects test accuracies, any other split the
# validation accuracies. The third element is not read by this script.
datasets = OrderedDict([
    ('CIFAR-10 (val)', ('cifar10-valid', 'x-valid', True)),
    ('CIFAR-10 (test)', ('cifar10-valid', 'ori-test', False)),
])

# CIFAR-100 and ImageNet16-120 are only tabulated for the NAS-Bench-201 space.
if args.nasspace == 'nasbench201':
    datasets.update([
        ('CIFAR-100 (val)', ('cifar100', 'x-valid', False)),
        ('CIFAR-100 (test)', ('cifar100', 'x-test', False)),
        ('ImageNet16-120 (val)', ('ImageNet16-120', 'x-valid', False)),
        ('ImageNet16-120 (test)', ('ImageNet16-120', 'x-test', False)),
    ])


# Per-dataset accuracies of the architectures chosen by each search run;
# the entries are overwritten for every sample size in the loop below.
dataset_top1s = OrderedDict()

# Accuracy columns of the printed table, in display order. Columns whose
# dataset is not in `datasets` (i.e. not evaluated for this search space)
# are shown as '-'.
result_columns = [
    'CIFAR-10 (val)', 'CIFAR-10 (test)',
    'CIFAR-100 (val)', 'CIFAR-100 (test)',
    'ImageNet16-120 (val)', 'ImageNet16-120 (test)',
]


def _mean_std(values):
    """Return ``values`` summarised as ``"<mean> +- <sample stdev>"`` (2 d.p.).

    ``statistics.stdev`` raises ``StatisticsError`` for fewer than two values,
    so this requires ``--n_runs`` >= 2.
    """
    return f"{mean(values):.2f} +- {std(values):.2f}"


def _results_path(dset, n_samples):
    """Return the path of the saved search results for ``dset`` and ``n_samples``."""
    return (
        f"{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{dset}"
        f"_{args.augtype}_{args.sigma}_{args.repeat}_{args.batch_size}"
        f"_{args.n_runs}_{n_samples}_{args.seed}.t7"
    )


rows = []
# NOTE: the sample sizes are fixed here; the --n_samples option is not consulted.
for n_samples in [10, 100]:
    method = f"Ours (N={n_samples})"
    # Median search time, taken from the CIFAR-10 (val) results file;
    # stays 0. if that dataset is absent from `datasets`.
    search_time = 0.

    for dataset, params in datasets.items():
        dset, split = params[0], params[1]
        # Splits whose name contains 'test' read test accuracies, all others
        # read validation accuracies.
        acc_type = 'accs' if 'test' in split else 'val_accs'

        full_scores = torch.load(_results_path(dset, n_samples))
        if dataset == 'CIFAR-10 (val)':
            search_time = f"{median(full_scores['times']):.2f}"

        # Explicit indexing (rather than slicing) so a results file with fewer
        # than --n_runs entries fails loudly instead of being silently truncated.
        accs = [full_scores[acc_type][n] for n in range(args.n_runs)]
        if args.nasspace == 'nasbench101':
            # NAS-Bench-101 accuracies are presumably stored as fractions;
            # scale to percent to match the other search spaces.
            accs = [100. * acc for acc in accs]
        dataset_top1s[dataset] = accs

    rows.append([method, search_time] + [
        _mean_std(dataset_top1s[column]) if column in datasets else '-'
        for column in result_columns
    ])


df = pd.DataFrame(rows, columns=['Method', 'Search time (s)'] + result_columns)

print(tabulate.tabulate(df.values, df.columns, tablefmt="pipe"))
