import os, sys
from typing import OrderedDict
import imageio
import numpy as np
import argparse
import math
import torchvision.transforms as transforms
import torch
import cv2 as cv
import glob as glob
from numpy import clip

from . import densenet as model

# import model
from PIL import Image


def load_images(args, image_dir):
    """Load the image files in ``image_dir`` as normalized float arrays.

    Only files with a recognised raster-image extension are read; other
    directory entries (sub-directories, checkpoints, hidden files, ...) are
    skipped instead of being handed to PIL.

    Args:
        args: Not used by this function; kept for call compatibility.
        image_dir: Directory that is scanned (non-recursively) for images.

    Returns:
        A list with one NumPy array per image, in sorted filename order. Each
        array is the output of ``ToTensor`` followed by ``Normalize``.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp"}
    # The normalization statistics are three-channel, so every image is
    # converted to RGB before the transform is applied.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    images = []
    # Sorting makes the order (and therefore the per-split scores computed
    # downstream) reproducible; os.listdir gives no ordering guarantee.
    for fn in sorted(os.listdir(image_dir)):
        ext = os.path.splitext(fn)[1].lower()
        if ext not in image_exts:
            continue
        img_path = os.path.join(image_dir, fn)
        # The context manager closes the underlying file once the pixel data
        # has been converted.
        with Image.open(img_path) as img:
            images.append(transform(img.convert("RGB")).numpy())
    return images


def softmax(x):
    """Compute a softmax over axis 1 for each row of scores in ``x``.

    Args:
        x: Array-like of scores with at least two dimensions; the softmax is
            taken along axis 1.

    Returns:
        A NumPy array of the same shape as ``x`` whose entries along axis 1
        sum to 1.
    """
    x = np.asarray(x)
    # Shift by the per-row maximum rather than the global one: a row that lies
    # far below the global maximum would otherwise underflow to all zeros and
    # produce 0/0 = NaN.
    shifted = x - np.max(x, axis=1, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=1, keepdims=True)


def preds2score(preds, splits=10):
    """Compute the mean and standard deviation of the per-split score.

    ``preds`` is cut into ``splits`` contiguous chunks of rows. For each chunk
    the score is ``exp(mean_i KL(p_i || p_mean))`` where ``p_i`` is a row and
    ``p_mean`` is the column-wise mean of the chunk.

    Args:
        preds: 2-D array-like of class probabilities, one row per sample.
        splits: Number of contiguous chunks to score separately.

    Returns:
        A ``(mean, std)`` tuple over the per-chunk scores. A chunk with no
        rows (fewer samples than ``splits``) still yields NaN.
    """
    preds = np.asarray(preds)
    n_samples = preds.shape[0]
    scores = []
    for i in range(splits):
        part = preds[(i * n_samples // splits) : ((i + 1) * n_samples // splits), :]
        marginal = np.mean(part, axis=0, keepdims=True)
        # log(0) is -inf and 0 * -inf is NaN; silence the warnings here and
        # apply the 0 * log(0) = 0 convention explicitly below.
        with np.errstate(divide="ignore", invalid="ignore"):
            kl = part * (np.log(part) - np.log(marginal))
        # A positive entry always has a positive column mean (for non-negative
        # probabilities), so only the zero-probability terms need replacing.
        kl = np.where(part > 0, kl, 0.0)
        kl = np.mean(np.sum(kl, 1))
        scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)


def get_inception_score(args, images):
    """Score ``images`` with the DenseNet121 classifier loaded from a checkpoint.

    The classifier's outputs are turned into class probabilities with a
    row-wise softmax and passed to ``preds2score``.

    Args:
        args: Object providing ``model_dir`` (checkpoint path), ``batch_size``
            and ``num_splits``. The two numeric fields are coerced with
            ``int`` so values that arrive as strings are accepted.
        images: Sequence of equally shaped arrays, one per image.

    Returns:
        The ``(mean, std)`` tuple produced by ``preds2score``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    images = list(images)
    if not images:
        raise ValueError("no images were provided to score")
    batch_size = int(args.batch_size)
    splits = int(args.num_splits)

    net = model.DenseNet121()
    state_dict = torch.load(args.model_dir, map_location="cpu", weights_only=True)[
        "net"
    ]
    # Strip a leading "module." only (presumably left by a DataParallel
    # wrapper when the checkpoint was saved); str.replace would also rewrite
    # an occurrence in the middle of a key.
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        key = k[len(prefix) :] if k.startswith(prefix) else k
        new_state_dict[key] = v
    net.load_state_dict(new_state_dict)
    # Inference mode: layers such as batch-norm/dropout must not behave as in
    # training, and running statistics must not be updated while scoring.
    net.eval()
    print("load model successfully")

    logits = []
    with torch.no_grad():  # no autograd graph is needed for scoring
        for start in range(0, len(images), batch_size):
            sys.stdout.write(".")
            sys.stdout.flush()
            batch = np.stack(
                [
                    np.asarray(img, dtype=np.float32)
                    for img in images[start : start + batch_size]
                ]
            )
            outputs = net(torch.from_numpy(batch))
            logits.append(outputs.cpu().numpy().astype(np.float64))
    logits = np.concatenate(logits, 0)

    # Numerically stable softmax: subtracting the row maximum keeps exp()
    # from overflowing to inf (which would give inf/inf = NaN).
    logits = logits - logits.max(axis=1, keepdims=True)
    exp_logits = np.exp(logits)
    preds = exp_logits / np.sum(exp_logits, 1, keepdims=True)
    mean_, std_ = preds2score(preds, splits)
    return mean_, std_


