from torch.utils.data import DataLoader
from PIL import Image
import os
import numpy as np
from numpy.random import randint
import random
import torch.utils.data as data
from .downsample_perframe_data_layers import LSTRDataLayer
from rekognition_online_action_detection.utils.registry import Registry
# Module-level Registry instance. Nothing in this section registers into or
# reads from it; presumably other modules use it to register/look up data-layer
# classes by name — TODO confirm at its call sites.
DATA_LAYERS = Registry()

def build_dataset(video_data, model_name, data_name, data_root,
                  visual_feature, motion_feature, target_perframe,
                  long_memory_length, long_memory_sample_rate,
                  long_memory_num_samples, work_memory_length,
                  work_memory_sample_rate, work_memory_num_samples,
                  phase, num_sel_frames, memory, tag=''):
    """Construct the ``LSTRDataLayer`` dataset for a single phase.

    ``model_name``, ``data_name`` and ``tag`` are accepted for interface
    compatibility but are not used by this function.

    Returns:
        A new ``LSTRDataLayer`` built from the remaining arguments.
    """
    layer_args = (
        video_data, data_root,
        visual_feature, motion_feature, target_perframe,
        long_memory_length, long_memory_sample_rate, long_memory_num_samples,
        work_memory_length, work_memory_sample_rate, work_memory_num_samples,
        # NOTE(review): the tail order (memory, num_sel_frames, phase) differs
        # from this function's own parameter order; presumably it matches
        # LSTRDataLayer.__init__ — confirm there before reordering anything.
        memory, num_sel_frames, phase,
    )
    return LSTRDataLayer(*layer_args)

def build_data_loader(video_data, batch_size, shuffle, num_workers, pin_memory,
                      model_name, data_name, data_root,
                      visual_feature, motion_feature, target_perframe,
                      long_memory_length, long_memory_sample_rate,
                      long_memory_num_samples, work_memory_length,
                      work_memory_sample_rate, work_memory_num_samples,
                      phase, num_sel_frames, memory):
    """Wrap the dataset for one phase in a ``torch.utils.data.DataLoader``.

    The dataset comes from ``build_dataset`` called with the same arguments.

    Note:
        Batches are shuffled exactly when ``phase == 'train'``. The
        ``shuffle`` argument is accepted for interface compatibility but is
        not consulted.

    Returns:
        The constructed ``DataLoader``.
    """
    dataset = build_dataset(
        video_data, model_name, data_name, data_root,
        visual_feature, motion_feature, target_perframe,
        long_memory_length, long_memory_sample_rate, long_memory_num_samples,
        work_memory_length, work_memory_sample_rate, work_memory_num_samples,
        phase, num_sel_frames, memory,
    )
    shuffle_batches = bool(phase == 'train')
    return data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle_batches,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

