import numpy as np
import torch
from scipy.spatial.distance import jensenshannon
from scipy.stats import chisquare
from skimage.metrics import structural_similarity as ssim
from pytorch_msssim import ssim as pyssim

# Chi-square test
def chi_square(ori_prob, cur_prob, N=100000):
    """Return the p-value of a chi-square goodness-of-fit test.

    Both inputs are multiplied by ``N`` to turn them into pseudo-counts;
    ``ori_prob`` supplies the observed frequencies and ``cur_prob`` the
    expected ones.

    Args:
        ori_prob: Array-like of observed probabilities.
        cur_prob: Array-like of expected probabilities.
        N: Pseudo sample size used to scale probabilities into counts.

    Returns:
        The p-value reported by ``scipy.stats.chisquare``.
    """
    observed = np.asarray(ori_prob, dtype=np.float64) * N
    expected = np.asarray(cur_prob, dtype=np.float64) * N

    # Recent SciPy versions raise ValueError when the observed and expected
    # totals do not agree to a tight relative tolerance, which happens with
    # probabilities that only sum to 1 up to rounding (e.g. float32 softmax
    # outputs). Rescale the expected counts to the observed total; when the
    # totals already match the ratio is 1 and nothing changes.
    totals_ratio = (
        observed.sum(axis=0, keepdims=True) / expected.sum(axis=0, keepdims=True)
    )
    expected = expected * totals_ratio

    _, p_value = chisquare(f_obs=observed, f_exp=expected)
    return p_value

def chi2_statistic(O, E, eps=1e-8):
    """Return the Pearson chi-square statistic ``sum((O - E)^2 / E)``.

    ``eps`` is added to both tensors before the computation so that the
    denominator is never exactly zero.

    Args:
        O: Tensor of observed values.
        E: Tensor of expected values.
        eps: Small constant added to both inputs.

    Returns:
        A scalar tensor holding the summed statistic.
    """
    observed = O + eps
    expected = E + eps
    squared_residual = (observed - expected) ** 2
    return torch.sum(squared_residual / expected)


# Jensen-Shannon divergence (differentiable)
def js(ori_prob, cur_prob):
    """Return the Jensen-Shannon divergence between two distributions.

    The result is the average of the KL divergences of each input against
    their midpoint ``0.5 * (ori_prob + cur_prob)``, summed over every
    element. A small constant is added to both arguments of each KL term
    to keep the logarithm finite.

    Args:
        ori_prob: Tensor holding the first distribution.
        cur_prob: Tensor holding the second distribution.

    Returns:
        A scalar tensor.
    """
    eps = 1e-8
    mixture = 0.5 * (ori_prob + cur_prob)

    def _kl_to_mixture(dist):
        # Smooth both sides independently before taking the log-ratio.
        smoothed = dist + eps
        reference = mixture + eps
        return torch.sum(smoothed * torch.log(smoothed / reference))

    return 0.5 * _kl_to_mixture(ori_prob) + 0.5 * _kl_to_mixture(cur_prob)

def js_batch(ori_prob, cur_prob, eps=1e-8):
    """Return the row-wise Jensen-Shannon divergence of two batches.

    ``eps`` is added to both inputs before the midpoint distribution is
    formed; each KL term is then summed along dimension 1.

    Args:
        ori_prob: Tensor of shape (B, C).
        cur_prob: Tensor of shape (B, C).
        eps: Small constant added to both inputs to keep the log finite.

    Returns:
        Tensor of shape (B,) with one divergence per row.
    """
    smoothed_ori = ori_prob + eps
    smoothed_cur = cur_prob + eps
    midpoint = 0.5 * (smoothed_ori + smoothed_cur)

    def _kl_rows(dist):
        # KL(dist || midpoint), reduced over the class dimension.
        return torch.sum(dist * torch.log(dist / midpoint), dim=1)

    return 0.5 * (_kl_rows(smoothed_ori) + _kl_rows(smoothed_cur))


