import os
import torch
from torchvision import transforms
from models.utils import rescale
from omegaconf import DictConfig, OmegaConf, open_dict
import numpy as np

from typing import List

from models import utils as m_utils


def get_dataset(params, config=None):
    """Build the per-task datasets for a continual-learning benchmark.

    Args:
        params: Run parameters. ``params.dataset`` selects the benchmark
            (compared case-insensitively). For ``'continual_tetrominoes'``
            this function reads ``params.num_task`` (and
            ``params.total_colors`` when ``config['color'] == 'random'``) and
            writes ``params.num_labels``, ``params.resolution``,
            ``params.num_slots``, ``params.in_channels`` and ``params.steps``.
        config: Dataset configuration; must not be ``None`` for either
            supported benchmark.

    Returns:
        ``(datasets, params)``. For ``'continual_tetrominoes'``, ``datasets``
        is a list with one ``(train_dataset, val_dataset, sample_dataset)``
        tuple per task, where ``sample_dataset`` is the validation dataset.
        For ``'continual_clevr'`` both values come from
        ``get_continual_clevr``.

    Raises:
        ValueError: if ``params.dataset`` names an unsupported benchmark, or
            if ``config['shape'] == 'random'`` (not implemented yet).
    """
    name = params.dataset.lower()

    if name == 'continual_clevr':
        from dataset.continual_dataset.continual_clevr import get_continual_clevr
        assert config is not None

        datasets, params = get_continual_clevr(params=params, configs=config)
        return datasets, params

    if name != 'continual_tetrominoes':
        # Previously this fell through to a return of never-assigned locals
        # and crashed with UnboundLocalError; fail with a clear message instead.
        raise ValueError(f"Unsupported dataset: {params.dataset!r}")

    from dataset.continual_dataset.contiual_tetrominoes import get_continual_tetrominoes

    assert config is not None

    colors = _tetromino_colors(params, config)
    shapes = _tetromino_shapes(params, config)

    # One (train_size, test_size, dataset_f, shape_list, color_list) per task.
    tasks = []
    for t_idx in range(params.num_task):
        ds_cfg = config[t_idx]['dataset']
        func_name = ds_cfg['_function_']
        # Resolve the per-task generator function from its dotted module path.
        module = __import__(ds_cfg['_target_'], fromlist=[func_name])
        dataset_f = getattr(module, func_name)
        # The validation size is unpacked (enforcing a 3-item entry) but is not
        # forwarded to get_continual_tetrominoes.
        train_size, _val_size, test_size = ds_cfg['data_sizes']
        tasks.append((train_size, test_size, dataset_f, shapes[t_idx], colors[t_idx]))

    params.num_labels = []
    datasets = []
    for t_idx, (train_size, test_size, dataset_f, shape_list, color_list) in enumerate(tasks):
        kargs = {
            'name': name,
            'width': config['width'],
            'height': config['height'],
            'max_num_objects': config['max_num_objects'],
            'num_background_objects': config['num_background_objects'],
            'input_channels': config['input_channels'],
            'dataset_root': '.',
            'dataset_name': '.',
            'generate_dataset': True,
            'dataset_f': dataset_f,
            'shape_list': shape_list,
            'color_list': color_list,
            'train_dataset_size': train_size,
            'test_dataset_size': test_size,
        }
        if m_utils.is_main_process():
            print('task:', t_idx, 'shape', shape_list.shape, shape_list)
            print('task:', t_idx, 'color', color_list.shape, color_list)
        train_dataset, val_dataset = get_continual_tetrominoes(**kargs)
        # The validation set doubles as the sample/visualisation set.
        sample_dataset = val_dataset

        params.resolution = (config['width'], config['height'])
        params.num_slots = config['max_num_objects'] + config['num_background_objects']
        params.in_channels = config['input_channels']
        # NOTE(review): hard-coded training length; consider moving to config.
        params.steps = 500000

        params.num_labels += [train_dataset.len_labels]

        datasets.append((train_dataset, val_dataset, sample_dataset))

    return datasets, params


