from skimage import io, img_as_float, img_as_ubyte
from skimage.restoration import denoise_tv_chambolle
import sys
import os
from pathlib import Path
from datasets import load_from_disk
from datasets import concatenate_datasets
from utils import *

PATH = Path('/data/vision/___/scratch/___ht/fastmri_dir')

# Value range the real part of the reconstructed image is assumed to lie in
# (taken from the original author's "range [-2, 2]" note) -- TODO confirm
# against the data; values outside it are not clipped.
VALUE_MIN, VALUE_MAX = -2.0, 2.0
# Regularisation strength passed to skimage's Chambolle TV denoiser.
TV_WEIGHT = 0.1
# The train split is concatenated with itself this many times before export,
# so files 1..N and N+1..2N (for NUM_COPIES == 2) come from the same inputs.
NUM_COPIES = 2


def _tv_denoise(image, weight=TV_WEIGHT, vmin=VALUE_MIN, vmax=VALUE_MAX):
    """Total-variation denoise ``image`` on a [0, 1] scale and map it back.

    ``image`` is linearly rescaled from [vmin, vmax] to [0, 1] before calling
    ``denoise_tv_chambolle`` (single-channel, ``channel_axis=None``), and the
    denoised result is mapped back to the original [vmin, vmax] scale.

    Args:
        image: 2-D real-valued array (assumed; TODO confirm shape).
        weight: TV regularisation weight for ``denoise_tv_chambolle``.
        vmin, vmax: value range mapped to [0, 1] during denoising.

    Returns:
        The denoised array, on the same scale as ``image``.
    """
    scale = vmax - vmin
    unit = (image - vmin) / scale
    denoised = denoise_tv_chambolle(unit, weight=weight, channel_axis=None)
    return denoised * scale + vmin


if __name__ == "__main__":

    # Reuse the module-level constant instead of repeating the literal path.
    path_prefix = PATH
    path = path_prefix / "hf/fastmri-kspace-r6"
    save_path = path_prefix / "datasets_for_eval/tv"
    # Without this, a fresh run fails with FileNotFoundError on the first save.
    save_path.mkdir(parents=True, exist_ok=True)

    dataset = load_from_disk(path)
    dataset.set_format('numpy')

    trainset_yA = dataset['train']
    trainset_yA = concatenate_datasets([trainset_yA] * NUM_COPIES)

    # enumerate() has no len(), so give tqdm the total explicitly for a real
    # progress bar / ETA.
    for idx, kspace in tqdm(enumerate(trainset_yA['y']), total=len(trainset_yA)):
        # ifft2c / real2complex come from utils; presumably a real->complex
        # packing followed by a centred inverse 2-D FFT -- TODO confirm.
        image_arr = ifft2c(real2complex(kspace)).real  # assumed range [-2, 2]
        denoised = _tv_denoise(image_arr)
        # To export the un-denoised reconstruction instead, save
        # to_pil(image_arr); it is on the same scale as `denoised`.
        to_pil(denoised).save(save_path / f"{idx+1}.png")
    