def crop10x10(in_path, out_path):
    """Cut every PNG in ``in_path`` into a 10x10 grid of 28x28 crops.

    Each crop is written to ``out_path`` as ``<count>_<x>_<y>.png`` where
    ``count`` is the 1-based index of the source image (in sorted filename
    order), ``x`` is the row offset and ``y`` the column offset of the crop.

    Args:
        in_path: Directory that is searched (non-recursively) for ``*.png``.
        out_path: Directory the crops are written to; created if missing.

    Raises:
        ValueError: If an image cannot be read or is too small to contain the
            full grid of crops.
    """
    # mnist
    offsets = (2, 32, 62, 92, 122, 152, 182, 212, 242, 272)
    img_size = 28
    number_channel = 1
    read_flag = cv.IMREAD_GRAYSCALE if number_channel == 1 else cv.IMREAD_COLOR

    print(out_path)
    os.makedirs(out_path, exist_ok=True)
    # Sorted so that the <count> part of the output names is reproducible.
    in_list = sorted(glob.glob(os.path.join(in_path, "*.png")))
    for count, img_name in enumerate(in_list, start=1):
        img = cv.imread(img_name, read_flag)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise ValueError(f"could not read image: {img_name}")
        for x in offsets:
            for y in offsets:
                img_crop = img[x : x + img_size, y : y + img_size]
                # Slicing past the border silently yields a smaller crop, so
                # the size has to be verified explicitly.
                h, w = img_crop.shape[:2]
                if (h != img_size) or (w != img_size):
                    raise ValueError(
                        f"{img_name}: crop at ({x}, {y}) is {h}x{w}, "
                        f"expected {img_size}x{img_size}"
                    )

                out_name = os.path.join(out_path, f"{count}_{x}_{y}.png")
                cv.imwrite(out_name, img_crop)


def get_cifar_inception_score(
    input_image_dir, model_dir=None, batch_size=100, num_splits=10, device="cpu"
):
    """Load the images in a directory and return their ``(mean, std)`` score.

    Args:
        input_image_dir: Directory passed to ``load_images``.
        model_dir: Checkpoint path; when ``None`` the file
            ``checkpoints/cifar_model.pth`` next to this source file is used.
        batch_size: Batch size forwarded to ``get_inception_score``.
        num_splits: Number of splits forwarded to ``get_inception_score``.
        device: Accepted but not used by this function.

    Returns:
        ``(mean, std)`` from ``get_inception_score``, or ``(0.0, 0.0)`` when
        the mean is NaN.
    """
    if model_dir is None:
        # Resolve the default checkpoint relative to this file, not the CWD.
        here = os.path.dirname(__file__)
        model_dir = os.path.join(here, "checkpoints/cifar_model.pth")

    # Lightweight attribute bag mimicking the parsed command-line arguments.
    args = argparse.Namespace(
        input_image_dir=input_image_dir,
        model_dir=model_dir,
        batch_size=batch_size,
        num_splits=num_splits,
    )

    loaded = load_images(args, input_image_dir)
    mean, std = get_inception_score(args, loaded)
    if math.isnan(mean):
        # Expected (not verified) to happen when the predictor only ever
        # predicts a single class.
        return 0.0, 0.0
    return mean, std


def main(args):
    """Score the images in ``args.input_image_dir`` and print mean and std."""
    loaded = load_images(args, args.input_image_dir)
    score_mean, score_std = get_inception_score(args, loaded)
    print("\nInception mean: ", score_mean)
    print("Inception std: ", score_std)


if __name__ == "__main__":
    GPUID = 0
    os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
    print("PACKAGES LOADED")

    in_path = "./data/"
    out_path = in_path + "/crops/"

    # Numeric options declare type=int: without it, a value given on the
    # command line arrives as a string while the default stays an int.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_image_dir", default=out_path)
    parser.add_argument("--model_dir", default="checkpoints/mnist_model_10.ckpt")
    parser.add_argument("--img_size", type=int, default=28)
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--channel", type=int, default=1)
    parser.add_argument("--num_splits", type=int, default=10)
    # Parse before doing any work so that --help or an invalid option exits
    # without writing crops to disk.
    args = parser.parse_args()

    crop10x10(in_path, out_path)
    main(args)
