# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
from tensorflow.python.ops.init_ops import Initializer,_compute_fans
from numpy.random import RandomState


def calc_SNR(y, y_):
    """Return the signal-to-noise ratio (dB) of an estimate against a reference.

    Both inputs are converted with ``np.array`` and flattened, then

        SNR = 10 * log10(||y_||^2 / ||y_ - y||^2)

    Args:
        y: the estimate (any array-like).
        y_: the reference signal (any array-like); its energy is the numerator.

    Returns:
        The SNR as a NumPy float. A zero error gets no special handling.
    """
    estimate, reference = (np.array(v).flatten() for v in (y, y_))
    error_energy = np.linalg.norm(reference - estimate) ** 2
    signal_energy = np.linalg.norm(reference) ** 2
    return 10 * np.log10(signal_energy / error_energy)


def calc_PSNR(y, y_):
    """Return the peak signal-to-noise ratio (dB) between two signals.

    Both inputs are converted with ``np.array`` and flattened, then

        PSNR = 10 * log10(N * max|y|^2 / ||y_ - y||^2)

    where N is the number of elements. Note that the peak is taken from
    ``y`` (the first argument), not from ``y_``.

    Args:
        y: first signal (any array-like); supplies the peak value.
        y_: second signal (any array-like).

    Returns:
        The PSNR as a NumPy float.
    """
    flat_y = np.array(y).flatten()
    flat_ref = np.array(y_).flatten()
    sq_err = np.linalg.norm(flat_ref - flat_y) ** 2
    peak = np.abs(flat_y).max()
    count = flat_y.size
    return 10 * np.log10(count * peak ** 2 / sq_err)


def mse(recon, label):
    """Return the mean squared error between ``recon`` and ``label``.

    For complex inputs, the residual's real and imaginary parts are stacked
    on a new trailing axis and averaged together. The result is therefore
    ``mean(|recon - label|^2) / 2`` over the original elements. Real inputs
    use the plain mean of the squared residual.

    Any complex dtype (complex64 or complex128) takes the complex path. A
    plain ``residual**2`` on complex data would be a complex square, not
    the squared magnitude.

    Args:
        recon: reconstruction (tensor or array with a ``dtype`` attribute).
        label: target, broadcast-compatible with ``recon``.

    Returns:
        Scalar tensor with the mean squared error.
    """
    residual = recon - label
    if tf.as_dtype(recon.dtype).is_complex:
        residual = tf.stack(
            [tf.math.real(residual), tf.math.imag(residual)], axis=-1)
    return tf.reduce_mean(residual ** 2)

def t_svd(x, t_rank):
    """Return the leading ``t_rank`` left/right factors of a tensor SVD.

    The input is transformed with an FFT along axis 1 (time). A batched
    matrix SVD of every (axis-2 x axis-3) slice is then taken in that
    transformed domain. Finally, the singular-vector tensors are brought back
    with an inverse FFT along the same axis.

    Args:
        x: complex tensor, assumed laid out as (nb, nt, nx, ny) (not checked).
        t_rank: number of leading singular-vector columns to keep.

    Returns:
        ``(U, V)`` shaped (nb, nt, nx, k) and (nb, nt, ny, k), where
        k = min(t_rank, min(nx, ny)).
    """
    def _along_time(tensor, transform):
        # Put axis 1 last, apply `transform` along it, then restore the order.
        moved = tf.transpose(tensor, perm=[0, 2, 3, 1])
        return tf.transpose(transform(moved), perm=[0, 3, 1, 2])

    spectrum = _along_time(x, tf.signal.fft)          # nb nt nx ny (frequency)
    _, u_hat, v_hat = tf.linalg.svd(spectrum)         # nb nt nx r / nb nt ny r
    left = _along_time(u_hat, tf.signal.ifft)
    right = _along_time(v_hat, tf.signal.ifft)
    return left[..., :t_rank], right[..., :t_rank]

