from torch.utils.data import DataLoader
from PIL import Image
import os
import numpy as np
from numpy.random import randint
import random
import torch.utils.data as data
from rekognition_online_action_detection.utils.registry import Registry
# Registry of data-layer callables. ``build_dataset`` looks entries up with the
# key ``model_name + tag + data_name`` and calls the result to create a dataset.
DATA_LAYERS = Registry()

def build_dataset(video_data, model_name, data_name, data_root,
                  visual_feature, motion_feature, target_perframe,
                  long_memory_length, long_memory_sample_rate,
                  long_memory_num_samples, work_memory_length,
                  work_memory_sample_rate, work_memory_num_samples,
                  phase, memory, tag=''):
    """Instantiate the data layer registered for a model/dataset pair.

    The layer is looked up in ``DATA_LAYERS`` under the key
    ``model_name + tag + data_name`` and called with the remaining
    arguments, all passed positionally.

    Returns:
        Whatever the registered data layer returns when called.
    """
    registry_key = model_name + tag + data_name
    layer_factory = DATA_LAYERS[registry_key]

    # NOTE: the layer receives ``memory`` *before* ``phase``, i.e. the reverse
    # of this function's own parameter order. Keep this order unless the
    # registered layers are changed as well.
    layer_args = (
        video_data,
        data_root,
        visual_feature,
        motion_feature,
        target_perframe,
        long_memory_length,
        long_memory_sample_rate,
        long_memory_num_samples,
        work_memory_length,
        work_memory_sample_rate,
        work_memory_num_samples,
        memory,
        phase,
    )
    return layer_factory(*layer_args)

