import math, torch, os
import torch.nn as nn
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from utils.misc import *
import pytorch_ssim.pytorch_ssim as pytorch_ssim

class AverageMeter:
    """Keep the most recent value and a sample-weighted running average.

    Each ``update(val, n)`` call counts ``val`` as the value of ``n`` samples,
    so ``avg`` is ``sum / count`` over everything recorded since the last
    ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, average, sum and sample count back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` samples and refresh ``avg``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


# def compute_psnr(X, Y):
#     criteria = nn.MSELoss()
#     return 20 * math.log10(1 / math.sqrt(criteria(X, Y)))


def PSNR(Xk, X):
    """Per-image PSNR of a two-channel (real, imaginary) reconstruction.

    The complex image formed from the two channels of ``Xk`` is reduced to its
    magnitude, clamped to [0, 1], and compared with channel 0 of ``X``. The
    peak signal is the per-image maximum of that reference channel, and the
    MSE is averaged over the W * H pixels of one channel.

    Args:
        Xk: reconstruction; channel dimension must have size 2 (required by
            ``torch.view_as_complex``), shape (batch, 2, W, H).
        X: reference of shape (batch, C, W, H); only channel 0 is used.

    Returns:
        Tensor of shape (batch,) with PSNR values in dB.
    """
    n_batch, _, width, height = X.shape
    complex_img = torch.view_as_complex(Xk.permute(0, 2, 3, 1).contiguous())
    recon = complex_img.abs().clamp(min=0, max=1)
    reference = X[:, 0, :, :]

    per_image_sse = ((recon - reference) ** 2).reshape(n_batch, -1).sum(dim=1)
    mse = per_image_sse / (width * height)
    peak = reference.flatten(1).max(dim=1).values
    return 20 * torch.log10(peak / mse.sqrt())


def compute_metrics(Xk, X, X0):
    """Batch-averaged PSNR figures and SSIM for two-channel (real, imaginary) images.

    Args:
        Xk: reconstruction with a size-2 channel dimension (real, imaginary).
        X: reference; channel 0 is the clean image.
        X0: initial estimate, in the same layout as ``Xk``.

    Returns:
        Tuple ``(avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim)``.
        The first three are batch means of ``PSNR(X0, X)``, ``PSNR(Xk, X)``
        and their difference. The last is whatever ``pytorch_ssim.ssim``
        returns for the clamped magnitude of ``Xk`` versus channel 0 of ``X``.
    """
    n_batch = X.shape[0]
    psnr_init = PSNR(X0, X)
    psnr_recon = PSNR(Xk, X)

    mean_init = psnr_init.sum() / n_batch
    mean_recon = psnr_recon.sum() / n_batch
    mean_gain = (psnr_recon - psnr_init).sum() / n_batch

    # SSIM is computed on the same clamped magnitude image that PSNR scores.
    complex_img = torch.view_as_complex(Xk.permute(0, 2, 3, 1).contiguous())
    magnitude = complex_img.abs().clamp(min=0, max=1)
    mean_ssim = pytorch_ssim.ssim(magnitude.unsqueeze(1), X[:, 0:1, :, :])
    return mean_init, mean_recon, mean_gain, mean_ssim


def PSNR1chan(Xk, X):
    """Per-image PSNR computed on channel 0 of both tensors.

    No clamping is applied. The peak signal is the per-image maximum of the
    reference channel, and the MSE is averaged over the W * H pixels.

    Args:
        Xk: reconstruction, indexed as (batch, channel, W, H); channel 0 is used.
        X: reference of shape (batch, C, W, H); channel 0 is used.

    Returns:
        Tensor of shape (batch,) with PSNR values in dB.
    """
    n_batch, _, width, height = X.shape
    recon, reference = Xk[:, 0, :, :], X[:, 0, :, :]
    per_image_sse = ((recon - reference) ** 2).reshape(n_batch, -1).sum(dim=1)
    mse = per_image_sse / (width * height)
    peak = reference.flatten(1).max(dim=1).values
    return 20 * torch.log10(peak / torch.sqrt(mse))


def compute_metrics1chan(Xk, X, X0):
    """Batch-averaged PSNR figures and SSIM, using channel 0 only.

    Args:
        Xk: reconstruction, indexed as (batch, channel, W, H).
        X: reference, indexed as (batch, channel, W, H).
        X0: initial estimate, in the same layout as ``Xk``.

    Returns:
        Tuple ``(avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim)``.
        The first three are batch means of ``PSNR1chan(X0, X)``,
        ``PSNR1chan(Xk, X)`` and their difference. The last is whatever
        ``pytorch_ssim.ssim`` returns for channel 0 of ``Xk`` versus channel 0
        of ``X``.
    """
    bs = X.shape[0]
    init_psnr = PSNR1chan(X0, X)
    recon_psnr = PSNR1chan(Xk, X)
    avg_init_psnr = torch.sum(init_psnr) / bs
    avg_recon_psnr = torch.sum(recon_psnr) / bs
    avg_delta_psnr = torch.sum(recon_psnr - init_psnr) / bs

    # Compare channel 0 with channel 0, exactly as PSNR1chan does. The
    # original sliced only X, so an Xk with extra channels gave the two SSIM
    # inputs different channel counts. For a 1-channel Xk this slice is a
    # no-op.
    avg_ssim = pytorch_ssim.ssim(Xk[:, 0:1, :, :], X[:, 0:1, :, :])
    return avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim


def PSNR3chan(Xk, X):
    """Per-image PSNR over all channels with a fixed peak signal of 1.

    The MSE is averaged over every C * W * H value of each batch entry.

    Args:
        Xk: reconstruction, same shape as ``X``.
        X: reference of shape (batch, C, W, H).

    Returns:
        Tensor of shape (batch,) with PSNR values in dB.
    """
    n_batch, channels, width, height = X.shape
    n_values = channels * width * height
    sse = ((Xk - X) ** 2).reshape(n_batch, -1).sum(dim=1)
    rmse = torch.sqrt(sse / n_values)
    return 20 * torch.log10(1 / rmse)


def compute_metrics3chan(Xk, X, X0):
    """Batch-averaged PSNR figures (peak 1, all channels) and SSIM.

    Args:
        Xk: reconstruction, same shape as ``X``.
        X: reference of shape (batch, C, W, H).
        X0: initial estimate, same shape as ``X``.

    Returns:
        Tuple ``(avg_init_psnr, avg_recon_psnr, avg_delta_psnr, avg_ssim)``.
        The first three are batch means of ``PSNR3chan(X0, X)``,
        ``PSNR3chan(Xk, X)`` and their difference. The last is whatever
        ``pytorch_ssim.ssim(Xk, X)`` returns.
    """
    n_batch = X.shape[0]
    psnr_init = PSNR3chan(X0, X)
    psnr_recon = PSNR3chan(Xk, X)
    mean_init = psnr_init.sum() / n_batch
    mean_recon = psnr_recon.sum() / n_batch
    mean_gain = (psnr_recon - psnr_init).sum() / n_batch
    mean_ssim = pytorch_ssim.ssim(Xk, X)
    return mean_init, mean_recon, mean_gain, mean_ssim


def plot_MRI(Xk, X0, X, criteria, save_path, epoch):
    """Save a 4x3 grid comparing initial, reconstructed and clean images.

    Rows, top to bottom:
        1. Clamped magnitude of ``X0``.
        2. Clamped magnitude of ``Xk``.
        3. Clean image (channel 0 of ``X``).
        4. Absolute error of the reconstruction.

    Up to three batch entries are shown, one per column. The figure is
    written to ``<save_path>/<epoch>_results.png``.

    Args:
        Xk, X0: tensors with a size-2 (real, imaginary) channel dimension.
        X: reference tensor; channel 0 is the clean image.
        criteria: callable applied to two images whose result is used as an
            MSE in the PSNR titles (assumed e.g. ``nn.MSELoss()``; confirm).
        save_path: output directory.
        epoch: used in the output file name.
    """
    plt.figure(figsize=(10, 20))
    x_hat = torch.clamp(torch.abs(torch.view_as_complex(Xk.permute(0, 2, 3, 1).contiguous())), min=0, max=1)
    init = torch.view_as_complex(X0.permute(0, 2, 3, 1).contiguous())
    init_clamp = torch.clamp(torch.abs(init), min=0, max=1)
    X_true = X[:, 0, :, :]
    err = torch.abs(X_true - x_hat)

    # Loop-invariant values. The PSNR peak and the "Error, max" title are
    # batch-wide maxima, not per-image ones.
    peak = torch.max(X_true)
    max_err = torch.max(err)

    # Show at most three examples so batches smaller than 3 do not raise IndexError.
    for i in range(min(3, X_true.shape[0])):
        ii = i
        plt.subplot(4, 3, i + 1)
        psnr = 20 * math.log10(peak / math.sqrt(criteria(init_clamp[ii], X_true[ii])))
        plt.imshow(init_clamp[ii].detach().cpu(), cmap='gray')
        plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')

        plt.subplot(4, 3, i + 4)
        psnr = 20 * math.log10(peak / math.sqrt(criteria(x_hat[ii], X_true[ii])))
        plt.imshow(x_hat[ii].detach().cpu(), cmap='gray')
        # Raw string: '\h' is an invalid escape sequence in a normal literal.
        plt.title(r'$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')

        plt.subplot(4, 3, i + 7)
        plt.imshow(X_true[ii].detach().cpu(), cmap='gray')
        plt.title('Clean image')
        plt.axis('off')

        plt.subplot(4, 3, i + 10)
        plt.imshow(err[ii].detach().cpu(), cmap='gray')
        plt.title('Error, max:{0:.2f}'.format(max_err))
        plt.axis('off')

    plt.savefig(os.path.join(save_path, f'{epoch}_results.png'))
    plt.close()


def plot_CelebA(Xk, X0, X, criteria, save_path, epoch):
    """Save a 4x3 grid comparing initial, reconstructed and clean images.

    ``Xk`` and ``X0`` are clamped to [0, 1]. All tensors are permuted from
    (batch, C, H, W) to (batch, H, W, C) order so ``imshow`` receives
    channel-last images.

    Rows, top to bottom:
        1. ``X0``.
        2. ``Xk``.
        3. ``X``.
        4. ``|X - Xk|``.

    Up to three batch entries are shown, one per column. The figure is
    written to ``<save_path>/<epoch>_results.png``.

    Args:
        Xk, X0, X: 4-D image tensors with the channel at dim 1.
        criteria: callable applied to two images whose result is used as an
            MSE in the PSNR titles (assumed e.g. ``nn.MSELoss()``; confirm).
        save_path: output directory.
        epoch: used in the output file name.
    """
    Xk = torch.clamp(Xk, min=0, max=1).permute(0, 2, 3, 1)
    X0 = torch.clamp(X0, min=0, max=1).permute(0, 2, 3, 1)
    X = X.permute(0, 2, 3, 1)
    err = torch.abs(X - Xk)

    # Loop-invariant values. The PSNR peak and the "Error, max" title are
    # batch-wide maxima, not per-image ones.
    peak = torch.max(X)
    max_err = torch.max(err)

    plt.figure(figsize=(10, 20))
    # Show at most three examples so batches smaller than 3 do not raise IndexError.
    for i in range(min(3, X.shape[0])):
        ii = i
        plt.subplot(4, 3, i + 1)
        psnr = 20 * math.log10(peak / math.sqrt(criteria(X0[ii], X[ii])))
        plt.imshow(X0[ii].detach().cpu())
        plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')

        plt.subplot(4, 3, i + 4)
        psnr = 20 * math.log10(peak / math.sqrt(criteria(Xk[ii], X[ii])))
        plt.imshow(Xk[ii].detach().cpu())
        # Raw string: '\h' is an invalid escape sequence in a normal literal.
        plt.title(r'$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')

        plt.subplot(4, 3, i + 7)
        plt.imshow(X[ii].detach().cpu())
        plt.title('Clean image')
        plt.axis('off')

        plt.subplot(4, 3, i + 10)
        plt.imshow(err[ii].detach().cpu())
        plt.title('Error, max:{0:.2f}'.format(max_err))
        plt.axis('off')

    plt.savefig(os.path.join(save_path, f'{epoch}_results.png'))
    plt.close()


def plot_CT(Xk, X0, X, criteria, save_path, epoch):
    """Save a 4x3 grid comparing initial, reconstructed and clean images.

    ``Xk`` and ``X0`` are clamped to [0, 1]; ``X`` is plotted as given.

    Rows, top to bottom:
        1. ``X0``.
        2. ``Xk``.
        3. ``X``.
        4. ``|X - Xk|``.

    All rows use a gray colormap. Up to three batch entries are shown, one
    per column. The figure is written to ``<save_path>/<epoch>_results.png``.

    Args:
        Xk, X0, X: batched images indexed along dim 0 (per-entry layout
            assumed to be something ``imshow`` accepts; confirm with caller).
        criteria: callable applied to two images whose result is used as an
            MSE in the PSNR titles (assumed e.g. ``nn.MSELoss()``; confirm).
        save_path: output directory.
        epoch: used in the output file name.
    """
    Xk = torch.clamp(Xk, min=0, max=1)
    X0 = torch.clamp(X0, min=0, max=1)
    err = torch.abs(X - Xk)

    # Loop-invariant values. The PSNR peak and the "Error, max" title are
    # batch-wide maxima, not per-image ones.
    peak = torch.max(X)
    max_err = torch.max(err)

    plt.figure(figsize=(10, 20))
    # Show at most three examples so batches smaller than 3 do not raise IndexError.
    for i in range(min(3, X.shape[0])):
        ii = i
        plt.subplot(4, 3, i + 1)
        psnr = 20 * math.log10(peak / math.sqrt(criteria(X0[ii], X[ii])))
        plt.imshow(X0[ii].detach().cpu(), cmap='gray')
        plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')

        plt.subplot(4, 3, i + 4)
        psnr = 20 * math.log10(peak / math.sqrt(criteria(Xk[ii], X[ii])))
        plt.imshow(Xk[ii].detach().cpu(), cmap='gray')
        # Raw string: '\h' is an invalid escape sequence in a normal literal.
        plt.title(r'$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
        plt.axis('off')

        plt.subplot(4, 3, i + 7)
        plt.imshow(X[ii].detach().cpu(), cmap='gray')
        plt.title('Clean image')
        plt.axis('off')

        plt.subplot(4, 3, i + 10)
        plt.imshow(err[ii].detach().cpu(), cmap='gray')
        plt.title('Error, max:{0:.2f}'.format(max_err))
        plt.axis('off')

    plt.savefig(os.path.join(save_path, f'{epoch}_results.png'))
    plt.close()


def plot_LU_result(X, args, op, invBlock, forward_operator, measurement_process, criteria, save_path, epoch):
    """Run one reconstruction pass on ``X`` and save a 4x3 comparison grid.

    Under ``torch.no_grad()``:
        1. ``X`` is normalised with ``op.normalize``.
        2. It is given a zero second (imaginary) channel and moved to
           ``args.device``.
        3. It is passed through ``measurement_process`` to obtain ``y``.
        4. ``X0 = forward_operator.adjoint(y)`` and ``Xk = invBlock(X0, y)``.

    Rows, top to bottom:
        1. Clamped magnitude of ``X0``.
        2. Clamped magnitude of ``Xk``.
        3. Clean image.
        4. Absolute error.

    All panels use a fixed [0, 1] gray scale. Up to three batch entries are
    shown. The figure is written to ``<save_path>/<epoch>_results.png``.

    Args:
        X: batch of clean images (layout as expected by ``op.normalize``;
            assumed (batch, H, W) since a channel is added via ``unsqueeze(1)``
            — confirm).
        args: object providing ``device``.
        op: object providing ``normalize(X, batch_size)``.
        invBlock: callable ``invBlock(X0, y)`` returning the reconstruction.
        forward_operator: object providing ``adjoint(y)``.
        measurement_process: callable producing measurements from ``X``.
        criteria: callable applied to two images whose result is used as an
            MSE in the PSNR titles (assumed e.g. ``nn.MSELoss()``; confirm).
        save_path: output directory.
        epoch: used in the output file name.
    """
    with torch.no_grad():
        X = op.normalize(X, X.shape[0]).unsqueeze(1)
        zeros = torch.zeros_like(X)
        X = torch.cat((X, zeros), dim=1).to(args.device)
        y = measurement_process(X).squeeze()
        X0 = forward_operator.adjoint(y)
        Xk = invBlock(X0, y)

        plt.figure(figsize=(10, 20))
        x_hat = torch.clamp(torch.abs(torch.view_as_complex(Xk.permute(0, 2, 3, 1).contiguous())), min=0, max=1)
        init = torch.view_as_complex(X0.permute(0, 2, 3, 1).contiguous())
        init_clamp = torch.clamp(torch.abs(init), min=0, max=1)
        X_true = X[:, 0, :, :]
        err = torch.abs(X_true - x_hat)

        # Loop-invariant values. The PSNR peak and the "Error, max" title are
        # batch-wide maxima, not per-image ones.
        peak = torch.max(X_true)
        max_err = torch.max(err)

        # Show at most three examples so batches smaller than 3 do not raise IndexError.
        for i in range(min(3, X_true.shape[0])):
            ii = i
            plt.subplot(4, 3, i + 1)
            psnr = 20 * math.log10(peak / math.sqrt(criteria(init_clamp[ii], X_true[ii])))
            plt.imshow(init_clamp[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
            plt.axis('off')

            plt.subplot(4, 3, i + 4)
            psnr = 20 * math.log10(peak / math.sqrt(criteria(x_hat[ii], X_true[ii])))
            plt.imshow(x_hat[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            # Raw string: '\h' is an invalid escape sequence in a normal literal.
            plt.title(r'$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
            plt.axis('off')

            plt.subplot(4, 3, i + 7)
            plt.imshow(X_true[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            plt.title('Clean image')
            plt.axis('off')

            plt.subplot(4, 3, i + 10)
            plt.imshow(err[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            plt.title('Error, max:{0:.2f}'.format(max_err))
            plt.axis('off')

        plt.savefig(os.path.join(save_path, f'{epoch}_results.png'))
        plt.close()


def plot_DEQasProx_result(X, args, op, invBlock, forward_operator, measurement_process, criteria, save_path, epoch):
    """Run the iterative reconstruction on ``X``; save a comparison grid and an error histogram.

    Under ``torch.no_grad()``:
        1. ``X`` is moved to ``args.device`` and normalised with
           ``op.normalize``.
        2. It is given a zero second (imaginary) channel and passed through
           ``measurement_process`` to obtain ``y``.
        3. Starting from ``forward_operator.adjoint(y)``, ``invBlock(Xk, y, k,
           True)`` is applied ``args.maxiters`` times, detaching between
           iterations.

    Two figures are written:
        * ``<save_path>/<epoch>_results.png``: a 4x3 grid with rows initial
          estimate / reconstruction / clean image / absolute error, all on a
          fixed [0, 1] gray scale. PSNR titles use a peak of 1.
        * ``<save_path>/err_hist_<epoch>_results.png``: histogram of the
          absolute error of batch entry 0 (25 bins over [0, 0.25]).

    Args:
        X: batch of clean images (assumed (batch, H, W) since a channel is
            added via ``unsqueeze(1)`` — confirm).
        args: object providing ``device`` and ``maxiters``.
        op: object providing ``normalize(X, batch_size)``.
        invBlock: callable ``invBlock(Xk, y, k, True)`` performing one iteration.
        forward_operator: object providing ``adjoint(y)``.
        measurement_process: callable producing measurements from ``X``.
        criteria: callable applied to two images whose result is used as an
            MSE in the PSNR titles (assumed e.g. ``nn.MSELoss()``; confirm).
        save_path: output directory.
        epoch: used in the output file names.
    """
    with torch.no_grad():
        X = X.to(args.device)
        X = op.normalize(X, X.size(0)).unsqueeze(1)
        zeros = torch.zeros_like(X)
        X = torch.cat((X, zeros), dim=1).to(args.device)
        y = measurement_process(X).squeeze()
        Xk = forward_operator.adjoint(y)

        for k in range(args.maxiters):
            Xk = Xk.detach()
            Xk = invBlock(Xk, y, k, True)

        plt.figure(figsize=(10, 20))
        x_hat = torch.clamp(torch.abs(torch.view_as_complex(Xk.permute(0, 2, 3, 1).contiguous())), min=0, max=1)
        init = torch.view_as_complex(forward_operator.adjoint(y).permute(0, 2, 3, 1).contiguous())
        init_clamp = torch.clamp(torch.abs(init), min=0, max=1)
        X = X[:, 0, :, :]
        err = torch.abs(x_hat - X)
        # Batch-wide maximum error, shown in every error-panel title.
        max_err = torch.max(err)

        # Preferred examples are batch entries 2, 7 and 12. Fall back to the
        # first entries when the batch is too small to contain them, instead
        # of raising IndexError.
        n_batch = X.shape[0]
        indices = [i * 5 + 2 for i in range(3)]
        if indices[-1] >= n_batch:
            indices = list(range(min(3, n_batch)))

        for i, ii in enumerate(indices):
            plt.subplot(4, 3, i + 1)
            psnr = 20 * math.log10(1 / math.sqrt(criteria(init_clamp[ii], X[ii])))
            plt.imshow(init_clamp[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            plt.title('$x_0$, PSNR = {0:.3f}'.format(psnr))
            plt.axis('off')

            plt.subplot(4, 3, i + 4)
            psnr = 20 * math.log10(1 / math.sqrt(criteria(x_hat[ii], X[ii])))
            plt.imshow(x_hat[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            # Raw string: '\h' is an invalid escape sequence in a normal literal.
            plt.title(r'$\hat{x}$, ' + 'PSNR = {0:.3f}'.format(psnr))
            plt.axis('off')

            plt.subplot(4, 3, i + 7)
            plt.imshow(X[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            plt.title('Clean image')
            plt.axis('off')

            plt.subplot(4, 3, i + 10)
            plt.imshow(err[ii].detach().cpu(), cmap='gray', vmin=0, vmax=1)
            plt.title('Error, max:{0:.2f}'.format(max_err))
            plt.axis('off')

        plt.savefig(os.path.join(save_path, f'{epoch}_results.png'))
        plt.close()

        # Error histogram of the first batch entry. The data must be a flat
        # 1-D array: matplotlib treats each column of a 2-D input as a
        # separate dataset. The previous (1, N) reshape therefore produced N
        # one-sample histograms.
        plt.figure(figsize=(10, 5))
        plt.hist(err[0].reshape(-1).cpu().numpy(), bins=25, range=[0, 0.25])
        plt.savefig(os.path.join(save_path, f'err_hist_{epoch}_results.png'))
        plt.close()
