import os, sys
import imageio
import numpy as np
import argparse
import math
import torchvision.transforms as transforms
import torch
import cv2 as cv
import glob as glob
from numpy import clip

from . import resnet_v2 as model

# import model
from PIL import Image


def load_images(args, image_dir):
    images = []
    transform = transforms.Compose(
        [
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    for fn in os.listdir(image_dir):
        ext = os.path.splitext(fn)[1].lower()
        img_path = os.path.join(image_dir, fn)
        img = Image.open(img_path).convert("L")
        img = transform(img).numpy()
        images.append(img)
    return images


def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / np.expand_dims(e_x.sum(axis=1), axis=1)  # only difference


def preds2score(preds, splits=10):
    scores = []
    for i in range(splits):
        part = preds[
            (i * preds.shape[0] // splits) : ((i + 1) * preds.shape[0] // splits), :
        ]
        kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
        kl = np.mean(np.sum(kl, 1))
        scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)


def get_inception_score(args, images):
    splits = args.num_splits
    inps = []
    for img in images:
        img = img.astype(np.float32)
        # inps.append(np.expand_dims(img, 0))
        inps.append(img)
    preds = []
    n_batches = int(math.ceil(float(len(inps)) / float(args.batch_size)))
    n_preds = 0

    net = torch.nn.DataParallel(model.ResNet(20, num_classes=10))
    net.load_state_dict(torch.load(args.model_dir, weights_only=True))
    net.cpu()
    net.cuda()
    print("load model successfully")

    for i in range(n_batches):
        sys.stdout.write(".")
        sys.stdout.flush()
        inp = inps[(i * args.batch_size) : min((i + 1) * args.batch_size, len(inps))]
        inp = np.concatenate(inp, 0)
        inp = np.expand_dims(inp, axis=1)
        inp = torch.from_numpy(inp).cuda()
        outputs = net(inp)
        pred = outputs.data.tolist()
        # pred = softmax(pred)
        preds.append(pred)
        n_preds += outputs.shape[0]
    preds = np.concatenate(preds, 0)
    preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True)
    mean_, std_ = preds2score(preds, splits)
    return mean_, std_


def get_fashion_mnist_inception_score(
    input_image_dir,
    model_dir=None,
    img_size=28,
    batch_size=100,
    channel=1,
    num_splits=10,
    device="cpu",
):

    if model_dir is None:
        model_dir = "checkpoints/fashion_resnet20_model.pth"
        # set relative to the current file
        model_dir = os.path.join(os.path.dirname(__file__), model_dir)

    class Args:
        pass

    args = Args()
    args.input_image_dir = input_image_dir
    args.model_dir = model_dir
    args.img_size = img_size
    args.batch_size = batch_size
    args.channel = channel
    args.num_splits = num_splits

    images = load_images(args, input_image_dir)
    mean, std = get_inception_score(args, images)
    if mean != mean:
        mean = 0.0  # this happens (I expect) when the predictor only predicts one class
        std = 0.0
    return mean, std


def main(args):
    images = load_images(args, args.input_image_dir)
    mean, std = get_inception_score(args, images)
    print("\nInception mean: ", mean)
    print("Inception std: ", std)


if __name__ == "__main__":
    GPUID = 0
    os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
    print("PACKAGES LOADED")

    in_path = "./data/"
    out_path = in_path + "/crops/"
    # crop10x10(in_path, out_path)
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_image_dir", default=out_path)
    parser.add_argument("--model_dir", default="checkpoints/mnist_model_10.ckpt")
    parser.add_argument("--img_size", default=28)
    parser.add_argument("--batch_size", default=100)
    parser.add_argument("--channel", default=1)
    parser.add_argument("--num_splits", default=10)
    args = parser.parse_args()
    main(args)
