from sklearn.linear_model import OrthogonalMatchingPursuit
import numpy as np
from utilies import normalize_matrix, down_sample_matrix, mat2vec, vec2mat, mask_response_circle
import matplotlib.pyplot as plt
import os
import scipy.io
from sklearn.linear_model import OrthogonalMatchingPursuit







def save_data(file_name_response, file_name_gt, file_name_gt_save, file_name_recon_save, file_name_response_save,
              load_num, down_sample_factor, matrix, start_index=1, response_key='bla', gt_key='temp_patch'):
    """
    Crop, mask, clip and normalise paired response / ground-truth .mat files, then re-save them.

    For every k in 1..load_num the files ``{file_name_response}{k}.mat`` and
    ``{file_name_gt}{k}.mat`` are loaded. The response image (variable 'img1')
    is cropped to rows 300:1500 / columns 480:1680, multiplied by the mask
    returned by ``mask_response_circle``, clipped at zero and passed through
    ``normalize_matrix``. The ground-truth image (variable 'img0') is clipped
    at zero, cropped to rows 1070:1570 / columns 1080:1580 and normalised the
    same way. Results are written as ``response_<n>.mat`` (variable
    'mat_response') and ``gt_<n>.mat`` (variable 'mat_gt'), where
    n = (k - 1) + start_index.

    Parameters:
    - file_name_response: Path prefix of the input response files (index and '.mat' are appended).
    - file_name_gt: Path prefix of the input ground-truth files (index and '.mat' are appended).
    - file_name_gt_save: Folder receiving the processed ground-truth files.
    - file_name_recon_save: Currently unused (the OMP reconstruction step is disabled).
    - file_name_response_save: Folder receiving the processed response files.
    - load_num: Number of file pairs to process; input numbering always starts at 1.
    - down_sample_factor: Currently unused (the down-sampling step is disabled).
    - matrix: Currently unused (dictionary for the disabled OMP reconstruction step).
    - start_index: Number given to the first saved pair; later pairs count up from it.
    - response_key: Currently unused; the response variable name 'img1' is hard-coded.
    - gt_key: Currently unused; the ground-truth variable name 'img0' is hard-coded.

    Returns:
    - 0 once all pairs have been written.
    """
    # NOTE(review): response_key / gt_key are deliberately NOT wired in here.
    # Their defaults ('bla', 'temp_patch') do not match the variables actually
    # read ('img1', 'img0'), so honouring them would break existing calls.

    # Create the output folders up front so a missing folder does not abort
    # the run at the first savemat call.
    for out_dir in (file_name_response_save, file_name_gt_save):
        if out_dir:  # os.makedirs('') raises, and '' already means "current directory"
            os.makedirs(out_dir, exist_ok=True)

    for i in range(load_num):
        if i % 100 == 0:
            print(f'Current File is {i + 1}')

        # Load MAT files (input numbering is 1-based and independent of start_index)
        mat_response = scipy.io.loadmat(f'{file_name_response}{i + 1}.mat')
        mat_gt = scipy.io.loadmat(f'{file_name_gt}{i + 1}.mat')

        # Process response: crop, apply the mask, clip negatives
        response = mat_response['img1'].astype(float)
        response_part = response[300:1500, 480:1680]
        # The product allocates a new array, so the clip below never touches `response`.
        response_masked = response_part * mask_response_circle(response_part)
        response_masked[response_masked < 0] = 0

        # Process ground truth: clip negatives, then crop
        gt = mat_gt['img0'].astype(float)
        gt[gt < 0] = 0
        gt_part = gt[1070:1570, 1080:1580]

        # Normalize
        mat_response_normalized = normalize_matrix(response_masked)
        mat_gt_normalized = normalize_matrix(gt_part)

        # Save the processed response and ground truth as .mat files
        res_save_path = os.path.join(file_name_response_save, f"response_{i + start_index}.mat")
        gt_save_path = os.path.join(file_name_gt_save, f"gt_{i + start_index}.mat")

        scipy.io.savemat(res_save_path, {'mat_response': mat_response_normalized})
        scipy.io.savemat(gt_save_path, {'mat_gt': mat_gt_normalized})

    return 0




