import time
import datetime
from pathlib import Path
import xlwt

from src.dataset.data_loader import GMDataset, get_dataloader
from src.evaluation_metric import *
from src.parallel import DataParallel
from src.utils.model_sl import load_model
from src.utils.data_to_cuda import data_to_cuda
from src.utils.timer import Timer

from src.utils.config import cfg
from pygmtools.benchmark import Benchmark

from torch.autograd import Variable
from torch.distributions.normal import Normal
from core_match_pixel import Smooth
import numpy as np
import os
import argparse

# Command-line options specific to this certification script.
parser = argparse.ArgumentParser(description='Certify examples')
parser.add_argument("--method", default="ngmv2", choices=["ngmv2", "pca", "cie", "gmn-rrwm"], help="matching methods")
parser.add_argument("--n0", type=int, default=10, help="passed to Smooth.certify as n0")
parser.add_argument("--n", type=int, default=100, help="number of samples to use")
parser.add_argument("--gamma", type=int, default=5, help="Normalization parameter")
# The __main__ section re-parses sys.argv with the project parser
# (src.utils.parse_args), so options meant for that parser must be tolerated
# here rather than aborting the run; parse_known_args ignores them.
args, _ = parser.parse_known_args()

# parameters
n0 = args.n0
n = args.n
gamma = args.gamma
method = args.method

file_name = "pixel_certify"
# Create result_<file_name>/<method>/ (including the parent) if missing;
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs("result_" + file_name + "/" + method, exist_ok=True)

# Column header shared by every certification result file.
_RESULT_HEADER = "idx\tlabel\tpredict\tradius\tcorrect\ttime"


def _result_path(ori_sigma, mode_tag, bound_tag):
    """Return the path of one certification result file.

    The name encodes the smoothing sigma, the module-level ``n``/``n0``,
    ``cfg.EVAL.SAMPLES * 20``, the sampling mode (``'cov'`` or ``'RS'``) and
    the bound (``'Lvolume'``, ``'Llower'`` or ``'Lmax'``).
    """
    return ("result_" + file_name + "/" + method + "/sigmaori" + str(ori_sigma)
            + '_n' + str(n) + '_n0' + str(n0)
            + '_sample' + str(cfg.EVAL.SAMPLES * 20)
            + '_' + mode_tag + '_' + bound_tag)


def _write_prediction(f, number, label, prediction, gt_row, time_elapsed):
    """Append one tab-separated result row to ``f`` and flush it.

    Args:
        f: open text file receiving the row.
        number: sample index written to the ``idx`` column.
        label: ground-truth match index of this keypoint.
        prediction: ``(predicted_index, radius)`` pair for this keypoint. A
            predicted index of -1 (presumably an abstention from
            ``Smooth.certify``) is always recorded as incorrect.
        gt_row: this keypoint's row of the ground-truth permutation matrix;
            the prediction is correct when ``gt_row[predicted_index]`` is 1.
        time_elapsed: formatted certification time of the whole sample.
    """
    pred_idx, radius = prediction[0], prediction[1]
    # Check the -1 sentinel before indexing: gt_row[-1] would silently read
    # the last column instead.
    correct = int(pred_idx != -1 and gt_row[pred_idx].item() == 1)
    print("{}\t{}\t{}\t{:.3}\t{}\t{}".format(number, label, pred_idx, radius, correct, time_elapsed),
          file=f, flush=True)


