import os
import cv2
import numpy as np
from dataloaders.utils import divide_sequence, create_patches
from dataloaders.generic_dataloader import GenericDataLoader


class ShanghaiTechLoader(GenericDataLoader):
    """Loader that discovers ShanghaiTech-style frame sequences on disk.

    This class only implements file discovery (``_load_files`` and
    ``_retrieve_files``); every constructor argument except ``split`` is
    forwarded unchanged to ``GenericDataLoader``.
    """

    def __init__(
        self,
        path,
        patch_size=16,
        time_steps=20,
        flatten=True,
        scale=False,
        normalize=False,
        split=None,
        vit_clip_processor=None,
        normalized_images=False,
    ):
        """Store ``split`` and delegate the remaining setup to the parent.

        Args:
            path: Forwarded to the parent constructor.
            patch_size: Forwarded to the parent constructor.
            time_steps: Forwarded to the parent constructor.
            flatten: Forwarded to the parent constructor.
            scale: Forwarded to the parent constructor.
            normalize: Forwarded to the parent constructor.
            split: Dataset sub-folder name(s) to load. ``None`` selects the
                default ``['shanghaitech']`` folder; a single string is
                treated as one folder name.
            vit_clip_processor: Forwarded to the parent constructor.
            normalized_images: Forwarded to the parent constructor.
        """
        # Assigned before super().__init__ -- presumably the parent
        # constructor calls _load_files, which reads self.split
        # (TODO confirm against GenericDataLoader).
        self.split = split
        super().__init__(path, patch_size, time_steps, flatten, scale, normalize, vit_clip_processor, normalized_images)

    def _load_files(self, basedir):
        """
        Implementation of the abstract method to load ShanghaiTech dataset files.

        Collects the frame sequences found under ``<basedir>/<folder>/training/videos``
        and ``<basedir>/<folder>/testing/videos`` for every selected folder,
        then moves every fifth training sequence into an evaluation set.

        Returns:
            A tuple ``(train_files, eval_files, test_files, eval_idx)``. The
            first three are lists of sequences (each a list of frame paths);
            ``eval_idx`` holds the positions, within the training list as it
            was before the split, of the sequences moved to ``eval_files``.

        TODO: this could be moved to the parent class.
        """
        if self.split is None:
            folders = ['shanghaitech']
        elif isinstance(self.split, str):
            # A bare string would otherwise be iterated character by character.
            folders = [self.split]
        else:
            folders = self.split

        train_folder = 'training/videos'
        test_folder = 'testing/videos'

        train_files = []
        test_files = []

        print(f"\n[DEBUG] Loading from folders: {folders}")
        for folder in folders:
            train_dir = os.path.join(basedir, folder, train_folder)
            test_dir = os.path.join(basedir, folder, test_folder)
            print(f"[DEBUG] Looking in train_dir: {train_dir}")
            print(f"[DEBUG] Looking in test_dir: {test_dir}")

            train_files.extend(self._retrieve_files(train_dir))
            test_files.extend(self._retrieve_files(test_dir))
            print(f"[DEBUG] Found {len(train_files)} train files and {len(test_files)} test files so far")

        # Select evaluation videos: every fifth training sequence.
        eval_idx = list(range(0, len(train_files), 5))
        eval_files = train_files[::5]
        train_files = [seq for i, seq in enumerate(train_files) if i % 5 != 0]

        print("[DEBUG] Final split:")
        print(f"[DEBUG] Train files: {len(train_files)}")
        print(f"[DEBUG] Eval files: {len(eval_files)}")
        print(f"[DEBUG] Test files: {len(test_files)}")
        # Guard on the first sequence as well: a sequence directory without
        # any matching frame yields an empty list.
        if train_files and train_files[0]:
            print(f"[DEBUG] Sample train file: {train_files[0][0]}")
        if eval_files and eval_files[0]:
            print(f"[DEBUG] Sample eval file: {eval_files[0][0]}")
        if test_files and test_files[0]:
            print(f"[DEBUG] Sample test file: {test_files[0][0]}")

        return train_files, eval_files, test_files, eval_idx

    def _retrieve_files(self, basedir, maxframes=200):
        """Retrieve ShanghaiTech file paths.

        Args:
            basedir: Directory whose sub-directories each hold one sequence.
            maxframes: Highest frame number probed; frames numbered above
                this value are not collected.

        Returns:
            One list of frame paths per sequence directory, with sequences in
            sorted name order and frames in ascending frame number. Frame
            numbers for which no candidate file exists are skipped.

        Raises:
            OSError: Propagated from ``os.listdir`` when ``basedir`` cannot
                be listed.
        """
        # Candidate file names for frame number i, probed in this order.
        name_formats = ('{:03d}.tif', '{:03d}.png', '{:04d}.png', '{:02d}.png')

        out = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sequence order (and the train/eval split built on it) reproducible.
        for seq_name in sorted(os.listdir(basedir)):
            seq_dir = os.path.join(basedir, seq_name)
            if not os.path.isdir(seq_dir):
                # Stray files next to the sequence folders are not sequences.
                continue

            frames = []
            for i in range(1, maxframes + 1):
                for fmt in name_formats:
                    filename = os.path.join(seq_dir, fmt.format(i))
                    if os.path.exists(filename):
                        frames.append(filename)
                        break  # first matching format wins
            out.append(frames)

        return out


def load_frames(files, patch_size, scale):
    """Read grayscale frames and crop them to a multiple of ``patch_size``.

    Args:
        files: Iterable of sequences, each an iterable of image file paths.
        patch_size: Both image dimensions are cropped down to the nearest
            multiple of this value.
        scale: When truthy, pixel values are divided by 255.0.

    Returns:
        A list with one ``np.array`` per sequence, built from that
        sequence's cropped (and optionally scaled) frames.

    Raises:
        OSError: If ``cv2.imread`` cannot read one of the files.
    """
    dest = []
    for sequence in files:
        frames = []

        for frame_path in sequence:
            image = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError(f"Could not read image file: {frame_path}")

            # Crop to the bottom-left region: drop the surplus rows from the
            # top and the surplus columns from the right.
            height, width = image.shape[0], image.shape[1]
            kept_height = height - (height % patch_size)
            kept_width = width - (width % patch_size)
            image = image[height - kept_height:, :kept_width]

            # Truthiness test so that e.g. numpy booleans also enable scaling.
            if scale:
                image = image / 255.0

            frames.append(image)

        dest.append(np.array(frames))

    return dest


if __name__ == '__main__':
    # Manual smoke run: build a loader on the preprocessed data directory,
    # restricted to the 'avenue' sub-folder.
    path = '../../Shanghaitech/preprocessed/'
    loader = ShanghaiTechLoader(
        path,
        split=['avenue'],
        flatten=False,
        scale=False,
    )