def prox_tnn(x, rho):
    """Soft-threshold the singular values of ``x`` after an FFT along axis 1.

    Steps (x is assumed to be (nb, nt, nx, ny) and complex64 -- not checked):
      1. FFT along axis 1, then a batched matrix SVD of every (nx, ny) slice
         in that transformed domain.
      2. Inverse FFT, along the same axis, of the left/right singular vectors
         and of the singular values. Only the real part of the transformed
         singular values is kept.
      3. In each (batch, time) slice, ``sigmoid(rho) * S[..., 0]`` (a fraction
         of the first singular-value entry) is subtracted from every entry.
         The result is then clipped at zero with ReLU.
      4. Rebuild ``U @ diag(S) @ V^H`` slice by slice.

    Args:
        x: input tensor. The thresholded singular values are cast to
            complex64 before the matmul, so a non-complex64 ``x`` would cause
            a dtype mismatch (assumed complex64).
        rho: passed through ``tf.sigmoid`` to get the relative threshold;
            presumably a trainable scalar -- confirm with callers.

    Returns:
        Tensor of shape (nb, nt, nx, ny).

    NOTE(review): the factors are multiplied slice-wise after the inverse
    FFT, i.e. in the time domain. A t-product reconstruction would multiply
    in the Fourier domain and then apply the inverse FFT; confirm this is
    intended.
    """
    # FFT along the time axis: move it last, transform, move it back.
    W_t = tf.transpose(x, perm=[0, 2, 3, 1])  # nb nx ny nt
    W_t = tf.signal.fft(W_t)
    W_t = tf.transpose(W_t, perm=[0, 3, 1, 2])  # nb nt nx ny
    # Batched SVD of each (nx, ny) slice; r = min(nx, ny).
    S, Ut, Vt = tf.linalg.svd(W_t)
    Ut = tf.transpose(Ut, perm=[0, 2, 3, 1])  # nb nt nx r -> nb nx r nt
    Vt = tf.transpose(Vt, perm=[0, 2, 3, 1])  # nb nt ny r -> nb ny r nt
    # S: nb, nt, r
    S = tf.transpose(S, perm=[0, 2, 1])  # nb, nt, r -> nb, r, nt
    # Inverse FFT along the (now last) frequency axis for all three factors;
    # S is real from the SVD, so cast it to complex before the ifft.
    Ut = tf.signal.ifft(Ut)
    Vt = tf.signal.ifft(Vt)
    S = tf.signal.ifft(tf.cast(S, dtype=tf.complex64))

    Ut = tf.transpose(Ut, perm=[0, 3, 1, 2])  # nb nx r nt -> nb nt nx r
    # Transpose V so it can be conjugated into V^H below.
    Vt_conj = tf.transpose(Vt, perm=[0, 3, 2,
                                        1])  # nb ny r nt -> nb nt r ny
    # Keep only the real part of the transformed singular values.
    S = tf.transpose(tf.math.real(S), perm=[0, 2,
                                            1])  #  nb, r, nt -> nb, nt, r
    # Threshold = sigmoid(rho) times the first singular-value entry, per
    # (batch, time) slice; broadcast over the r axis.
    thres = tf.sigmoid(rho) * S[..., 0]
    thres = tf.expand_dims(thres, -1)
    # Soft-threshold: shrink by `thres`, clip negatives to zero.
    S = tf.nn.relu(S - thres)
    S = tf.linalg.diag(S)  # nb nt r -> nb nt r r
    S = tf.dtypes.cast(S, tf.complex64)
    Vt_conj = tf.math.conj(Vt_conj)  # V^H
    # U @ diag(S) @ V^H per (batch, time) slice -> nb nt nx ny
    US = tf.linalg.matmul(Ut, S)
    W_t = tf.linalg.matmul(US, Vt_conj)

    return W_t

def dct_2d(
        feature_map,    # nb, nt, nx, ny (complex)
        norm='ortho'    # 'ortho' or None, as accepted by tf.signal.dct
):
    """Apply a separable 2-D type-II DCT over the two trailing axes.

    The real and imaginary parts of ``feature_map`` are transformed
    independently, because ``tf.signal.dct`` is real-valued. They are then
    recombined into a complex tensor of the same shape.

    Args:
        feature_map: complex tensor, assumed rank 4 (nb, nt, nx, ny).
        norm: normalization passed to ``tf.signal.dct`` ('ortho' or None).

    Returns:
        Complex tensor with the same shape as ``feature_map``.
    """
    # (nb, nt, 2, nx, ny): channel 0 = real part, channel 1 = imaginary part.
    parts = tf.stack(
        [tf.math.real(feature_map), tf.math.imag(feature_map)], axis=-3)
    swap = [0, 1, 2, 4, 3]  # exchange the two spatial axes

    along_y = tf.signal.dct(parts, type=2, norm=norm)
    along_x = tf.transpose(
        tf.signal.dct(tf.transpose(along_y, perm=swap), type=2, norm=norm),
        perm=swap)
    return tf.complex(along_x[:, :, 0], along_x[:, :, 1])

