import numpy as np 
import GPy

# Error functions

def relative_error(x, y):
	"""
	Get the relative error: norm of difference / norm of ground truth.
	Derived from code for "FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL EQUATIONS"

	Uses the vector 2 norm of the flattened images (i.e. the Frobenius norm for
	2-D images), the same per-image quantity averaged by avg_relative_error.
	Flattening matters: np.linalg.norm(..., ord=2) on a 2-D array is the spectral
	norm (largest singular value), and on 3-D+ arrays it raises ValueError.

	args:
		- x: image to compare
		- y: second image - the ground truth image
	returns:
		- error: float. A ground truth image with zero norm gives inf/nan.
	"""
	norm_ord = 2
	# Subtract before flattening so x and y keep their usual NumPy broadcasting.
	difference_vec = (np.asarray(x) - np.asarray(y)).reshape(-1)
	y_vec = np.asarray(y).reshape(-1)
	difference_norm = np.linalg.norm(difference_vec, ord=norm_ord)
	y_norm = np.linalg.norm(y_vec, ord=norm_ord)
	return difference_norm / y_norm


def MSE(x, y):
	"""
	Get the MSE (mean squared error) between the two images.

	Every pixel contributes equally. np.mean already reduces over all axes, so no
	explicit flattening is needed; this also avoids np.product, which was
	removed in NumPy 2.0.
	args:
		- x: image to compare
		- y: second image - the ground truth image
	returns:
		- error: float
	"""
	difference_image = np.asarray(x) - np.asarray(y)
	return np.mean(np.square(difference_image))

def avg_relative_error(xset, yset):
	"""
	Mean over all image pairs of the per-image relative error
	||x_i - y_i||_2 / ||y_i||_2, with each image flattened to a vector first.

	args:
		- xset: list of images (the images being compared)
		- yset: list of images (the ground truth), one per entry of xset
	returns:
		- average relative error (numpy float). A ground truth image whose norm
		  is zero makes the result inf/nan.
	"""
	stacked_x = np.array(xset)
	stacked_y = np.array(yset)
	n_images = stacked_x.shape[0]

	# One row per image so the norms can be taken along axis 1.
	rows_x = stacked_x.reshape((n_images, -1))
	rows_y = stacked_y.reshape((n_images, -1))

	per_image_error = (
		np.linalg.norm(rows_x - rows_y, ord=2, axis=1)
		/ np.linalg.norm(rows_y, ord=2, axis=1)
	)
	return per_image_error.mean()

def avg_MSE(xset, yset):
	"""
	Mean squared error pooled over every pixel of every image pair.

	All images share one shape (np.array stacks them), so this equals the mean
	of the per-image MSEs.
	args:
		- xset: list of images (the images being compared)
		- yset: list of images (the ground truth), one per entry of xset
	returns:
		- average MSE (numpy float)
	"""
	preds = np.array(xset)
	truths = np.array(yset)
	count = preds.shape[0]

	# Lay each image out as a single row before differencing.
	residuals = preds.reshape((count, -1)) - truths.reshape((count, -1))
	return np.mean(np.square(residuals))

def _flatten_sample(sample):
	"""
	Flatten one dataset entry into a new 1 dimensional array. A tuple entry has
	each of its members flattened and the results concatenated in order, so the
	members may have different sizes.
	"""
	if isinstance(sample, tuple):
		flat = []
		for img in sample:
			flat.extend(np.asarray(img).reshape(-1))
		return np.array(flat)
	# np.array copies, so the returned vector never aliases the caller's data.
	return np.array(sample).reshape(-1)


def convert_imgdataset_to_vecdataset(self, dataset):
	"""
	Convert dataset: list of images or list of tuples of images to a dataset that is a list of
	1 dimensional vectors. If tuple of images it will create an extended list of the flattened
	images in the tuple concatenated.

	Entries may be numpy arrays or nested lists, and the images inside a tuple
	may differ in size. Whether flattening is needed is decided from the first
	entry rather than np.array(dataset), which raises on ragged data.
	args:
		- self: unused by this function; kept so existing callers still work
		- dataset: list of images or list of tuples of images
	returns:
		- dataset_vec: list of 1 dimensional arrays, or `dataset` unchanged when
		  its entries are scalars (already a flat dataset)
	"""
	if len(dataset) == 0:
		# Nothing to flatten; an empty >=2-D array (e.g. shape (0, k)) maps to [].
		return [] if np.ndim(dataset) >= 2 else dataset

	first = dataset[0]
	# Check the tuple case first: np.ndim on a ragged tuple would raise.
	if isinstance(first, tuple) or np.ndim(first) >= 1:
		return [_flatten_sample(sample) for sample in dataset]

	return dataset

def compute_stds_off(mean_image, var_image, ground_truth_image):
	"""
	Create a corresponding image that states how many standard deviations the
	ground truth image is off from the mean image.

	Pixels whose variance is exactly zero are divided by 1 instead of 0, so there
	the result is the plain absolute difference. A negative variance gives NaN
	(square root of a negative number). Inputs may be arrays, nested lists, or
	scalars.
	args:
		- mean_image: the mean image
		- var_image: var image - each pixel is the variance of the corresponding pixel in the mean image
		- ground_truth_image: the ground truth image
	returns:
		- std_off_image: image where each pixel corresponds to how many standard deviations the ground
						 truth image is from the mean image.
	"""
	diff_image = np.abs(np.asarray(mean_image) - np.asarray(ground_truth_image))
	std_image = np.sqrt(np.asarray(var_image))
	# np.where builds a new array, so this also works for scalar / 0-d variances,
	# which do not support boolean-mask assignment.
	std_image = np.where(std_image == 0, 1, std_image)

	return diff_image / std_image