def _tetromino_colors(params, config):
    """Return one array of color values per task, chosen by ``config['color']``.

    Modes:
      * ``'random'``: task ``t`` receives ``int(config[t]['dataset']['color'][0])``
        values drawn without replacement (global NumPy RNG) from
        ``params.total_colors`` evenly spaced values in [0, 1). The per-task
        counts must sum to ``params.total_colors``.
      * ``'user'``: task ``t`` uses the list at ``config[t]['dataset']['color'][1]``.
      * anything else: task ``t`` uses ``np.arange(c[1], c[2], c[3])`` with
        ``c = config[t]['dataset']['color']``.
    """
    mode = config['color']
    if mode == 'random':
        counts = [int(config[t]['dataset']['color'][0]) for t in range(params.num_task)]
        c_idx = np.cumsum(counts)
        assert c_idx[-1] == params.total_colors
        # k / N instead of arange(0, 1, 1/N): float rounding in the step can
        # make arange return N + 1 values (e.g. N = 49), admitting ~1.0.
        palette = np.arange(params.total_colors) / params.total_colors
        colors = np.random.choice(palette, params.total_colors, replace=False)
        # The last split index equals the array length, so np.split also yields
        # a trailing empty chunk; callers only index the first num_task chunks.
        return np.split(colors, indices_or_sections=c_idx)

    if mode == 'user':
        return [np.array(config[t]['dataset']['color'][1]) for t in range(params.num_task)]

    colors = []
    for t in range(params.num_task):
        c = config[t]['dataset']['color']
        colors.append(np.arange(c[1], c[2], c[3]))
    return colors


def _tetromino_shapes(params, config):
    """Return one array of shape ids per task, chosen by ``config['shape']``.

    Modes:
      * ``'random'``: not implemented; raises ``ValueError('T.B.U')``.
      * ``'user'``: task ``t`` uses the list at ``config[t]['dataset']['shape'][1]``.
      * anything else: task ``t`` uses ``np.arange(s[0], s[1], s[2])`` with
        ``s = config[t]['dataset']['shape']``.
    """
    mode = config['shape']
    if mode == 'random':
        raise ValueError('T.B.U')

    if mode == 'user':
        return [np.array(config[t]['dataset']['shape'][1]) for t in range(params.num_task)]

    # NOTE(review): this range form reads indices 0..2, whereas the 'user' form
    # (and the color range form) skip index 0 -- confirm the config layout.
    shapes = []
    for t in range(params.num_task):
        s = config[t]['dataset']['shape']
        shapes.append(np.arange(s[0], s[1], s[2]))
    return shapes

    

def get_dataloader(params, dataset: List, config=None):
    """Wrap every task's datasets in DataLoaders.

    Args:
        params: Supplies ``batch_size``, ``val_batch_size`` and ``num_workers``.
        dataset: Sequence of ``(train, val, sample)`` dataset triples, one per task.
        config: Not used by this function.

    Returns:
        List of ``(train_loader, val_loader, sample_loader)`` triples in task
        order. The train loader draws from a shuffling ``DistributedSampler``.
        The val and sample loaders iterate in order and keep the final partial
        batch.
    """
    def _ordered_loader(ds):
        # Shared construction for the non-training loaders.
        return torch.utils.data.DataLoader(
            ds,
            batch_size=params.val_batch_size,
            num_workers=params.num_workers,
            pin_memory=True,
            drop_last=False,
            shuffle=False,
        )

    loaders = []
    for train_ds, val_ds, sample_ds in dataset:
        train_loader = torch.utils.data.DataLoader(
            train_ds,
            sampler=torch.utils.data.DistributedSampler(train_ds, shuffle=True),
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            pin_memory=True,
        )
        loaders.append((train_loader, _ordered_loader(val_ds), _ordered_loader(sample_ds)))
    return loaders


