import os
import numpy as np
from PIL import Image
from torchvision import datasets

#Code for downloading MNIST and saving to image files.
#Training images are written to ./MNIST/images and test images to ./MNIST/test,
#each named "{label}_idx{index}.jpg".  A split is skipped only when its output
#directory already exists; any other failure (e.g. a failed download) is raised.
#NOTE(review): an interrupted extraction leaves a partial directory that will be
#skipped on the next run -- delete it manually to re-extract.
def downloadMnist():

    #Training set.  download=True fetches the raw files into ./MNIST (torchvision
    #stores them there), which also provides the parent directory os.mkdir needs.
    dataset = datasets.MNIST(".", download=True)
    try:
        os.mkdir("./MNIST/images")
    except FileExistsError:
        print("MNIST training set already exists!  Skipping extraction.")
    else:
        for idx, (img, lbl) in enumerate(dataset):
            img.save(f"./MNIST/images/{lbl}_idx{idx}.jpg")
            print(f"Extracting training / val MNIST sets.  {idx + 1} / 60 000", end="\r")
        print()  # Terminate the "\r" progress line.

    #Testing set.  The raw files were already fetched by the download above.
    dataset = datasets.MNIST(".", train=False)
    try:
        os.mkdir("./MNIST/test")
    except FileExistsError:
        print("MNIST testing set already exists! Skipping extraction.")
    else:
        for idx, (img, lbl) in enumerate(dataset):
            img.save(f"./MNIST/test/{lbl}_idx{idx}.jpg")
            print(f"Extracting test MNIST set.  {idx + 1} / 10 000", end="\r")
        print()  # Terminate the "\r" progress line.


#Function to create a pointer file to all generated images in the dataset.
#Writes ./{setName}/{setName}_{mode}.txt with one line per file found (recursively)
#under ./{setName}/images/{mode}/.  Each line is the file's path relative to
#./{setName}, prefixed with "./" and using "/" separators (e.g. "./images/train/a.jpg").
#Files in nested subfolders keep their subfolder in the path.  Entries are sorted so
#the output is reproducible.  If the images folder is missing, the file is left empty.
##  setName (str): Relative path for the dataset.
##  mode (str): "train", "test", or "validate"
def makePointerFile(setName: str, mode: str):
    set_root = f"./{setName}"
    path = f"{set_root}/images/{mode}/"
    with open(f"{set_root}/{setName}_{mode}.txt", 'w') as out:
        for dirpath, dirnames, filenames in os.walk(path):
            #Sorting in place makes os.walk descend into subfolders in a stable order.
            dirnames.sort()
            for filename in sorted(filenames):
                rel = os.path.relpath(os.path.join(dirpath, filename), set_root)
                print("./" + rel.replace(os.sep, "/"), file=out)


#Function to scan the entered folder.  Recursively scans all images in the folder and its subfolders.
#Only files ending in .jpg, .jpeg or .png (case-insensitive) are read; everything else is ignored.
#All images must be of the same dimensions, or this will not work.
##  folder (str): The path to the folder to scan.
##  Returns (np.ndarray): Average magnitudes that are fft shifted to center.
##  Raises (ValueError): If no images are found, or if an image's shape differs from the first image read.
def scanFolderMagnitudes(folder: str) -> np.ndarray:

    mag = None
    images_checked = 0

    #os.walk already descends into every subfolder, so no manual folder queue is needed.
    for root, dirs, files in os.walk(folder):
        for item in files:
            if not item.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue

            img_path = os.path.join(root, item)
            #Close the file handle promptly; the grayscale conversion is kept exactly as before.
            with Image.open(img_path) as im:
                img = np.asarray(im.convert('RGB').convert('L'))
            out = np.abs(np.fft.fftshift(np.fft.fft2(img)))

            if mag is None:
                mag = out
            elif mag.shape != out.shape:
                raise ValueError(f"Expected {img_path} to have shape {mag.shape}.  Received shape {out.shape}.  Ensure all images in {folder} have the same dimensions.")
            else:
                mag += out
            images_checked += 1

    if mag is None:
        raise ValueError(f"No .jpg, .jpeg or .png images found in {folder}.")

    return mag / images_checked