def idct_2d(
        feature_map,    # nb, nt, nx, ny (complex)
        norm='ortho'    # 'ortho' or None, as accepted by tf.signal.idct
):
    """Invert ``dct_2d``: separable 2-D inverse type-II DCT on trailing axes.

    The real and imaginary parts are handled independently and recombined.
    The inverse is applied along nx first, then along ny.

    Args:
        feature_map: complex tensor, assumed rank 4 (nb, nt, nx, ny).
        norm: normalization passed to ``tf.signal.idct`` ('ortho' or None).

    Returns:
        Complex tensor with the same shape as ``feature_map``.
    """
    # (nb, nt, 2, nx, ny): channel 0 = real part, channel 1 = imaginary part.
    parts = tf.stack(
        [tf.math.real(feature_map), tf.math.imag(feature_map)], axis=-3)
    swap = [0, 1, 2, 4, 3]  # exchange the two spatial axes

    along_x = tf.transpose(
        tf.signal.idct(tf.transpose(parts, perm=swap), type=2, norm=norm),
        perm=swap)
    along_y = tf.signal.idct(along_x, type=2, norm=norm)
    return tf.complex(along_y[:, :, 0], along_y[:, :, 1])

def fft2c_mri(x):
    """Centered, orthonormally scaled 2-D FFT over the two trailing axes.

    Computes ``fftshift(fft2(ifftshift(x))) / sqrt(nx * ny)`` over the last
    two axes (nx, ny). The input is expected as (nb, nt, nx, ny), but any
    rank >= 2 works.

    nx and ny are read at run time with ``tf.shape``, so leading dimensions
    may be unknown while tracing (e.g. a ``None`` batch size). The scale
    factor is cast to ``x.dtype``, so complex128 input is supported as well
    as complex64.

    Args:
        x: complex tensor or array-like convertible to one.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    x = tf.convert_to_tensor(x)
    rank = x.shape.ndims
    axes = (rank - 2, rank - 1)  # (nx, ny)
    n_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[-2:]), x.dtype.real_dtype)
    scale = tf.cast(tf.sqrt(n_pixels), x.dtype)

    X = tf.signal.ifftshift(x, axes=axes)
    X = tf.signal.fft2d(X)
    X = tf.signal.fftshift(X, axes=axes)
    return X / scale


def ifft2c_mri(X):
    """Centered, orthonormally scaled 2-D inverse FFT over the trailing axes.

    Computes ``fftshift(ifft2(ifftshift(X))) * sqrt(nx * ny)`` over the last
    two axes (nx, ny). This is the inverse of ``fft2c_mri``. The input is
    expected as (nb, nt, nx, ny), but any rank >= 2 works.

    nx and ny are read at run time with ``tf.shape``, so unknown leading
    dimensions are supported. The scale factor is cast to ``X.dtype``, so
    complex128 input works.

    Args:
        X: complex k-space tensor or array-like convertible to one.

    Returns:
        Tensor with the same shape and dtype as ``X``.
    """
    X = tf.convert_to_tensor(X)
    rank = X.shape.ndims
    axes = (rank - 2, rank - 1)  # (nx, ny)
    n_pixels = tf.cast(tf.reduce_prod(tf.shape(X)[-2:]), X.dtype.real_dtype)
    scale = tf.cast(tf.sqrt(n_pixels), X.dtype)

    x = tf.signal.ifftshift(X, axes=axes)
    x = tf.signal.ifft2d(x)
    x = tf.signal.fftshift(x, axes=axes)
    return x * scale

def ts_conj(x):
    """Return the conjugate transpose of the two trailing axes of ``x``.

    For the rank-4 tensors used in this module, this equals swapping axes 2
    and 3 and then conjugating. ``tf.linalg.adjoint`` also accepts any
    rank >= 2, and real inputs are simply transposed.

    Args:
        x: tensor of rank >= 2.

    Returns:
        Tensor with the last two axes swapped and values conjugated.
    """
    return tf.linalg.adjoint(x)

def mtimes(b, inv, mask, csm):
    """Apply the sampling/encoding operator or its inverse direction.

    Single-coil (``csm is None``):
        forward: ``fft2c_mri_singlecoil(b) * mask``
        inverse: ``ifft2c_mri_singlecoil(b * mask)``
    Multi-coil (``csm`` given, coil axis 1):
        forward: ``b`` gets a coil axis, is weighted by ``csm``, transformed
                 with ``fft2c_mri_multicoil`` and masked.
        inverse: masked k-space is transformed with ``ifft2c_mri_multicoil``,
                 multiplied by ``conj(csm)`` and summed over coils. There is
                 no division by ``sum |csm|^2``.

    In the multi-coil case, a mask with more than 3 dimensions has its axis 2
    cropped to the length of axis 2 of the coil-expanded data.

    Args:
        b: forward: image (B, nt, nx, ny); inverse: k-space, which in the
            multi-coil case is (B, nc, nt, nx, ny).
        inv: if True, go from k-space to image; otherwise image to k-space.
        mask: sampling mask, broadcast against the k-space data.
        csm: coil sensitivities (B, nc, nt, nx, ny), or None for single-coil.

    Returns:
        The transformed tensor.
    """
    # `is None` (not `== None`): with a NumPy `csm`, `==` is elementwise and
    # the `if` would raise "truth value of an array is ambiguous".
    if csm is None:
        if inv:
            return ifft2c_mri_singlecoil(b * mask)
        return fft2c_mri_singlecoil(b) * mask

    if not inv:
        # image -> coil images: (B, nt, nx, ny) -> (B, nc, nt, nx, ny)
        b = tf.expand_dims(b, 1) * csm
    if len(mask.shape) > 3:
        # Crop the mask's axis 2 to the number of frames in the data.
        mask = mask[:, :, 0:b.shape[2], :, :]

    if inv:
        x = ifft2c_mri_multicoil(b * mask)
        return tf.reduce_sum(x * tf.math.conj(csm), 1)
    return fft2c_mri_multicoil(b) * mask

def fft2c_mri_multicoil(x):
    """Centered, orthonormally scaled 2-D FFT for multi-coil data.

    Computes ``fftshift(fft2(ifftshift(x))) / sqrt(nx * ny)`` over the last
    two axes. The input is expected as (nb, nc, nt, nx, ny), but any
    rank >= 2 works.

    nx and ny are read at run time with ``tf.shape``, so unknown leading
    dimensions are supported. The scale factor is cast to ``x.dtype``, so
    complex128 input works.

    Args:
        x: complex tensor or array-like convertible to one.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    x = tf.convert_to_tensor(x)
    rank = x.shape.ndims
    axes = (rank - 2, rank - 1)  # (nx, ny)
    n_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[-2:]), x.dtype.real_dtype)
    scale = tf.cast(tf.sqrt(n_pixels), x.dtype)

    X = tf.signal.ifftshift(x, axes=axes)
    X = tf.signal.fft2d(X)
    X = tf.signal.fftshift(X, axes=axes)
    return X / scale

