##############################################################################
####                             Yurong Chen                              ####
####                      chenyurong1998 AT outlook.com                   ####
####                          Hunan University                            ####
####                       Happy Coding Happy Ending                      ####
##############################################################################

import cv2
import math
import torch
import numpy as np
from pytorch_msssim import ssim
import rTV
# Compute device for the helpers below: the first CUDA GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

def A(data, Phi):
    """Forward (measurement) operator: weight ``data`` by ``Phi`` and sum over axis 2."""
    weighted = data * Phi
    return weighted.sum(dim=2)


def At(meas, Phi):
    """Adjoint of ``A``: copy the 2-D ``meas`` along a new axis 2 and weight it by ``Phi``.

    The new axis is broadcast (``expand``) to ``Phi.shape[2]`` entries before
    the element-wise product, so the result has the shape of ``Phi``.
    """
    n_ch = Phi.shape[2]
    spread = meas.unsqueeze(2).expand(-1, -1, n_ch)
    return spread * Phi

def A_si(x, Phi):
    """Side-information forward model.

    Returns a pair ``(y, y_si)``:
      * ``y``    -- the coded measurement, ``x * Phi`` summed over axis 2;
      * ``y_si`` -- the side-information measurement, ``x`` summed over axis 2
        (no mask applied).
    """
    coded = torch.sum(x * Phi, dim=2)
    side = x.sum(dim=2)
    return coded, side

def At_si(meas, meas_si, Phi):
    '''
    Transpose of the side-information forward model ``A_si``.

    Both 2-D measurements are copied along a new axis 2 (``Phi.shape[2]``
    entries). The coded measurement is weighted by ``Phi``; the
    side-information measurement is added without weighting.
    '''
    n_ch = Phi.shape[2]
    coded = meas.unsqueeze(2).expand(-1, -1, n_ch)
    side = meas_si.unsqueeze(2).expand(-1, -1, n_ch)
    return coded * Phi + side



def shift(inputs, step):
    """Disperse an ``(h, w, nC)`` cube horizontally.

    Channel ``c`` is written into columns ``[c*step, c*step + w)`` of a
    zero-filled float32 tensor of shape ``(h, w + (nC - 1)*step, nC)``
    allocated on the module-level ``device``.
    """
    height, width, n_ch = inputs.shape
    out = torch.zeros((height, width + (n_ch - 1) * step, n_ch), device=device)
    for ch in range(n_ch):
        start = ch * step
        out[:, start:start + width, ch] = inputs[:, :, ch]
    return out


def shift_back(inputs, step):
    """Undo ``shift``: crop each channel back out of a horizontally dispersed cube.

    With ``inputs`` of shape ``(h, w, nC)`` and ``w_out = w - step*(nC - 1)``,
    channel ``i`` of the result is ``inputs[:, i*step : i*step + w_out, i]``.
    This equals rolling channel ``i`` left by ``i*step`` and keeping the first
    ``w_out`` columns. The crop is computed by slicing, so the caller's tensor
    is NOT modified. The previous implementation overwrote ``inputs`` in place
    and returned a view of it.

    Expects ``nC >= 1`` and ``0 <= step*(nC - 1) < w``.

    Returns:
        A new tensor of shape ``(h, w_out, nC)`` with the dtype/device of ``inputs``.
    """
    _, w, nC = inputs.shape
    w_out = w - step * (nC - 1)
    return torch.stack(
        [inputs[:, i * step : i * step + w_out, i] for i in range(nC)],
        dim=2,
    )


def shift_v(inputs, step): # vertical shifting
    """Disperse an ``(h, w, nC)`` cube vertically.

    The output is a zero-filled float32 tensor of shape
    ``(h + (nC - 1)*step, w, nC)`` on the module-level ``device``. Channel
    ``c`` occupies rows ``[(nC-1-c)*step, (nC-1-c)*step + h)``, so channel 0
    sits at the bottom and the last channel at the top.
    """
    height, width, n_ch = inputs.shape
    tall = height + (n_ch - 1) * step
    out = torch.zeros((tall, width, n_ch), device=device)
    for ch in range(n_ch):
        top = (n_ch - 1 - ch) * step  # == tall - height - ch*step
        out[top:top + height, :, ch] = inputs[:, :, ch]
    return out


