import numpy as np
import math
import os

import copy
from skimage.metrics import structural_similarity
import torch


def img_normalize(img):
    '''
    Min-max normalize an image to [-1, 1].

    Args:
        img: numpy array of any shape
    Return:
        norm2: img linearly mapped so its minimum becomes -1 and its maximum 1
        min_, max_: the original minimum and maximum (needed to undo the mapping)

    A constant image (min_ == max_) has no range to stretch. It is mapped to
    -1 everywhere instead of producing NaN from 0/0; img_denormalize still
    maps that back to min_.
    '''
    min_ = np.min(img)
    max_ = np.max(img)
    span = max_ - min_
    if span == 0:
        # (img - min_) is all zeros here, so any non-zero divisor works;
        # 1 avoids the 0/0 -> NaN of a constant image.
        span = 1
    # norm1: [0, 1]
    norm1 = (img - min_) / span
    # norm2: [-1, 1]
    norm2 = norm1 * 2 - 1
    return norm2, min_, max_

def img_denormalize(img, min_, max_):
    '''
    Undo a min-max normalization to [-1, 1].

    Args:
        img: normalized image with values in [-1, 1]
        min_, max_: the original minimum and maximum returned by img_normalize
    Return:
        the image mapped back so that -1 -> min_ and 1 -> max_
    '''
    span = max_ - min_
    unit = (img + 1) / 2  # [-1, 1] -> [0, 1]
    return unit * span + min_

# all images are assumed in range [0, 1]

# def MPSNR_per_image(img1, img2):
#     '''
#     img1 and img2 have range [0, 1]
#     Args:
#         img1, img2: shape (C, H, W)
#     '''

#     if np.min(img1) < 0 :
#         raise Exception(" The input image is not [0, 1]")


#     img1 = img1.astype(np.float64)
#     img2 = img2.astype(np.float64)
#     ch = img1.shape[0]
    
#     # mse: shape
#     sum = 0
#     for i in range(ch):
#         mse = np.mean((img1[i,:,:] - img2[i,:,:]) ** 2)
#         if mse == 0:
#             return 100
#         PIXEL_MAX = 1.0
#         s = 20 * math.log10(PIXEL_MAX / np.sqrt(mse))
#         sum = sum + s
#     s = sum / ch
#     return s

# def MPSNR_loop(img1, img2):
#     B = img1.shape[0]
#     s = []
#     for i in range(B):
#         s.append( MPNSR_per_image(img1[i], img2[i]) )
    

#     return np.array(s)


# def MPSNR(y_gt, x_pred):
#     '''
#     y_gt, x_pred have range [0, 1]

#     Compute PNSR per image band, compute mean, and compute mean for all images
#     Args:
#         y_gt, x_pred: shape (B, C, H, W)
#     Return:
#         pnsr: shape (B) 
#     '''
#     img1, img2 = y_gt, x_pred
#     assert img1.shape == img2.shape
#     # if np.min(img1) < 0 :
#     #     raise Exception(" The input image is not [0, 1]")


#     img1 = img1.astype(np.float64)
#     img2 = img2.astype(np.float64)
    

#     mse = np.mean((img1 - img2) ** 2, axis = (-2,-1))
#     # Avoid divide 0
#     # This makes the PNSE for each image' each channel's PNSR as 100
#     mse[mse == 0] = (1/ (10 **(100/20) ))**2
#     PIXEL_MAX = 1.0
#     s = 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

#     # mean PNSR across channels
#     pnsr = np.mean(s, axis = -1)

#     return pnsr


def MPSNR(y_gt, x_pred):
    '''
    Mean PSNR across spectral bands, computed separately for each image.

    Inputs are assumed to lie in [0, 1] (peak value 1.0). PSNR is computed
    per band, then averaged over the bands of each image.
    Args:
        y_gt, x_pred: shape (B, C, H, W)
    Return:
        pnsr: shape (B), band-averaged PSNR in dB of every image
    '''
    assert y_gt.shape == x_pred.shape

    gt = y_gt.astype(np.float64)
    pred = x_pred.astype(np.float64)
    peak = 1.0

    # per-image, per-band mean squared error: shape (B, C)
    band_mse = np.mean((gt - pred) ** 2, axis=(-2, -1))
    # identical bands would hit log10(0); pin their MSE to 1e-10 so that
    # such a band scores exactly 100 dB
    band_mse[band_mse == 0] = np.power(10, -100/10)

    band_psnr = -10 * np.log10(band_mse / (peak ** 2))
    return np.mean(band_psnr, axis=-1)


# def SSIM_per_image_ch_loop(img1, img2):
#     '''
#     img1 and img2 have range [0, 1]
#     Args:
#         img1, img2: shape (C, H, W)
#     '''
#     C,H,W = img1.shape

