from scipy.stats import kstest, ks_2samp
from sklearn.metrics import roc_auc_score
from multiprocessing import Pool
from tabulate import tabulate
from tqdm.auto import tqdm
import numpy as np
import torch as th
import argparse
import sys

# Dataset names evaluated by every test. The in-distribution ``target`` (set from
# the CLI) is expected to be one of them, because ``auroc`` looks up the target's
# scores by name in dictionaries keyed by these entries.
test_list = ['cifar10', 'cifar100', 'celeba', 'lsun', 'svhn']
# Number of samples drawn / kept per dataset.
total_size = 5000
# Fall back to CPU so the script can still run on hosts without CUDA; in this
# file only ks_test moves tensors to this device.
device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')


def ks_test_(param):
    """Mean two-sample Kolmogorov-Smirnov statistic across projection columns.

    Args:
        param: tuple ``(noiseT, zT)`` of 2-D arrays whose columns correspond to the
            same projection directions; ``noiseT`` is one batch of projected test
            samples and ``zT`` the projected reference samples.

    Returns:
        Mean over columns ``j`` of the KS statistic between ``noiseT[:, j]`` and
        ``zT[:, j]``.
    """
    noiseT, zT = param
    # Take the column count from the data rather than the ``repeat_size`` global:
    # this function runs inside multiprocessing workers, and under the "spawn"
    # start method the ``__main__`` block that defines that global never runs there.
    stats = [ks_2samp(noiseT[:, j], zT[:, j])[0] for j in range(noiseT.shape[1])]
    return np.mean(stats)