def hellinger_distance(p, q, eps=1e-8):
    """Compute the Hellinger distance between two batches of distributions.

    Both inputs are clamped from below at ``eps`` before the square root is
    taken, then the L2 norm of the difference of the square roots is reduced
    along dimension 1 and divided by sqrt(2).

    Args:
        p: Tensor of shape (batch_size, num_bins).
        q: Tensor of shape (batch_size, num_bins).
        eps: Lower bound applied to every entry of ``p`` and ``q``.

    Returns:
        Tensor with one distance per row (dimension 1 is reduced away).
    """
    sqrt_p = torch.sqrt(torch.clamp(p, min=eps))
    sqrt_q = torch.sqrt(torch.clamp(q, min=eps))
    # Divide by a Python float rather than a freshly built float32 tensor:
    # the scalar adopts the dtype of the norm, so float64 inputs keep full
    # precision and no tensor is allocated on every call.
    return torch.norm(sqrt_p - sqrt_q, p=2, dim=1) / (2.0 ** 0.5)


# Perceptual image similarity: Structural Similarity Index
def SSIM(img_0, img_1):
    """Return the scikit-image structural similarity of two tensors.

    Both tensors are converted to NumPy arrays and compared with
    ``data_range=1.0``.

    Args:
        img_0: First image tensor.
        img_1: Second image tensor.

    Returns:
        The value returned by ``skimage.metrics.structural_similarity``.
    """
    # Tensor.numpy() raises for tensors that require grad or live on a
    # non-CPU device, so detach and move to host memory first. For plain CPU
    # tensors both calls are no-ops.
    array_0 = img_0.detach().cpu().numpy()
    array_1 = img_1.detach().cpu().numpy()
    return ssim(array_0, array_1, data_range=1.0)

def pySSIM(img_0, img_1):
    """Return the mean SSIM of two image tensors via ``pytorch_msssim``.

    When ``img_0`` has exactly three dimensions, a leading axis is inserted
    into both inputs before the call. Only the rank of ``img_0`` is
    inspected; ``img_1`` is assumed to match it.

    Args:
        img_0: First image tensor.
        img_1: Second image tensor.

    Returns:
        The value returned by ``pytorch_msssim.ssim`` with
        ``data_range=1.0`` and ``size_average=True``.
    """
    if img_0.ndim == 3:
        # Presumably (C, H, W) -> (1, C, H, W); confirm against callers.
        img_0, img_1 = img_0.unsqueeze(0), img_1.unsqueeze(0)
    return pyssim(img_0, img_1, data_range=1.0, size_average=True)


def fid_imgs(img, fake):
    """Return the per-sample squared cosine similarity of two batches.

    Each sample is flattened and divided by its L2 norm; the batched inner
    product of the two unit vectors is then squared and passed through
    ``abs``.

    Args:
        img: Tensor whose first dimension is the batch size.
        fake: Tensor that can be reshaped to the same batch size as ``img``.

    Returns:
        Tensor of shape (B, 1, 1), where B is ``img.shape[0]``.
    """
    batch = img.shape[0]
    flat_img = img.reshape(batch, -1)
    flat_fake = fake.reshape(batch, -1)

    unit_img = flat_img / torch.norm(flat_img, dim=1, keepdim=True)
    unit_fake = flat_fake / torch.norm(flat_fake, dim=1, keepdim=True)

    # (B, 1, D) @ (B, D, 1) -> (B, 1, 1): one inner product per sample.
    overlap = torch.bmm(unit_img.unsqueeze(1), unit_fake.unsqueeze(2))
    return torch.abs(overlap ** 2)


def finite_difference_gradient(f, x, eps=1e-6):
    """Estimate the gradient of ``f`` at ``x`` with central differences.

    Each element of ``x`` is perturbed by ``+eps`` and ``-eps`` in turn and
    ``f`` is evaluated on both perturbed copies, so ``f`` is called
    ``2 * x.numel()`` times.

    Args:
        f: Callable taking a tensor shaped like ``x`` and returning a scalar.
        x: Point at which the gradient is estimated; it is not modified.
        eps: Half-width of the central difference.

    Returns:
        Tensor with the shape, dtype and device of ``x`` holding the
        estimated partial derivatives.
    """
    # reshape() instead of view(): clone() and zeros_like() preserve strides,
    # so view(-1) raises on non-contiguous input such as a transposed tensor.
    base = x.detach().reshape(-1).clone()
    grad_flat = torch.zeros_like(base)
    width = 2 * eps

    for i in range(base.numel()):
        # Fresh copies per evaluation so f never sees a buffer that is
        # mutated afterwards.
        shifted_up = base.clone()
        shifted_down = base.clone()
        shifted_up[i] += eps
        shifted_down[i] -= eps

        f_plus = f(shifted_up.reshape(x.shape))
        f_minus = f(shifted_down.reshape(x.shape))

        grad_flat[i] = (f_plus - f_minus) / width

    return grad_flat.reshape(x.shape)