#     ssims = []
#     for i in range(C):
#         ssimi = structural_similarity(img1[i,:,:], img2[i,:,:], data_range=1 - 0, multichannel = False)
#         ssims.append(ssimi)
#     ssim = np.mean(ssims)
#     print(ssims)
#     return ssim

def SSIM_per_image(y_gt, x_pred):
    '''
    SSIM of one multi-band image, averaged over its bands.

    Each band is scored as a 2-D grayscale image by skimage's
    ``structural_similarity`` with data_range = 1 (inputs assumed in [0, 1]),
    and the per-band scores are averaged. This is the same computation the
    former ``multichannel=True`` call performed. The loop is spelled out because
    that keyword was deprecated in scikit-image 0.19 and later removed. Since
    ``structural_similarity`` accepts ``**kwargs``, a stale ``multichannel``
    flag may be silently ignored on newer versions, in which case the
    (H, W, C) stack would be scored as a single 3-D volume.
    Args:
        y_gt, x_pred: shape (C, H, W)
    Return:
        mean SSIM across the C bands
    '''
    img1, img2 = y_gt, x_pred
    C = img1.shape[0]
    band_scores = [
        structural_similarity(img1[c], img2[c], data_range=1 - 0)
        for c in range(C)
    ]
    return np.mean(band_scores)

def SSIM(y_gt, x_pred):
    '''
    Per-image SSIM for a batch of multi-band images.

    Inputs are assumed to lie in [0, 1]. Each image is scored with
    SSIM_per_image (band-averaged SSIM).
    Args:
        y_gt, x_pred: shape (B, C, H, W)
    Return:
        ssims: ndarray of shape (B)
    '''
    assert y_gt.shape == x_pred.shape
    # unpacking enforces the 4-D (B, C, H, W) layout
    _, _, _, _ = y_gt.shape
    scores = [SSIM_per_image(gt_img, pred_img) for gt_img, pred_img in zip(y_gt, x_pred)]
    return np.array(scores)


# def SAM_per_image(img1, img2):
#     '''
#     Compute the correlation coefficient of each corresponding spectral vector of two images,
#     then compute the mean
#     Args:
#         img1, img2: shape (C, H, W)
#     Return:
#         mean_sam: mean SAM across all pixel
#         var_sam:  variance across each pixel
#     '''
#     x_true = copy.deepcopy(img1)
#     x_pred = copy.deepcopy(img2)

#     assert x_true.ndim ==3 and x_true.shape == x_pred.shape

#     C, H, W = x_true.shape
#     # x_true, x_pred: shape (C, H*W)
#     x_true = x_true.reshape(C, -1)
#     x_pred = x_pred.reshape(C, -1)

#     x_pred[:, np.where((np.linalg.norm(x_pred, 2, 0))==0)]+=0.0001

#     # sam: shape (H*W)
#     sam = (x_true * x_pred).sum(axis=0) / (np.linalg.norm(x_true, 2, 0) * np.linalg.norm(x_pred, 2, 0))

#     sam = np.arccos(sam) * 180 / np.pi
#     mean_sam = sam.mean()
#     var_sam = np.var(sam)
#     return mean_sam, var_sam

# def SAM_old(img1, img2):
#     '''
#     img1 and img2 have range [0, 1]
#     Args:
#         img1, img2: shape (B, C, H, W)
#     Return:
#         sams: shape (B)
#     '''
#     assert img1.shape == img2.shape and img1.ndim ==4
#     B,C,H,W = img1.shape
#     sams = []
#     for i in range(B):
#         sam_i, var_sam_i = SAM_per_image(img1[i], img2[i]) 
#         sams.append(sam_i)
#     # print(sams)
#     sams = np.array(sams)
#     return sams

def SAM(y_gt, x_pred):
    '''
    Spectral Angle Mapper (in degrees) for a batch of multi-band images.

    For every pixel, the angle between the ground-truth and predicted spectral
    vectors (the C values at that pixel) is computed. The angles are then
    summarised per image.
    Args:
        y_gt, x_pred: shape (B, C, H, W)
    Return:
        mean_sam: shape (B), mean angle over the H*W pixels of each image
        var_sam:  shape (B), variance of those angles
    '''
    import numpy.ma as ma
    assert y_gt.shape == x_pred.shape
    B, C, H, W = y_gt.shape

    # Work in float64, as MPSNR does. Integer inputs would otherwise overflow in
    # the dot product, and float32 rounding is amplified by arccos near cos = 1.
    # y, x: shape (B, C, H*W)
    y = y_gt.astype(np.float64).reshape(B, C, -1)
    x = x_pred.astype(np.float64).reshape(B, C, -1)

    epsilon = 1e-7

    # per-pixel dot product and vector norms: shape (B, H*W)
    prod = (y * x).sum(axis=1)
    y_norm = np.linalg.norm(y, 2, 1)
    x_norm = np.linalg.norm(x, 2, 1)

    # Cosine of the angle, regularised by epsilon so zero vectors do not divide
    # by zero (two all-zero spectra give a huge ratio that clips to 1, i.e. 0°).
    # Clipping keeps arccos inside its domain after that regularisation.
    cos = (prod + epsilon) / ((y_norm + epsilon) * (x_norm + epsilon))
    cos = np.clip(cos, a_min=-1, a_max=1)

    # sam: shape (B, H*W), in degrees
    sam = np.arccos(cos) * 180 / np.pi

    # NaN angles (e.g. from NaN in the inputs) are excluded from the statistics
    sam = ma.masked_array(sam, mask=np.isnan(sam))
    mean_sam = np.array(sam.mean(axis=-1))
    var_sam = np.array(sam.var(axis=-1))
    return mean_sam, var_sam

