import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import os
import scipy.io
import numpy as np
from PIL import Image


import math

from sklearn.linear_model import OrthogonalMatchingPursuit
from utilies import normalize_matrix, down_sample_matrix, mat2vec, vec2mat, mask_response_circle

def matrix_energy(matrix):
    """Return ``sqrt(sum(matrix ** 2))`` taken over every entry of ``matrix``.

    For real-valued input this is the Frobenius (element-wise L2) norm. Any
    array-like of any rank is accepted.

    Integer and boolean inputs are promoted to float64 before squaring: squaring
    in a narrow integer dtype (e.g. int16) wraps around silently and would give
    a wrong (or NaN) result. Floating and complex inputs are used as-is.
    """
    # asarray avoids copying an ndarray input; the data is only read here.
    values = np.asarray(matrix)
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(np.float64)
    return np.sqrt(np.sum(np.square(values)))


def _crop(matrix, bounds):
    """Return ``matrix[row_start:row_stop, col_start:col_stop]`` for ``bounds = (row_start, row_stop, col_start, col_stop)``."""
    row_start, row_stop, col_start, col_stop = bounds
    return matrix[row_start:row_stop, col_start:col_stop]


def _preprocess_response(raw, crop):
    """Crop a raw response image, apply the circular mask, clip negatives to 0 and normalize it."""
    part = _crop(raw.astype(float), crop)
    masked = part * mask_response_circle(part)
    masked[masked < 0] = 0
    return normalize_matrix(masked)


def _preprocess_gt(raw, crop):
    """Clip negatives of a raw ground-truth image to 0, crop it and normalize it."""
    # astype(float) returns a copy, so clipping the cropped view never touches the loaded data.
    part = _crop(raw.astype(float), crop)
    part[part < 0] = 0
    return normalize_matrix(part)


def _plot_energies(res_energy_list, gt_energy_list):
    """Plot the per-sample response and ground-truth energies (debugging aid)."""
    import matplotlib.pyplot as plt  # local import: only needed when plotting is requested

    plt.plot(res_energy_list, label='Response Energy')
    plt.plot(gt_energy_list, label='Ground Truth Energy')
    plt.title('Energy Comparison')
    plt.xlabel('Index')
    plt.ylabel('Energy')
    plt.legend()
    plt.show()


def load_and_process_data(load_num, down_sample_factor_res, down_sample_factor_gt, response_file_path, gt_file_path, start_index=1,
                          *, response_crop=(300, 1500, 480, 1680), gt_crop=(1070, 1570, 1080, 1580), plot_energy=False):
    """Load paired response / ground-truth .mat files, preprocess them and flatten each into one row.

    For every index ``k`` in ``start_index .. start_index + load_num - 1`` this reads
    ``f'{response_file_path}{k}.mat'`` (variable ``'img1'``) and
    ``f'{gt_file_path}{k}.mat'`` (variable ``'img0'``). The two path arguments are
    therefore file-name *prefixes*, not directories.

    Response pipeline: cast to float, crop to ``response_crop``, multiply by
    ``mask_response_circle`` of the crop, clip negatives to 0, normalize,
    down-sample by ``down_sample_factor_res``, flatten, and keep only the strictly
    positive entries.

    Ground-truth pipeline: cast to float, clip negatives to 0, crop to ``gt_crop``,
    normalize, down-sample by ``down_sample_factor_gt`` and flatten.

    Args:
        load_num: Number of sample pairs to load.
        down_sample_factor_res: Down-sampling factor passed to ``down_sample_matrix`` for responses.
        down_sample_factor_gt: Down-sampling factor passed to ``down_sample_matrix`` for ground truth.
        response_file_path: File-name prefix of the response .mat files.
        gt_file_path: File-name prefix of the ground-truth .mat files.
        start_index: Index of the first file to load.
        response_crop: ``(row_start, row_stop, col_start, col_stop)`` crop of the response image.
        gt_crop: ``(row_start, row_stop, col_start, col_stop)`` crop of the ground-truth image.
        plot_energy: If True, plot the per-sample energies (``matrix_energy``) of the
            normalized images before returning.

    Returns:
        ``(response_down_flat_list, gt_down_flat_list)`` as NumPy arrays with one row per
        sample (assuming ``mat2vec`` returns 1-D vectors; TODO confirm).

    Raises:
        ValueError: if the response vectors end up with different lengths. This happens
            when the set of strictly positive entries differs between samples, so the
            rows cannot be stacked into one 2-D array.
    """
    response_vectors = []
    gt_vectors = []
    res_energy_list = []
    gt_energy_list = []
    print('Start Loading Data')

    for index in range(start_index, start_index + load_num):
        mat_response = scipy.io.loadmat(f'{response_file_path}{index}.mat')
        mat_gt = scipy.io.loadmat(f'{gt_file_path}{index}.mat')

        response_normalized = _preprocess_response(mat_response['img1'], response_crop)
        gt_normalized = _preprocess_gt(mat_gt['img0'], gt_crop)

        if plot_energy:
            res_energy_list.append(matrix_energy(response_normalized))
            gt_energy_list.append(matrix_energy(gt_normalized))

        response_vec = mat2vec(down_sample_matrix(response_normalized, down_sample_factor_res))
        # Drop the zeroed-out entries (masked-out region and clipped negatives) so only
        # positive response pixels remain as features.
        response_vectors.append(response_vec[response_vec > 0])
        gt_vectors.append(mat2vec(down_sample_matrix(gt_normalized, down_sample_factor_gt)))

    # Ragged rows would otherwise fail inside np.array with an opaque error (or silently
    # become an object array on older NumPy), so report the problem explicitly.
    response_lengths = sorted({len(vec) for vec in response_vectors})
    if len(response_lengths) > 1:
        raise ValueError(
            'Response vectors have different lengths after dropping non-positive entries '
            f'(lengths seen: {response_lengths}); they cannot be stacked into a 2-D array.')

    response_down_flat_list = np.array(response_vectors)
    gt_down_flat_list = np.array(gt_vectors)

    if plot_energy:
        _plot_energies(res_energy_list, gt_energy_list)

    print('Data Loaded')

    return response_down_flat_list, gt_down_flat_list



