import numpy as np
import scipy.io
import os
import scipy.io
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt


import math

from sklearn.linear_model import OrthogonalMatchingPursuit



def normalize_matrix(matrix):
    """
    Min-max normalize a matrix to the range [0, 1].

    Parameters:
    matrix (array-like): Input values; converted with np.array.

    Returns:
    np.ndarray: (matrix - min) / (max - min). A constant matrix (max == min)
        maps to all zeros instead of NaN, and an empty input returns an
        empty float64 array.
    """
    # Convert the matrix to a numpy array if it isn't already
    matrix = np.array(matrix)

    # np.min/np.max raise on zero-size arrays; there is nothing to scale.
    if matrix.size == 0:
        return matrix.astype(np.float64)

    min_val = np.min(matrix)
    max_val = np.max(matrix)

    span = max_val - min_val
    if span == 0:
        # Constant input: every element equals the minimum, so the shifted
        # values are all zero. Dividing by 1 keeps the same output dtype as
        # the normal path while avoiding 0/0 -> NaN.
        span = 1

    return (matrix - min_val) / span


def down_sample_matrix(input_matrix, downfactor):
    """
    Downsample a 2-D matrix by averaging non-overlapping square blocks.

    Parameters:
    input_matrix (array-like): 2-D input array.
    downfactor (int): Side length of each averaging block; must be >= 1.

    Returns:
    np.ndarray: float64 array of shape (rows // downfactor, cols // downfactor).
        Trailing rows/columns that do not fill a complete block are dropped.

    Raises:
    ValueError: if downfactor is less than 1.
    """
    if downfactor < 1:
        raise ValueError("downfactor must be a positive integer")

    input_matrix = np.asarray(input_matrix)
    rows, cols = input_matrix.shape

    new_rows = rows // downfactor
    new_cols = cols // downfactor

    # Crop to a whole number of blocks, then give each block its own pair of
    # axes (row-in-block, col-in-block) so one vectorized mean replaces the
    # per-block Python loop.
    cropped = input_matrix[:new_rows * downfactor, :new_cols * downfactor]
    blocks = cropped.reshape(new_rows, downfactor, new_cols, downfactor)
    return blocks.mean(axis=(1, 3), dtype=np.float64)

