import argparse
import torch
import numpy as np
import torch.optim as optim
from utils import KpiReader
from models import StackedVAGT
from logger import Logger
import time
from evaluate_pot import *
import matplotlib.pyplot as plt
import scipy.io as sio
import torch.nn.functional as F


class Tester(object):
    """Run a trained model over a test loader and score the result with POT.

    The tester restores weights from a checkpoint, computes the reconstruction
    log-likelihood of the last timestamp of every test window, feeds those
    scores to ``pot_eval`` and stores the raw outputs in a ``.mat`` file.

    Args:
        model: Network to evaluate; it is moved to ``device``.  ``model_test``
            reads its ``alpha0`` / ``alpha1`` attributes and ``load_checkpoint``
            writes its ``beta`` attribute.
        test: Test dataset object (only stored on the instance).
        testloader: Iterable yielding ``(timestamps, labels, data)`` batches.
        log_path: Directory handed to ``Logger``.
        log_file: File name handed to ``Logger``.
        device: Torch device used for the model and the input batches.
        learning_rate: Learning rate of the Adam optimizer that is created so
            the optimizer state stored in a checkpoint can be restored.
        level: ``level`` argument forwarded to ``pot_eval``.
        nsamples: Stored on the instance; not used by this class.
        sample_path: Stored on the instance; not used by this class.
        checkpoints: Checkpoint path prefix; ``'_epochs<N>.pth'`` is appended.
    """

    def __init__(self, model, test, testloader, log_path='log_tester', log_file='loss', device=torch.device('cpu'),
                 learning_rate=0.0002, level=0.003, nsamples=None, sample_path=None, checkpoints=None):
        self.model = model
        self.model.to(device)
        self.device = device
        self.test = test
        self.testloader = testloader
        self.log_path = log_path
        self.log_file = log_file
        self.learning_rate = learning_rate
        self.nsamples = nsamples
        self.sample_path = sample_path
        self.checkpoints = checkpoints
        self.start_epoch = 0
        self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
        self.epoch_losses = []
        self.logger = Logger(self.log_path, self.log_file)
        self.loss = {}
        self.level = level

    def load_checkpoint(self, start_ep):
        """Restore model, optimizer and loss history from a saved checkpoint.

        The file ``<self.checkpoints>_epochs<start_ep>.pth`` is loaded.  Loading
        is best-effort: on any failure a message is printed, ``start_epoch`` is
        reset to 0 and the method returns normally.

        Args:
            start_ep: Epoch number embedded in the checkpoint file name.
        """
        checkpoint_file = self.checkpoints + '_epochs{}.pth'.format(start_ep)
        try:
            print("Loading Chechpoint from ' {} '".format(checkpoint_file))
            # map_location lets a checkpoint written on a GPU be restored on a
            # machine that only has the configured device (e.g. CPU).
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.start_epoch = checkpoint['epoch']
            self.model.beta = checkpoint['beta']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.epoch_losses = checkpoint['losses']
            print("Resuming Training From Epoch {}".format(self.start_epoch))
        except FileNotFoundError:
            print("No Checkpoint Exists At '{}', Starting Fresh Training".format(checkpoint_file))
            self.start_epoch = 0
        except Exception as err:
            # Still best-effort, but report *why* loading failed so that a
            # corrupt or incompatible checkpoint is not mistaken for a missing
            # one.  KeyboardInterrupt / SystemExit are no longer swallowed.
            print("Failed To Load Checkpoint '{}' ({}: {}), Starting Fresh Training".format(
                checkpoint_file, type(err).__name__, err))
            self.start_epoch = 0

    @staticmethod
    def _append_result(res, row):
        """Return ``res`` with ``row`` added as one record; ``None`` passes through."""
        if res is None:
            return None
        # Imported lazily so pandas is only required when a result table is used.
        import pandas as pd
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        return pd.concat([res, pd.DataFrame([row])], ignore_index=True)

    def model_test(self, dataName, res=None):
        """Evaluate the model on ``self.testloader`` and persist the outputs.

        Args:
            dataName: Name of the evaluated data set.  The part before the
                first ``'-'`` selects the sub-directory of ``./log_tester``
                that receives ``<dataName>.mat``.
            res: Optional pandas ``DataFrame`` collecting one row of metrics
                per call; when ``None`` no table is built.

        Returns:
            Tuple ``(res, x_data, x_mu, z_feature, loss_res_box, labels_box,
            Gx, Gz)`` where ``res`` is the extended table (or ``None``) and
            the remaining entries are NumPy arrays.
        """
        import os  # the module only imports os inside main(), so import it here

        self.model.eval()
        loss_res_box = []
        labels_box = []
        x_data = []
        x_mu = []
        z_feature = []

        # ReLU-rectified Gram matrices of the model's alpha0 / alpha1 parameters.
        Gx = F.relu(torch.mm(self.model.alpha0, self.model.alpha0.t())).detach().cpu().numpy()
        Gz = F.relu(torch.mm(self.model.alpha1, self.model.alpha1.t())).detach().cpu().numpy()
        for timestamps, labels, data in self.testloader:
            data = data.to(self.device)
            z_posterior, _, _, _, _, x_mu_list, x_logsigma_list = self.forward_test(data)

            # data / x_mu_list / x_logsigma_list are indexed as 5-D tensors;
            # presumably this slice is the final timestamp of each window
            # (hence the method name below) -- TODO confirm axis layout.
            x_last = data[:, -1, -1, :, -1]
            mu_last = x_mu_list[:, -1, -1, :, -1]
            logsigma_last = x_logsigma_list[:, -1, -1, :, -1]

            x_data.append(x_last)
            x_mu.append(mu_last)
            z_feature.append(z_posterior[:, -1, :])
            loss_res_box.append(self.loglikelihood_last_timestamp(x_last, mu_last, logsigma_last))
            labels_box.append(labels[:, -1, -1, -1])

        loss_res_box = torch.cat(loss_res_box).cpu().numpy().astype(float).reshape(-1)
        labels_box = torch.cat(labels_box).cpu().numpy().reshape(-1)
        x_data = torch.cat(x_data).cpu().numpy().astype(float)
        x_mu = torch.cat(x_mu).cpu().numpy().astype(float)
        z_feature = torch.cat(z_feature).cpu().numpy().astype(float)

        # The test scores are passed as both the first and the second score
        # argument of pot_eval.
        metrics = {}
        metrics.update(pot_eval(loss_res_box, loss_res_box, labels_box, level=self.level))

        TP = np.sum([metrics['pot-TP']])
        TN = np.sum([metrics['pot-TN']])
        FP = np.sum([metrics['pot-FP']])
        FN = np.sum([metrics['pot-FN']])

        # The small epsilon keeps the ratios finite when a denominator is zero.
        precision = TP / (TP + FP + 0.00001)
        recall = TP / (TP + FN + 0.00001)
        f1 = 2 * precision * recall / (precision + recall + 0.00001)
        print("pot-f1: {} pot-precision: {} pot-recall: {}".format(f1, precision, recall))
        print("Testing is complete!")

        res = self._append_result(res, {'name': dataName, 'pot-TP': TP, 'pot-TN': TN, 'pot-FP': FP, 'pot-FN': FN,
                                        'pot-precision': precision, 'pot-recall': recall, 'pot-f1': f1})

        mat_path = './log_tester/{}/{}.mat'.format(dataName.split('-')[0], dataName)
        # savemat does not create missing parent directories.
        os.makedirs(os.path.dirname(mat_path), exist_ok=True)
        sio.savemat(mat_path, {'x_data': x_data, 'x_mu': x_mu,
                               'z_feas': z_feature,
                               "loss_res": loss_res_box,
                               'labels': labels_box,
                               'Gx': Gx, 'Gz': Gz})
        return res, x_data, x_mu, z_feature, loss_res_box, labels_box, Gx, Gz

    def forward_test(self, data):
        """Run the model on ``data`` with gradient tracking disabled.

        Returns:
            The seven model outputs as a tuple: posterior sample, posterior
            mean, posterior log-variance, prior mean, prior log-variance,
            reconstruction mean and reconstruction log-sigma.
        """
        with torch.no_grad():
            (z_posterior, z_mean_posterior, z_logvar_posterior,
             z_mean_prior, z_logvar_prior, x_mu, x_logsigma) = self.model(data)
        return (z_posterior, z_mean_posterior, z_logvar_posterior,
                z_mean_prior, z_logvar_prior, x_mu, x_logsigma)

    def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logsigma):
        """Return the diagonal-Gaussian log-likelihood of ``x``.

        Computes ``-0.5 * sum(((x - mu) / exp(logsigma))**2 + 2 * logsigma
        + log(2 * pi))`` over the last dimension.

        Args:
            x: Observed values.
            recon_x_mu: Reconstruction mean, broadcastable against ``x``.
            recon_x_logsigma: Log standard deviation of the reconstruction.

        Returns:
            Tensor with the last dimension of the inputs summed out.
        """
        log_sigma = recon_x_logsigma.float()
        standardized = (x.float() - recon_x_mu.float()) / torch.exp(log_sigma)
        llh = -0.5 * torch.sum(torch.pow(standardized, 2) + 2 * log_sigma + np.log(np.pi * 2), dim=-1)
        return llh


