# Optimize and output the isotropic noise scale (theta) for each evaluation sample.
import time
from datetime import datetime
from pathlib import Path
import xlwt

from src.dataset.data_loader import GMDataset, get_dataloader
from src.evaluation_metric import *
from src.parallel import DataParallel
from src.utils.model_sl import load_model
from src.utils.data_to_cuda import data_to_cuda
from src.utils.timer import Timer

from src.utils.config import cfg
from pygmtools.benchmark import Benchmark

from torch.autograd import Variable
from torch.distributions.normal import Normal
import numpy as np

# Parameters
# Starting value of the optimised noise scale; also embedded in the result file name.
sigma = 20
# Label embedded in the result file name: pca / bbgm / cie.
method = "cie"
# Run tag embedded in the result file name.
date = 723
# Learning rate handed to the Adam optimiser.
learning_rate = 0.04
# Progress-printing flag read by the evaluation loop.
verbose = True
# Number of optimisation steps run for each evaluated batch.
iteration_num = 100

if __name__ == '__main__':
    # Script entry point. Loads a trained graph-matching network and, for every
    # evaluated batch, optimises a scalar noise scale ``theta`` (initialised to
    # ``sigma``) with Adam for ``iteration_num`` steps. The final theta of each
    # batch is stored in ``all_theta`` and written as CSV text under
    # ``Isotropic_DD_result/``.
    import importlib

    # Imported explicitly because the file's top-level imports do not bind the
    # name ``torch`` themselves.
    import torch

    from src.utils.dup_stdout_manager import DupStdoutFileManager
    from src.utils.parse_args import parse_args
    from src.utils.print_easydict import print_easydict

    # Called before any cfg field is read -- presumably this is what fills cfg
    # from the command line / config file (TODO confirm in parse_args).
    parse_args('Deep learning of graph matching evaluation code.')

    Net = importlib.import_module(cfg.MODULE).Net

    torch.manual_seed(cfg.RANDOM_SEED)

    # Optional dataset-specific keyword arguments forwarded to Benchmark.
    if 'DATASET_FULL_NAME' in cfg and cfg.DATASET_FULL_NAME in cfg:
        ds_dict = cfg[cfg.DATASET_FULL_NAME]
    else:
        ds_dict = {}
    benchmark = Benchmark(name=cfg.DATASET_FULL_NAME,
                          sets='test',
                          problem=cfg.PROBLEM.TYPE,
                          obj_resize=cfg.PROBLEM.RESCALE,
                          filter=cfg.PROBLEM.FILTER,
                          **ds_dict)

    # 'none' / 'all' select every class the benchmark offers; anything else
    # is treated as a single class name.
    if cfg.EVAL.CLASS in ['none', 'all']:
        clss = benchmark.classes
    else:
        clss = [cfg.EVAL.CLASS]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Net().to(device)
    model = DataParallel(model, device_ids=cfg.GPUS)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_dir = Path(cfg.OUTPUT_PATH)
    output_dir.mkdir(parents=True, exist_ok=True)
    now_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    log_path = output_dir / ('eval_log_' + now_time + '.log')

    with DupStdoutFileManager(str(log_path)):

        # An explicit PRETRAINED_PATH takes precedence over the per-epoch
        # checkpoint derived from OUTPUT_PATH.
        model_path = ''
        if cfg.EVAL.EPOCH is not None and cfg.EVAL.EPOCH > 0:
            model_path = str(output_dir / 'params' / 'params_{:04}.pt'.format(cfg.EVAL.EPOCH))
        if cfg.PRETRAINED_PATH:
            model_path = cfg.PRETRAINED_PATH
        if model_path:
            print('Loading model parameters from {}'.format(model_path))
            load_model(model, model_path)

        print('Start evaluation...')

        # Use the device the parameters actually live on after wrapping/loading.
        device = next(model.parameters()).device

        model.eval()

        # One dataloader per class; the seed is reset before each one is built
        # so every class starts from the same RNG state.
        dataloaders = []
        for cls in clss:
            image_dataset = GMDataset(name=cfg.DATASET_FULL_NAME,
                                      bm=benchmark,
                                      problem=cfg.PROBLEM.TYPE,
                                      length=cfg.EVAL.SAMPLES,
                                      cls=cls,
                                      using_all_graphs=cfg.PROBLEM.TEST_ALL_GRAPHS)
            torch.manual_seed(cfg.RANDOM_SEED)
            dataloaders.append(get_dataloader(image_dataset, shuffle=True))

        # Result buffer with cfg.EVAL.SAMPLES slots per class. At least 20
        # class slots are kept so the saved shape and file name are unchanged
        # for benchmarks with up to 20 classes; larger benchmarks get enough
        # room instead of raising IndexError. Unused slots stay at 1.0.
        num_slots = cfg.EVAL.SAMPLES * max(20, len(clss))
        all_theta = torch.ones([1, num_slots])

        for i, cls in enumerate(clss):
            if verbose:
                print('Evaluating class {}: {}/{}'.format(cls, i, len(clss)))

            iter_num = 0

            for inputs in dataloaders[i]:
                # Stop once cfg.EVAL.SAMPLES samples of this class were seen.
                if iter_num >= cfg.EVAL.SAMPLES / inputs['batch_size']:
                    break
                if model.module.device != torch.device('cpu'):
                    inputs = data_to_cuda(inputs)

                iter_num += 1

                keypoint_num = inputs['ns'][0].cpu().item()

                # Leaf tensor holding the noise scale being optimised.
                theta = torch.full((1,), float(sigma), device=device, requires_grad=True)
                optimizer = torch.optim.Adam([theta], lr=learning_rate)

                # Standard normal, used only for its inverse CDF below.
                m = Normal(
                    torch.zeros(keypoint_num, device=device),
                    torch.ones(keypoint_num, device=device)
                )

                for step in range(iteration_num):

                    # NOTE(review): the noise is added onto the positions left
                    # by the previous step, so perturbations accumulate across
                    # steps and the autograd graph keeps growing (this is why
                    # backward() needs retain_graph=True). Confirm whether a
                    # fresh perturbation of the original points was intended.
                    noise = torch.randn_like(inputs['Ps'][0]) * theta
                    inputs['Ps'][0] = inputs['Ps'][0] + noise
                    out = model(inputs)

                    # Two largest entries in every row of out['ds_mat'][0].
                    vls, _ = out['ds_mat'][0].topk(2, dim=1, largest=True, sorted=True)
                    # Clamping keeps the inverse CDF finite at 0 and 1.
                    gap = (m.icdf(vls[:, 0].clamp_(0.001, 0.999)) - m.icdf(vls[:, 1].clamp_(0.001, 0.999))) / 2

                    if step % 20 == 0:
                        print(theta)

                    # Negated so that minimising with Adam maximises theta/2 * gap.
                    radius_maximizer = -(
                        (theta / 2 * gap).sum()
                    )

                    radius_maximizer.backward(retain_graph=True)
                    optimizer.step()
                    optimizer.zero_grad()

                all_theta[0][i * cfg.EVAL.SAMPLES + iter_num - 1] = theta.detach()
        print(all_theta)

        # Create the result directory up front so a long run cannot end in a
        # FileNotFoundError at save time.
        result_dir = Path("Isotropic_DD_result")
        result_dir.mkdir(parents=True, exist_ok=True)
        result_name = "date{}_orisigma{}_samples{}_method{}_iso".format(
            date, sigma, num_slots, method)
        np.savetxt(str(result_dir / result_name), all_theta.numpy(), delimiter=',')

