import torch 
import numpy as np 
import lpips
from scipy.stats import pearsonr
from skimage.metrics import structural_similarity

def conv2D_output_size(img_size, padding, kernel_size, stride):
    """Return the (height, width) produced by a 2D convolution.

    Each axis is computed as
    ``floor((size + 2 * padding - (kernel_size - 1) - 1) / stride + 1)``,
    i.e. the usual convolution output-size formula with a dilation of 1.

    Args:
        img_size: input size, either a single int (same size on both axes)
            or an indexable pair ``(height, width)``.
        padding: padding per axis, int or pair.
        kernel_size: kernel size per axis, int or pair.
        stride: stride per axis, int or pair.

    Returns:
        A 2-tuple of NumPy integers ``(out_height, out_width)``.
    """
    def _pair(value):
        # Mirror convtrans2D_output_size: a bare integer applies to both axes.
        return (value, value) if isinstance(value, (int, np.integer)) else value

    img_size = _pair(img_size)
    padding = _pair(padding)
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)

    def _axis_size(axis):
        span = img_size[axis] + 2 * padding[axis] - (kernel_size[axis] - 1) - 1
        return np.floor(span / stride[axis] + 1).astype(int)

    return (_axis_size(0), _axis_size(1))

def convtrans2D_output_size(img_size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    """Return the (height, width) produced by a 2D transposed convolution.

    Each axis is computed as
    ``(size - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
    + output_padding + 1``.

    Every argument may be a single integer (applied to both axes) or an
    indexable pair ``(height, width)``. NumPy integer scalars, such as the
    values returned by ``conv2D_output_size``, are accepted as well as
    built-in ints.

    Returns:
        A 2-tuple ``(out_height, out_width)``.
    """
    def _pair(value):
        # np.integer is not a subclass of int on Python 3, so test for both.
        return (value, value) if isinstance(value, (int, np.integer)) else value

    img_size = _pair(img_size)
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    padding = _pair(padding)
    output_padding = _pair(output_padding)
    dilation = _pair(dilation)

    def _axis_size(axis):
        return ((img_size[axis] - 1) * stride[axis]
                - 2 * padding[axis]
                + dilation[axis] * (kernel_size[axis] - 1)
                + output_padding[axis]
                + 1)

    return (_axis_size(0), _axis_size(1))

'''
    Computes the Loss function for the HVAE (Binary cross entropy + KLD)
'''
def vae_loss(target, model, x):
    """Return the summed reconstruction + KL loss and the raw model output.

    ``model(x)`` must return a 4-tuple ``(x_reconst, z, mus, log_vars)``.
    ``mus`` and ``log_vars`` are either single tensors or lists of tensors
    of matching length; a non-list value is treated as a one-element list.

    The loss is the binary cross entropy between ``x_reconst`` and
    ``target`` (summed over all elements) plus, for every
    ``(mu, log_var)`` pair,
    ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``.

    Returns:
        ``(loss, model_out)`` where ``model_out`` is the untouched tuple
        returned by ``model(x)``.
    """
    outputs = model(x)
    reconstruction, _latent, means, log_variances = outputs

    reconstruction_term = torch.nn.functional.binary_cross_entropy(
        reconstruction, target, reduction="sum"
    )

    # Normalise the single-level case so one code path handles both.
    if not isinstance(means, list):
        means, log_variances = [means], [log_variances]

    kl_term = sum(
        -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())
        for mean, log_variance in zip(means, log_variances)
    )
    return reconstruction_term + kl_term, outputs

class neural_decoder_loss(object):
    """Callable loss combining latent MSE, LPIPS and pixel-wise BCE terms.

    Args:
        device: device on which the LPIPS (VGG) network is placed. When
            ``None`` (the default) CUDA is used if it is available,
            otherwise the CPU, so the loss can also be built on machines
            without a GPU.
    """

    def __init__(self, device=None) -> None:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loss_fn = lpips.LPIPS(net='vgg').to(device)

    def __call__(self, img, model, x):
        """Compute the loss for one batch.

        ``model(x)`` must return ``(x_reconst, z_pred)`` and ``model.hvae``
        must expose ``encode``; the first element of ``encode(img)`` is used
        as the sequence of target latents paired with ``z_pred``.

        The returned loss is the sum of:
          * the mean squared error of every ``(z, z_pred)`` pair,
          * the summed LPIPS distance between ``img`` and ``x_reconst``
            (called with ``normalize=True``),
          * the summed binary cross entropy of ``x_reconst`` against ``img``.

        Returns:
            ``(loss, model_out)`` where ``model_out`` is the tuple returned
            by ``model(x)``.
        """
        model_out = model(x)
        x_reconst, z_pred = model_out
        z = model.hvae.encode(img)[0]

        loss = 0
        for zz, zz_pred in zip(z, z_pred):
            # Functional form avoids building a new MSELoss module per pair.
            loss = loss + torch.nn.functional.mse_loss(zz, zz_pred)

        d = self.loss_fn.forward(img, x_reconst, normalize=True)
        loss = loss + torch.sum(d)
        loss = loss + torch.nn.functional.binary_cross_entropy(x_reconst, img, reduction="sum")
        return loss, model_out

"""
    Computes the n-way classification accuracy.
    y and y_pred are [B, 3, W, H]
"""
def n_way_classification(y, y_pred, n=2, metric="pearson"):
    """Estimate the n-way identification accuracy of ``y_pred`` against ``y``.

    For every sample ``i`` the prediction ``y_pred[i]`` is scored against its
    own target ``y[i]`` and against ``n - 1`` distractor targets drawn at
    random (via ``np.random``, without replacement) from the other samples.
    The sample counts as correct when no distractor scores strictly higher
    than the true target.

    Args:
        y: targets, indexed along the first axis (``y.shape[0]`` samples).
        y_pred: predictions aligned with ``y``.
        n: total number of candidates per trial, including the true target.
            If fewer than ``n - 1`` other samples exist, all of them are used.
        metric: ``"pearson"`` (Pearson correlation of the flattened arrays)
            or ``"ssim"`` (structural similarity with channels on axis 0).

    Returns:
        Fraction of samples identified correctly, as a float in [0, 1].

    Raises:
        ValueError: if ``metric`` is not supported or the batch is empty.
    """
    if metric not in ("pearson", "ssim"):
        raise ValueError(f"metric must be 'pearson' or 'ssim', got {metric!r}")
    num_samples = y.shape[0]
    if num_samples == 0:
        raise ValueError("n_way_classification requires at least one sample")

    def _score(pred, target):
        if metric == "pearson":
            # pearsonr returns (statistic, p-value); only the statistic ranks.
            return pearsonr(pred.flatten(), target.flatten())[0]
        # NOTE(review): recent scikit-image releases appear to require an
        # explicit data_range for floating-point images -- confirm for the
        # installed version.
        return structural_similarity(pred, target, channel_axis=0)

    num_distractors = max(0, min(n - 1, num_samples - 1))
    correct = 0
    for i in range(num_samples):
        pred = y_pred[i]
        if num_distractors:
            # Exclude i so the true target is never drawn as its own distractor.
            others = np.delete(np.arange(num_samples), i)
            distractors = list(np.random.choice(others, size=num_distractors, replace=False))
        else:
            distractors = []

        # The true target is evaluated first, so ties resolve in its favour.
        best_index = i
        best_score = -np.inf
        for j in [i] + distractors:
            score = _score(pred, y[j])
            if score > best_score:
                best_score = score
                best_index = j
        if best_index == i:
            correct += 1
    return correct / num_samples
            
                
    
