import os
import pickle

import torch
import torchvision.transforms
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from einops import rearrange


class FullDataset(Dataset):
    """Dataset pairing stimulus images with rows of a spike matrix.

    Row ``i`` of ``self.spike_array`` is the target for ``self.image_paths[i]``.
    """

    def __init__(self, stimulus_path, name_path, transform, spike_array_path, idx=None,
                 target_idx=None, silhouettes=False):
        """
        Args:
            stimulus_path: Directory containing the stimulus image files.
            name_path: If name path is set, images are loaded in the order given by the
                pickled list of filenames stored at this path. If name path is None, images
                are loaded according to the order of os.listdir(stimulus_path).
            transform: Callable applied to each PIL image in __getitem__, or None to return
                the PIL image unchanged.
            spike_array_path: Path to an array file readable by np.load. Its first axis must
                line up with the (possibly subset) image list.
            idx: Indices of images to use. If target_idx is None, also selects the
                corresponding rows of the spike matrix.
            target_idx: Useful if indices between images and spike matrix do not match
                because not all images were shown. If None, use idx to select rows of the
                spike matrix. If not None, use these rows from the spike matrix. Ignored
                when idx is None.
            silhouettes: If True, __getitem__ sets to 0 every pixel that differs (in any
                channel) from the first-channel value of the top-left pixel.

        Raises:
            ValueError: If the number of selected images differs from the number of
                selected spike-matrix rows.
        """
        # Images
        self.stimulus_path = stimulus_path
        if name_path is not None:
            # NOTE(review): pickle.load can execute arbitrary code; only point name_path at
            # trusted files.
            with open(name_path, "rb") as fp:
                self.image_paths = pickle.load(fp)
        else:
            # NOTE(review): os.listdir order is arbitrary (filesystem dependent), yet rows of
            # the spike matrix are matched to images by position. Prefer passing name_path.
            self.image_paths = os.listdir(stimulus_path)
        self.transform = transform

        # Targets
        self.spike_array = np.load(spike_array_path)

        # Select subset
        if idx is not None:
            self.image_paths = [self.image_paths[i] for i in idx]
            # Spike rows follow idx unless an explicit mapping (target_idx) is supplied.
            rows = idx if target_idx is None else target_idx
            self.spike_array = self.spike_array[rows, :]

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.image_paths) != self.spike_array.shape[0]:
            raise ValueError(
                f"Number of images ({len(self.image_paths)}) does not match number of "
                f"spike-array rows ({self.spike_array.shape[0]})"
            )
        self.silhouettes = silhouettes

    @staticmethod
    def _png_fallback(filename):
        """Return ``filename`` with a ``.jpg`` extension swapped for ``.png``, else None."""
        root, ext = os.path.splitext(filename)
        if ext.lower() != ".jpg":
            return None
        return root + ".png"

    def _open_image(self, filename):
        """Open ``filename`` inside ``stimulus_path``, retrying with a ``.png`` twin of a ``.jpg``."""
        try:
            return Image.open(os.path.join(self.stimulus_path, filename))
        except OSError:
            # Some stimuli are stored as .png even though the name list says .jpg.
            fallback = self._png_fallback(filename)
            if fallback is None:
                raise
            return Image.open(os.path.join(self.stimulus_path, fallback))

    def __getitem__(self, index):
        """Return ``(image, spikes)`` for sample ``index``.

        ``image`` is the output of ``self.transform`` (or the raw PIL image when transform
        is None). ``spikes`` is a tensor created with torch.from_numpy from row ``index``
        of ``self.spike_array``, so it shares memory with that row.
        """
        x = self._open_image(self.image_paths[index])
        if self.transform is not None:
            x = self.transform(x)

        if self.silhouettes:
            # Assumes transform produced a channel-first tensor (C, H, W) — TODO confirm.
            # Pixels that differ from the top-left pixel's first-channel value in any
            # channel are zeroed in all channels (presumably the top-left is background).
            mask = (x != x[0, 0, 0]).any(dim=0)
            x[:, mask] = 0

        # Target
        y = torch.from_numpy(self.spike_array[index, :])
        return x, y

    def __len__(self):
        """Return the number of (image, spike-row) pairs."""
        return len(self.image_paths)


class ImageDataset(Dataset):
    '''
    Dataset containing only images, not neural data
    '''
    def __init__(self, stimulus_path, idx=None, transform=None):
        """
        Args:
            stimulus_path: Directory whose entries (in os.listdir order) are the images.
            idx: Optional indices into that listing selecting a subset of images.
            transform: Callable applied to each PIL image in __getitem__. When None, defaults
                to ToTensor() followed by Resize(280), the previously hard-coded pipeline.
        """
        # Images
        self.stimulus_path = stimulus_path
        # NOTE(review): os.listdir order is arbitrary (filesystem dependent); idx refers to
        # positions in that order.
        self.image_paths = os.listdir(stimulus_path)
        if transform is None:
            transform = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Resize(280)
            ])
        self.transform = transform
        # Select subset
        if idx is not None:
            self.image_paths = [self.image_paths[i] for i in idx]

    def set_subset(self, filenames):
        '''
        Replace the list of image filenames with ``filenames``.

        The names are resolved relative to ``stimulus_path``. They are not checked against
        the directory contents. The sequence is copied into a new list, so later changes
        to the caller's sequence do not affect the dataset, and iterators are supported.
        '''
        self.image_paths = list(filenames)

    def __getitem__(self, index):
        """Open image ``index`` from ``stimulus_path`` and return ``self.transform(image)``."""
        # Image
        image_path = os.path.join(self.stimulus_path, self.image_paths[index])
        x = Image.open(image_path)
        x = self.transform(x)
        return x

    def __len__(self):
        """Return the number of images currently selected."""
        return len(self.image_paths)

