import torch
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp

def normalize(tensor):
    """Affinely map values via ``x -> (x + 1) / 2`` (e.g. ``[-1, 1]`` onto ``[0, 1]``).

    No clipping is applied, so values outside ``[-1, 1]`` land outside ``[0, 1]``.
    """
    shifted = tensor + 1
    return shifted / 2

def normalize_clip(tensor):
    """Rescale with ``normalize`` and clamp the result into ``[0, 1]``."""
    rescaled = normalize(tensor)
    return rescaled.clip(0, 1)

class PSNR(torch.nn.Module):
    """Peak signal-to-noise ratio, in dB, between two tensors of equal shape.

    The mean squared error is averaged over *all* elements, so a batched input
    yields one scalar rather than one value per sample.

    Args:
        max_val: Peak signal value used in the numerator (``max_val ** 2``).
    """

    def __init__(self, max_val=1.0):
        super(PSNR, self).__init__()
        self.max_val = max_val

    def _psnr(self, img1, img2, max_val=1.0):
        """Return ``10 * log10(max_val**2 / MSE(img1, img2))`` as a 0-d tensor.

        Identical inputs give zero MSE and therefore ``inf``.

        Raises:
            ValueError: If ``img1`` and ``img2`` have different shapes.
        """
        if img1.shape != img2.shape:
            raise ValueError(f"img1 and img2 must have the same shape, got {img1.shape} vs {img2.shape}")

        # Integer tensors (e.g. uint8 images) wrap around on subtraction and
        # ``mean`` refuses integer dtypes, so promote them to floating point.
        if not torch.is_floating_point(img1):
            img1 = img1.float()
        if not torch.is_floating_point(img2):
            img2 = img2.float()

        mse = (img1 - img2).pow(2).mean()
        return 10.0 * torch.log10((max_val ** 2) / mse)

    def forward(self, img1, img2):
        """Compute PSNR using the ``max_val`` given at construction."""
        return self._psnr(img1, img2, self.max_val)

class SSIM(torch.nn.Module):
    """Structural similarity (SSIM) between two image batches.

    A Gaussian window (sigma = 1.5) is applied per channel through a grouped
    convolution. The stabilising constants ``C1 = 0.01**2`` and ``C2 = 0.03**2``
    correspond to a dynamic range of 1, so inputs are presumably expected in
    ``[0, 1]`` — TODO confirm against callers.

    Args:
        window_size: Side length of the Gaussian window.
        dim: ``3`` selects the volumetric path (``conv3d``, inputs
            ``(N, C, D, H, W)``); any other value selects the 2-D path
            (``conv2d``, inputs ``(N, C, H, W)``).
        size_average: If True, return the mean SSIM as a scalar; otherwise
            return one value per batch element (shape ``(N,)``).
    """

    def __init__(self, window_size = 11, dim=2, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # The window starts out single-channel; ``forward`` rebuilds it when an
        # input with a different channel count arrives.
        self.channel = 1
        self.dim = dim
        self.window = self._make_window(window_size, self.channel)

    def _make_window(self, window_size, channel):
        """Build the 2-D or 3-D window (per ``self.dim``) for ``channel`` channels."""
        if self.dim == 3:
            return self._create_window_3D(window_size, channel)
        return self._create_window(window_size, channel)

    def _gaussian(self, window_size, sigma):
        """Return a normalised 1-D Gaussian of length ``window_size`` centred at ``window_size // 2``."""
        gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
        return gauss/gauss.sum()

    def _create_window(self, window_size, channel):
        """Return a ``(channel, 1, k, k)`` separable 2-D Gaussian window (sigma 1.5)."""
        _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        # One identical kernel per channel, matching ``groups=channel`` in the conv.
        return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

    def _create_window_3D(self, window_size, channel):
        """Return a ``(channel, 1, k, k, k)`` separable 3-D Gaussian window (sigma 1.5)."""
        _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t())
        _3D_window = _1D_window.mm(_2D_window.reshape(1, -1)).reshape(window_size, window_size, window_size).float().unsqueeze(0).unsqueeze(0)
        return _3D_window.expand(channel, 1, window_size, window_size, window_size).contiguous()

    @staticmethod
    def _ssim_map(img1, img2, window, window_size, channel, conv):
        """Compute the per-position SSIM map, blurring with ``conv`` (``F.conv2d`` or ``F.conv3d``)."""
        # Match the input's device *and* dtype so double/half inputs work too.
        window = window.to(device=img1.device, dtype=img1.dtype)
        pad = window_size // 2

        def blur(x):
            return conv(x, window, padding=pad, groups=channel)

        mu1 = blur(img1)
        mu2 = blur(img2)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1*mu2

        # Local (co)variances via E[xy] - E[x]E[y].
        sigma1_sq = blur(img1*img1) - mu1_sq
        sigma2_sq = blur(img2*img2) - mu2_sq
        sigma12 = blur(img1*img2) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        return ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    @staticmethod
    def _reduce(ssim_map, size_average):
        """Average the SSIM map to a scalar, or to one value per batch element."""
        if size_average:
            return ssim_map.mean()
        # Average over every non-batch dimension; correct for 4-D and 5-D maps alike.
        return ssim_map.flatten(1).mean(1)

    def _ssim(self, img1, img2, window, window_size, channel, size_average = True):
        """2-D SSIM for ``(N, C, H, W)`` inputs."""
        ssim_map = self._ssim_map(img1, img2, window, window_size, channel, F.conv2d)
        return self._reduce(ssim_map, size_average)

    def _ssim_3D(self, img1, img2, window, window_size, channel, size_average = True):
        """3-D SSIM for ``(N, C, D, H, W)`` inputs."""
        ssim_map = self._ssim_map(img1, img2, window, window_size, channel, F.conv3d)
        return self._reduce(ssim_map, size_average)

    def forward(self, img1, img2):
        """Compute SSIM between ``img1`` and ``img2`` (channel axis is dim 1)."""
        channel = img1.size(1)
        # Grouped conv needs one kernel per channel; rebuild the cached window
        # whenever the channel count changes.
        if channel != self.channel:
            self.window = self._make_window(self.window_size, channel)
            self.channel = channel
        if self.dim==3:
            return self._ssim_3D(img1, img2, self.window, self.window_size, channel, self.size_average)
        return self._ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def compute_confidence_and_coverage(y_true, y_samples, n_bins=20):
    """Measure how often central sample intervals contain the true values.

    For each nominal level ``c`` in ``linspace(0.01, 0.99, n_bins)``, the
    interval is bounded by the ``(1 - c) / 2`` and ``1 - (1 - c) / 2``
    quantiles of ``y_samples`` taken along axis 1. Coverage is the fraction of
    ``y_true`` entries that fall inside their interval (bounds inclusive).

    Args:
        y_true: Ground-truth values, presumably 1-D ``(pixels,)`` — TODO confirm.
        y_samples: One row of samples per ``y_true`` entry, presumably
            ``(pixels, n_samples)``; quantiles are taken over axis 1.
        n_bins: Number of confidence levels to evaluate.

    Returns:
        ``(confidence_levels, coverage)``, each of length ``n_bins``.
    """
    levels = np.linspace(0.01, 0.99, n_bins)
    tail = (1 - levels) / 2

    # Both bounds have shape (n_bins, rows): one interval per level and row.
    lo = np.quantile(y_samples, tail, axis=1)
    hi = np.quantile(y_samples, 1 - tail, axis=1)

    truth = y_true[None, :]  # (1, rows) broadcasts against (n_bins, rows)
    inside = np.logical_and(truth >= lo, truth <= hi)
    return levels, inside.mean(axis=1)

