import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

import argparse


"""
Dictionary containing preprocessing configuration for each dataset:
{
    'dataset_name': {
        # OUTPUT CONFIGURATION
        'extension': str - File extension for output images (e.g., 'png', 'tif');
                     a leading dot is optional
        
        # INPUT PATHS (relative to BASE_DATA_FOLDER)
        'frames_folders': list[str] - Directories whose immediate sub-directories
                           hold the pre-extracted frame images to be processed
        
        # OPTIONAL INPUT PATHS (relative to BASE_DATA_FOLDER)
        'videos_folders': list[str] - Directories containing video files that need
                           to be converted to frames first
    }
}

Processing pipeline:
1. If 'videos_folders' is present:
   - Videos are extracted to frames (saved to PREPROCESSED_FOLDER/videos_folder/video_name/)
   - Each frame is converted to grayscale and resized to maintain aspect ratio with target_height

2. For all 'frames_folders':
   - Each frame is converted to grayscale and resized
   - Processed frames are saved to corresponding paths under PREPROCESSED_FOLDER
   - Output structure mirrors input structure: PREPROCESSED_FOLDER/frames_folder/...

Output: All processed images are saved to the PREPROCESSED_FOLDER directory,
        maintaining the same relative path structure as the input paths.
"""
# Per-dataset preprocessing settings. All folder paths are relative to
# BASE_DATA_FOLDER.
PREPROCESSING_CONFIG = dict(
    ucsd=dict(
        extension='tif',
        # Train and Test splits of both pedestrian scenes, ped1 first.
        frames_folders=[
            f'UCSD_Anomaly_Dataset.v1p2/{scene}/{split}'
            for scene in ('UCSDped1', 'UCSDped2')
            for split in ('Train', 'Test')
        ],
    ),
    shanghaitech=dict(
        extension='png',
        frames_folders=[
            'shanghaitech/testing/frames',
            'shanghaitech/training/videos',
        ],
        # The training split is stored as videos and is split into frames first.
        videos_folders=['shanghaitech/training/videos'],
    ),
    avenue=dict(
        extension='png',
        frames_folders=['avenue/training/videos', 'avenue/testing/videos'],
        videos_folders=['avenue/training/videos', 'avenue/testing/videos'],
    ),
)

# Raw data is read from BASE_DATA_FOLDER; results are written under
# PREPROCESSED_FOLDER using the same relative layout.
BASE_DATA_FOLDER = 'data/'
PREPROCESSED_FOLDER = 'data/preprocessed/'

# Height in pixels every frame is resized to; the width follows the aspect
# ratio. For reference, UCSD frames are 238x158 px.
target_height = 160


def preprocess_datasets(datasets_to_process):
    """Run the preprocessing pipeline for each of the given datasets.

    datasets_to_process: iterable of dataset names; each must be a key of
        PREPROCESSING_CONFIG.

    For every dataset, the videos in 'videos_folders' (when configured) are
    first split into frames by preprocess_videos; then the directories in
    'frames_folders' are handled by preprocess_frames. All output is written
    under PREPROCESSED_FOLDER.

    Raises:
        KeyError: a dataset name has no entry in PREPROCESSING_CONFIG.
        ValueError: a dataset's configured 'extension' is empty.
    """
    # The output root only has to be created once, not once per dataset;
    # exist_ok avoids the check-then-create race.
    os.makedirs(PREPROCESSED_FOLDER, exist_ok=True)

    for dataset in datasets_to_process:
        try:
            config = PREPROCESSING_CONFIG[dataset]
        except KeyError:
            raise KeyError(
                f"Unknown dataset {dataset!r}; known datasets: "
                f"{', '.join(PREPROCESSING_CONFIG)}"
            ) from None
        frames_folders = config['frames_folders']

        extension = config['extension']
        if not extension:
            raise ValueError(f"Dataset {dataset!r} has an empty 'extension'")
        # The config may spell the extension as 'png' or '.png'.
        if not extension.startswith('.'):
            extension = f'.{extension}'

        # Some datasets (avenue, shanghaitech) are stored as videos, which
        # must be split into frames first.
        if 'videos_folders' in config:
            preprocess_videos(config['videos_folders'], extension)

        # Once the video frames exist as images, they can be preprocessed.
        preprocess_frames(frames_folders, extension)
        
    
