from torchvision import transforms
from PIL import Image
from torch.nn import functional as F
from scipy.stats import entropy
from torch import nn
from torch.autograd import Variable
from torchvision.models.inception import inception_v3
from pytorch_fid import fid_score, inception
import os
import torch.utils.data
import torch
import open_clip
import clip
import glob
import numpy as np

def inception_score(imgs, cuda=True, batch_size=32, resize=True, splits=10):
    """Compute the Inception Score (IS) of a set of generated images.

    For each of ``splits`` equal-sized chunks of the predictions, the score is
    exp(mean_x KL(p(y|x) || p(y))), where p(y|x) is the softmax output of an
    ImageNet-pretrained torchvision Inception v3 and p(y) is the mean of p(y|x)
    over the chunk. If ``N`` is not divisible by ``splits``, the trailing
    ``N % splits`` predictions are not used.

    Args:
        imgs: dataset / sequence of (3, H, W) image tensors that
            ``torch.utils.data.DataLoader`` can batch. The model is built with
            ``transform_input=False``; the original contract says images are
            normalized to [-1, 1]. NOTE(review): ``get_images`` in this file
            applies ImageNet mean/std instead; confirm which is intended.
        cuda: run on the GPU (``torch.cuda.FloatTensor``) when True.
        batch_size: batch size fed to Inception v3. Must be > 0 and <= len(imgs).
        resize: bilinearly upsample each batch to 299x299 before the model.
        splits: number of chunks the predictions are split into. Must be in
            [1, len(imgs)].

    Returns:
        Tuple ``(mean, std)`` of the per-split scores.

    Raises:
        AssertionError: if ``batch_size <= 0`` or ``len(imgs) < batch_size``.
        ValueError: if ``splits`` is outside [1, len(imgs)]. Without this check,
            every split would be empty and the score silently NaN.
    """
    N = len(imgs)

    assert batch_size > 0
    assert N >= batch_size
    if not 1 <= splits <= N:
        raise ValueError(
            f"splits must be between 1 and the number of images ({N}), got {splits}"
        )

    # Set up dtype (this also selects the device via the legacy tensor type).
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval()
    # align_corners=False is the implicit default for bilinear; stating it
    # keeps behavior identical and avoids relying on the default.
    up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).type(dtype)

    def get_pred(x):
        """Return the class-probability rows for one batch as a numpy array."""
        if resize:
            x = up(x)
        logits = inception_model(x)
        # Softmax over the class axis (dim=1) of the (batch, classes) logits.
        return F.softmax(logits, dim=1).cpu().numpy()

    # Get predictions; 1000 = number of ImageNet classes of the pretrained head.
    preds = np.zeros((N, 1000))
    start = 0
    # Pure inference: without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.type(dtype)
            n_i = batch.size(0)
            preds[start:start + n_i] = get_pred(batch)
            start += n_i

    # Mean KL divergence per split, exponentiated.
    split_size = N // splits
    split_scores = []
    for k in range(splits):
        part = preds[k * split_size:(k + 1) * split_size, :]
        py = np.mean(part, axis=0)
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)


def batch_cosine_similarity(imgs1, imgs2, batch_size, device, model=None,
                            preprocess=None, model_name='ViT-H-14',
                            pretrained='/root/ViT-H-14/open_clip_pytorch_model.bin'):
    """Return the mean cosine similarity between embeddings of paired images.

    ``imgs1[i]`` is compared with ``imgs2[i]``. Each image is passed through
    ``preprocess``, embedded with ``model.encode_image``, and L2-normalized.
    The per-pair dot products are averaged over all pairs.

    Args:
        imgs1, imgs2: equal-length sequences of images that ``preprocess``
            accepts (PIL images for the open_clip transforms).
        batch_size: number of pairs embedded per forward pass (> 0).
        device: torch device the model and batches are moved to.
        model: optional pre-loaded model exposing ``encode_image``. Pass one
            (together with ``preprocess``) to avoid reloading the checkpoint
            on every call.
        preprocess: optional image -> tensor transform. If ``model`` is None
            and this is None, the transform returned by open_clip is used.
        model_name: open_clip architecture name used when ``model`` is None
            (e.g. 'ViT-B-32', 'ViT-H-14').
        pretrained: open_clip pretrained tag or local checkpoint path used
            when ``model`` is None (e.g. 'laion2b_s34b_b79k' for ViT-B-32,
            'laion2b_s32b_b79k' for ViT-H-14).

    Returns:
        The average cosine similarity as a Python float.

    Raises:
        ValueError: if the sequences differ in length or are empty, if
            ``batch_size <= 0``, or if ``model`` is given without ``preprocess``.
    """
    if len(imgs1) != len(imgs2):
        raise ValueError(
            f"image lists must have equal length, got {len(imgs1)} and {len(imgs2)}"
        )
    if len(imgs1) == 0:
        raise ValueError("cannot compute a similarity over zero image pairs")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if model is None:
        model, _, default_preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        if preprocess is None:
            preprocess = default_preprocess
    elif preprocess is None:
        raise ValueError("preprocess must be supplied together with a custom model")

    model.eval()
    model = model.to(device)

    similarities = []
    with torch.no_grad():
        # Preprocess lazily per batch so only one batch of tensors is alive.
        for start in range(0, len(imgs1), batch_size):
            stop = start + batch_size
            batch1 = torch.stack([preprocess(img) for img in imgs1[start:stop]]).to(device)
            batch2 = torch.stack([preprocess(img) for img in imgs2[start:stop]]).to(device)

            emb1 = model.encode_image(batch1)
            emb2 = model.encode_image(batch2)

            # Unit-normalize so the row-wise dot product is the cosine similarity.
            emb1 = emb1 / emb1.norm(dim=1, keepdim=True)
            emb2 = emb2 / emb2.norm(dim=1, keepdim=True)

            similarities.extend((emb1 * emb2).sum(dim=1).cpu().tolist())

    return sum(similarities) / len(similarities)