def _certify_config(model, benchmark, clss, ori_sigma, ancer, cov_pro, files):
    """Certify the test samples of every class in ``clss`` for one configuration.

    ``files`` holds the (Lvolume, Llower, Lmax) result files, in the same
    order as the three predictions returned by ``Smooth.certify``; each gets
    a header and then one row per keypoint of every evaluated sample.
    """
    for f in files:
        print(_RESULT_HEADER, file=f, flush=True)
    verbose = True
    print('Start evaluation...')

    model.eval()

    # create the smoothed classifier g
    smoothed_classifier = Smooth(model)

    dataloaders = []
    for cls in clss:
        image_dataset = GMDataset(name=cfg.DATASET_FULL_NAME,
                                  bm=benchmark,
                                  problem=cfg.PROBLEM.TYPE,
                                  length=cfg.EVAL.SAMPLES,
                                  cls=cls,
                                  using_all_graphs=cfg.PROBLEM.TEST_ALL_GRAPHS)
        # Re-seed so each configuration sees the same shuffled samples.
        torch.manual_seed(cfg.RANDOM_SEED)
        dataloaders.append(get_dataloader(image_dataset, shuffle=True))

    number = 0  # running sample index written to the idx column
    for i, cls in enumerate(clss):
        if verbose:
            print('Evaluating class {}: {}/{}'.format(cls, i, len(clss)))

        iter_num = 0
        for inputs in dataloaders[i]:
            if iter_num >= cfg.EVAL.SAMPLES / inputs['batch_size']:
                break
            if model.module.device != torch.device('cpu'):
                inputs = data_to_cuda(inputs)
            print("sample_number: ", iter_num)
            iter_num += 1

            with torch.set_grad_enabled(False):
                before_time = time.time()
                # presumably the keypoint count of the first graph — TODO confirm
                clas_item = inputs['ns'][0].cpu().item()

                predictions = smoothed_classifier.certify(inputs, n0, n, alpha=0.001, batch_size=8,
                                                          clas=clas_item,
                                                          sigma_pro=cov_pro, sigma=ori_sigma,
                                                          if_ancer=ancer, down=gamma, k=1.1)

                after_time = time.time()

                gt_perm = inputs['gt_perm_mat'][0]
                total_label = gt_perm.argmax(axis=1)
                time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))

                for j in range(clas_item):
                    label = total_label[j].item()
                    # Rows go to Lvolume, Llower, Lmax in that order.
                    for f, prediction in zip(files, predictions):
                        _write_prediction(f, number, label, prediction[j], gt_perm[j], time_elapsed)
            number += 1


if __name__ == '__main__':
    from src.utils.dup_stdout_manager import DupStdoutFileManager
    from src.utils.parse_args import parse_args
    from src.utils.print_easydict import print_easydict

    args = parse_args('Deep learning of graph matching evaluation code.')

    import importlib
    mod = importlib.import_module(cfg.MODULE)
    Net = mod.Net

    torch.manual_seed(cfg.RANDOM_SEED)

    ds_dict = cfg[cfg.DATASET_FULL_NAME] if ('DATASET_FULL_NAME' in cfg) and (cfg.DATASET_FULL_NAME in cfg) else {}
    benchmark = Benchmark(name=cfg.DATASET_FULL_NAME,
                          sets='test',
                          problem=cfg.PROBLEM.TYPE,
                          obj_resize=cfg.PROBLEM.RESCALE,
                          filter=cfg.PROBLEM.FILTER,
                          **ds_dict)

    cls = None if cfg.EVAL.CLASS in ['none', 'all'] else cfg.EVAL.CLASS
    clss = benchmark.classes if cls is None else [cls]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Net()
    model = model.to(device)
    model = DataParallel(model, device_ids=cfg.GPUS)

    Path(cfg.OUTPUT_PATH).mkdir(parents=True, exist_ok=True)
    now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

    with DupStdoutFileManager(str(Path(cfg.OUTPUT_PATH) / ('eval_log_' + now_time + '.log'))) as _:

        model_path = ''
        if cfg.EVAL.EPOCH is not None and cfg.EVAL.EPOCH > 0:
            model_path = str(Path(cfg.OUTPUT_PATH) / 'params' / 'params_{:04}.pt'.format(cfg.EVAL.EPOCH))
        if len(cfg.PRETRAINED_PATH) > 0:
            model_path = cfg.PRETRAINED_PATH
        if len(model_path) > 0:
            print('Loading model parameters from {}'.format(model_path))
            load_model(model, model_path)

        for ori_sigma in [1.5, 2]:
            for ancer in [False]:
                for cov in [1]:
                    # Only these two configurations have result files defined:
                    # cov == 1 is two-stage sampling, cov == 0 is plain RS.
                    if ancer or cov not in (0, 1):
                        raise ValueError('no result files defined for ancer={}, cov={}'.format(ancer, cov))
                    mode_tag = 'cov' if cov == 1 else 'RS'
                    cov_pro = 10 if cov == 0 else 0  # correlation probability (passed as sigma_pro)

                    # `with` guarantees the result files are closed even if
                    # certification raises part-way through.
                    with open(_result_path(ori_sigma, mode_tag, 'Lvolume'), 'w') as f_volume, \
                            open(_result_path(ori_sigma, mode_tag, 'Llower'), 'w') as f_lower, \
                            open(_result_path(ori_sigma, mode_tag, 'Lmax'), 'w') as f_max:
                        print("method", method, "/ori_sigma:", ori_sigma, "/cov:", cov)
                        _certify_config(model, benchmark, clss, ori_sigma, ancer, cov_pro,
                                        (f_volume, f_lower, f_max))