def nes_bandits(x, ori_out, ori_prob, obj, g_pre, sigma=0.1, beta=0.9, steps=50):
    """Natural Evolution Strategies gradient estimate blended with a prior.

    The gradient of ``obj`` with respect to the input ``x`` is estimated
    from ``steps`` antithetic Gaussian probes and then mixed with the
    previous estimate ``g_pre``.

    Args:
        x: Input tensor at which the gradient is estimated.
        ori_out: Forwarded unchanged to ``obj`` as its second argument.
        ori_prob: Forwarded unchanged to ``obj`` as its third argument.
        obj: Callable returning a 3-tuple whose first item is the objective.
        g_pre: Previous gradient estimate.
        sigma: Standard deviation of the probes.
        beta: Weight of the fresh estimate; ``1 - beta`` weights ``g_pre``.
        steps: Number of antithetic probe pairs.

    Returns:
        ``(1 - beta) * g_pre + beta * estimate``.
    """
    estimate = torch.zeros_like(x)
    for _step in range(steps):
        direction = torch.randn_like(x)
        offset = sigma * direction
        loss_fwd, _, _ = obj(x + offset, ori_out, ori_prob)
        loss_bwd, _, _ = obj(x - offset, ori_out, ori_prob)
        estimate += (loss_fwd - loss_bwd) * direction
    estimate /= 2 * sigma * steps
    return (1 - beta) * g_pre + beta * estimate


def nes(x, y, loss, sigma=0.1, steps=50):
    """Estimate the gradient of ``loss(x, y)`` w.r.t. ``x`` with NES.

    Draws ``steps`` Gaussian directions, evaluates the loss at the two
    antithetic points ``x +/- sigma * u`` and averages the finite-difference
    slopes along those directions.

    Args:
        x: Input tensor at which the gradient is estimated.
        y: Forwarded unchanged to ``loss`` as its second argument.
        loss: Callable ``loss(x, y)`` returning a scalar.
        sigma: Standard deviation of the probes.
        steps: Number of antithetic probe pairs.

    Returns:
        Tensor shaped like ``x`` with the estimated gradient.
    """
    estimate = torch.zeros_like(x)
    for _step in range(steps):
        direction = torch.randn_like(x)
        offset = sigma * direction
        loss_fwd = loss(x + offset, y)
        loss_bwd = loss(x - offset, y)
        estimate += (loss_fwd - loss_bwd) * direction
    estimate /= 2 * sigma * steps
    return estimate


def bandit_grad_est(x, grad, ori_out, ori_prob, obj, delta=1,eps=0.01):
    """Two-query bandit estimate of the update for the gradient prior.

    A random exploration direction ``u`` is added to and subtracted from
    the current prior ``grad``; the objective is queried at ``x`` shifted
    by ``eps`` along each normalised candidate, and the loss difference is
    projected back onto ``u``.

    Args:
        x: Input tensor at which the objective is probed.
        grad: Current gradient prior.
        ori_out: Forwarded unchanged to ``obj`` as its second argument.
        ori_prob: Forwarded unchanged to ``obj`` as its third argument.
        obj: Callable returning a 3-tuple whose first item is the objective.
        delta: Exploration step applied to the prior.
        eps: Finite-difference step applied to ``x``.

    Returns:
        Tensor shaped like ``x``:
        ``(obj(x + eps*q2/|q2|) - obj(x + eps*q1/|q1|)) / (delta * eps) * u``.
    """
    # Scale the noise by 1/sqrt(d) so its expected norm does not grow with
    # the input dimension.
    u = torch.randn_like(x) / (x.numel() ** 0.5)

    # Antithetic candidates around the prior.
    q1 = grad + delta * u
    q2 = grad - delta * u

    # Both queries step in the +q direction. The previous `x - eps * q2`
    # made the two query points coincide whenever the prior was zero (so
    # the estimate was identically zero) and cancelled the u-dependent
    # first-order term otherwise.
    loss_q1, _, _ = obj(x + eps * q1 / torch.norm(q1), ori_out, ori_prob)
    loss_q2, _, _ = obj(x + eps * q2 / torch.norm(q2), ori_out, ori_prob)

    grad_v = ((loss_q2 - loss_q1) / (delta * eps)) * u
    return grad_v