import numpy as np
from scipy.fftpack import fft2, ifft2, fftshift, ifftshift
import matplotlib.pyplot as plt
from PIL import Image
from scipy.signal import convolve2d, wiener
from skimage import color, data, restoration

def load_image(path):
    """
    Read an image file and return it as a grayscale numpy array.

    Parameters:
    - path: Location of the image file, passed to ``Image.open``.

    Returns:
    - A numpy array built from the image after conversion to mode 'L'.
    """
    with Image.open(path) as source:
        # convert() returns a new image, so it is still usable once the file is closed
        grayscale = source.convert('L')
    return np.array(grayscale)

def frequency_filter(shape):
    """
    Build an inverted-Gaussian weighting mask for a centred 2D spectrum.

    Parameters:
    - shape: (rows, cols) of the mask to build.

    Returns:
    - A (rows, cols) float array equal to 1 minus a Gaussian centred at
      (rows // 2, cols // 2) with standard deviation (rows + cols) / 10:
      0 at the centre, approaching 1 far away from it.
    """
    rows, cols = shape
    # Column/row vectors of offsets from the centre; broadcasting forms the grid.
    row_offsets = np.arange(rows).reshape(-1, 1) - rows // 2
    col_offsets = np.arange(cols).reshape(1, -1) - cols // 2
    squared_radius = col_offsets**2 + row_offsets**2
    gaussian_lowpass = np.exp(-squared_radius / (2*((rows + cols)/10)**2))
    return 1 - gaussian_lowpass

def filter_enhance_image(image):
    """
    Enhance an image by re-weighting its 2D Fourier spectrum.

    Parameters:
    - image: A 2D numpy array.

    Returns:
    - A 2D real array: the magnitude of the inverse transform of the
      spectrum multiplied by the mask from ``frequency_filter``.

    Raises:
    - ValueError: If ``image`` is not two-dimensional.
    """
    if image.ndim != 2:
        raise ValueError("Input image must be a 2D numpy array")

    # Shift the zero-frequency term to the centre so it lines up with the centred mask
    centred_spectrum = fftshift(fft2(image))
    weighted_spectrum = centred_spectrum * frequency_filter(image.shape)

    # Undo the shift before inverting; the result is complex, so return its magnitude
    restored = ifft2(ifftshift(weighted_spectrum))
    return np.abs(restored)





def wiener_deblur(image, noise_ratio):
    """
    Filter a 2D image with SciPy's Wiener filter over a 5x5 window.

    Parameters:
    - image: A 2D numpy array, or a 3D array whose extra axis is a singleton
      (e.g. shape (1, H, W)); singleton axes are squeezed away first.
    - noise_ratio: Noise power forwarded to ``scipy.signal.wiener`` as its
      ``noise`` argument. Pass None to let SciPy estimate it from the image.

    Returns:
    - A 2D numpy array with the filtered image.

    Raises:
    - ValueError: If the image is not 2D after squeezing.
    """
    if image.ndim == 3:
        image = image.squeeze()  # Remove the channel dimension if it's singleton

    if image.ndim != 2:
        raise ValueError("Image must be 2D")

    # Only the window size is needed by scipy.signal.wiener, not a PSF array
    window = (5, 5)

    # Honour the caller's noise level instead of a hard-coded 0.1
    return wiener(image, mysize=window, noise=noise_ratio)


def Richardson_Lucy_Deconv(image):
    """
    Blur an image with a 5x5 box kernel, then run Richardson-Lucy deconvolution on it.

    Parameters:
    - image: A 2D numpy array.

    Returns:
    - The output of ``restoration.richardson_lucy`` applied to the blurred
      image with the same box kernel (library-default iteration count).
    """
    # Uniform 5x5 kernel whose entries sum to 1
    box_kernel = np.full((5, 5), 1 / 25)

    # NOTE(review): the input is blurred synthetically before being deconvolved,
    # so the deconvolution never sees the raw input - confirm this is intended.
    blurred_input = convolve2d(image, box_kernel, mode='same')

    return restoration.richardson_lucy(blurred_input, box_kernel)