def shift_back_v(inputs, step): # vertical shifting: from bottom to up
    """Undo ``shift_v``: crop each channel back out of a vertically dispersed cube.

    With ``inputs`` of shape ``(h, w, nC)`` and ``h_out = h - step*(nC - 1)``,
    channel ``i`` of the result is read from rows
    ``[(nC-1-i)*step, (nC-1-i)*step + h_out)``, which are the rows ``shift_v``
    wrote it to. This equals rolling channel ``i`` down by ``i*step`` and
    keeping the last ``h_out`` rows. The crop is computed by slicing, so the
    caller's tensor is NOT modified. The previous implementation overwrote
    ``inputs`` in place.

    Expects ``nC >= 1`` and ``0 <= step*(nC - 1) < h``.

    Returns:
        A new tensor of shape ``(h_out, w, nC)`` with the dtype/device of ``inputs``.
    """
    h, _, nC = inputs.shape
    h_out = h - step * (nC - 1)
    return torch.stack(
        [inputs[(nC - 1 - i) * step : (nC - 1 - i) * step + h_out, :, i] for i in range(nC)],
        dim=2,
    )


def ssim_(data, recon):
    """Mean SSIM of two NumPy images, using an 11x11 Gaussian window (sigma 1.5).

    Both inputs are cast to float64. The stabilising constants assume a data
    range of 1. Every filtered map is cropped by 5 pixels per side, so only
    'valid' window positions contribute to the returned mean.
    """
    a = data.astype(np.float64)
    b = recon.astype(np.float64)
    k1d = cv2.getGaussianKernel(11, 1.5)
    win = np.outer(k1d, k1d.transpose())

    def local_mean(img):
        # Gaussian-weighted local mean, cropped to the 'valid' region.
        return cv2.filter2D(img, -1, win)[5:-5, 5:-5]

    c1 = (0.01 * 1) ** 2
    c2 = (0.03 * 1) ** 2
    m_a = local_mean(a)
    m_b = local_mean(b)
    m_aa = m_a ** 2
    m_bb = m_b ** 2
    m_ab = m_a * m_b
    var_a = local_mean(a ** 2) - m_aa
    var_b = local_mean(b ** 2) - m_bb
    cov_ab = local_mean(a * b) - m_ab
    numer = (2 * m_ab + c1) * (2 * cov_ab + c2)
    denom = (m_aa + m_bb + c1) * (var_a + var_b + c2)
    return (numer / denom).mean()


def calculate_ssim(data, recon, border=0):
    """SSIM between a reference image ``data`` and a reconstruction ``recon``.

    Args:
        data, recon: images of identical shape, either 2-D ``(H, W)`` or 3-D
            ``(H, W, C)``. 3-D inputs must be torch tensors. 2-D inputs may be
            torch tensors or NumPy arrays.
        border: number of pixels cropped from every spatial edge before scoring.

    Returns:
        2-D input: the NumPy scalar returned by ``ssim_``.
        3-D input: a torch tensor from ``pytorch_msssim.ssim`` (``data_range=1``),
        where each channel is scored as a separate single-channel image.

    Raises:
        ValueError: if the shapes differ or the input is neither 2-D nor 3-D.
    """
    if not data.shape == recon.shape:
        raise ValueError('Data size must have the same dimensions!')
    if not data.dtype == recon.dtype:
        data, recon = data.float(), recon.float()

    h, w = data.shape[:2]
    data = data[border:h - border, border:w - border]
    recon = recon[border:h - border, border:w - border]
    if data.ndim == 2:
        # ssim_ is NumPy/OpenCV based (it calls .astype), so torch tensors must
        # be converted first; NumPy inputs pass through unchanged.
        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        if torch.is_tensor(recon):
            recon = recon.detach().cpu().numpy()
        return ssim_(data, recon)
    if data.ndim == 3:
        # (H, W, C) -> (C, 1, H, W): each channel becomes a batch item with one channel.
        return ssim(torch.unsqueeze(data, 0).permute(3, 0, 1, 2),
                    torch.unsqueeze(recon, 0).permute(3, 0, 1, 2),
                    data_range=1).data
    raise ValueError('Data must be 2-D (H, W) or 3-D (H, W, C).')
     

