import os
import numpy as np
import h5py
import cv2
import matplotlib.pyplot as plt


# Default directory holding the stimulus file and receiving the feature matrices.
DATA_DIR = '/home/users/yhung7/SDAM/data/'


def _generate_gabor_params():
    """Return the Gabor parameter tuples for the complex-cell filter bank.

    Each entry is ``(ksize, sigma, theta, lambd, gamma, psi)`` in the argument
    order of ``cv2.getGaborKernel``. Entries are ordered by scale, then by a
    ``scale x scale`` spatial grid, then by orientation, then by phase, so
    consecutive entries form (phase 0, phase pi/2) pairs.

    NOTE(review): the spatial grid position is not encoded in the tuple, so the
    ``scale**2`` grid copies for one scale are identical. They are kept so the
    number of output features (the saved matrix width) is unchanged; actually
    placing each wavelet at its grid location is still TODO.
    """
    scales = [1, 2, 4, 8, 16, 32]  # spatial-frequency scales
    orientations = 8
    phases = [0, np.pi / 2]  # quadrature phase pair
    ksize = 31  # kernel size in pixels
    sigma = 4.0
    gamma = 0.5

    gabor_params = []
    for scale in scales:
        # NOTE(review): for scale >= 16 the wavelength falls below 2 px, and
        # sigma does not scale with it -- confirm this is intended.
        lambd = float(ksize) / scale
        bank = [
            (ksize, sigma, ori * np.pi / orientations, lambd, gamma, psi)
            for ori in range(orientations)
            for psi in phases
        ]
        # One copy of the orientation/phase bank per cell of the scale x scale grid.
        gabor_params.extend(bank * (scale * scale))
    return gabor_params


def _gabor_kernel(params):
    """Build the float64 Gabor kernel for one ``(ksize, sigma, theta, lambd, gamma, psi)`` tuple."""
    ksize, sigma, theta, lambd, gamma, psi = params
    return cv2.getGaborKernel((ksize, ksize), sigma, theta, lambd, gamma, psi,
                              ktype=cv2.CV_64F)


def _simple_cell_response(image, gabor_params):
    """Return the half-wave-rectified filter response of *image* for every parameter tuple.

    Not called by ``main``; kept for exploratory use.
    """
    return [
        np.maximum(cv2.filter2D(image, cv2.CV_64F, _gabor_kernel(params)), 0)
        for params in gabor_params
    ]


def _complex_cell_response(image, gabor_params, kernel_cache=None):
    """Return one log-energy feature per (phase 0, phase pi/2) pair of Gabor filters.

    For each consecutive pair ``gabor_params[2k], gabor_params[2k + 1]`` the
    image is filtered with both kernels; each response is half-wave rectified
    and squared, the two are summed, passed through ``log1p`` and averaged over
    all elements.

    NOTE(review): the classic energy model sums the squares of the
    *unrectified* quadrature responses (equivalent to four half-squared simple
    cells). Rectifying first, as done here, keeps only two of those four cells,
    so the feature is not phase-invariant. Confirm this is intended.

    Args:
        image: array accepted by ``cv2.filter2D`` (presumably a 2-D image).
        gabor_params: even-length list of parameter tuples, paired by phase.
        kernel_cache: optional dict mapping parameter tuples to kernels; pass
            the same dict across calls so each kernel is built only once.

    Returns:
        1-D float64 array of length ``len(gabor_params) // 2``.

    Raises:
        ValueError: if ``gabor_params`` has an odd length.
    """
    n_params = len(gabor_params)
    if n_params % 2:
        raise ValueError('gabor_params must contain (phase 0, phase pi/2) pairs; '
                         'got an odd number of entries')
    if kernel_cache is None:
        kernel_cache = {}

    responses = np.zeros(n_params // 2)
    # Identical filter pairs yield identical features, so each distinct pair is
    # filtered only once per image (the spatial-grid copies make most pairs repeats).
    energy_by_pair = {}
    for pair_idx in range(n_params // 2):
        pair = (gabor_params[2 * pair_idx], gabor_params[2 * pair_idx + 1])
        energy = energy_by_pair.get(pair)
        if energy is None:
            filtered_sum = np.zeros(image.shape, dtype=np.float64)
            for params in pair:
                kernel = kernel_cache.get(params)
                if kernel is None:
                    kernel = kernel_cache[params] = _gabor_kernel(params)
                filtered = cv2.filter2D(image, cv2.CV_64F, kernel)
                filtered_sum += np.square(np.maximum(filtered, 0))
            energy = energy_by_pair[pair] = np.log1p(filtered_sum).mean()
        responses[pair_idx] = energy
    return responses


def _complex_feature_matrix(images, gabor_params, kernel_cache):
    """Apply ``_complex_cell_response`` to every image along axis 0 of *images*.

    Returns an array of shape ``(images.shape[0], len(gabor_params) // 2)``.
    """
    features = np.zeros((images.shape[0], len(gabor_params) // 2))
    for row, image in enumerate(images):
        features[row] = _complex_cell_response(image, gabor_params, kernel_cache)
    return features


def main(path=DATA_DIR):
    """Compute complex-cell Gabor features for the train and val stimuli and save them.

    Reads ``stimu_dict.npy`` from *path* (a dict stored as a 0-d object array,
    hence ``allow_pickle`` and ``.item()``) and uses its 'Train' and 'Val'
    entries, which are indexed along axis 0 as stacks of images. Writes
    ``train_complex_matrix.npy`` and ``val_complex_matrix.npy`` into *path*.

    Args:
        path: directory holding the stimulus file; outputs are written there too.
    """
    # SECURITY: allow_pickle=True can execute arbitrary code embedded in a
    # malicious file -- only load stimulus files from a trusted source.
    stimu_dict = np.load(os.path.join(path, 'stimu_dict.npy'), allow_pickle=True).item()
    # Look up both splits up front so a missing key fails before any heavy work.
    x_train = stimu_dict['Train']
    x_val = stimu_dict['Val']

    gabor_params = _generate_gabor_params()
    # Kernels depend only on their parameters; build each once for both splits.
    kernel_cache = {}

    for images, filename in ((x_train, 'train_complex_matrix.npy'),
                             (x_val, 'val_complex_matrix.npy')):
        features = _complex_feature_matrix(images, gabor_params, kernel_cache)
        np.save(os.path.join(path, filename), features)


if __name__ == "__main__":
    main()







