import os
import cv2
import numpy as np
from dataloaders.utils import divide_sequence, create_patches
from torchvision import transforms
import torch
from dataloaders.generic_dataloader import GenericDataLoader

class UCSD_Loader(GenericDataLoader):
    """Loader for the UCSD pedestrian videos (``UCSDped1`` and ``UCSDped2``).

    Frames are read as grayscale ``.tif`` images, cropped so that both
    spatial dimensions are multiples of ``self.patch_size``, and optionally
    divided by 255 when ``self.scale is True``.

    NOTE(review): ``patch_size`` and ``scale`` are not set in this class;
    presumably they are provided by ``GenericDataLoader`` — confirm.
    """

    def _load_files(self, basedir):
        """Implementation of the abstract method to load UCSD dataset files.

        TODO: this could be moved to the parent class.

        Args:
            basedir: root directory containing the ``UCSDped1`` and
                ``UCSDped2`` folders, each with ``Train`` and ``Test``
                subdirectories.

        Returns:
            A tuple ``(train_files, eval_files, test_files, eval_idx)``.
            Each ``*_files`` value is a list of sequences, and each sequence
            is a list of frame paths. ``eval_idx`` gives the positions in the
            combined (Ped1 followed by Ped2) training list that were moved
            into ``eval_files``.

        Raises:
            IndexError: if fewer training sequences are found than the
                hard-coded evaluation indices require.
        """
        folders = ["UCSDped1", "UCSDped2"]
        train_folder = "Train"
        test_folder = "Test"

        train_files = []
        test_files = []
        for folder in folders:
            train_dir = os.path.join(basedir, folder, train_folder)
            test_dir = os.path.join(basedir, folder, test_folder)
            train_files.extend(self._retrieve_files(train_dir))
            test_files.extend(self._retrieve_files(test_dir))

        # Fixed evaluation split. These indices are only meaningful because
        # _retrieve_files returns sequences in a deterministic (sorted) order.
        eval_idx = [0, 4, 9, 14, 19, 24, 29, 34, 39, 44]
        if len(train_files) <= max(eval_idx):
            raise IndexError(
                "Expected at least {} training sequences under {!r}, "
                "found {}".format(max(eval_idx) + 1, basedir, len(train_files))
            )

        eval_set = set(eval_idx)
        eval_files = [train_files[i] for i in eval_idx]
        train_files = [
            seq for i, seq in enumerate(train_files) if i not in eval_set
        ]

        return train_files, eval_files, test_files, eval_idx

    def _retrieve_files(self, basedir):
        """Retrieve UCSD frame paths for every sequence directory in ``basedir``.

        A sequence directory is a sub-directory whose name contains
        ``"Train"`` or ``"Test"`` and does not contain ``"gt"`` (ground-truth
        folders are skipped). Directories are visited in sorted order because
        ``os.listdir`` order is arbitrary. Callers rely on a stable order for
        the index-based evaluation split.

        Args:
            basedir: directory holding the sequence sub-directories.

        Returns:
            One list per sequence, each containing the existing
            ``001.tif`` ... ``999.tif`` paths in ascending frame order.
        """
        out = []
        for s in sorted(os.listdir(basedir)):
            if ("Train" not in s and "Test" not in s) or "gt" in s:
                continue
            seq_dir = os.path.join(basedir, s)
            # Plain files (e.g. archives or ._* metadata files) can never hold
            # frames. Skip them instead of recording an empty sequence that
            # would shift the evaluation indices.
            if not os.path.isdir(seq_dir):
                continue
            frames = []
            for i in range(1, 1000):
                filename = os.path.join(seq_dir, "{:03d}.tif".format(i))
                if os.path.exists(filename):
                    frames.append(filename)
            out.append(frames)
        return out

    def _load_frames(self, files):
        """Load, crop and optionally scale the frames of each sequence.

        Args:
            files: list of sequences, each a list of image paths (as produced
                by ``_retrieve_files``).

        Returns:
            A list with one ``np.ndarray`` per sequence, built by stacking
            that sequence's 2-D grayscale frames along the first axis.

        Raises:
            OSError: if OpenCV cannot read a frame.
        """
        dest = []
        for seq in files:
            frames = []

            for path in seq:
                frame = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                # cv2.imread reports failure by returning None, not by raising.
                if frame is None:
                    raise OSError("Could not read frame: {}".format(path))

                # Crop to the bottom-left region whose height and width are
                # exact multiples of patch_size. Leading rows and trailing
                # columns are dropped.
                rows, cols = frame.shape[0], frame.shape[1]
                new_rows = rows - (rows % self.patch_size)
                new_cols = cols - (cols % self.patch_size)
                frame = frame[rows - new_rows:, :new_cols]

                # Scaling
                if self.scale is True:
                    frame = frame / 255.0

                frames.append(frame)

            dest.append(np.array(frames))

        return dest

if __name__ == "__main__":
    # Smoke run: build a loader over a local copy of the UCSD dataset,
    # without flattening and without scaling the frames.
    ucsd_root = "../../UCSD/"
    dataset_loader = UCSD_Loader(ucsd_root, scale=False, flatten=False)