def split_data(response_list, gt_list, train_ratio=0.97, val_ratio=0.02, test_ratio=0.01, rng=None):
    """
    Shuffle paired samples and split them into train / validation / test sets.

    Parameters:
    response_list (sequence): Input samples.
    gt_list (sequence): Ground-truth samples, paired index-by-index with response_list.
    train_ratio (float): Fraction of samples for training (count is truncated with int()).
    val_ratio (float): Fraction of samples for validation (count is truncated with int()).
    test_ratio (float): Nominal test fraction. It is NOT used in the computation:
        the test set receives every sample left over after train and validation.
    rng (np.random.Generator | None): Optional random generator for a reproducible
        shuffle. When None, the global NumPy RNG (np.random.shuffle) is used.

    Returns:
    tuple: ((response_train, gt_train), (response_val, gt_val), (response_test, gt_test)),
        each element a numpy array. Pairs stay aligned after shuffling.

    Raises:
    ValueError: if the two lists have different lengths.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(response_list) != len(gt_list):
        raise ValueError("The two lists must have the same length")

    # Convert to numpy arrays for easier manipulation
    response_array = np.array(response_list)
    gt_array = np.array(gt_list)

    total_samples = len(response_list)

    # Shuffle one index permutation and apply it to both arrays so pairs stay aligned.
    indices = np.arange(total_samples)
    if rng is None:
        np.random.shuffle(indices)
    else:
        rng.shuffle(indices)
    response_array = response_array[indices]
    gt_array = gt_array[indices]

    # Calculate the split boundaries; the test set takes the remainder.
    train_end = int(train_ratio * total_samples)
    val_end = train_end + int(val_ratio * total_samples)

    response_train = response_array[:train_end]
    gt_train = gt_array[:train_end]
    response_val = response_array[train_end:val_end]
    gt_val = gt_array[train_end:val_end]
    response_test = response_array[val_end:]
    gt_test = gt_array[val_end:]

    return (response_train, gt_train), (response_val, gt_val), (response_test, gt_test)


def mat2vec(matrix):
    """
    Flatten a matrix (or any array-like) into a 1-D vector in row-major order.

    Parameters:
    matrix (array-like): The values to flatten.

    Returns:
    np.ndarray: A new 1-D array; it does not share memory with the input.
    """
    # np.array always copies, so reshaping the copy cannot alias the caller's data.
    owned_copy = np.array(matrix)
    return owned_copy.reshape(-1)

def vec2mat(vec, matrix_shape):
    """
    Turn a vector into an array with the given shape, zero-padding if too short.

    Parameters:
    vec (array-like): The input values; flattened in row-major order first.
    matrix_shape (tuple): Desired output shape, e.g. (rows, cols). Any number
        of dimensions is accepted.

    Returns:
    np.ndarray: A new array of shape matrix_shape. If vec has fewer elements
        than the shape needs, zeros are appended at the end before reshaping.

    Raises:
    ValueError: if vec has more elements than matrix_shape can hold.
    """
    # np.array copies, so the result never aliases the caller's data; ravel
    # makes len/pad operate on the element count even for non-1-D input.
    flat = np.array(vec).ravel()
    required_length = int(np.prod(matrix_shape))

    if flat.size < required_length:
        # Pad the tail with zeros to match the required length.
        flat = np.pad(flat, (0, required_length - flat.size), 'constant')
    elif flat.size > required_length:
        raise ValueError(
            f"vec has {flat.size} elements but shape {matrix_shape} holds only {required_length}"
        )

    return flat.reshape(matrix_shape)


def mask_response_circle(image):
    """
    Build a binary circular mask centred on an image.

    Only the image's height and width (its first two dimensions) are used;
    pixel values are ignored, so multi-channel (H, W, C) images also work.

    Parameters:
    image (array-like): Input image of shape (height, width) or (height, width, ...).

    Returns:
    np.ndarray: uint8 array of shape (height, width), 1 inside the circle and
        0 outside. The centre is x = width // 2, y = height // 2, and the radius
        is min(width // 2, height // 2).
    """
    height, width = np.asarray(image).shape[:2]

    center_x, center_y = width // 2, height // 2
    radius = min(center_x, center_y)  # Adjust radius as needed

    y, x = np.ogrid[:height, :width]

    # Compare integer squared distances: equivalent to sqrt(d2) <= radius,
    # without computing a floating-point square root per pixel.
    inside = (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2

    return inside.astype(np.uint8)



def show_img(file_name, show_range):
    '''
    Load .mat files from a directory and display their matrices as grayscale images.

    Files are processed in sorted filename order, and only the first
    'show_range' of them are displayed. For each file, the shape of the
    displayed matrix is printed and a new figure is shown.

    Parameters:
        file_name (str): Directory containing .mat files.
        show_range (int): Maximum number of images to display.

    Raises:
        ValueError: if a .mat file contains no data variables.
    '''

    # Sort the .mat files to maintain a consistent display order
    files = sorted(f for f in os.listdir(file_name) if f.endswith('.mat'))

    for i, file in enumerate(files[:show_range]):
        full_path = os.path.join(file_name, file)
        mat_data = scipy.io.loadmat(full_path)

        # loadmat's result also holds metadata entries whose keys start with
        # '__' (e.g. '__header__'); skip them so only real variables are
        # considered. Assumes the image is stored in the last variable.
        variable_keys = [key for key in mat_data if not key.startswith('__')]
        if not variable_keys:
            raise ValueError(f"No data variables found in {full_path}")
        matrix = mat_data[variable_keys[-1]]

        print(matrix.shape)

        plt.figure(figsize=(5, 5))
        plt.imshow(matrix, cmap='gray')
        plt.axis('off')
        plt.title(f"Image from file: {i}")
        plt.show()


