import os
import numpy as np
import h5py
import cv2
import matplotlib.pyplot as plt
import argparse

def _generate_gabor_params(ksize=128, scales=(1, 2, 4, 8, 16, 32), orientations=8,
                           phases=(0, np.pi / 2), sigma=4.0, gamma=0.5):
    """Build the flat list of Gabor kernel parameter tuples.

    Each entry is ``(ksize, sigma, theta, lambd, gamma, psi)``, in the
    positional order used for ``cv2.getGaborKernel`` below (``ksize`` is used
    for both kernel width and height). Entries are emitted in the nested
    order scale -> grid row -> grid column -> orientation -> phase. With the
    default two phases, entries ``2k`` and ``2k + 1`` differ only in ``psi``
    and form the pair that produces output column ``k``.

    Args:
        ksize: Kernel side length passed to ``cv2.getGaborKernel``.
        scales: Spatial-frequency scales. The wavelength is ``ksize / scale``
            and each scale contributes a ``scale x scale`` grid of entries.
        orientations: Number of evenly spaced orientations in ``[0, pi)``.
        phases: Phase offsets ``psi`` emitted for every orientation.
        sigma: Gaussian envelope standard deviation.
        gamma: Spatial aspect ratio.

    Returns:
        A list of length
        ``sum(s * s for s in scales) * orientations * len(phases)``.
    """
    params = []
    for scale in scales:
        lambd = float(ksize) / scale
        # NOTE(review): the (row, col) grid position is not encoded in the
        # parameters, so every cell of the scale x scale grid repeats the
        # same entries. The original comment said the wavelet position may
        # need to be adjusted manually when used. The repetition is kept
        # here so the output column layout stays unchanged. TODO confirm
        # whether positional wavelets were intended.
        for _row in range(scale):
            for _col in range(scale):
                for ori in range(orientations):
                    theta = ori * np.pi / orientations
                    for psi in phases:
                        params.append((ksize, sigma, theta, lambd, gamma, psi))
    return params


def _quadrature_pair_kernels(gabor_params):
    """Build each distinct kernel pair once and map output columns to it.

    Entries ``(2k, 2k + 1)`` of ``gabor_params`` define output column ``k``.
    Columns with identical parameter pairs share one pair of kernels, so the
    filtering runs once per distinct pair instead of once per column.

    Args:
        gabor_params: Sequence of 6-element parameter tuples as produced by
            ``_generate_gabor_params``.

    Returns:
        ``(kernels, column_to_pair)``. ``kernels`` is a list of
        ``(kernel_a, kernel_b)`` float64 arrays. ``column_to_pair`` is an
        integer array where ``column_to_pair[k]`` indexes ``kernels`` for
        output column ``k``.

    Raises:
        ValueError: If ``gabor_params`` has an odd number of entries.
    """
    n = len(gabor_params)
    if n % 2:
        raise ValueError(f"gabor_params must have an even length, got {n}")

    pair_index = {}
    column_to_pair = np.empty(n // 2, dtype=np.intp)
    for col in range(n // 2):
        pair = (tuple(gabor_params[2 * col]), tuple(gabor_params[2 * col + 1]))
        column_to_pair[col] = pair_index.setdefault(pair, len(pair_index))

    # Dicts preserve insertion order, so list position matches the indices
    # assigned by setdefault above.
    kernels = [
        tuple(
            cv2.getGaborKernel((p[0], p[0]), p[1], p[2], p[3], p[4], p[5],
                               ktype=cv2.CV_64F)
            for p in pair
        )
        for pair in pair_index
    ]
    return kernels, column_to_pair


def _complex_cell_response(image, kernels, column_to_pair):
    """Compute the per-column complex-cell values for one image.

    For each kernel pair, each kernel is applied with ``cv2.filter2D``. The
    filtered map is flattened and multiplied by the flattened image, giving a
    (1, 1) value. The two squared values are summed, square-rooted and passed
    through ``log1p``. Pair values are then scattered to output columns via
    ``column_to_pair``.

    NOTE(review): each value is ``<image, filter2D(image, kernel)>``, a
    quadratic form in the image, rather than a projection onto a single
    positioned wavelet. Confirm this is intended.

    Args:
        image: Array accepted by ``cv2.filter2D``.
        kernels: Kernel pairs from ``_quadrature_pair_kernels``.
        column_to_pair: Column-to-pair index array from
            ``_quadrature_pair_kernels``.

    Returns:
        1-D float64 array with ``len(column_to_pair)`` entries.
    """
    # Column vector of the raw pixels, reused for every kernel.
    image_col = image.reshape(-1, 1)
    pair_values = np.empty(len(kernels))
    for p, pair in enumerate(kernels):
        energy = 0
        for kernel in pair:
            filtered = cv2.filter2D(image, cv2.CV_64F, kernel).reshape(1, -1)
            # (1, P) @ (P, 1) -> (1, 1)
            projection = filtered @ image_col
            energy += np.square(projection)
        pair_values[p] = np.log1p(np.sqrt(energy).item())
    return pair_values[column_to_pair]


def main(argv=None):
    """Compute Gabor complex-cell features for one stimulus file and save them.

    Loads ``<path>/stimu_dict_<seed>.npy`` and computes one value per output
    column for every stimulus (indexed along the first axis). It then writes
    the resulting ``(num_stimuli, len(gabor_params) // 2)`` matrix to
    ``<path>/complex_matrix_k<ksize><seed>.npy``.

    Args:
        argv: Optional argument list for ``argparse``. ``None`` reads
            ``sys.argv`` as before.
    """
    parser = argparse.ArgumentParser(
        description="Compute Gabor complex-cell responses for a stimulus set")
    parser.add_argument('--seed', type=str, default="1")
    parser.add_argument('--path', type=str, default='/home/users/yhung7/SDAM/data/',
                        help="directory holding stimu_dict_<seed>.npy; "
                             "the output is written here too")
    args = parser.parse_args(argv)

    # allow_pickle=True can execute arbitrary code from a crafted file:
    # only point this at trusted data.
    X_trn = np.load(os.path.join(args.path, 'stimu_dict_' + args.seed + '.npy'),
                    allow_pickle=True)

    ksize = 128
    gabor_params = _generate_gabor_params(ksize=ksize)
    # Kernels depend only on the parameters, so build them once for all images.
    kernels, column_to_pair = _quadrature_pair_kernels(gabor_params)

    complex_matrix = np.zeros((X_trn.shape[0], len(gabor_params) // 2))
    for i in range(X_trn.shape[0]):
        complex_matrix[i] = _complex_cell_response(X_trn[i], kernels, column_to_pair)

    # There is no separator between the kernel size and the seed (seed "1" ->
    # complex_matrix_k1281.npy). This is kept as-is so existing readers of
    # these files keep working.
    out_name = 'complex_matrix_k' + str(ksize) + str(args.seed) + '.npy'
    np.save(os.path.join(args.path, out_name), complex_matrix)


if __name__ == "__main__":
    main()