def build_data_loader(video_data, batch_size, shuffle, num_workers, pin_memory,
                      model_name, data_name, data_root,
                      visual_feature, motion_feature, target_perframe,
                      long_memory_length, long_memory_sample_rate,
                      long_memory_num_samples, work_memory_length,
                      work_memory_sample_rate, work_memory_num_samples,
                      phase, memory):
    """Build a ``DataLoader`` over the dataset created by ``build_dataset``.

    ``batch_size``, ``num_workers`` and ``pin_memory`` are forwarded to the
    loader; the remaining arguments (except ``shuffle``) are forwarded to
    ``build_dataset``.

    NOTE: the ``shuffle`` argument is accepted but not used. Shuffling is
    enabled exactly when ``phase == 'train'``.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    dataset = build_dataset(
        video_data, model_name, data_name, data_root,
        visual_feature, motion_feature, target_perframe,
        long_memory_length, long_memory_sample_rate,
        long_memory_num_samples, work_memory_length,
        work_memory_sample_rate, work_memory_num_samples,
        phase, memory,
    )

    # Only the training phase shuffles; every other phase keeps dataset order.
    is_train_phase = phase == 'train'

    return data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=is_train_phase,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
class CILSetTask:
    """Iterator over a sequence of tasks that yields per-phase data loaders.

    ``video_tasks`` is an indexable sequence with one entry per task; each
    entry is a mapping from a class key to a sliceable collection of items.
    Iterating the object walks the tasks in order. For every task it returns
    a 4-tuple ``(train_loaders, val_loaders, next_num_classes, video_data)``:

    * ``train_loaders`` -- dict ``{phase: DataLoader}`` for each entry of
      ``phases``. For the first task it covers the whole task; for later
      tasks it covers the first ``perc`` fraction of every class.
    * ``val_loaders`` -- ``None`` for the first task; for later tasks a dict
      ``{phase: DataLoader}`` built from the remaining items of every class.
    * ``next_num_classes`` -- number of class keys in the following task, or
      ``None`` when the current task is the last one.
    * ``video_data`` -- the unmodified mapping of the current task.

    The object also keeps a rehearsal ``memory`` (class key -> items) that is
    passed to every loader and updated by ``rehearsal_randomMethod``.
    """

    def __init__(self, video_tasks, memory_size, batch_size,
                 shuffle, num_workers, pin_memory, model_name, data_name, data_root,
                 visual_feature, motion_feature, target_perframe, long_memory_length,
                 long_memory_sample_rate, long_memory_num_samples, work_memory_length,
                 work_memory_sample_rate, work_memory_num_samples, phase, phases, perc,
                 train_enable=True):
        """Store the task list and the settings forwarded to the loaders.

        Args:
            video_tasks: sequence of per-task mappings (class key -> items).
            memory_size: ``'ALL'`` to keep every item in the rehearsal
                memory, otherwise an integer budget split evenly per class.
            phases: iterable of phase names; one loader is built per phase.
            perc: fraction of each class used for the training split of
                tasks after the first one.
            train_enable: stored on the instance; not used inside this class.

        All other arguments are stored unchanged and forwarded to
        ``build_data_loader``.
        """
        self.model_name = model_name
        self.data_name = data_name
        self.memory = {}
        self.num_tasks = len(video_tasks)
        self.current_task = 0
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.current_task_dataset = None
        self.memory_size = memory_size
        self.video_tasks = video_tasks
        self.train_enable = train_enable
        self.data_root = data_root
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples
        self.phase = phase
        self.phases = phases
        self.num_frame_to_save = 'ALL'
        self.perc = perc

    def __iter__(self):
        """Restart iteration: clear the rehearsal memory and rewind to task 0."""
        self.memory = {}
        self.current_task_dataset = None
        self.current_task = 0
        return self

    def __next__(self):
        """Return the loaders for the current task and advance to the next.

        Returns:
            ``(train_loaders, val_loaders, next_num_classes, video_data)`` as
            described in the class docstring.

        Raises:
            StopIteration: when every task has already been returned.
        """
        total_tasks = len(self.video_tasks)
        # Honour the iterator protocol instead of failing with IndexError.
        if self.current_task >= total_tasks:
            raise StopIteration

        video_data = self.video_tasks[self.current_task]

        if self.current_task == 0:
            # The first task is used whole; no validation split is produced.
            train_loaders = self._build_phase_loaders(video_data)
            val_loaders = None
        else:
            # Later tasks are split per class: the leading ``perc`` fraction
            # goes to training, the rest to validation.
            train_split = {}
            val_split = {}
            for key, values in video_data.items():
                split_at = int(len(values) * self.perc)
                train_split[key] = values[:split_at]
                val_split[key] = values[split_at:]
            train_loaders = self._build_phase_loaders(train_split)
            val_loaders = self._build_phase_loaders(val_split)

        self.current_task += 1

        # Peek at the next task only if it exists (also covers a single task).
        if self.current_task < total_tasks:
            next_num_classes = len(self.video_tasks[self.current_task])
        else:
            next_num_classes = None

        return train_loaders, val_loaders, next_num_classes, video_data

    def _build_loader(self, video_data, phase):
        """Build one data loader for ``video_data`` using the stored settings."""
        return build_data_loader(
            video_data, self.batch_size, self.shuffle, self.num_workers,
            self.pin_memory, self.model_name, self.data_name, self.data_root,
            self.visual_feature, self.motion_feature, self.target_perframe,
            self.long_memory_length, self.long_memory_sample_rate,
            self.long_memory_num_samples, self.work_memory_length,
            self.work_memory_sample_rate, self.work_memory_num_samples,
            phase, self.memory,
        )

    def _build_phase_loaders(self, video_data):
        """Build ``{phase: loader}`` for every entry of ``self.phases``."""
        return {
            phase: self._build_loader(video_data, phase)
            for phase in self.phases
        }

    def rehearsal_randomMethod(self, current_task):
        """Merge ``current_task`` into the rehearsal memory, sampling randomly.

        Entries of ``current_task`` override memory entries with the same
        class key. When ``memory_size`` is not ``'ALL'``, every class keeps
        ``memory_size // number_of_distinct_classes`` randomly chosen items.

        The caller's collections are never reordered: a copy is shuffled.

        Args:
            current_task: mapping from class key to a collection of items.
        """
        merged = {**self.memory, **current_task}

        # ``merged`` may be empty; skip sampling to avoid dividing by zero.
        if self.memory_size != 'ALL' and merged:
            # Count distinct classes so a class present in both the memory
            # and the current task does not shrink the per-class budget.
            per_class = self.memory_size // len(merged)
            for class_n, elems in merged.items():
                shuffled = list(elems)  # copy: do not mutate caller data
                random.shuffle(shuffled)
                merged[class_n] = shuffled[:per_class]

        self.memory = merged

    def get_dataloader(self, data):
        """Return a ``'train'``-phase data loader for ``data``.

        The current rehearsal memory is passed along, as it is for the
        loaders produced during iteration.
        """
        return self._build_loader(data, 'train')






