import os
import random
import cv2
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.ndimage import map_coordinates
import argparse
import re
import h5py
from tqdm import tqdm

def add_black_border(img, sz):
    """Return a copy of ``img`` padded by ``sz`` pixels on all four sides.

    The padding is a constant border of value 0 (black), produced with
    ``cv2.copyMakeBorder`` in ``BORDER_CONSTANT`` mode.
    """
    top = bottom = left = right = sz
    return cv2.copyMakeBorder(
        img,
        top, bottom, left, right,
        borderType=cv2.BORDER_CONSTANT,
        value=0,
    )

def elastic_transform(image, alpha, sigma, random_state=None):
    """Warp ``image`` with a smooth random displacement field.

    For each of the two spatial axes, a displacement in [-1, 1) is drawn
    independently for every pixel. It is smoothed with a Gaussian filter of
    standard deviation ``sigma`` and multiplied by ``alpha``. The image is
    then resampled at the displaced coordinates with linear interpolation
    (``order=1``). Samples falling outside the image take the constant fill
    value 0.

    Args:
        image: 2-D array ``(H, W)`` or 3-D array ``(H, W, C)``. For 3-D input
            the same displacement field is applied to every channel.
        alpha: Scale applied to the smoothed displacement, in pixels.
        sigma: Standard deviation of the Gaussian smoothing, in pixels.
        random_state: ``np.random.RandomState`` to draw from, an int seed,
            or ``None`` for a fresh OS-seeded generator.

    Returns:
        The warped image, with the same shape and dtype as ``image``.

    Raises:
        ValueError: If ``image`` is not 2-D or 3-D.
    """
    if not isinstance(random_state, np.random.RandomState):
        # None -> entropy-seeded generator (previous behaviour); int -> reproducible.
        random_state = np.random.RandomState(random_state)

    image = np.asarray(image)
    if image.ndim not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D image, got shape {image.shape}")

    height, width = image.shape[:2]
    # The field is drawn over the spatial axes only, so channels share one warp.
    # dx is drawn before dy, so 2-D inputs consume the RNG exactly as before.
    dx = gaussian_filter(random_state.rand(height, width) * 2 - 1, sigma) * alpha
    dy = gaussian_filter(random_state.rand(height, width) * 2 - 1, sigma) * alpha
    x, y = np.meshgrid(np.arange(width), np.arange(height))
    # Shape (2, H, W): displaced row coordinates, then displaced column coordinates.
    coords = np.stack([y + dy, x + dx])

    def _warp(plane):
        # Resample one 2-D plane at the displaced coordinates.
        return map_coordinates(plane, coords, order=1, mode='constant')

    if image.ndim == 2:
        return _warp(image)
    return np.stack([_warp(image[..., c]) for c in range(image.shape[2])], axis=-1)

