
#%%
from pytorch_lightning.metrics import Precision
from pytorch_lightning.metrics.classification import AveragePrecision

import torch
# Toy binary problem: six ground-truth labels and matching per-sample scores.
target = torch.tensor([0, 1, 0, 0, 1, 1], dtype=torch.int64)
preds = torch.tensor([0.1, 0.8, 0, 0, 0, 1])

# threshold=0.5 is the cut-off the metric uses to binarise the scores.
precision = Precision(threshold=0.5, num_classes=1)
print(precision(preds, target))

# Threshold-free ranking metric evaluated on the same scores/labels.
ap = AveragePrecision()
print(ap(preds, target))


# %%
import os

# Presumably works around the Intel OpenMP "duplicate runtime" abort
# (libiomp loaded twice). NOTE(review): this only helps if set before the
# OpenMP runtime is loaded; when this file runs top to bottom, torch has
# already been imported by the first cell, so move this line to the very
# top of the file if the workaround is actually needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

m = 1000    # number of samples drawn from each distribution
dim = 2     # dimensionality of each sample
seed = 0    # fixed so that X and Y are reproducible run to run

# Both distributions are zero-mean.
x_mean = torch.zeros(dim)
y_mean = torch.zeros(dim)

# IMPORTANT: covariance matrices must be positive definite.
# X is isotropic with unit variance; Y differs only by a variance of 10
# along the first axis.
x_cov = torch.eye(dim)
y_cov = torch.eye(dim)
y_cov[0, 0] = 10

px = MultivariateNormal(x_mean, x_cov)
qy = MultivariateNormal(y_mean, y_cov)

# Seed immediately before sampling so the draws below are deterministic.
torch.manual_seed(seed)
X = px.sample([m])  # shape (m, dim)
Y = qy.sample([m])  # shape (m, dim)


# %%
from utils.metrics.mmd import MMDLoss
import torch
# NOTE(review): MMDLoss is project-local (utils.metrics.mmd) and its source is
# not visible here, so the two points below are assumptions to confirm:
#  - num_samples is hard-coded to 1000; presumably it must equal the number of
#    rows in X and Y (currently m = 1000). Consider passing len(X) instead.
#  - fix_sigma is assigned the boolean True. If MMDLoss uses fix_sigma as a
#    numeric kernel bandwidth, True would behave like 1.0. Confirm this is
#    intended.
loss = MMDLoss(kernel_num=5, kernel_mul=4.0, num_samples=1000)
loss.fix_sigma = True

# Evaluate the MMD loss between the two sample sets and report it.
my_statistic = loss(X, Y)
print(my_statistic)

