from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from attrdict import AttrDict

import os
import cv2
import matplotlib
import h5py
import numpy as np
import torch
import torch.nn.functional as F

from torch import Tensor
from torch.utils.data import DataLoader, Dataset



def get_continual_clevr(params, configs):
    """Build per-task ContinualClevr splits and fill in model parameters.

    For each of ``configs.num_task`` tasks, the ``task{i}`` sub-config is read
    and three datasets are created from ``{dataset_path}/{dataset.name}_{split}.h5``:

    - train: file suffix ``_train``, size ``dataset.data_sizes[0]``;
    - val and test: two independent instances over the ``_test`` file, both
      sized by ``dataset.data_sizes[2]``. ``data_sizes[1]`` is not read here
      (presumably a validation size -- confirm whether val should use it).

    ``params`` is updated in place with ``resolution`` (width, height),
    ``num_slots`` (foreground + background slots), ``in_channels`` and a
    fixed ``steps`` of 500000.

    Returns:
        ``(datasets, params)`` where ``datasets`` is a list holding one
        ``(train, val, test)`` tuple per task.
    """
    cfg = AttrDict(configs)
    shared_kwargs = dict(
        width=cfg.width,
        height=cfg.height,
        num_background_objects=cfg.num_background_objects,
        max_num_objects=cfg.max_num_objects,
        input_channels=cfg.input_channels,
    )

    def build_split(task_cfg, suffix, size_index):
        # One dataset instance over the task's `<name>_<suffix>.h5` file.
        return ContinualClevr(
            name=f'{task_cfg.dataset.name}_{suffix}',
            dataset_root=task_cfg.dataset.dataset_path,
            downstream_features=task_cfg.downstream_features,
            data_sizes=task_cfg.dataset.data_sizes[size_index],
            **shared_kwargs,
        )

    datasets = []
    for task_idx in range(cfg.num_task):
        task_cfg = AttrDict(getattr(cfg, f'task{task_idx}'))
        # Built in the order train, test, val; each construction opens a file
        # and prints its feature list, so this order is kept deliberately.
        train_split = build_split(task_cfg, 'train', 0)
        test_split = build_split(task_cfg, 'test', 2)
        val_split = build_split(task_cfg, 'test', 2)
        datasets.append((train_split, val_split, test_split))

    params.resolution = (cfg.width, cfg.height)
    params.num_slots = cfg.max_num_objects + cfg.num_background_objects
    params.in_channels = cfg.input_channels
    params.steps = 500000
    return datasets, params
        