class CILSetTask:
    """Step through a sequence of incremental tasks, one task per ``next()``.

    Each step builds a ``{phase: DataLoader}`` dict for the current task via
    ``build_data_loader`` and advances to the following task. The instance
    also keeps a rehearsal ``memory`` dict, which ``rehearsal_randomMethod``
    updates and which is forwarded to every dataset built afterwards.

    Args:
        video_tasks: Indexable sequence of per-task mappings. Each entry is
            passed on as ``video_data``; its keys are counted when reporting
            the size of the next task.
        memory_size: Total number of elements kept in rehearsal memory, or
            the string ``'ALL'`` to keep every element.
        shuffle: Stored and forwarded to ``build_data_loader``. That function
            derives shuffling from the phase instead of using this value.
        phase: Stored as ``self.phase``; loaders are built for every entry of
            ``phases`` instead.
        phases: Iterable of phase names; one loader is built per entry.
        train_enable: Stored as ``self.train_enable``; not read within this
            class.
        The remaining arguments are stored unchanged and forwarded to
        ``build_data_loader``.
    """

    def __init__(self, video_tasks, memory_size, batch_size,
                 shuffle, num_workers, pin_memory, model_name, data_name, data_root,
                 visual_feature, motion_feature, target_perframe, long_memory_length,
                 long_memory_sample_rate, long_memory_num_samples, work_memory_length,
                 work_memory_sample_rate, work_memory_num_samples, phase, phases,
                 num_sel_frames, train_enable=True):
        self.model_name = model_name
        self.data_name = data_name
        # Rehearsal memory: class key -> elements (see rehearsal_randomMethod).
        self.memory = {}
        self.num_tasks = len(video_tasks)
        self.current_task = 0
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.current_task_dataset = None
        # Populated by __next__; initialized so the attribute always exists.
        self.current_task_dataloaders = None
        self.memory_size = memory_size
        self.video_tasks = video_tasks
        self.train_enable = train_enable
        self.data_root = data_root
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples
        self.phase = phase
        self.phases = phases
        self.num_frame_to_save = 'ALL'
        self.num_sel_frames = num_sel_frames

    def __iter__(self):
        """Restart from the first task with an empty rehearsal memory."""
        self.memory = {}
        self.current_task_dataset = None
        self.current_task = 0
        return self

    def __next__(self):
        """Build the loaders for the current task and advance to the next one.

        Returns:
            A 3-tuple ``(dataloaders, next_task_size, video_data)``:
            ``dataloaders`` maps each entry of ``self.phases`` to a loader for
            the current task. ``next_task_size`` is the number of keys in the
            following task's mapping, or ``None`` when the current task is the
            last one. ``video_data`` is the current task's entry.

        Raises:
            StopIteration: Once every task has been consumed, so the object
                works in a ``for`` loop.
        """
        # Previously indexing past the last task raised IndexError/KeyError
        # instead of ending iteration cleanly.
        if self.current_task >= len(self.video_tasks):
            raise StopIteration
        video_data = self.video_tasks[self.current_task]
        self.current_task_dataloaders = {
            phase: build_data_loader(
                video_data, self.batch_size, self.shuffle, self.num_workers,
                self.pin_memory, self.model_name, self.data_name, self.data_root,
                self.visual_feature, self.motion_feature, self.target_perframe,
                self.long_memory_length, self.long_memory_sample_rate,
                self.long_memory_num_samples, self.work_memory_length,
                self.work_memory_sample_rate, self.work_memory_num_samples,
                phase, self.num_sel_frames, self.memory,
            )
            for phase in self.phases
        }

        self.current_task += 1
        if self.current_task < len(self.video_tasks):
            next_task_size = len(self.video_tasks[self.current_task].keys())
        else:
            next_task_size = None
        return self.current_task_dataloaders, next_task_size, video_data

    def rehearsal_randomMethod(self, current_task):
        """Merge ``current_task`` into rehearsal memory, trimming each class.

        Both ``self.memory`` and ``current_task`` map a class key to a sequence
        of elements. On duplicate keys, the entry from ``current_task`` wins.
        When ``self.memory_size`` is not ``'ALL'``, each merged class keeps a
        random subset of ``self.memory_size // <number of merged classes>``
        elements. The merged result replaces ``self.memory``.

        Args:
            current_task: Mapping of class key -> sequence of elements for the
                task just learned. It is not modified.
        """
        elem_to_save = {**self.memory, **current_task}
        # An empty merge has nothing to trim; guarding avoids dividing by zero.
        if self.memory_size != 'ALL' and elem_to_save:
            # Count classes after merging: a class present in both memory and
            # the new task must take only one share of the budget.
            num_instances_per_class = self.memory_size // len(elem_to_save)
            for class_n, elems in elem_to_save.items():
                # Shuffle a copy so the caller's lists (and lists shared with
                # the previous memory) are not reordered in place.
                shuffled = list(elems)
                random.shuffle(shuffled)
                elem_to_save[class_n] = shuffled[:num_instances_per_class]
        self.memory = elem_to_save

    def get_dataloader(self, data):
        """Build a ``'train'``-phase loader for ``data`` using this object's settings.

        Args:
            data: Task mapping passed to the dataset as ``video_data``. Inside
                this method the name shadows the module-level ``data`` alias.

        Returns:
            The loader produced by ``build_data_loader``.
        """
        # build_data_loader requires num_sel_frames and memory. The previous
        # call omitted both, so it always failed with TypeError.
        return build_data_loader(
            data, self.batch_size, self.shuffle, self.num_workers,
            self.pin_memory, self.model_name, self.data_name, self.data_root,
            self.visual_feature, self.motion_feature, self.target_perframe,
            self.long_memory_length, self.long_memory_sample_rate,
            self.long_memory_num_samples, self.work_memory_length,
            self.work_memory_sample_rate, self.work_memory_num_samples,
            phase='train', num_sel_frames=self.num_sel_frames, memory=self.memory,
        )