def calculate_psnr(data, recon, data_range=1.):
    """Peak signal-to-noise ratio, in dB, of ``recon`` against ``data``.

    Args:
        data: reference tensor.
        recon: reconstruction tensor, broadcast-compatible with ``data``.
        data_range: peak signal value (default 1., matching the previously
            hard-coded maximum).

    Returns:
        A 0-dim tensor. For identical inputs (MSE exactly 0) the PSNR is
        infinite, so a capped value of 100 is returned. This value is also a
        tensor with the dtype/device of the MSE, so callers can always use
        ``.item()``. Previously this case returned a plain ``int``.
    """
    mse = torch.mean((recon - data) ** 2)
    if mse == 0:
        return torch.full_like(mse, 100.)
    return 20 * torch.log10(data_range / torch.sqrt(mse))
    
    
def shrink(x, _lambda):
    """Soft-thresholding (shrinkage) with threshold ``1 / _lambda``.

    Returns ``sign(x) * max(|x| - 1/_lambda, 0)``. Entries whose magnitude is
    below the threshold become 0, and all others move toward zero by the
    threshold, symmetrically for positive and negative values. The previous
    code thresholded ``x`` rather than ``|x|``, so every negative entry was
    zeroed.
    """
    return torch.sign(x) * torch.clamp(torch.abs(x) - 1 / _lambda, min=0)