def PCE(y_true, y_samples, order=1, n_bins=20):
    """Mean calibration gap ``|confidence - coverage| ** order`` over the levels.

    Levels and empirical coverages come from
    ``compute_confidence_and_coverage(y_true, y_samples, n_bins=n_bins)``.

    Args:
        y_true: Ground-truth values (forwarded unchanged).
        y_samples: Predictive samples (forwarded unchanged).
        order: Exponent applied to each absolute gap (1 = mean absolute gap).
        n_bins: Number of confidence levels evaluated.

    Returns:
        Scalar mean of the per-level gaps raised to ``order``.
    """
    confidence_levels, empirical_coverage = compute_confidence_and_coverage(y_true, y_samples, n_bins=n_bins)

    # Take the absolute value *before* exponentiating: a negative gap raised to
    # a fractional ``order`` is NaN, while ``|gap| ** order`` is always defined.
    # For integer orders both forms give the same result.
    gaps = np.abs(confidence_levels - empirical_coverage)
    return np.mean(gaps ** order)

def pretty_plot(x, y, color='#1f77b4', label=None, x_label=None, y_label=None):
    """Draw a small square line plot on the unit square and show it.

    The axes are fixed to ``[0, 1] x [0, 1]`` with ticks at 0.2 steps and
    equal aspect ratio.

    Args:
        x, y: Data passed to ``Axes.plot``.
        color: Line colour.
        label: Legend entry; defaults to ``'Reliability curve'``.
        x_label: x-axis label; defaults to ``'Confidence Interval'``.
        y_label: y-axis label; defaults to ``'Empirical Coverage'``.
    """
    # These parameters used to be ignored in favour of hard-coded strings;
    # keep those strings as defaults so existing callers see no change.
    if label is None:
        label = 'Reliability curve'
    if x_label is None:
        x_label = 'Confidence Interval'
    if y_label is None:
        y_label = "Empirical Coverage"

    with plt.style.context('seaborn-v0_8-whitegrid'), mpl.rc_context({
        'font.size': 6,
        'axes.labelsize': 6,
        'axes.titlesize': 6,
        'legend.fontsize': 6,
    }):
        _, ax = plt.subplots(figsize=(2, 2), dpi=300)
        ax.plot(x, y, color=color, linewidth=2, label=label)

        ticks = [0.2, 0.4, 0.6, 0.8, 1.]
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_aspect('equal', adjustable='box')

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)
        plt.show()