def get_images(path, score_type, device=None):
    """Load every ``*.jpg`` directly inside ``path`` for scoring.

    Files are returned in sorted file-name order. This matters because callers
    pair images from two folders by list position (e.g. for the CLIP score).

    Args:
        path: directory to read images from.
        score_type: 'IS' returns tensors resized to 256x256, converted with
            ``ToTensor`` and normalized with ImageNet mean/std, moved to
            ``device``. 'CLIP' returns the decoded RGB PIL images unchanged.
        device: target device for the 'IS' tensors. Defaults to CUDA when
            available, else CPU.

    Returns:
        A list of tensors ('IS') or PIL images ('CLIP'); empty if no files match.

    Raises:
        ValueError: if ``score_type`` is neither 'IS' nor 'CLIP'.
    """
    if score_type not in ('IS', 'CLIP'):
        raise ValueError(f"score_type must be 'IS' or 'CLIP', got {score_type!r}")

    # glob order is arbitrary; sort so both folders of a pair line up by name.
    files = sorted(glob.glob(os.path.join(glob.escape(path), "*.jpg")))

    images = []
    for file in files:
        # convert() decodes into a new in-memory image, so the file handle can
        # be closed here; it also turns grayscale/CMYK JPEGs into 3-channel RGB.
        with Image.open(file) as img:
            images.append(img.convert('RGB'))

    if score_type == 'CLIP':
        return images

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return [transform(image).to(device) for image in images]

# Compute device shared by all metrics; get_images() may also read this
# module-level name.
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")

# Batch size used for IS, FID and CLIP computations.
bsz = 64

# DataLoader worker count passed to pytorch_fid.
NUM_WORKERS = 16

# Set to True to also compute the Inception Score of the generated images.
COMPUTE_IS = False

# Image folder roots: [ground-truth images, generated images]. Each root is
# expected to contain 'forget' and 'retain' sub-folders.
path_to_images = ['/root/autodl-tmp/img2img_unlearning/diffusion/experiments/test_ours_all-6.6-encoder-center-1e-5-100-10-multi/ours_all-6.6-encoder-center-1e-5-100-10-multi/GT',
                  '/root/autodl-tmp/img2img_unlearning/diffusion/experiments/test_ours_all-6.6-encoder-center-1e-5-100-10-multi/ours_all-6.6-encoder-center-1e-5-100-10-multi/Out']

# (sub-folder name, label used in the printed output)
SUBSETS = (('forget', 'Forget'), ('retain', 'Retain'))


def main():
    """Print the FID and CLIP scores (and optionally IS) for each subset."""
    if len(path_to_images) != 2:
        return
    gt_root, out_root = path_to_images

    # Inception Score of the generated images (disabled unless COMPUTE_IS).
    if COMPUTE_IS:
        for subset, label in SUBSETS:
            imgs = get_images(os.path.join(out_root, subset), 'IS')
            is_mean, is_std = inception_score(imgs, cuda=(device.type == "cuda"),
                                              batch_size=bsz, resize=True, splits=10)
            print(f"Inception Score of {label} Set: Mean - {is_mean}, Std - {is_std}")

    # FID between the ground-truth and generated images, per subset.
    for subset, label in SUBSETS:
        paths = [os.path.join(gt_root, subset), os.path.join(out_root, subset)]
        fid_value = fid_score.calculate_fid_given_paths(
            paths, batch_size=bsz, device=device, dims=2048, num_workers=NUM_WORKERS
        )
        print(f"FID Score of {label} Set: ", fid_value)

    # CLIP score: mean cosine similarity of CLIP embeddings of paired images.
    # NOTE(review): pairs are matched by list position, so both folders must
    # list corresponding files in the same order.
    for subset, label in SUBSETS:
        imgs1 = get_images(os.path.join(gt_root, subset), 'CLIP')
        imgs2 = get_images(os.path.join(out_root, subset), 'CLIP')
        average_cosine_similarity = batch_cosine_similarity(imgs1, imgs2, bsz, device)
        print(f"CLIP Score of {label} Set: {average_cosine_similarity}")


# Guard so importing this module (e.g. to reuse the metric functions) does not
# launch the full evaluation.
if __name__ == "__main__":
    main()