import numpy as np
import os
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm


#### Runs only when the required dataset folder is available locally

def crop_center(image, crop_size):
    """Return the central ``crop_size`` x ``crop_size`` window of ``image``.

    The last two axes of ``image`` are treated as (height, width). Any leading
    axes, such as a channel axis, are kept unchanged. Both ``(c, h, w)`` and
    ``(h, w)`` inputs are therefore accepted. Works for NumPy arrays and
    torch tensors alike, since only ``.shape`` and slicing are used.

    Args:
        image: array or tensor with at least two dimensions.
        crop_size: side length of the square crop, in pixels.

    Returns:
        A slice of ``image`` with shape ``(..., crop_size, crop_size)``.

    Raises:
        ValueError: if ``image`` has fewer than two dimensions, or if
            ``crop_size`` is negative or exceeds the image height or width.
    """
    if len(image.shape) < 2:
        raise ValueError(f"expected an image with >= 2 dims, got shape {tuple(image.shape)}")
    h, w = image.shape[-2:]
    # Without this check a too-large crop gives negative start indices, which
    # Python slicing silently wraps around instead of failing.
    if not 0 <= crop_size <= min(h, w):
        raise ValueError(f"crop_size={crop_size} does not fit in an image of size {h}x{w}")
    start_x = w // 2 - crop_size // 2
    start_y = h // 2 - crop_size // 2
    return image[..., start_y:start_y + crop_size, start_x:start_x + crop_size]

# Parameters
folder_path = 'SARS-COV-2_CT_COVID'
# Lists to store data
dataset = []
file_numbers = []  # NOTE(review): never filled in by this script

crop_size = 40


def _load_cropped_png(file_path, crop_size):
    """Load one PNG as a max-normalised ``(1, crop_size, crop_size)`` float32 tensor.

    Colour images are converted to grey by averaging their colour channels;
    an alpha channel, if present, is ignored. The centre square is cut out
    with ``crop_center`` and divided by its maximum so the brightest pixel
    becomes 1. An all-zero crop is returned as-is instead of turning into NaNs.
    """
    img = plt.imread(file_path)
    if img.ndim == 3:
        if img.shape[2] in (2, 4):
            # LA / RGBA: drop alpha so it does not shift the grey level.
            img = img[..., :-1]
        img = img.mean(axis=2)
    # Add a leading channel axis -> (1, h, w). crop_center slices the last two
    # axes, and torch.stack later yields (N, 1, crop_size, crop_size).
    tensor_image = torch.tensor(np.asarray(img, dtype=np.float32)).unsqueeze(0)
    cropped_image = crop_center(tensor_image, crop_size)
    peak = cropped_image.max()
    if peak > 0:
        cropped_image = cropped_image / peak
    return cropped_image


# Process each file in the folder. Only a missing folder is reported as
# "not found"; genuine processing errors are allowed to surface.
try:
    # Sorted so the sample order in the saved tensor is reproducible
    # (os.listdir order is arbitrary).
    png_names = [name for name in sorted(os.listdir(folder_path)) if name.endswith('.png')]
except (FileNotFoundError, NotADirectoryError):
    print('DATA NOT FOUND')
else:
    for file_name in tqdm(png_names):
        dataset.append(_load_cropped_png(os.path.join(folder_path, file_name), crop_size))

    if dataset:
        dataset = torch.stack(dataset)

        print("Dataset shape:", dataset.shape)

        # File name tracks crop_size so it cannot drift from the actual crop.
        torch.save(dataset, f'SARS-COV-2_CT_COVID-{crop_size}.pt')
    else:
        # Folder exists but contains no PNGs; torch.stack([]) would raise.
        print('DATA NOT FOUND')