import os
import cv2
import numpy as np
from abc import ABC, abstractmethod
from dataloaders.utils import divide_sequence, create_patches
import torch
from torchvision import transforms



class GenericDataLoader(ABC):
    """Abstract loader that turns video frames into batched patch arrays.

    Subclasses provide :meth:`_load_files`, which lists the frame files of the
    train/eval/test videos. The constructor then loads the frames (or a cached
    copy of the clips), cuts every video into clips of ``time_steps`` frames,
    converts the clips into patches and stores them time-major, i.e. the
    sample axis is axis 1 of ``train``/``eval``/``test``.

    If a subclass sets ``self.split`` before calling ``super().__init__`` and
    it contains ``"avenue"``, a dedicated ``cache_avenue`` directory is used.
    """

    # Split names, in the order used for cache files and patch attributes.
    _SPLITS = ("train", "eval", "test")

    def __init__(
        self,
        path,
        patch_size=16,
        time_steps=20,
        flatten=True,
        scale=False,
        normalize=False,
        vit_clip_processor: 'Optional[vit.ViTClipProcessor]' = None,  # x->x
        normalized_images=False,
        cache_clips=True,
    ):
        """Load (or restore from cache) all clips and build the patch arrays.

        Args:
            path: Dataset root directory; also hosts the cache directory.
            patch_size: Patch edge length; frames are cropped to a multiple
                of it.
            time_steps: Number of frames per clip.
            flatten: Flatten every frame of a patch into one vector. Forced
                to ``False`` when a ViT processor or ``normalized_images``
                is requested.
            scale: When exactly ``True``, frames are divided by 255.
            normalize: Standardize each split with its own mean/std
                (only applied when no ViT processor is used).
            vit_clip_processor: Optional callable applied to every clip
                instead of the standard patch extraction.
            normalized_images: Treated like the presence of a ViT processor
                for the purpose of disabling ``flatten``.
            cache_clips: Save clips (and ViT patches) as ``.npy`` files in
                the cache directory when they had to be computed.
        """
        self.patch_size = patch_size
        self.time_steps = time_steps
        self.scale = scale
        self.normalized_images = normalized_images
        # Always expose the flag, regardless of which branch is taken below.
        self.normalize = normalize
        # ``None`` selects the standard (non-ViT) patch pipeline.
        self.vit_processor = vit_clip_processor
        # Remembered so that file lists can be resolved lazily after a
        # cache hit (see _ensure_files_loaded).
        self._dataset_path = path

        if normalized_images or vit_clip_processor is not None:
            if flatten:
                print("Warning, got ViT processor and flatten=True. ViT requires unflattened frames.")
                flatten = False

        # Determine cache path - use split-specific cache for Avenue dataset
        cache_path = os.path.join(path, "cache")
        split = getattr(self, "split", None)
        if split and "avenue" in split:
            cache_path = os.path.join(path, "cache_avenue")
            print(f"Using Avenue-specific cache path: {cache_path}")

        self._load_clips(path, cache_path, cache_clips)
        self._build_patches(cache_path, cache_clips)

        self.train_num, self.eval_num, self.test_num = (
            self.train_patches.shape[0],
            self.eval_patches.shape[0],
            self.test_patches.shape[0],
        )
        print(
            "Data formatted to patches. Number of {} frame long patches: {} train, {} eval, {} test".format(
                time_steps, self.train_num, self.eval_num, self.test_num
            )
        )

        # Flatten frames into a single vector
        if flatten:
            self.train_patches = self.train_patches.reshape(
                [self.train_num, time_steps, -1]
            )
            self.eval_patches = self.eval_patches.reshape(
                [self.eval_num, time_steps, -1]
            )
            self.test_patches = self.test_patches.reshape(
                [self.test_num, time_steps, -1]
            )

        # Apply normalization if requested for non-ViT models
        if normalize and self.vit_processor is None:
            print("Applying standard normalization for non-ViT model.")
            self._apply_normalization()

        # Permute samples/frames indices: the sample axis becomes axis 1.
        self.train_patches = np.swapaxes(self.train_patches, 0, 1)
        self.eval_patches = np.swapaxes(self.eval_patches, 0, 1)
        self.test_patches = np.swapaxes(self.test_patches, 0, 1)

        self._reset_dataloader_indices()

        self.train_data_std = np.std(self.train_patches)
        self.eval_data_std = np.std(self.eval_patches)
        self.test_data_std = np.std(self.test_patches)

        self.shuffle()

    @classmethod
    def _cache_files(cls, cache_path, suffix):
        """Return the per-split cache file paths ``<split>_<suffix>.npy``."""
        return [
            os.path.join(cache_path, f"{split}_{suffix}.npy")
            for split in cls._SPLITS
        ]

    def _load_clips(self, path, cache_path, cache_clips):
        """Populate ``train_clips``/``eval_clips``/``test_clips``.

        The cache is used only when all three clip files exist; a partially
        written cache falls back to loading from disk and is rewritten.
        """
        clip_files = self._cache_files(cache_path, "clips")
        if all(os.path.exists(f) for f in clip_files):
            print(f"Loading clips from cache at {cache_path}")
            self.train_clips, self.eval_clips, self.test_clips = [
                np.load(f) for f in clip_files
            ]
            return

        print(f"Loading from disk {path}")
        self.train_files, self.eval_files, self.test_files, self.eval_idx = self._load_files(path)

        # Loading frames
        train_videos = self._load_frames(self.train_files)
        eval_videos = self._load_frames(self.eval_files)
        test_videos = self._load_frames(self.test_files)
        print(
            "Videos loaded: {} train, {} eval, {} test".format(
                len(train_videos), len(eval_videos), len(test_videos)
            )
        )

        # Dividing videos into clips
        self.train_clips = self._divide_videos(train_videos)
        self.eval_clips = self._divide_videos(eval_videos)
        self.test_clips = self._divide_videos(test_videos)

        if cache_clips:
            print(f"Caching clip numpy arrays to disk at {cache_path}")
            os.makedirs(cache_path, exist_ok=True)
            all_clips = (self.train_clips, self.eval_clips, self.test_clips)
            for target, clips in zip(clip_files, all_clips):
                np.save(target, clips)

    def _divide_videos(self, videos):
        """Cut every video into clips and return them as one flat list."""
        return [
            clip
            for video in videos
            for clip in divide_sequence(video, self.time_steps)
        ]

    def _build_patches(self, cache_path, cache_clips):
        """Populate ``train_patches``/``eval_patches``/``test_patches``."""
        all_clips = (self.train_clips, self.eval_clips, self.test_clips)

        if self.vit_processor is None:
            # Traditional patching (for non-ViT models)
            # create_patches out: (num_patches, T, patch_size, patch_size)
            print("Creating standard patches.")
            self.train_patches, self.eval_patches, self.test_patches = [
                np.concatenate([create_patches(c, self.patch_size) for c in clips])
                for clips in all_clips
            ]
            return

        # ViT preprocessing - only for ViT models. The processor output of
        # every clip is concatenated along the first axis.
        vit_files = self._cache_files(cache_path, "vit_patches")
        if all(os.path.exists(f) for f in vit_files):
            print(f"Loading ViT-normalized patches from cache at {cache_path}")
            self.train_patches, self.eval_patches, self.test_patches = [
                np.load(f) for f in vit_files
            ]
            return

        print("Creating ViT-normalized patches.")
        self.train_patches, self.eval_patches, self.test_patches = [
            np.concatenate([self.vit_processor(c) for c in clips])
            for clips in all_clips
        ]

        if cache_clips:
            print(f"Caching ViT-normalized patches to disk at {cache_path}")
            os.makedirs(cache_path, exist_ok=True)
            all_patches = (self.train_patches, self.eval_patches, self.test_patches)
            for target, patches in zip(vit_files, all_patches):
                np.save(target, patches)

    def _ensure_files_loaded(self):
        """Resolve the file lists if a cache hit skipped loading them."""
        if not hasattr(self, "eval_files"):
            (
                self.train_files,
                self.eval_files,
                self.test_files,
                self.eval_idx,
            ) = self._load_files(self._dataset_path)

    @abstractmethod
    def _load_files(self, basedir):
        """
        Load file paths from the dataset directory
        Returns: train_files, eval_files, test_files, eval_idx
        """
        pass

    def _load_frames(self, files):
        """Load frames from file paths.

        Args:
            files: Iterable of videos, each an iterable of frame file paths.

        Returns:
            A list with one array per video, stacking its grayscale frames.

        Raises:
            OSError: If a frame cannot be read by ``cv2.imread``.
        """
        dest = []
        for video in files:
            frames = []

            # Preprocessing frames
            for frame_path in video:
                frame = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
                if frame is None:
                    # cv2.imread signals failure by returning None.
                    raise OSError(f"cv2.imread could not read frame: {frame_path}")

                # Cropping to bottom left so both sides are multiples of
                # patch_size.
                rows, cols = frame.shape[0], frame.shape[1]
                kept_rows = rows - (rows % self.patch_size)
                kept_cols = cols - (cols % self.patch_size)
                frame = frame[rows - kept_rows:, :kept_cols]

                # Scaling
                if self.scale is True:
                    frame = frame / 255.0

                frames.append(frame)

            dest.append(np.array(frames))

        return dest

    def _apply_normalization(self):
        """Standardize every split with its own mean and standard deviation.

        The normalizers are kept as ``normalizer_train``/``normalizer_eval``/
        ``normalizer_test``. Integer patches (e.g. unscaled uint8 frames) are
        cast to float32 first because ``transforms.Normalize`` only accepts
        floating point tensors.

        NOTE(review): eval/test use their own statistics rather than the
        training statistics - confirm this is intended.
        """
        for split in self._SPLITS:
            patches = getattr(self, f"{split}_patches")
            normalizer = transforms.Normalize(patches.mean(), patches.std())
            setattr(self, f"normalizer_{split}", normalizer)

            if not np.issubdtype(patches.dtype, np.floating):
                patches = patches.astype(np.float32)
            normalized = normalizer(torch.tensor(patches)).numpy()
            setattr(self, f"{split}_patches", normalized)

    def _reset_dataloader_indices(self):
        """Reset dataloader indices"""
        self.current_idx_train = 0
        self.current_idx_eval = 0
        self.current_idx_test = 0

    def load_videos_as_arrays(self, split):
        """Load videos as arrays for visualization or analysis.

        Args:
            split: ``"eval"`` or ``"test"``.

        Returns:
            ``(videos, files)`` where ``videos`` are the frames loaded by
            :meth:`_load_frames` and ``files`` are string identifiers
            (``eval_idx`` entries for eval, running indices for test).
            ``(None, None)`` for any other split name.
        """
        files = None
        videos = None
        if split == "eval":
            self._ensure_files_loaded()
            files = [str(i) for i in self.eval_idx]
            videos = self._load_frames(self.eval_files)
        elif split == "test":
            self._ensure_files_loaded()
            files = [str(i) for i in range(len(self.test_files))]
            videos = self._load_frames(self.test_files)
        return videos, files

    @property
    def train(self):
        """Training patches, sample axis at position 1."""
        return self.train_patches

    @property
    def eval(self):
        """Evaluation patches, sample axis at position 1."""
        return self.eval_patches

    @property
    def test(self):
        """Test patches, sample axis at position 1."""
        return self.test_patches

    @property
    def eval_f(self):
        """Evaluation file lists (resolved lazily after a cache hit)."""
        self._ensure_files_loaded()
        return self.eval_files

    @property
    def test_f(self):
        """Test file lists (resolved lazily after a cache hit)."""
        self._ensure_files_loaded()
        return self.test_files

    def reset_indices(self):
        """Reset all dataloader indices"""
        self._reset_dataloader_indices()

    def shuffle(self):
        """Shuffle training set (video patches, across all clips).

        NOTE: this reseeds NumPy's global RNG with 2021 on every call, which
        is kept as-is because callers may rely on it for reproducibility.
        """
        np.random.seed(2021)
        indices = np.random.permutation(self.train_patches.shape[1])
        self.train_patches = self.train_patches[:, indices, ...]
        self.current_idx_train = 0

    def load_batch_train(self, batch_size):
        """Load a batch of training data.

        When fewer than ``batch_size`` samples remain, the batch is completed
        with samples from the start of the set and the index restarts at 0.
        """
        start = self.current_idx_train
        if start + batch_size >= self.train_num:
            tail = self.train_patches[:, start:, ...]
            head = self.train_patches[:, 0:batch_size - tail.shape[1], ...]
            self.current_idx_train = 0
            return np.concatenate((tail, head), axis=1)

        self.current_idx_train = start + batch_size
        return self.train_patches[:, start:start + batch_size, ...]

    def _next_sequential_batch(self, split, batch_size):
        """Return the next batch of ``split`` without wrapping around.

        The final batch of a pass holds whatever samples remain (possibly
        fewer than ``batch_size``) and resets the split's index to 0.
        """
        patches = getattr(self, f"{split}_patches")
        total = getattr(self, f"{split}_num")
        index_attr = f"current_idx_{split}"
        start = getattr(self, index_attr)
        stop = start + batch_size

        if stop >= total:
            setattr(self, index_attr, 0)
            return patches[:, start:, ...]

        setattr(self, index_attr, stop)
        return patches[:, start:stop, ...]

    def load_batch_validation(self, batch_size):
        """Load a batch of validation data"""
        return self._next_sequential_batch("eval", batch_size)

    def load_batch_test(self, batch_size):
        """Load a batch of test data"""
        return self._next_sequential_batch("test", batch_size)
    