import numpy as np
from scipy.ndimage import gaussian_filter

def gaussian_smoothing(image, sigma=1):
    """
    Smooth an image with a Gaussian kernel.

    Parameters:
    - image: Numpy array holding the image to smooth.
    - sigma: Standard deviation of the Gaussian kernel (default 1).

    Returns:
    - The smoothed array, as produced by ``scipy.ndimage.gaussian_filter``.
    """
    return gaussian_filter(image, sigma=sigma)


import numpy as np
from scipy.ndimage import median_filter

def median_denoising(image, size=5):
    """
    Remove noise from an image with a median filter.

    Parameters:
    - image: Numpy array holding the noisy image.
    - size: Side length of the neighbourhood the median is taken over (default 5).

    Returns:
    - The filtered array, as produced by ``scipy.ndimage.median_filter``.
    """
    return median_filter(image, size=size)


def test_result_OMP_show(D_est_matrix, show_range, nonzeronum, test_response, test_gt):
    """
    Reconstruct test samples with Orthogonal Matching Pursuit and plot them next to the ground truth.

    For each of the first ``show_range`` samples, an OMP model is fitted with
    ``D_est_matrix`` as the design matrix and ``test_response[it]`` as the
    target. The coefficient vector is reshaped to 50x50 with ``vec2mat``,
    entries below 0.2 are zeroed, and a figure shows the ground truth, the
    thresholded prediction and a Gaussian-smoothed (sigma=1) prediction.

    Parameters:
    - D_est_matrix: Dictionary / design matrix handed to ``OrthogonalMatchingPursuit.fit``.
    - show_range: Number of samples to reconstruct and display.
    - nonzeronum: Number of non-zero coefficients OMP may use.
    - test_response: Indexable collection of response vectors.
    - test_gt: Indexable collection of ground-truth vectors (reshaped to 50x50 for display).

    Returns:
    - None. One matplotlib figure is shown per sample (``plt.show()``).
    """
    # One estimator is enough: each fit() call replaces coef_ with a fresh solution.
    omp_model = OrthogonalMatchingPursuit(n_nonzero_coefs=nonzeronum)

    for it in range(show_range):
        omp_model.fit(D_est_matrix, test_response[it])

        gt_est = vec2mat(omp_model.coef_, (50, 50))
        # Suppress weak coefficients before display
        gt_est_new = np.where(gt_est < 0.2, 0, gt_est)

        # Only the smoothed variant is displayed. The frequency-enhanced,
        # median-filtered and Richardson-Lucy variants used to be computed
        # here as well but were never shown, so they are no longer computed.
        smoothed_image = gaussian_smoothing(gt_est_new, sigma=1)

        panels = (
            ('Original GT', vec2mat(test_gt[it], (50, 50))),
            ('Predicted GT', gt_est_new),
            ('Smoothed Image', smoothed_image),
        )

        # Display the images (top row of a 2x3 grid, same layout as before)
        plt.figure(figsize=(20, 10))
        for position, (title, panel) in enumerate(panels, start=231):
            plt.subplot(position)
            plt.imshow(panel, cmap='gray')
            plt.title(title)
            plt.axis('off')

        # Show all plots
        plt.show()








# Disabled example driver, kept inside a bare string literal so it never runs:
# it loads a dictionary matrix with np.load and passes it to save_data().
'''
D = np.load('F:/0125_dataset/Ela_param/Ela_load32000_down25_data0519_alpha0.001_l1_0.5.npy')


save_data(
    file_name_response='F:/0125_dataset/01092025 test slide 1/s', 
    file_name_gt='F:/0125_dataset/01092025 test slide 1/gt', 
    file_name_gt_save='F:/0125_dataset/0125_gt_500/', 
    file_name_recon_save='F:/0125_dataset/0125_recon_down25/', 
    file_name_response_save='F:/0125_dataset/0125_res_1200/', 
    load_num=32000, 
    down_sample_factor=25, 
    matrix=D, 
    start_index=1, 
    response_key='bla', 
    gt_key='temp_patch')
'''