def ks_test():
    """Score datasets with random-projection KS statistics and report the AUROC.

    Loads the ``'noise'`` array of ``temp/inv_test/{target}/{target}_recon.npz``,
    keeps ``total_size`` randomly chosen samples, and projects them onto
    ``repeat_size`` random unit-norm directions. For every dataset in
    ``test_list`` the same is done with ``{item}_recon.npz``; each batch of
    ``batch_size`` projected samples is compared with the target projections by
    ``ks_test_`` (run in a process pool). The per-batch scores go to ``auroc``.

    Relies on the module globals ``target``, ``batch_size`` and ``repeat_size``.

    Returns:
        dict mapping dataset name -> AUROC (also printed as a table).
    """
    with np.load(f'temp/inv_test/{target}/{target}_recon.npz') as npz:
        z = npz['noise']
    # choice(n) draws exactly the same indices as choice(list(range(n))).
    choice = np.random.choice(len(z), total_size, replace=False)
    z = z[choice].reshape(total_size, -1)
    z = th.from_numpy(z).float().to(device)

    # Random projection matrix with unit-norm columns, shape (feature_dim, repeat_size).
    T = np.random.randn(z.shape[1], repeat_size)
    T = T / np.linalg.norm(T, axis=0, keepdims=True)
    T = th.from_numpy(T).float().to(device)
    zT = th.mm(z, T).cpu().numpy()

    score_dict = {}
    # One pool for all datasets; the context manager tears the workers down even
    # if a task raises. Every imap result is consumed by list() before exit.
    with Pool(processes=16) as pool:
        for item in tqdm(test_list, desc='ks_test'):
            with np.load(f'temp/inv_test/{target}/{item}_recon.npz') as npz:
                data = npz['noise']
            choice = np.random.choice(len(data), total_size, replace=False)
            data = data[choice].reshape(total_size, -1)

            params = []
            for i in range(total_size // batch_size):
                noise = data[i * batch_size: (i + 1) * batch_size]
                noise = th.from_numpy(noise).float().to(device)
                params.append((th.mm(noise, T).cpu().numpy(), zT))

            score_dict[item] = list(tqdm(pool.imap(ks_test_, params), total=len(params),
                                         desc=item, leave=False))

    result = auroc(score_dict)
    print(tabulate([[name, value] for name, value in result.items()], ['name', 'ks'],
                   tablefmt='simple'))
    return result


def denoising_test():
    """Score datasets by the gap between saved noise and an intermediate reconstruction.

    For every dataset in ``test_list`` this loads
    ``temp/ood_main/sin_r{repeat_size}/{target}/{item}.npz`` (keys ``'img'``,
    ``'noise'``, ``'inter'``). For each batch of ``batch_size`` images:

    * loss: mean L1 norm (``compute_norm`` default) of ``noise - inter[2]`` over
      the ``batch_size * repeat_size`` rows belonging to that batch;
    * info: summed L1 norm of neighbouring-element differences of the images along
      the last two axes.

    Prints AUROC for the raw loss (``loss``), the loss corrected with a slope fitted
    on the target (``cor``), and with a slope fitted on the other datasets (``cross``).

    Relies on the module globals ``target``, ``batch_size`` and ``repeat_size``.
    """
    loss_dict, info_dict = {}, {}
    n_keep = total_size * repeat_size
    chunk = batch_size * repeat_size  # rows per batch in noise / inter

    for item in tqdm(test_list, desc='loss_test'):
        # Close the NpzFile handle once the arrays are read.
        with np.load(f'temp/ood_main/sin_r{repeat_size}/{target}/{item}.npz') as data:
            img, noise, inter = data['img'], data['noise'], data['inter']
        img = img[:total_size]
        noise = noise[:n_keep]
        # Truncate every intermediate step to the kept samples. Building a new list
        # works whether ``inter`` is a stacked ndarray or a sequence of per-step
        # arrays; assigning a shorter slice back into an ndarray row cannot shrink it.
        inter = [step[:n_keep] for step in inter]

        loss, info = [], []
        for i in range(total_size // batch_size):
            img0 = img[i * batch_size: (i + 1) * batch_size]
            ori = noise[i * chunk: (i + 1) * chunk]
            rec = inter[2][i * chunk: (i + 1) * chunk]
            loss.append(compute_norm(ori, rec, (chunk, -1)))
            info.append(compute_norm(img0[..., 1:, :], img0[..., :-1, :], (batch_size, -1)) +
                        compute_norm(img0[..., :, 1:], img0[..., :, :-1], (batch_size, -1)))

        loss_dict[item] = loss
        info_dict[item] = info

    auroc_loss = auroc(loss_dict)
    auroc_cor = correct(info_dict, loss_dict)
    auroc_cross = correct_cross(info_dict, loss_dict)

    output_table = [[item, auroc_loss[item], auroc_cor[item], auroc_cross[item]]
                    for item in auroc_loss]
    print(tabulate(output_table, ['name', 'loss', 'cor', 'cross'], tablefmt='simple'))


def inter_test(method='pndm4'):
    """Score datasets by disagreement between two saved intermediate reconstructions.

    For every dataset in ``test_list`` this loads
    ``temp/ood_main/mul_s{sample_speed}_r{repeat_size}_{method}/{target}/{item}.npz``
    (keys ``'img'`` and ``'inter'``). For each batch of ``batch_size`` images:

    * loss: ``order``-norm of ``inter[0] - inter[1]`` per row, averaged over the
      batch's ``batch_size * repeat_size`` rows. This assumes the ``repeat_size``
      rows of each image are stored consecutively; TODO confirm against the writer.
    * info: summed ``order1``-norm of neighbouring-element differences of the
      images along the last two axes.

    Prints AUROC for the raw loss (``loss``), the loss corrected with a slope fitted
    on the target (``cor``), and with a slope fitted on the other datasets (``cross``).

    Relies on the module globals ``target``, ``batch_size``, ``repeat_size``,
    ``sample_speed``, ``order`` and ``order1``.

    Args:
        method: sampler tag used in the results directory name (e.g. 'pndm4', 'ddim').
    """
    loss_dict, info_dict = {}, {}
    n_keep = total_size * repeat_size
    chunk = batch_size * repeat_size  # rows per batch in each inter step

    for item in tqdm(test_list, desc='inter_test'):
        path = f'temp/ood_main/mul_s{sample_speed}_r{repeat_size}_{method}/{target}/{item}.npz'
        # Close the NpzFile handle once the arrays are read.
        with np.load(path) as data:
            img, inter = data['img'], data['inter']
        img = img[:total_size]
        # Truncate every intermediate step to the kept samples. Building a new list
        # works whether ``inter`` is a stacked ndarray or a sequence of per-step
        # arrays; assigning a shorter slice back into an ndarray row cannot shrink it.
        inter = [step[:n_keep] for step in inter]

        loss, info = [], []
        for i in range(total_size // batch_size):
            img0 = img[i * batch_size: (i + 1) * batch_size]
            rec0 = inter[0][i * chunk: (i + 1) * chunk]
            rec1 = inter[1][i * chunk: (i + 1) * chunk]
            # compute_norm flattens each (image, repeat) pair itself, so no
            # image-size-specific reshape is needed here.
            loss.append(compute_norm(rec0, rec1, (batch_size, repeat_size, -1), order_c=order))
            info.append(
                compute_norm(img0[..., 1:, :], img0[..., :-1, :], (batch_size, -1), order_c=order1) +
                compute_norm(img0[..., :, 1:], img0[..., :, :-1], (batch_size, -1), order_c=order1))

        loss_dict[item] = loss
        info_dict[item] = info

    auroc_loss = auroc(loss_dict)
    auroc_cor = correct(info_dict, loss_dict)
    auroc_cross = correct_cross(info_dict, loss_dict)

    output_table = [[item, auroc_loss[item], auroc_cor[item], auroc_cross[item]]
                    for item in auroc_loss]
    print(tabulate(output_table, ['name', 'loss', 'cor', 'cross'], tablefmt='simple'))


def auroc(score_dict, reverse=False, target_key=None):
    """AUROC of separating the target dataset's scores from each dataset's scores.

    The target's scores are labelled 0 and the compared dataset's scores 1, so with
    ``roc_auc_score`` a value near 1 means higher scores indicate "not target".

    Args:
        score_dict: dataset name -> sequence of per-batch scores.
        reverse: if True, swap the labels (target = 1, compared dataset = 0).
        target_key: key of the reference dataset; defaults to the module global
            ``target``.

    Returns:
        dict mapping every key of ``score_dict`` (including the target itself) to
        its AUROC.
    """
    if target_key is None:
        target_key = target
    ref = np.asarray(score_dict[target_key])
    ref_label, other_label = (1, 0) if reverse else (0, 1)

    auroc_dict = {}
    for item, item_scores in score_dict.items():
        other = np.asarray(item_scores)
        # Label counts follow each side's own length, so datasets with different
        # batch counts are handled; concatenation (not ``+``) keeps this correct
        # even when the scores are arrays rather than lists.
        y = np.concatenate([np.full(len(ref), ref_label), np.full(len(other), other_label)])
        auroc_dict[item] = roc_auc_score(y, np.concatenate([ref, other]))

    return auroc_dict


def correct(info_dict, loss_dict):
    """AUROC of losses linearly corrected by the ``info`` term, slope fitted on the target.

    A degree-1 least-squares fit of the target's losses against its ``info``
    values (for the callers in this file, norms of neighbouring-pixel differences)
    gives the slope gamma. Every batch score becomes ``loss - gamma * info``, and
    the corrected scores are passed to ``auroc``.

    Relies on the module global ``target``.
    """
    slope = np.polyfit(info_dict[target], loss_dict[target], 1)[0]

    corrected = {}
    for name in info_dict:
        infos = info_dict[name]
        corrected[name] = [loss_value - slope * infos[idx]
                           for idx, loss_value in enumerate(loss_dict[name])]

    return auroc(corrected)


def correct_cross(info_dict, loss_dict):
    """AUROC of info-corrected losses, slope fitted on the *other* datasets.

    For each dataset ``item``, the slope gamma of a degree-1 least-squares fit of
    loss against info is estimated on the pooled batches of every dataset in
    ``test_list`` except ``item`` and the target. The correction therefore never
    sees the dataset being scored. Scores are ``loss - gamma * info``; the target's
    scores are labelled 0 and ``item``'s 1.

    Relies on the module globals ``target`` and ``test_list``.

    Returns:
        dict mapping dataset name -> AUROC.
    """
    auroc_dict = {}
    for item in info_dict:
        excluded = {item, target}
        info, loss = [], []
        for other in test_list:
            if other not in excluded:
                info += info_dict[other]
                loss += loss_dict[other]

        gamma = np.polyfit(info, loss, 1)[0]
        ref = [l - gamma * i for l, i in zip(loss_dict[target], info_dict[target])]
        cand = [l - gamma * i for l, i in zip(loss_dict[item], info_dict[item])]

        # Size each label block by its own side so unequal batch counts work.
        y = [0] * len(ref) + [1] * len(cand)
        auroc_dict[item] = roc_auc_score(y, ref + cand)

    return auroc_dict


def compute_norm(a, b, shape, order_c=1):
    """Mean vector norm of ``a - b`` after reshaping.

    The difference is reshaped to ``shape``, the ``order_c``-norm is taken along
    the last axis, and the mean over all remaining entries is returned.
    """
    flat = np.reshape(a - b, shape)
    per_row = np.linalg.norm(flat, ord=order_c, axis=-1)
    return per_row.mean()


if __name__ == "__main__":
    # The evaluation functions above read these settings as module globals
    # (target, batch_size, repeat_size, sample_speed, order, order1), so they are
    # bound at module scope here rather than passed as arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--target', '-t', type=str, required=True,
                        help='Choose the target dataset')
    parser.add_argument('--batch_size', '-b', default=5, type=int,
                        help='Choose the batch size')
    parser.add_argument('--repeat_size', '-r', default=8, type=int,
                        help='Choose the repeat size')
    parser.add_argument('--sample_speed', '-s', default=10, type=int,
                        help='Choose the sample speed')
    parser.add_argument('--order', '-o', default='1', type=str,
                        help='Choose the order of the norm for the inter_test loss (int or "inf")')
    parser.add_argument('--order1', default='1', type=str,
                        help='Choose the order of the norm for the inter_test info term (int or "inf")')
    parser.add_argument('--method', '-m', default='pndm4', type=str,
                        help='Choose the method')
    args = parser.parse_args()

    target = args.target
    batch_size = args.batch_size
    repeat_size = args.repeat_size
    sample_speed = args.sample_speed
    order = np.inf if args.order == 'inf' else int(args.order)
    order1 = np.inf if args.order1 == 'inf' else int(args.order1)

    print('dataset:', target, ' batch_size:', batch_size, ' repeat_size:', repeat_size,
          ' sample_speed:', sample_speed, ' order:', order, ' order1:', order1,
          ' method:', args.method, sep='')

    # Alternative evaluations (they do not use --method); uncomment to run.
    # ks_test()
    # denoising_test()
    inter_test(method=args.method)