def generate_deformed_serial(input_folder, output_folder, alpha, sigma, seed=None):
    """Write an elastically deformed copy of a PNG image stack.

    Files in ``input_folder`` ending in ``.png`` (case-insensitive) are
    processed in sorted filename order. Each is read as single-channel
    grayscale and written to ``output_folder`` as ``0000.png``,
    ``0001.png``, ... in that order.

    The first image is written unchanged. Every later image is warped
    independently with ``elastic_transform``, using ``alpha * height`` as the
    displacement scale and ``sigma * width`` as the Gaussian sigma.

    Args:
        input_folder: Directory containing the source ``.png`` slices.
        output_folder: Destination directory; created if missing.
        alpha: Displacement scale as a fraction of image height.
        sigma: Smoothing sigma as a fraction of image width.
        seed: Optional int seed for reproducible deformations. ``None``
            (default) uses an OS-seeded generator.

    Raises:
        OSError: If an input image cannot be read or an output cannot be written.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Case-insensitive match, consistent with generate_crop_serial.
    img_files = sorted(f for f in os.listdir(input_folder) if f.lower().endswith('.png'))
    n = len(img_files)
    print('len of image stacks: ', n)

    # One generator for the whole stack: reproducible when seeded, and still
    # random (OS-seeded) when seed is None.
    rng = np.random.RandomState(seed)

    for i, fname in enumerate(tqdm(img_files)):
        img_path = os.path.join(input_folder, fname)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"could not read image '{img_path}'")
        # Slice 0 is kept undeformed (presumably the reference slice - confirm).
        if i > 0:
            img = elastic_transform(img, alpha * img.shape[0], sigma * img.shape[1],
                                    random_state=rng)
        out_path = os.path.join(output_folder, f'{i:04d}.png')
        if not cv2.imwrite(out_path, img):
            raise OSError(f"could not write image '{out_path}'")
        

def generate_crop_serial(input_folder, output_folder, sz, border_sz):
    """Centre-crop every PNG in a folder and surround it with a black border.

    Each file in ``input_folder`` ending in ``.png`` (case-insensitive) is
    processed as follows:

    1. Read it as grayscale.
    2. Crop it to its central ``sz`` x ``sz`` region.
    3. Pad it with ``border_sz`` pixels of value 0 on every side.
    4. Write it to ``output_folder`` under its original filename.

    Every output is therefore ``sz + 2 * border_sz`` pixels square.

    Args:
        input_folder: Directory containing the source images.
        output_folder: Destination directory; created if missing.
        sz: Side length of the centre crop, in pixels.
        border_sz: Width of the black border added on each side, in pixels.

    Raises:
        OSError: If an image cannot be read or written.
        ValueError: If an image is smaller than ``sz`` in either dimension.
    """
    os.makedirs(output_folder, exist_ok=True)

    for fname in tqdm(sorted(os.listdir(input_folder))):
        if not fname.lower().endswith('.png'):
            continue

        img_path = os.path.join(input_folder, fname)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"could not read image '{img_path}'")
        h, w = img.shape
        if h < sz or w < sz:
            # A negative offset would make the slice below wrap around from
            # the end of the array, silently producing a smaller, off-centre crop.
            raise ValueError(
                f"image '{img_path}' is {h}x{w}, smaller than crop size {sz}x{sz}"
            )

        top = (h - sz) // 2
        left = (w - sz) // 2
        cropped = img[top:top + sz, left:left + sz]

        bordered = cv2.copyMakeBorder(
            cropped,
            top=border_sz, bottom=border_sz,
            left=border_sz, right=border_sz,
            borderType=cv2.BORDER_CONSTANT,
            value=0,
        )

        out_path = os.path.join(output_folder, fname)
        if not cv2.imwrite(out_path, bordered):
            raise OSError(f"could not write image '{out_path}'")

    print(f"Processed images saved to '{output_folder}'.")


def main(argv=None):
    """Command-line entry point: crop a raw PNG stack, then deform it.

    Runs ``generate_crop_serial`` from ``--input`` into ``--crop-out``, then
    ``generate_deformed_serial`` from ``--crop-out`` into ``--deform-out``.
    Every option defaults to the value previously hard-coded here, so running
    with no arguments behaves as before.

    Args:
        argv: Argument list to parse. ``None`` (default) makes argparse read
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Centre-crop a PNG image stack, add a black border, '
                    'then write an elastically deformed copy.'
    )
    parser.add_argument('--input', default='./data/raw',
                        help='folder with the raw .png slices')
    parser.add_argument('--crop-out', default='./data/test',
                        help='folder for the cropped, bordered slices')
    parser.add_argument('--deform-out', default='./data/test_data1',
                        help='folder for the deformed slices')
    parser.add_argument('--size', type=int, default=1024,
                        help='side length of the centre crop in pixels')
    parser.add_argument('--border', type=int, default=40,
                        help='black border width in pixels')
    parser.add_argument('--alpha', type=float, default=1.1,
                        help='deformation scale as a fraction of image height')
    parser.add_argument('--sigma', type=float, default=0.08,
                        help='smoothing sigma as a fraction of image width')
    args = parser.parse_args(argv)

    generate_crop_serial(args.input, args.crop_out, sz=args.size, border_sz=args.border)
    generate_deformed_serial(args.crop_out, args.deform_out, alpha=args.alpha, sigma=args.sigma)


if __name__ == '__main__':
    main()