def preprocess_videos(video_folders, extension):
    """Split every video in ``video_folders`` into grayscale, resized frames.

    video_folders: list of video directories, relative to BASE_DATA_FOLDER.
        Every entry of each directory is tried as a video file; entries that
        OpenCV cannot open are skipped with a message.
    extension: file extension of the saved frames, including the leading dot.

    Output:
    Frames are saved under PREPROCESSED_FOLDER/<folder>/<video name without
    its extension>/ as <index><extension>, where the index is zero-padded to
    3 digits. The layout mirrors the input, so a folder named 'videos' keeps
    that name under the preprocessed folder. Each frame is converted to
    grayscale and resized to ``target_height`` keeping the aspect ratio.
    Existing frames are overwritten.

    NOTE(review): indices above 999 get 4+ digits, so a plain lexicographic
    sort of the output names will misorder them; confirm consumers sort
    numerically.
    """
    for folder in video_folders:
        input_folder = os.path.join(BASE_DATA_FOLDER, folder)
        for video in sorted(os.listdir(input_folder)):
            videofile = os.path.join(input_folder, video)
            # splitext strips only the final suffix, so names containing
            # extra dots are not truncated.
            video_name = os.path.splitext(video)[0]
            outputfolder = os.path.join(PREPROCESSED_FOLDER, folder, video_name)

            print(video)

            vc = cv2.VideoCapture(videofile)
            try:
                if not vc.isOpened():
                    # Not a readable video; skip it rather than leaving an
                    # empty output folder behind.
                    print(f'Skipping {video}: cannot be opened as a video')
                    continue
                os.makedirs(outputfolder, exist_ok=True)

                idx = 0
                ret, frame = vc.read()
                while ret:
                    # VideoCapture returns BGR frames, so BGR2GRAY applies
                    # the correct per-channel luminance weights.
                    image = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    height, width = image.shape[:2]
                    new_width = int(round(width * target_height / height))
                    image = cv2.resize(image, (new_width, target_height),
                                       interpolation=cv2.INTER_AREA)

                    cv2.imwrite(os.path.join(outputfolder, f'{idx:03d}{extension}'), image)
                    idx += 1
                    ret, frame = vc.read()
            finally:
                # Always free the decoder, even if writing a frame fails.
                vc.release()

def preprocess_frames(frames_folders, extension):
    """Convert pre-extracted frames to grayscale and resize them.

    frames_folders: list of directories, relative to BASE_DATA_FOLDER, whose
        immediate sub-directories hold the frame images.
    extension: file extension of the saved frames, including the leading dot.

    Each image <folder>/<subfolder>/<name>.<ext> is written to
    PREPROCESSED_FOLDER/<folder>/<subfolder>/<name><extension>, converted to
    grayscale and resized to ``target_height`` keeping the aspect ratio.
    Frames whose output file already exists are skipped, so an interrupted
    run can be resumed. Entries of a folder that are not directories (e.g.
    raw video files or scripts next to the frame folders) are ignored, as are
    files OpenCV cannot decode as images.
    """
    for folder in frames_folders:
        input_folder = os.path.join(BASE_DATA_FOLDER, folder)
        for subfolder in sorted(os.listdir(input_folder)):
            sf = os.path.join(input_folder, subfolder)
            # Plain files (videos, .m scripts, ...) cannot be listed as frame
            # folders; skipping them avoids NotADirectoryError.
            if not os.path.isdir(sf):
                continue

            of = os.path.join(PREPROCESSED_FOLDER, folder, subfolder)
            os.makedirs(of, exist_ok=True)

            print(subfolder)
            for f in sorted(os.listdir(sf)):
                outputfile = os.path.join(of, os.path.splitext(f)[0]) + extension
                # Check before decoding so frames already done cost nothing.
                if os.path.exists(outputfile):
                    continue

                image = cv2.imread(os.path.join(sf, f))
                if image is None:
                    # Not a decodable image (e.g. a hidden metadata file).
                    continue

                # imread returns BGR data, so BGR2GRAY applies the correct
                # per-channel luminance weights.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                height, width = image.shape[:2]
                new_width = int(round(width * target_height / height))
                image = cv2.resize(image, (new_width, target_height),
                                   interpolation=cv2.INTER_AREA)

                cv2.imwrite(outputfile, image)
        
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess datasets')
    parser.add_argument(
        '--datasets',
        nargs='+',
        type=str,
        choices=[*PREPROCESSING_CONFIG, 'all'],
        # The default must be a list: with the plain string "all",
        # args.datasets[0] would be the character 'a' and the "all" shortcut
        # would never match, so the run would fail with KeyError: 'a'.
        default=['all'],
        help='List of datasets to preprocess. "all" to process all known datasets',
    )
    args = parser.parse_args()

    datasets_to_process = args.datasets
    if 'all' in datasets_to_process:
        # Preprocess all datasets for which we have config
        datasets_to_process = list(PREPROCESSING_CONFIG.keys())
    preprocess_datasets(datasets_to_process)
    print("Preprocessing done")