def main(i, j, dataName, res=None):
    import os
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', type=int, default=None)

    parser.add_argument('--dataset_path', type=str, default='')
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--num_workers', type=int, default=None)
    parser.add_argument('--x_dim', type=int, default=None)
    parser.add_argument('--win_len', type=int, default=None)

    parser.add_argument('--c_dim', type=int, default=None)
    parser.add_argument('--z_dim', type=int, default=None)
    parser.add_argument('--h_dim', type=int, default=None)
    parser.add_argument('--n_head', type=int, default=None)
    parser.add_argument('--layer_xz', type=int, default=None)
    parser.add_argument('--layer_h', type=int, default=None)
    parser.add_argument('--q_len', type=int, default=None)
    parser.add_argument('--embd_h', type=int, default=None)
    parser.add_argument('--embd_s', type=int, default=None)
    parser.add_argument('--vocab_len', type=int, default=None)
    parser.add_argument('--dropout', type=float, default=None)
    parser.add_argument('--learning_rate', type=float, default=None)
    parser.add_argument('--beta', type=float, default=None)
    parser.add_argument('--max_beta', type=float, default=None)
    parser.add_argument('--anneal_rate', type=float, default=None)
    parser.add_argument('--epochs', type=int, default=None)
    parser.add_argument('--start_epoch', type=int, default=None)
    parser.add_argument('--checkpoints_interval', type=int, default=5)
    parser.add_argument('--checkpoints_path', type=str, default='')
    parser.add_argument('--checkpoints_file', type=str, default='')
    parser.add_argument('--log_path', type=str, default='')
    parser.add_argument('--log_file', type=str, default='')

    parser.add_argument('--level', type=float, default=None)
    parser.add_argument('--bf_search_min', type=float, default=None)
    parser.add_argument('--bf_search_max', type=float, default=None)
    parser.add_argument('--bf_search_step_size', type=float, default=None)

    args = parser.parse_args()

    if torch.cuda.is_available() and args.gpu_id >= 0:
        device = torch.device('cuda:%d' % args.gpu_id)
    else:
        device = torch.device('cpu')

    if not os.path.exists(args.dataset_path):
        raise ValueError('Unknown dataset path: {}'.format(args.dataset_path))

    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path)

    if not os.path.exists(args.checkpoints_path):
        os.makedirs(args.checkpoints_path)

    if args.checkpoints_file == '':
        args.checkpoints_file = 'c_dim-{}_x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_' \
                                'win_len-{}_q_len-{}_vocab_len-{}'.format(args.c_dim, args.x_dim, args.z_dim, args.h_dim,
                                                                          args.layer_xz, args.layer_h, args.embd_h,
                                                                          args.n_head, args.win_len, args.q_len,
                                                                          args.vocab_len)
    if args.log_file == '':
        args.log_file = 'c_dim-{}_x_dim-{}_z_dim-{}_h_dim-{}_layer_xz-{}_layer_h-{}_embd_h-{}_n_head-{}_win_len-{}_' \
                        'q_len-{}_vocab_len-{}_epochs-{}_loss'.format(args.c_dim, args.x_dim, args.z_dim, args.h_dim, args.layer_xz,
                                                       args.layer_h, args.embd_h, args.n_head, args.win_len,
                                                       args.q_len, args.vocab_len, args.epochs)

    kpi_value_test = KpiReader(args.dataset_path)
    test_loader = torch.utils.data.DataLoader(kpi_value_test, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.num_workers)

    stackedvagt = StackedVAGT(layer_xz=args.layer_xz, layer_h=args.layer_h, n_head=args.n_head, c_dim=args.c_dim,
                              x_dim=args.x_dim, z_dim=args.z_dim, h_dim=args.h_dim, embd_h=args.embd_h, embd_s=args.embd_s,
                              beta=args.beta, q_len=args.q_len, vocab_len=args.vocab_len, win_len=args.win_len,
                              dropout=args.dropout, anneal_rate=args.anneal_rate, max_beta=args.max_beta,
                              device=device, is_train=True).to(device)

    tester = Tester(stackedvagt, kpi_value_test, test_loader, log_path=args.log_path,
                    log_file=args.log_file, learning_rate=args.learning_rate, level=args.level, device=device,
                    checkpoints=os.path.join(args.checkpoints_path, args.checkpoints_file),
                    nsamples=None, sample_path=None)
    tester.load_checkpoint(args.epochs)
    res, x_data, x_mu, z_feature, loss_res_box, labels_box, Gx, Gz = tester.model_test('{}-{}-{}'.format(dataName, i, j), res)
    return res


if __name__ == "__main__":
    # Script entry point: evaluate a single, unnamed data set.
    import csv
    import warnings

    import pandas as pd

    # Suppress every warning category for the duration of the run.
    warnings.filterwarnings("ignore")

    # No result table and no data-set name/indices are configured here, so
    # main() receives an empty name together with None indices.
    res, dataName, i, j = None, '', None, None
    res = main(i, j, dataName, res)