def calculate_tv(x):
    """Mean anisotropic total variation of an ``(H, W, C)`` tensor.

    Forward differences are taken along axes 1 and 0. The last column/row is
    differenced against itself, giving 0. ``|dx| + |dy|`` is summed over
    axis 2 and then averaged over the remaining positions.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    # Index of the next row/column, saturating so the final one maps to itself.
    next_row = (torch.arange(n_rows) + 1).clamp(max=n_rows - 1)
    next_col = (torch.arange(n_cols) + 1).clamp(max=n_cols - 1)
    dx = x[:, next_col, :] - x
    dy = x[next_row, :, :] - x
    per_pixel = (dx.abs() + dy.abs()).sum(dim=2)
    return per_pixel.mean()

def TV_denoiser(x, _lambda, n_iter_max):
    """Total-variation denoising of an ``(H, W, C)`` tensor via dual fixed-point updates.

    Runs ``n_iter_max`` updates of two dual fields built from forward
    differences along axes 1 and 0 (replicate boundary), with step 0.25.
    Both fields are normalised by one per-pixel magnitude shared across
    channels. Returns ``x - div(p) / _lambda``; with ``n_iter_max == 0`` this
    is ``x``.
    """
    step = 0.25
    rows, cols = x.shape[0], x.shape[1]
    # Neighbour indices with replicate boundary: "next" saturates at the last
    # index and "previous" saturates at 0.
    nxt_r = (torch.arange(rows) + 1).clamp(max=rows - 1)
    prv_r = (torch.arange(rows) - 1).clamp(min=0)
    nxt_c = (torch.arange(cols) + 1).clamp(max=cols - 1)
    prv_c = (torch.arange(cols) - 1).clamp(min=0)

    dual_h = torch.zeros_like(x)
    dual_v = torch.zeros_like(x)
    div_p = torch.zeros_like(x)
    scaled = x * _lambda  # loop-invariant
    for _ in range(n_iter_max):
        z = div_p - scaled
        grad_h = z[:, nxt_c, :] - z
        grad_v = z[nxt_r, :, :] - z
        norm = torch.sqrt(torch.sum(grad_h ** 2 + grad_v ** 2, 2))
        denom = (1 + step * norm).unsqueeze(2)  # broadcast over channels
        dual_h = (dual_h + step * grad_h) / denom
        dual_v = (dual_v + step * grad_v) / denom
        div_p = dual_h - dual_h[:, prv_c, :] + dual_v - dual_v[prv_r, :, :]
    return x - div_p / _lambda

def TV_minimization(x, y, _lambda, n_iter_max, alpha):
    """TV-regularised smoothing of ``x`` using four halved difference directions.

    Iterates ``n_iter_max`` dual fixed-point updates with step ``dt = 0.25``.
    The data term is ``(x - alpha*10*y) * _lambda``. Each iteration:
      * takes halved forward and backward differences along axes 1 and 0,
        with replicate boundary;
      * normalises the four dual fields by one per-pixel magnitude shared
        across channels.
    The result is ``x - div(p) / _lambda``.
    NOTE(review): the update pattern resembles Chambolle's dual projection
    method — confirm against the intended reference.

    Args:
        x: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` tensor.
        y: tensor combined as ``x - alpha*10*y``. With a 3-D ``x``, a 2-D ``y``
            is broadcast over the channel axis.
        _lambda: non-zero scale of the data term.
        n_iter_max: number of dual iterations (0 returns ``x`` unchanged).
        alpha: weight of ``y`` (0 ignores it).

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` is neither 2-D nor 3-D.
    """
    if x.dim() not in (2, 3):
        raise ValueError("Unsupported tensor dimension. Expected 2 or 3 dimensions.")
    dt = 0.25

    # Broadcast a 2-D y across the channels of a 3-D x. The old code did this
    # with repeat(..., x.shape[2]), which raised IndexError whenever x was 2-D
    # and so made the 2-D path unreachable.
    if x.dim() == 3 and y.dim() == 2:
        y = y.unsqueeze(2)
    target = (x - alpha * 10 * y) * _lambda  # loop-invariant part of z

    # Process a 2-D image as (H, W, 1). Summing over a size-1 channel axis is
    # the identity, so this reproduces the per-pixel 2-D update exactly.
    single_channel = x.dim() == 2
    if single_channel:
        x = x.unsqueeze(2)
        target = target.unsqueeze(2)

    H, W = x.shape[0], x.shape[1]
    idx = (torch.arange(H) + 1).clamp(max=H - 1)  # next row (last maps to itself)
    iux = (torch.arange(H) - 1).clamp(min=0)      # previous row (first maps to itself)
    ir = (torch.arange(W) + 1).clamp(max=W - 1)   # next column
    il = (torch.arange(W) - 1).clamp(min=0)       # previous column

    p1 = torch.zeros_like(x)
    p2 = torch.zeros_like(x)
    p3 = torch.zeros_like(x)
    p4 = torch.zeros_like(x)
    divp = torch.zeros_like(x)
    for _ in range(n_iter_max):
        z = divp - target
        z1 = 0.5 * (z[:, ir, :] - z)
        z2 = 0.5 * (z[idx, :, :] - z)
        z3 = 0.5 * (z - z[:, il, :])
        z4 = 0.5 * (z - z[iux, :, :])
        # One magnitude per pixel, shared by all channels (broadcast over axis 2).
        denom = (1 + dt * torch.sqrt(torch.sum(z1**2 + z2**2 + z3**2 + z4**2, 2))).unsqueeze(2)
        p1 = (p1 + dt * z1) / denom
        p2 = (p2 + dt * z2) / denom
        p3 = (p3 + dt * z3) / denom
        p4 = (p4 + dt * z4) / denom
        divp = p1 - p1[:, il, :] + p2 - p2[iux, :, :] + p3[:, ir, :] - p3 + p4[idx, :, :] - p4

    u = x - divp / _lambda
    return u.squeeze(2) if single_channel else u



def RTV_minimization(x):
    """Smooth ``x`` with ``rTV.tsmooth`` (``maxIter=2``) and return it on ``device``.

    ``x`` is squeezed and copied to a CPU NumPy array. After squeezing it must
    be 3-D, otherwise the shape unpack below raises. The result dtype is
    whatever ``rTV.tsmooth`` returns; it is not checked here.
    """
    arr = x.squeeze().cpu().numpy()
    _rows, _cols, _chans = arr.shape  # shape check only: requires a 3-D array
    smoothed = rTV.tsmooth(arr, maxIter=2)
    return torch.from_numpy(smoothed).to(device)



# def TV_handcraft(self, x, tv_weight=1, dt=0.25):
#     B, C, H, W = x.shape
    
#     idx = torch.arange(1, H + 1)
#     idx[-1] = H - 1
#     iux = torch.arange(-1, H - 1)
#     iux[0] = 0
#     ir = torch.arange(1, W + 1)
#     ir[-1] = W - 1
#     il = torch.arange(-1, W - 1)
#     il[0] = 0
#     p1 = torch.zeros_like(x)
#     p2 = torch.zeros_like(x)
#     divp = torch.zeros_like(x)

#     tv_weight = tv_weight
#     for i in range(self.tv_iter_max):
#         z = divp - x *tv_weight
#         z1 = z[:, :, :, ir] - z
#         z2 = z[:, :,idx, :] - z
#         denom_2d = 1 + dt * torch.sqrt(torch.sum(z1**2 + z2**2, 1))
#         denom_3d = denom_2d.unsqueeze(1).repeat(1, C, 1, 1)
#         p1 = (p1 + dt * z1) / denom_3d
#         p2 = (p2 + dt * z2) / denom_3d
#         divp = p1 - p1[:, :, :, il] + p2 - p2[:, :, iux, :]
#     x = x - divp / tv_weight
#     return x 

#### Side Information for GAP
def gap_side_information(mask, alpha = 1):
    """Per-pixel 2x2 inverse terms used by GAP with side information.

    For each pixel, with sums taken over axis 2 of ``mask``:

        P11 = sum(mask**2)       (Phi_1_Phi_1_T)
        P12 = sum(mask)          (Phi_1_Phi_2_T, identical to Phi_2_Phi_1_T)
        D   = channel * alpha    (Phi_2_diag, a scalar)

    the returned terms are the block-inverse (Schur-complement) entries of

        M = [[P11,          alpha * P12],
             [alpha * P12,  alpha**2 * D]]

    i.e. [[Phi_B1, Phi_B2], [Phi_B3, Phi_B4]] == M^-1 wherever none of the
    zero guards below fire.

    NOTE(review): if the side-information row is alpha times the channel-sum
    operator (cf. A_si), its Gram entry would be alpha**2 * channel, whereas
    here it is alpha**2 * (channel * alpha). The two agree when alpha == 1.
    The author flagged the ``channel`` line as needing revision — confirm the
    intended scaling.

    Args:
        mask: coded-aperture tensor. Axis 2 is reduced, and ``mask.shape[-1]``
            is used as the channel count (presumably an (H, W, C) mask — confirm).
        alpha: side-information weight; must be non-zero (it is divided by).

    Returns:
        Tuple (Phi_B1, Phi_B2, Phi_B3, Phi_B4) of per-pixel tensors.
    """
    channel = mask.shape[-1] ## TODO: needs revision (author's note, translated from Chinese)
    Phi_1_Phi_1_T = torch.sum(mask ** 2, axis=2)
    Phi_sum = Phi_1_Phi_1_T.clone()
    # Guard: pixels with an all-zero mask would otherwise give 1/0.
    Phi_sum[Phi_sum == 0] = 1
    Phi_1_diag_inv = 1.0 / Phi_sum
    Phi_2_diag = channel * alpha
    Phi_2_diag_inv = 1 / (channel * alpha)
    Phi_1_Phi_2_T = torch.sum(mask, axis=2)
    # The 2x2 matrix is symmetric; this aliases the same tensor (no copy).
    Phi_2_Phi_1_T = Phi_1_Phi_2_T

    # Schur complement of the (2,2) entry: M11 - M12 * M22^-1 * M21.
    temp_B1 = Phi_1_Phi_1_T - (Phi_1_Phi_2_T * Phi_2_diag_inv) * Phi_2_Phi_1_T
    temp_B1[temp_B1 == 0] = 1  # guard: keep the reciprocal finite
    Phi_B1 = 1.0 / temp_B1

    # Schur complement of the (1,1) entry: M22 - M21 * M11^-1 * M12.
    temp_B4 = (alpha ** 2) * Phi_2_diag - (alpha ** 2) * Phi_2_Phi_1_T * Phi_1_diag_inv * Phi_1_Phi_2_T
    temp_B4[temp_B4 == 0] = 1  # guard: keep the reciprocal finite
    Phi_B4 = 1.0 / temp_B4

    # Off-diagonal entries: -M11^-1 * M12 * B4 and -M22^-1 * M21 * B1.
    Phi_B2 = -alpha * Phi_1_diag_inv * Phi_1_Phi_2_T * Phi_B4
    Phi_B3 = -(1 / alpha) * Phi_2_diag_inv * Phi_2_Phi_1_T * Phi_B1

    return Phi_B1, Phi_B2, Phi_B3, Phi_B4


#### Side Information for ADMM
def admm_side_information(mask, gamma, alpha=1):
    """Per-pixel entries of ``(I + G / gamma)^-1`` for ADMM with side information.

    For each pixel, with sums taken over axis 2 of ``mask``, G is the 2x2 matrix

        G = [[sum(mask**2),       alpha * sum(mask)],
             [alpha * sum(mask),  alpha**2 * (channel * alpha)]]

    and the result satisfies
    [[Phi_B1, Phi_B2], [Phi_B3, Phi_B4]] == (I + G / gamma)^-1 wherever the
    zero guards do not fire. The entries come from the 2x2 block-inverse
    (Schur-complement) formulas.

    Previously the (1,1) Schur complement used ``alpha**2 / gamma`` instead of
    ``alpha**2 / gamma**2``. That disagreed with the matrix implied by
    Phi_B2..Phi_B4, so the four terms were not an inverse unless gamma == 1.

    NOTE(review): the author flagged the ``channel`` line as needing revision;
    the alpha**3 scaling of G's (2,2) entry is kept unchanged — confirm.

    Args:
        mask: coded-aperture tensor. Axis 2 is reduced, and ``mask.shape[-1]``
            is used as the channel count (presumably an (H, W, C) mask — confirm).
        gamma: ADMM penalty parameter; must be non-zero.
        alpha: side-information weight.

    Returns:
        Tuple (Phi_B1, Phi_B2, Phi_B3, Phi_B4) of per-pixel tensors.
    """
    channel = mask.shape[-1]  ## TODO: needs revision (author's note, translated from Chinese)
    Phi_1_Phi_1_T = torch.sum(mask ** 2, axis=2)
    Phi_2_diag = channel * alpha
    Phi_1_Phi_2_T = torch.sum(mask, axis=2)
    # The matrix is symmetric; this aliases the same tensor (no copy).
    Phi_2_Phi_1_T = Phi_1_Phi_2_T

    # 1 / M22 with M = I + G / gamma (a Python scalar).
    tempB1_inverse = 1 / ((Phi_2_diag * (alpha ** 2) / gamma) + 1)
    # Schur complement M11 - M12 * M22^-1 * M21. Both off-diagonal entries are
    # alpha * sum(mask) / gamma, so their product carries alpha**2 / gamma**2.
    temp_B1 = (1 + (1 / gamma) * Phi_1_Phi_1_T
               - ((alpha ** 2) / (gamma ** 2)) * (Phi_1_Phi_2_T * tempB1_inverse) * Phi_2_Phi_1_T)
    temp_B1[temp_B1 == 0] = 1  # guard: keep the reciprocal finite
    Phi_B1 = 1.0 / temp_B1

    # 1 / M11.
    tempB4_inverse = 1 / ((Phi_1_Phi_1_T / gamma) + 1)
    # Schur complement M22 - M21 * M11^-1 * M12.
    temp_B4 = (1 + (((alpha ** 2) / gamma) * Phi_2_diag)
               - ((alpha ** 2) / (gamma ** 2)) * (Phi_2_Phi_1_T * tempB4_inverse) * Phi_1_Phi_2_T)
    temp_B4[temp_B4 == 0] = 1  # guard: keep the reciprocal finite
    Phi_B4 = 1.0 / temp_B4

    # Off-diagonal entries: -M11^-1 * M12 * B4 and -M22^-1 * M21 * B1.
    Phi_B2 = -(alpha / gamma) * tempB4_inverse * Phi_1_Phi_2_T * Phi_B4
    Phi_B3 = -(alpha / gamma) * tempB1_inverse * Phi_2_Phi_1_T * Phi_B1

    return Phi_B1, Phi_B2, Phi_B3, Phi_B4