class ContinualClevr(Dataset):
    """Map-style dataset over one split stored in an HDF5 file.

    The file ``{dataset_root}/{name}.h5`` is opened read-only. The first
    ``min(data_sizes, num_images)`` rows of every top-level HDF5 key are
    copied into ``self.data`` at construction time. ``__getitem__`` converts
    one row of each feature into a torch tensor and adds per-slot
    ``is_foreground`` / ``is_modified`` flags.

    Slots are ordered background first: the first ``num_background_objects``
    slots are background, the remaining ``max_num_objects`` are foreground.

    Args:
        name: File stem of the split (e.g. ``<dataset>_train``).
        width: Expected mask width, checked in ``__getitem__``.
        height: Expected mask height, checked in ``__getitem__``.
        max_num_objects: Number of foreground slots.
        num_background_objects: Number of background slots.
        input_channels: Stored on the instance; not used by this class.
        dataset_root: Directory containing ``{name}.h5``.
        data_sizes: Maximum number of samples to load.
        downstream_features: Accepted for interface compatibility; not used
            by this class.
    """

    def __init__(
        self,
        name: str,
        width: int,
        height: int,
        max_num_objects: int,
        num_background_objects: int,
        input_channels: int,
        dataset_root: str,
        data_sizes: int,
        downstream_features: List[str],
        ):
        super().__init__()

        self.name = name
        self.width = width
        self.height = height
        self.max_num_objects = max_num_objects
        self.num_background_objects = num_background_objects
        self.input_channels = input_channels
        self.data_sizes = data_sizes
        self.dataset_root = dataset_root
        self.dataset_path = f'{dataset_root}/{name}.h5'

        # Dict mapping each top-level HDF5 key to its h5py object.
        # NOTE(review): the underlying h5py.File is never closed and stays
        # alive through these references. h5py objects are also not picklable,
        # which may break DataLoader workers started with "spawn" -- confirm.
        self.dataset = self._load_data()

        self.data = {}
        self.preload_range = (0, self.data_sizes)
        # Number of exposed samples: requested size capped by the image count.
        self.idx_range = min(self.data_sizes, self.dataset['image'].shape[0])

        self.output_features = list(self.dataset.keys())

        # Slicing an h5py dataset reads the selected rows into a NumPy array,
        # so self.data holds in-memory copies rather than lazy file handles.
        start, stop = self.preload_range
        for feature in self.output_features:
            self.data[feature] = self.dataset[feature][start:stop][:self.idx_range]

        # 'is_foreground' is synthesized in __getitem__, not read from the file.
        self.output_features.append('is_foreground')
        print('>>>', self.output_features)

        # Not used inside this class; presumably assigned by external code.
        self.logits = None

    def __len__(self):
        """Return the number of preloaded samples."""
        return self.idx_range

    def _load_data(self):
        """Load the split's raw features (currently always from HDF5)."""
        return self._load_data_hdf5()

    def _load_data_hdf5(self):
        """Open ``self.dataset_path`` read-only and map each top-level key to its h5py object.

        Values are not read here. ``__init__`` slices them, assuming every key
        is a dataset rather than a group (TODO confirm for all files).
        """
        h5_file = h5py.File(self.dataset_path, "r")
        return {key: h5_file[key] for key in h5_file}

    def _preprocess_feature(self, feature: np.ndarray, feature_name: str) -> Tensor:
        """Convert one sample's raw feature array to a torch tensor.

        - ``image``: rank-3 array, presumably (H, W, C) with values in 0..255.
          Returned as float32 (C, H, W) divided by 255.
        - ``mask``: integer slot ids of shape (H, W, 1) (as required by the
          shape check in ``__getitem__``). One-hot encoded over all slots and
          returned as float32 (num_slots, 1, H, W).
        - ``num_actual_objects``, ``pixel_coords``, ``rotation``: float32.
        - anything else: cast to uint8 (values are not range-checked).
        """
        if feature_name == "image":
            return (
                torch.as_tensor(feature, dtype=torch.float32).permute(2, 0, 1) / 255.0
            )

        if feature_name == "mask":
            one_hot_masks = F.one_hot(
                torch.as_tensor(feature, dtype=torch.int64),
                num_classes=self.max_num_objects + self.num_background_objects,
            )
            # (H, W, 1, num_slots) -> (num_slots, 1, H, W)
            return one_hot_masks.permute(3, 2, 0, 1).to(torch.float32)

        if feature_name in ["num_actual_objects", "pixel_coords", "rotation"]:
            return torch.as_tensor(feature, dtype=torch.float32)

        return torch.as_tensor(feature, dtype=torch.uint8)

    def __getitem__(self, index: int):
        """Return a dict of tensors for sample ``index``.

        The dict holds every preloaded HDF5 feature (see
        ``_preprocess_feature``), plus:

        - ``is_foreground``: uint8 (num_slots, 1); 0 for background slots,
          1 for foreground slots.
        - ``is_modified``: the stored feature converted to float32 when the
          file provides it, otherwise a uint8 (num_slots,) tensor of zeros.

        A ``mask`` feature is required: its shape and value range are asserted.
        """
        out = {
            feature_name: self._preprocess_feature(values[index], feature_name)
            for feature_name, values in self.data.items()
        }

        num_slots = self.max_num_objects + self.num_background_objects
        out["is_foreground"] = torch.as_tensor(
            [0] * self.num_background_objects + [1] * self.max_num_objects,
            dtype=torch.uint8,
        ).view(-1, 1)

        if "is_modified" in out:
            # _preprocess_feature returns uint8 for this key. The legacy
            # torch.FloatTensor(tensor) constructor rejects non-float32 input
            # with a TypeError, so convert the dtype explicitly.
            out["is_modified"] = out["is_modified"].to(torch.float32)
        else:
            out["is_modified"] = torch.zeros(num_slots, dtype=torch.uint8)

        mask = out["mask"]
        assert mask.shape == (num_slots, 1, self.height, self.width)
        # Sum over the slot axis (dim 0), not the singleton channel axis:
        # no pixel may be assigned to more than one slot.
        assert mask.sum(0).max() <= 1.0
        assert mask.min() >= 0.0

        return out