def compute_RMSE(y_gt, x_pred):
    '''
    Root mean squared error per band and per image.

    Inputs are assumed to lie in [0, 1]. They are cast to float64 before
    differencing (as MPSNR does), so unsigned-integer inputs cannot wrap around
    in the subtraction or the squaring.
    Args:
        y_gt, x_pred: shape (B, C, H, W)
    Return:
        rmse_per_band: shape (B, C), RMSE of each band
        rmse: shape (B), RMSE over all bands and pixels of each image
    '''
    y = y_gt.astype(np.float64)
    x = x_pred.astype(np.float64)
    B, C, H, W = y.shape
    assert x.shape == y.shape

    # per-band mean squared error: shape (B, C)
    aux = np.mean((x - y) ** 2, axis=(-1, -2))
    rmse_per_band = np.sqrt(aux)
    # bands are equal-sized, so averaging band MSEs equals the MSE over all pixels
    rmse = np.sqrt(np.sum(aux, axis=1) / C)
    return rmse_per_band, rmse

def ERGAS(y_gt, x_pred, ratio_ergas):
    '''
    ERGAS (relative dimensionless global error in synthesis) per image.

        ERGAS = 100 * ratio_ergas * sqrt( mean_k (RMSE_k / mean(y_k))^2 )

    Here k runs over the C bands, RMSE_k is the RMSE of band k, and mean(y_k)
    is the mean of ground-truth band k. Inputs are assumed to lie in [0, 1].
    NOTE(review): a ground-truth band whose mean is 0 makes the ratio
    divide by zero.
    Args:
        y_gt, x_pred: shape (B, C, H, W)
        ratio_ergas: h/l, the ratio between the pixel sizes of the high and
            low spatial resolution images
    Return:
        ergas: shape (B)
    '''
    # band_rmse: shape (B, C); the per-image RMSE is not needed here
    band_rmse, _ = compute_RMSE(y_gt, x_pred)

    # ground-truth band means: shape (B, C)
    band_mean = np.mean(y_gt, axis=(-1, -2))
    relative_sq = (band_rmse / band_mean) ** 2
    return 100 * ratio_ergas * np.sqrt(np.mean(relative_sq, axis=1))

def eval_img_metric(y_gt, x_pred, ratio_ergas=1.0 / 2, eval_metric_flag=None):
    '''
    Compute the selected image-quality metrics (PSNR, ERGAS, SAM, SSIM) for a batch.

    y_gt, x_pred are assumed to lie in [0, 1].
    Args:
        y_gt, x_pred: shape (B, C, H, W)
        ratio_ergas: parameter required to compute ERGAS, h/l, where
            h - linear spatial resolution of pixels of the high resolution image,
            l - linear spatial resolution of pixels of the low resolution image (e.g., 1/4)
        eval_metric_flag: dict with keys 'psnr', 'ergas', 'sam', 'ssim' whose
            truthiness selects which metrics to compute. None (the default)
            computes all of them; a key missing from the dict is treated as enabled.
    Return:
        psnr: shape (B)
        ergas: shape (B)
        sam: shape (B), the per-image mean spectral angle (the variance is dropped)
        ssim: shape (B)
        A metric that is not selected is returned as zeros of shape (B).
    '''
    # Previously the None default crashed on None['psnr']; now it means "all".
    flags = {'psnr': True, 'ergas': True, 'sam': True, 'ssim': True}
    if eval_metric_flag is not None:
        flags.update(eval_metric_flag)

    batch = y_gt.shape[0]

    psnr = MPSNR(y_gt, x_pred) if flags['psnr'] else np.zeros(batch)
    ergas = ERGAS(y_gt, x_pred, ratio_ergas) if flags['ergas'] else np.zeros(batch)
    # SAM returns (mean, variance); only the per-image mean is reported
    sam = SAM(y_gt, x_pred)[0] if flags['sam'] else np.zeros(batch)
    ssim = SSIM(y_gt, x_pred) if flags['ssim'] else np.zeros(batch)

    return psnr, ergas, sam, ssim


