from lib.dataset.mydata import CifarData
import cv2
from skimage.metrics import structural_similarity
from sklearn import metrics as mr
import random
import argparse
import numpy as np
import warnings
import tqdm
import math


def mutual_info(img1, img2):
    """Return the normalized mutual information between two images.

    Both arrays are flattened with ``reshape(-1)`` and handed to
    ``sklearn.metrics.normalized_mutual_info_score``, which scores them as
    two label assignments: every distinct element value acts as one label.

    NOTE(review): NMI is a clustering metric, so continuous-valued inputs
    yield one label per unique value -- confirm the images are quantized
    (e.g. 8-bit levels) before relying on the score.
    """
    flat_a = img1.reshape(-1)
    flat_b = img2.reshape(-1)
    return mr.normalized_mutual_info_score(flat_a, flat_b)


def PSNR(img1, img2):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Both images are divided by 255 before comparison, so after scaling the
    peak value is 1. This presumes inputs on a 0..255 scale (TODO confirm
    against the callers' dtype). Identical images (zero MSE) return the
    sentinel ``100`` instead of infinity.
    """
    # Dividing first also promotes integer (e.g. uint8) inputs to float,
    # so the subtraction cannot wrap around.
    scaled_diff = img1 / 255. - img2 / 255.
    mse = np.mean(scaled_diff ** 2)
    if mse != 0:
        peak = 1
        return 20 * math.log10(peak / math.sqrt(mse))
    return 100


def SSIM(img1, img2):
    """Return the mean structural similarity (SSIM) of two colour images.

    Each input is converted to single-channel grayscale with
    ``cv2.COLOR_BGR2GRAY``, so it is expected to be a 3-channel image whose
    channels cv2 will interpret in BGR order. The two grayscale images are
    compared with ``skimage.metrics.structural_similarity``. Only the scalar
    score is returned; the per-pixel map produced by ``full=True`` is
    discarded.

    NOTE(review): ``data_range`` is not passed. Depending on the
    scikit-image version, floating-point inputs without it either raise or
    fall back to a dtype-based default -- confirm the dtype/range of the
    images passed in.
    """
    grays = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in (img1, img2)]
    mean_ssim, _ssim_map = structural_similarity(*grays, full=True)
    return mean_ssim


def cal(data, args=None):
    """Return the mean pairwise similarity over all unordered pairs in *data*.

    Every pair ``(data[i], data[j])`` with ``i < j`` is scored with the
    metric selected by ``args.md``:

    * ``'m'`` -- ``mutual_info`` (normalized mutual information)
    * ``'s'`` -- ``SSIM``
    * ``'p'`` -- ``PSNR``

    Args:
        data: Sequence of images accepted by the chosen metric.
        args: Object with an ``md`` attribute (e.g. an argparse Namespace).
            ``None`` falls back to ``'m'``, the CLI default.

    Returns:
        The average metric value over the ``n * (n - 1) // 2`` scored pairs,
        or ``0.0`` when fewer than two images are given (no pairs exist).

    Raises:
        ValueError: if the selected mode is not ``'m'``, ``'s'`` or ``'p'``.
    """
    mode = 'm' if args is None else args.md
    # Validate up front instead of silently returning 0 for a typo'd mode.
    if mode not in ('m', 's', 'p'):
        raise ValueError(
            f"unknown metric mode {mode!r}; expected one of 'm', 's', 'p'"
        )

    n = len(data)
    # Only i < j pairs are scored, so the count is n choose 2 -- the old
    # n * (n + 1) // 2 also counted the never-computed self-pairs and
    # biased the mean downward.
    n_pairs = n * (n - 1) // 2
    if n_pairs == 0:
        return 0.0

    metric = {'m': mutual_info, 's': SSIM, 'p': PSNR}[mode]
    total = 0.0
    # Context manager guarantees the progress bar is closed.
    with tqdm.tqdm(total=n_pairs) as pbar:
        for i in range(n):
            for j in range(i + 1, n):
                total += metric(data[i], data[j])
                pbar.update(1)
    return total / n_pairs


def main():
    """Command-line entry point: average pairwise image similarity on CIFAR.

    Options:
        -m   sampling mode. ``'tr'`` draws a batch via
             ``CifarData.get_tr_suf`` over the index range ``l..r`` using the
             log file chosen by ``-tr``. Any other value draws via
             ``CifarData.get_rnd_suf``.
        -tr  which training log to use: 'x', 'r', 'd' or 'v'. Any other
             value falls back to the fnn log.
        -md  metric mode forwarded to ``cal`` ('m', 's' or 'p').
        -b   number of images per batch.
        -r   number of repetitions. Each run's average is printed, followed
             by the mean and std over runs.
        -p, -t  parsed but not used by this function.

    NOTE(review): in 'tr' mode the window start ``5000 - b`` is the same on
    every repetition, so with ``-r > 1`` the runs presumably coincide unless
    ``get_tr_suf`` itself randomizes. The otherwise-unused ``random`` import
    hints this was once randomized -- confirm intent.
    """
    warnings.filterwarnings('ignore')
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', default='rnd', type=str)
    parser.add_argument('-tr', default='x', type=str)
    parser.add_argument('-md', default='m', type=str)
    parser.add_argument('-p', default='output.txt')
    parser.add_argument('-b', default=50, type=int)
    parser.add_argument('-r', default=1, type=int)
    parser.add_argument('-t', default=0, type=int)
    args = parser.parse_args()

    log_by_arch = {
        'x': './logs/record1/tr_cifar_train_ResNeXt29_2x64d.txt',
        'r': './logs/record1/tr_cifar10_res18.txt',
        'd': './logs/record1/tr_cifar_denesnet.txt',
        'v': './logs/record1/tr_cifar_train_vgg.txt',
    }
    path1 = log_by_arch.get(args.tr, './logs/record1/tr_cifar10_fnn.txt')
    print(args)
    print(path1)

    data = CifarData(norm=False)
    # Start index of the fixed window used in 'tr' mode; loop-invariant.
    left = 5000 - args.b
    sums = []
    for _ in range(args.r):
        if args.m == 'tr':
            print(left, left + args.b)
            batch = data.get_tr_suf(size=args.b, l=left, r=args.b + left,
                                    path=path1, loader=False)
        else:
            batch = data.get_rnd_suf(size=args.b, loader=False)
        # CHW tensors -> HWC numpy arrays for the cv2/skimage/sklearn metrics.
        pics = [pic.permute(1, 2, 0).numpy() for pic, _label in batch]
        sums.append(cal(pics, args=args))
        print(sums[-1])
    print(np.mean(sums), np.std(sums))


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()