import numpy as np
from sklearn.linear_model import ElasticNet

def _format_tol(value):
    """Format ``value`` in compact scientific notation, e.g. ``0.05 -> '5e-2'``."""
    mantissa, exponent = f'{value:e}'.split('e')
    mantissa = mantissa.rstrip('0').rstrip('.')
    return f'{mantissa}e{int(exponent)}'


def fit_and_save_elastic_net(gt_data, response_data, alpha_val, l1_val, load_num, down_sample_factor_res, down_sample_factor_gt, save_path,
                             *, tol=5e-2, max_iter=2000, random_state=None):
    """Fit an intercept-free ElasticNet mapping ``gt_data`` to ``response_data`` and save its coefficients.

    ``gt_data`` is the design matrix X and ``response_data`` the target(s) y, so the
    returned ``coef_`` satisfies ``response ~ gt @ coef_.T``. For multi-target y,
    scikit-learn's ``coef_`` has shape ``(n_targets, n_features)``.

    Args:
        gt_data: Design matrix X, one sample per row.
        response_data: Target values y, one sample per row.
        alpha_val: ElasticNet ``alpha`` (overall regularization strength).
        l1_val: ElasticNet ``l1_ratio`` (L1 / L2 mix).
        load_num: Number of samples used; only recorded in the output file name.
        down_sample_factor_res: Response down-sampling factor; only recorded in the file name.
        down_sample_factor_gt: Ground-truth down-sampling factor; only recorded in the file name.
        save_path: Directory the ``.npy`` file is written to. It is created if missing.
        tol: Solver tolerance. It is also encoded in the file name, e.g. ``err5e_2`` for 5e-2.
        max_iter: Maximum number of coordinate-descent iterations.
        random_state: Seed for the random coordinate selection. ``None`` means not seeded.

    Returns:
        The fitted ``coef_`` array, which is also saved to
        ``{save_path}/Ela_load..._err{tol}.npy``.
    """
    tol_label = _format_tol(tol)  # '5e-2' for the default tolerance

    # Create the output directory before the potentially long fit, so a bad path
    # does not surface only after fitting has finished.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    print('Initial ELASTICNET')
    elastic_model = ElasticNet(alpha=alpha_val, l1_ratio=l1_val, fit_intercept=False,
                               max_iter=max_iter, tol=tol, selection='random',
                               random_state=random_state)
    print(f'error {tol_label}')

    print('START FITTING')
    elastic_model.fit(gt_data, response_data)
    D_elastic_net = elastic_model.coef_

    # '-' -> '_' keeps the default file name identical to the previous '..._err5e_2.npy'.
    tol_tag = tol_label.replace('-', '_')
    file_name = (f'{save_path}/Ela_load{load_num}_res{down_sample_factor_res}_gt{down_sample_factor_gt}'
                 f'_NewFiberData_alpha{alpha_val}_l1{l1_val}_err{tol_tag}.npy')

    np.save(file_name, D_elastic_net)
    print(f'File saved: {file_name}')

    return D_elastic_net





# Disabled driver script. Everything below is wrapped in a bare triple-quoted
# string, so it is evaluated as a no-op expression and never runs, on import or
# when this file is executed.
# It shows the intended pipeline:
#   1. Load 32000 sample pairs with load_and_process_data, down-sampling
#      responses by 25 and ground truth by 5.
#   2. Fit an ElasticNet (alpha=1e-3, l1_ratio=0.5) with fit_and_save_elastic_net,
#      saving the coefficient matrix under 'F:/0125_dataset/Ela_param'.
# NOTE(review): to run it, move the code out of the string, ideally under an
# `if __name__ == '__main__':` guard so that importing this module stays side-effect free.
'''
load_num = 32000
down_sample_factor_res=25
down_sample_factor_gt=5

response_file_path='F:/0125_dataset/01092025 test slide 1/s'
gt_file_path='F:/0125_dataset/01092025 test slide 1/gt'


response_down_flat_list, gt_down_flat_list = load_and_process_data(
                    load_num=load_num, 
                    down_sample_factor_res=down_sample_factor_res, 
                    down_sample_factor_gt=down_sample_factor_gt,
                    response_file_path=response_file_path,
                    gt_file_path=gt_file_path)


D = fit_and_save_elastic_net(
    gt_data=gt_down_flat_list, 
    response_data=response_down_flat_list, 
    alpha_val=1e-3, 
    l1_val=0.5, 
    load_num=load_num, 
    down_sample_factor_res=down_sample_factor_res, 
    down_sample_factor_gt=down_sample_factor_gt,
    save_path='F:/0125_dataset/Ela_param')
'''