import matplotlib.pyplot as plt
import os
import numpy as np
import torch

from dataloaders.utils import divide_sequence, create_patches


class YUP_Loader:
    '''
    Video-patch loader for the YUP (YUPENN dynamic scenes) dataset.

    Clips listed in ``<path>/split.txt`` are loaded from ``.npy`` files, cut into
    ``time_steps``-long sequences and split into image patches. The resulting
    ``train`` / ``eval`` / ``test`` arrays have time on axis 0 and patches (samples)
    on axis 1; all mini-batch methods slice along axis 1.
    '''

    def __init__(self, path, patch_size=16, time_steps=20, flatten=True, scale=False, normalize=True, category=None, camera=None):
        '''
        :param path: path to the YUP dataset root (must contain ``split.txt`` and the ``.npy`` clips)
        :param patch_size: size of the image patches (forwarded to ``create_patches``)
        :param time_steps: number of frames per sequence; every clip is cut into sequences of this length
        :param flatten: flatten every frame of a patch into a one-dimensional array
        :param scale: scale 8-bit images of range 0-255 to range 0-1
        :param normalize: standardize all splits with the per-pixel mean/std of the training patches
        :param category: restrict the dataset to these categories (capitalized first path
                         component of the file name, e.g. 'Street')
        :param camera: restrict the dataset to these camera modes (prefix of the split label
                       before '_', e.g. 'static')
        '''
        files = self._read_file_list(path)

        # Restrict to categories and camera motion
        if category is not None:
            files = [f for f in files if f['cat'] in category]
        if camera is not None:
            files = [f for f in files if f['camera'] in camera]

        # NOTE: an earlier version of this loader used the 'test' split as train/val data,
        # suspecting that train and test are swapped in split.txt; the labels are now used as-is.
        self.test_files = [f for f in files if f['split'] == 'test']
        self.train_files = [f for f in files if f['split'] == 'train']
        self.eval_files = [f for f in files if f['split'] == 'eval']

        # Lists of time_steps-frame sequences, pooled over all clips of each split
        self.train_clips = self._load_clips(self.train_files, time_steps, scale)
        self.eval_clips = self._load_clips(self.eval_files, time_steps, scale)
        self.test_clips = self._load_clips(self.test_files, time_steps, scale)

        print('Data loaded. Number of {} frame long clips: {} train, {} eval, {} test'.format(time_steps, len(self.train_clips), len(self.eval_clips), len(self.test_clips)))

        # Generate image patches for training and testing.
        # NOTE(review): np.concatenate raises ValueError when a split is empty
        # (e.g. after category/camera filtering).
        self.train_patches = self._make_patches(self.train_clips, patch_size)
        self.eval_patches = self._make_patches(self.eval_clips, patch_size)
        self.test_patches = self._make_patches(self.test_clips, patch_size)

        self.train_num, self.eval_num, self.test_num = self.train_patches.shape[0], self.eval_patches.shape[0], self.test_patches.shape[0]
        print('Data formatted to patches. Number of {} frame long patches: {} train, {} eval, {} test'.format(time_steps,
                                                                                                            self.train_num,
                                                                                                            self.eval_num,
                                                                                                            self.test_num))

        self.train_patches = self._to_time_major(self.train_patches, self.train_num, time_steps, flatten)
        self.eval_patches = self._to_time_major(self.eval_patches, self.eval_num, time_steps, flatten)
        self.test_patches = self._to_time_major(self.test_patches, self.test_num, time_steps, flatten)

        # Per-pixel statistics over time and samples of the training set only
        mean = np.mean(self.train_patches, axis=(0, 1))
        std = np.std(self.train_patches, axis=(0, 1))
        print('Minimum per-pixel std of training patches: {}'.format(np.min(std)))
        self.mean = mean
        # Constant pixels have std 0; dividing by it would produce NaN/inf
        self.std = self._nonzero_std(std)

        if normalize:
            # In-place ops mutate the stored arrays through the loop variable
            for patches in (self.train_patches, self.eval_patches, self.test_patches):
                patches -= self.mean
                patches /= self.std

        # Tensors used by normalize()/unnormalize(); fall back to CPU without CUDA
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.mean = torch.from_numpy(self.mean).to(device)
        self.std = torch.from_numpy(self.std).to(device)

        self.current_idx_train = 0
        self.current_idx_eval = 0
        self.current_idx_test = 0

    @staticmethod
    def _read_file_list(path):
        '''
        Parse ``<path>/split.txt`` into a list of file descriptors.

        Each non-blank line is expected to look like ``<category>/<name>.avi <camera>_<split>``.

        :param path: dataset root directory
        :return: list of dicts with keys 'file', 'cat', 'split', 'camera' and 'location'
                 (path of the matching ``.npy`` file)
        '''
        files = []
        with open(os.path.join(path, 'split.txt'), 'r') as split_file:
            for line in split_file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines instead of raising IndexError
                name, label = fields[0], fields[1]
                label_parts = label.split('_')
                files.append({'file': name,
                              'cat': name.split('/')[0].capitalize(),
                              'split': label_parts[1],
                              'camera': label_parts[0],
                              'location': os.path.join(path, name.replace('.avi', '.npy'))})
        return files

    @staticmethod
    def _load_clips(file_entries, time_steps, scale):
        '''
        Load the ``.npy`` clips of ``file_entries`` as float32, optionally scale by 1/255,
        and cut each clip into ``time_steps``-long sequences.

        :return: flat list of sequences pooled over all clips
        '''
        clips = [np.load(f['location']).astype('float32') for f in file_entries]
        if scale:
            clips = [c / 255.0 for c in clips]
        return [seq for clip in clips for seq in divide_sequence(clip, time_steps)]

    @staticmethod
    def _make_patches(clips, patch_size):
        '''Cut every sequence into patches and stack them along axis 0.'''
        return np.concatenate([create_patches(c, patch_size) for c in clips])

    @staticmethod
    def _to_time_major(patches, num, time_steps, flatten):
        '''
        Optionally flatten every frame, then swap axes 0 and 1 so time comes first.

        The flatten reshape assumes ``patches`` is sample-major, i.e. (num, time_steps, ...).
        '''
        if flatten:
            patches = patches.reshape([num, time_steps, -1])
        return np.swapaxes(patches, 0, 1)

    @staticmethod
    def _nonzero_std(std):
        '''Return ``std`` with zero entries replaced by 1 (dtype preserved).'''
        return np.where(std > 0, std, np.ones_like(std))

    def load_videos_as_arrays(self, split):
        '''
        Reload the whole clips of a split as float32 arrays.

        The clips are not divided, scaled or normalized.

        :param split: 'eval' or 'test'
        :return: (videos, files); both are None for any other split name
        '''
        files = None
        videos = None
        if split == 'eval':
            files = [f['file'] for f in self.eval_files]
            videos = [np.load(f['location']).astype('float32') for f in self.eval_files]
        elif split == 'test':
            files = [f['file'] for f in self.test_files]
            videos = [np.load(f['location']).astype('float32') for f in self.test_files]
        return videos, files

    @property
    def train(self):
        return self.train_patches

    @property
    def eval(self):
        return self.eval_patches

    @property
    def test(self):
        return self.test_patches

    @property
    def eval_f(self):
        return self.eval_files

    @property
    def test_f(self):
        return self.test_files

    def reset_indices(self):
        '''Rewind the train, eval and test batch cursors to the beginning.'''
        self.current_idx_test, self.current_idx_eval, self.current_idx_train = 0, 0, 0

    def normalize(self, sample):
        '''Standardize ``sample`` with the training mean/std.'''
        return (sample - self.mean) / self.std

    def unnormalize(self, sample):
        '''Invert ``normalize``.'''
        return sample * self.std + self.mean

    def shuffle(self):
        '''
        Shuffle training set (video patches, across all clips) along the sample axis
        and rewind the training cursor.
        '''
        indices = np.random.permutation(self.train_patches.shape[1])
        self.train_patches = self.train_patches[:, indices, ...]
        self.current_idx_train = 0

    def load_batch_train(self, batch_size):
        '''
        Return the next ``batch_size`` training samples (axis 1).

        When the epoch runs out, the remaining samples are topped up from a freshly
        shuffled training set. The cursor then points past the samples already used,
        so they are not repeated in the next batch.
        '''
        if self.current_idx_train + batch_size >= self.train_num:
            batch_end = self.train_patches[:, self.current_idx_train:, ...]
            # shuffle() rebinds train_patches to a new array, so batch_end stays valid
            self.shuffle()
            batch_start = self.train_patches[:, 0:batch_size - batch_end.shape[1], ...]
            batch = np.concatenate((batch_end, batch_start), axis=1)
            self.current_idx_train = batch_start.shape[1]
        else:
            batch = self.train_patches[:, self.current_idx_train:self.current_idx_train + batch_size, ...]
            self.current_idx_train += batch_size
        return batch

    def load_batch_validation(self, batch_size):
        '''
        Return the next eval batch (axis 1).

        The last batch of a pass may be shorter; the cursor then wraps to 0.
        '''
        if self.current_idx_eval + batch_size >= self.eval_num:
            batch = self.eval_patches[:, self.current_idx_eval:, ...]
            self.current_idx_eval = 0
        else:
            batch = self.eval_patches[:, self.current_idx_eval: self.current_idx_eval + batch_size, ...]
            self.current_idx_eval += batch_size
        return batch

    def load_batch_test(self, batch_size):
        '''
        Return the next test batch (axis 1).

        The last batch of a pass may be shorter; the cursor then wraps to 0.
        '''
        if self.current_idx_test + batch_size >= self.test_num:
            batch = self.test_patches[:, self.current_idx_test:, ...]
            self.current_idx_test = 0
        else:
            batch = self.test_patches[:, self.current_idx_test: self.current_idx_test + batch_size, ...]
            self.current_idx_test += batch_size
        return batch

    def visualize(self, start=0, end=1):
        # NOT YET IMPLEMENTED: self.data is not assigned anywhere in this class,
        # so calling this raises AttributeError.
        for i in range(start, end):
            clip = self.data[:, i, :, :]
            clip = 255 - clip
            plt.figure(1)
            plt.clf()
            plt.title('our method')
            for j in range(7, 8):
                img = clip[j]
                plt.imshow(img, cmap='gray')
                plt.pause(100)  # blocks for 100 seconds
                plt.draw()


if __name__ == '__main__':
    # Smoke run: load only the static-camera 'Street' clips, unflattened and unscaled.
    dataset_root = '../../YUPENN Dynamic Scenes Data Set/'
    loader = YUP_Loader(
        dataset_root,
        flatten=False,
        scale=False,
        category=['Street'],
        camera=['static'],
    )