def ifft2c_mri_multicoil(X):
    """Centered, orthonormally scaled 2-D inverse FFT for multi-coil data.

    Computes ``fftshift(ifft2(ifftshift(X))) * sqrt(nx * ny)`` over the last
    two axes. The input is expected as (nb, nc, nt, nx, ny), but any
    rank >= 2 works.

    nx and ny are read at run time with ``tf.shape``, so unknown leading
    dimensions are supported. The scale factor is cast to ``X.dtype``, so
    complex128 input works.

    Args:
        X: complex k-space tensor or array-like convertible to one.

    Returns:
        Tensor with the same shape and dtype as ``X``.
    """
    X = tf.convert_to_tensor(X)
    rank = X.shape.ndims
    axes = (rank - 2, rank - 1)  # (nx, ny)
    n_pixels = tf.cast(tf.reduce_prod(tf.shape(X)[-2:]), X.dtype.real_dtype)
    scale = tf.cast(tf.sqrt(n_pixels), X.dtype)

    x = tf.signal.ifftshift(X, axes=axes)
    x = tf.signal.ifft2d(x)
    x = tf.signal.fftshift(x, axes=axes)
    return x * scale

def fft2c_mri_singlecoil(x):
    """Centered, orthonormally scaled 2-D FFT for single-coil data.

    Computes ``fftshift(fft2(ifftshift(x))) / sqrt(nx * ny)`` over the last
    two axes. The input is expected as (nb, nt, nx, ny), but any rank >= 2
    works.

    nx and ny are read at run time with ``tf.shape``, so unknown leading
    dimensions are supported. The scale factor is cast to ``x.dtype``, so
    complex128 input works.

    Args:
        x: complex tensor or array-like convertible to one.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    x = tf.convert_to_tensor(x)
    rank = x.shape.ndims
    axes = (rank - 2, rank - 1)  # (nx, ny)
    n_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[-2:]), x.dtype.real_dtype)
    scale = tf.cast(tf.sqrt(n_pixels), x.dtype)

    X = tf.signal.ifftshift(x, axes=axes)
    X = tf.signal.fft2d(X)
    X = tf.signal.fftshift(X, axes=axes)
    return X / scale

def ifft2c_mri_singlecoil(X):
    """Centered, orthonormally scaled 2-D inverse FFT for single-coil data.

    Computes ``fftshift(ifft2(ifftshift(X))) * sqrt(nx * ny)`` over the last
    two axes. The input is expected as (nb, nt, nx, ny), but any rank >= 2
    works.

    nx and ny are read at run time with ``tf.shape``, so unknown leading
    dimensions are supported. The scale factor is cast to ``X.dtype``, so
    complex128 input works.

    Args:
        X: complex k-space tensor or array-like convertible to one.

    Returns:
        Tensor with the same shape and dtype as ``X``.
    """
    X = tf.convert_to_tensor(X)
    rank = X.shape.ndims
    axes = (rank - 2, rank - 1)  # (nx, ny)
    n_pixels = tf.cast(tf.reduce_prod(tf.shape(X)[-2:]), X.dtype.real_dtype)
    scale = tf.cast(tf.sqrt(n_pixels), X.dtype)

    x = tf.signal.ifftshift(X, axes=axes)
    x = tf.signal.ifft2d(x)
    x = tf.signal.fftshift(x, axes=axes)
    return x * scale


class Emat_xyt():
    """Encoding operator bound to a fixed sampling mask.

    ``mtimes`` applies either the forward model (image -> masked k-space) or
    the inverse direction (masked k-space -> image), with optional coil
    sensitivity maps. The transforms are reached through the private
    ``_fft2c_*`` / ``_ifft2c_*`` methods. These forward to the module-level
    centered FFT helpers, and subclasses may override them.
    """

    def __init__(self, mask):
        super(Emat_xyt, self).__init__()
        self.mask = mask

    def mtimes(self, b, inv, csm):
        """Apply the operator (``inv=False``) or its inverse direction.

        Single-coil (``csm is None``):
            forward: ``fft(b) * mask``; inverse: ``ifft(b * mask)``.
        Multi-coil (coil axis 1):
            forward: ``fft(expand_dims(b, 1) * csm) * mask``
            inverse: ``sum_c(ifft(b * mask) * conj(csm))``. There is no
                     division by ``sum |csm|^2``.
        A mask with more than 3 dimensions has its axis 2 cropped to the
        length of axis 2 of the (coil-expanded) data.

        Args:
            b: image (forward) or k-space (inverse) tensor.
            inv: direction flag.
            csm: coil sensitivity maps, or None for single-coil data.

        Returns:
            The transformed tensor.
        """
        mask = self.mask
        # `is None`: a NumPy `csm` would make `== None` elementwise.
        if csm is None:
            if inv:
                return self._ifft2c_mri_singlecoil(b * mask)
            return self._fft2c_mri_singlecoil(b) * mask

        if not inv:
            # image -> coil images (new coil axis 1)
            b = tf.expand_dims(b, 1) * csm
        if len(mask.shape) > 3:
            mask = mask[:, :, 0:b.shape[2], :, :]

        if inv:
            x = self._ifft2c_mri_multicoil(b * mask)
            return tf.reduce_sum(x * tf.math.conj(csm), 1)
        return self._fft2c_mri_multicoil(b) * mask

    # Thin wrappers around the module-level FFT helpers. `mtimes` calls these
    # through `self`, so they must exist on the class; keeping them as
    # methods lets subclasses substitute a different transform.
    @staticmethod
    def _fft2c_mri_singlecoil(x):
        return fft2c_mri_singlecoil(x)

    @staticmethod
    def _ifft2c_mri_singlecoil(X):
        return ifft2c_mri_singlecoil(X)

    @staticmethod
    def _fft2c_mri_multicoil(x):
        return fft2c_mri_multicoil(x)

    @staticmethod
    def _ifft2c_mri_multicoil(X):
        return ifft2c_mri_multicoil(X)
    
class ComplexInit(Initializer):
    """Complex-valued kernel initializer: Rayleigh modulus, uniform phase.

    The modulus is drawn from a Rayleigh distribution whose scale depends on
    ``criterion`` and the kernel fans. The phase is drawn uniformly from
    [-pi, pi). The real and imaginary parts are concatenated along the last
    axis and returned as a NumPy array.
    """

    def __init__(self, kernel_size, input_dim,
                 weight_dim, nb_filters=None,
                 criterion='glorot', seed=None):
        """Store the kernel description.

        Args:
            kernel_size: sequence of kernel extents. Its length must equal
                ``weight_dim``, so pass a sequence even for 1-D kernels.
            input_dim: number of input channels/features.
            weight_dim: one of 0, 1, 2, 3; used as a sanity check.
            nb_filters: number of output filters for convolutional kernels,
                or None for a plain (input_dim, kernel_size[-1]) matrix.
            criterion: 'glorot' or 'he'; validated when the initializer is
                called.
            seed: seed for ``numpy.random.RandomState``; defaults to 1337.
        """
        # Consistent argument combinations:
        #   matrix (dense): nb_filters is None, len(kernel_size) == 2, weight_dim == 2
        #   conv1D: len(kernel_size) == 1 and weight_dim == 1
        #   conv2D: len(kernel_size) == 2 and weight_dim == 2
        #   conv3D: len(kernel_size) == 3 and weight_dim == 3
        assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3}
        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.input_dim = input_dim
        self.weight_dim = weight_dim
        self.criterion = criterion
        self.seed = 1337 if seed is None else seed

    def __call__(self, shape, dtype=None, partition_info=None):
        """Draw a kernel.

        ``shape``, ``dtype`` and ``partition_info`` are accepted for
        compatibility with the Initializer call signature but are not used.
        The shape comes from the constructor arguments, and the result is a
        float64 NumPy array. Its last axis is twice the last kernel-shape
        axis: real parts first, then imaginary parts.

        Raises:
            ValueError: if ``criterion`` is neither 'glorot' nor 'he'.
        """
        if self.nb_filters is not None:
            kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters)
        else:
            kernel_shape = (int(self.input_dim), self.kernel_size[-1])

        # Compute fans from the shape that is actually drawn. Appending
        # `self.nb_filters` unconditionally broke the matrix case, where it is
        # None and `_compute_fans` would multiply by None.
        fan_in, fan_out = _compute_fans(kernel_shape)

        if self.criterion == 'glorot':
            s = 1. / (fan_in + fan_out)
        elif self.criterion == 'he':
            s = 1. / fan_in
        else:
            raise ValueError('Invalid criterion: ' + self.criterion)
        # NOTE(review): with uniform phase, Var(W) = E|W|^2 = 2 * s**2 for a
        # Rayleigh modulus. Matching the Glorot/He variance would therefore
        # need s = 1/sqrt(...). The un-rooted scale is kept so existing
        # initial weights do not change -- confirm which is intended.
        rng = RandomState(self.seed)
        modulus = rng.rayleigh(scale=s, size=kernel_shape)
        phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)
        weight_real = modulus * np.cos(phase)
        weight_imag = modulus * np.sin(phase)
        return np.concatenate([weight_real, weight_imag], axis=-1)