import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from classifier import Inception, InceptionBlock, Flatten, find_model

import numpy as np
from collections import Counter
def inception_score(classifier, sequences, batch_size, splits=10):
    """Compute the Inception Score plus fidelity and diversity metrics.

    The classifier is run over *all* of ``sequences`` in chunks of
    ``batch_size`` (the last chunk may be smaller). Softmax is applied along
    dim 1 of its output to obtain per-sample class probabilities.

    Args:
        classifier: model with an ``eval()`` method whose call on a batch
            returns a logits tensor; softmax is taken over dim 1. It is put in
            eval mode and left there.
        sequences: sized, sliceable collection of inputs (presumably a tensor
            of shape (N, ...) -- TODO confirm against callers).
        batch_size: number of samples per forward pass (must be >= 1).
        splits: number of equal-sized chunks the predictions are divided into
            for the IS mean/std. The last ``N % splits`` samples are not used
            for the IS itself (standard IS practice), but they do count
            towards the fidelity and diversity metrics.

    Returns:
        Tuple ``(is_mean, is_std, avg_entropy, label_entropy)``:

        * ``is_mean``, ``is_std``: mean and standard deviation, across
          splits, of ``exp(E_x[KL(p(y|x) || p(y))])``.
        * ``avg_entropy``: mean per-sample predictive entropy ("fidelity";
          lower means more confident predictions).
        * ``label_entropy``: entropy of the empirical distribution of argmax
          labels ("diversity").

    Raises:
        ValueError: if ``batch_size < 1``, ``sequences`` is empty, ``splits < 1``,
            or there are fewer samples than ``splits`` (every split would be
            empty and the score would be NaN).
    """
    n_samples = len(sequences)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if n_samples == 0:
        raise ValueError("sequences must not be empty")
    if splits < 1 or n_samples < splits:
        raise ValueError(
            f"need 1 <= splits <= number of samples, got splits={splits} "
            f"with {n_samples} samples"
        )

    # Inference mode for the classifier (e.g. dropout / batch-norm behave deterministically).
    classifier.eval()

    # Class probabilities for every sample, including a trailing partial batch.
    batch_probs = []
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            batch = sequences[start:start + batch_size]
            probs = F.softmax(classifier(batch), dim=1)
            batch_probs.append(probs.cpu().numpy())
    preds = np.concatenate(batch_probs, axis=0)

    # ------------------------------ fidelity ------------------------------
    avg_entropy = np.mean(entropy(preds))

    # ------------------------------ diversity -----------------------------
    # Normalise by the number of predictions actually made so the label
    # distribution sums to 1.
    label_counts = np.bincount(np.argmax(preds, axis=1))
    label_dist = (label_counts / preds.shape[0]).reshape(1, -1)
    label_entropy = entropy(label_dist)[0]

    # ---------------------------- Inception Score --------------------------
    split_size = preds.shape[0] // splits
    scores = []
    for k in range(splits):
        # Small epsilon avoids log(0) for classes with zero probability.
        part = preds[k * split_size:(k + 1) * split_size, :] + 1e-12
        marginal = np.mean(part, axis=0, keepdims=True)  # p(y) within this split
        kl = np.sum(part * (np.log(part) - np.log(marginal)), axis=1)
        scores.append(np.exp(np.mean(kl)))

    return np.mean(scores), np.std(scores), avg_entropy, label_entropy


def entropy(probabilities):
    """Return the Shannon entropy (natural log) of each row of ``probabilities``.

    Values are clipped into ``[1e-12, 1.0]`` before the log is taken, so zero
    probabilities do not produce ``log(0)``. The sum runs over axis 1, so the
    input must be at least 2-D; for an (N, C) array the result has shape (N,).
    """
    safe = np.clip(probabilities, 1e-12, 1.0)
    plogp = safe * np.log(safe)
    return -plogp.sum(axis=1)