import os.path as osp
import pickle
from typing import Set, Tuple

import lmdb as lmdb
from jaxtyping import Shaped, jaxtyped, Float32
from beartype import beartype as typechecker
from torch.utils import data
from torch.utils.data import Dataset
import pyarrow as pa

from conf.dataset import BlenderParams
from data.BlenderDataset.blender_transforms import sample_transforms
from data.UtilsDataset import CustomDataModule
from utils.utils import display_tensor
import numpy as np


class BlenderLMDBDataset(Dataset):
    """Dataset reading rendered Blender scenes out of an LMDB database.

    Every database row is deserialized into a pair
    ``((cube_img, pyramid_img, cylinder_img), (cube_par, pyramid_par, cylinder_par))``:
    three images and the three matching raw parameter records.

    Three access modes are available:

    * ``get_item_for_translation=True``: one item per row, returned raw
      (no transform, no parameter parsing).
    * ``is_one=True``: every row is exposed as three consecutive items, one per
      object, so the dataset length is three times the number of rows.
    * ``is_one=False``: one item per row, containing the objects listed in
      ``return_domain``.

    Args:
        db_path: Path of the LMDB database (a directory or a single file).
        return_indice: When true, the requested index is appended to every item.
        return_params: When true, the parsed parameter vector(s) are returned
            next to the image(s).
        target_transform: Optional callable applied to each parsed parameter
            vector.
        return_domain: Object indices (0, 1, 2) to return when ``is_one`` is
            false. Defaults to all three. Must be ``None`` when ``is_one`` is
            true.
        is_one: Select the "one object per item" mode.
        get_item_for_translation: Select the raw "translation" mode.
        use_pickle: Rows were serialized with ``pickle`` instead of pyarrow.

    Raises:
        ValueError: If ``is_one`` is true and ``return_domain`` is given.
        KeyError: If the ``__len__`` / ``__keys__`` metadata entries are missing.
    """

    def __init__(
        self,
        db_path: str,
        return_indice: bool,
        return_params: bool = True,
        target_transform=None,
        return_domain: Set[int] = None,
        is_one: bool = True,
        get_item_for_translation: bool = False,
        use_pickle: bool = False,
    ):
        super().__init__()
        # Reject an invalid configuration before opening the database so that
        # a failed construction does not leave an LMDB environment open.
        if is_one and return_domain is not None:
            raise ValueError('return_domain must be None if is_one is True')
        if return_domain is None and not is_one:
            return_domain = sorted({0, 1, 2})

        self.return_domain = return_domain
        self.return_indice = return_indice
        self.return_params = return_params
        self.target_transform = target_transform
        self.is_one = is_one
        self.get_item_for_translation = get_item_for_translation
        self.use_pickle = use_pickle
        self.transform = sample_transforms

        self.db_path = db_path
        self.env = lmdb.open(
            path=db_path, subdir=osp.isdir(db_path),
            readonly=True, lock=False, readahead=False, meminit=False
        )
        with self.env.begin(write=False) as txn:
            self.length = self._deserialize(txn.get(b'__len__'), b'__len__')
            self.keys = self._deserialize(txn.get(b'__keys__'), b'__keys__')

    def _deserialize(self, payload, key):
        """Decode one raw LMDB value with the configured serializer.

        Raises:
            KeyError: If ``payload`` is ``None``, i.e. ``key`` is not stored.
        """
        if payload is None:
            raise KeyError(f'key {key!r} not found in LMDB database {self.db_path!r}')
        if self.use_pickle:
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload; only open databases coming from a trusted source.
            return pickle.loads(payload)
        # NOTE(review): pyarrow.deserialize is deprecated upstream and, as far
        # as I know, absent from recent pyarrow releases - confirm the pinned
        # pyarrow version still provides it.
        return pa.deserialize(payload)

    def _read_row(self, row_index):
        """Fetch and decode the database row stored under ``self.keys[row_index]``.

        Returns:
            ``(images, raw_params)``, two 3-tuples ordered cube, pyramid,
            cylinder.
        """
        key = self.keys[row_index]
        with self.env.begin(write=False) as txn:
            payload = txn.get(key.encode())
        # The nested unpacking also checks that the row holds exactly three
        # images and three parameter records.
        (cube_img, pyramid_img, cylinder_img), (cube_par, pyramid_par, cylinder_par) = \
            self._deserialize(payload, key)
        return (cube_img, pyramid_img, cylinder_img), (cube_par, pyramid_par, cylinder_par)

    def _prepare(self, image, raw_params):
        """Apply the image transform and, if requested, parse the parameters.

        Returns:
            ``(image, params)``; ``params`` is ``None`` when ``return_params``
            is false, in which case the raw record is not parsed at all.
        """
        if self.transform is not None:
            image = self.transform(image)
        if not self.return_params:
            return image, None
        params = self.normalize_params(self.filter_params(raw_params))
        if self.target_transform is not None:
            params = self.target_transform(params)
        return image, params

    def __getitem__(self, index):
        """Return the item at ``index`` according to the configured mode.

        When ``return_indice`` is true the index is appended: tuple/list results
        are extended with it, a lone image becomes ``(image, index)``.
        """
        # The three modes are mutually exclusive; translation mode wins.
        if self.get_item_for_translation:
            res = self.get_for_translate(index)
        elif self.is_one:
            res = self.get_data_one(index)
        else:
            res = self.get_data_all(index)

        if not self.return_indice:
            return res
        if isinstance(res, (tuple, list)):
            return tuple(res) + (index, )
        # A single image must not be passed to tuple(): that would iterate it
        # and split it along its first axis.
        return res, index

    def get_for_translate(self, index):
        """Return row ``index`` untouched: ``(images, raw_params)`` 3-tuples."""
        return self._read_row(index)

    def get_data_one(self, index):
        """Return a single object of a row.

        ``index // 3`` selects the row and ``index % 3`` the object inside it.
        Only the selected object is transformed / parsed.

        Returns:
            ``(image, params)`` when ``return_params`` is true, else ``image``.
        """
        fetch_in_db_index = index // 3
        in_row_index = index % 3

        images, raw_params = self._read_row(fetch_in_db_index)
        img, param = self._prepare(images[in_row_index], raw_params[in_row_index])

        if self.return_params:
            return img, param
        return img

    def get_data_all(self, index):
        """Return the objects of row ``index`` listed in ``self.return_domain``.

        Returns:
            ``(images, params)`` (two lists) when ``return_params`` is true,
            else the list of images.
        """
        images, raw_params = self._read_row(index)
        prepared = [self._prepare(images[i], raw_params[i]) for i in self.return_domain]
        imgs = [img for img, _ in prepared]

        if self.return_params:
            return imgs, [param for _, param in prepared]
        return imgs

    @jaxtyped
    @typechecker
    def normalize_params(self, params: Shaped[np.ndarray, '19']) -> Float32[np.ndarray, '19']:
        """
        Scale the colour entries of a filtered parameter vector by 1/255.

        Layout of ``params``: [0:3] one-hot object type, [3:6] position,
        [6:7] camera angle, [7:19] four RGB colours (object, floor, wall 1,
        wall 2). Only the colour entries are rescaled; the result is float32.
        """
        normalized = np.array(params, dtype=np.float64)
        # only need to normalize the colors
        normalized[7:19] /= 255.
        return normalized.astype(np.float32)

    @staticmethod
    def _parse_color(text):
        """Parse a bracketed, comma separated colour string into a list of floats."""
        # The first and last characters are the enclosing brackets.
        return [float(channel) for channel in text[1:-1].split(',')]

    @jaxtyped
    @typechecker
    def filter_params(self, params: np.ndarray) -> Shaped[np.ndarray, '19']:
        """Keep the learnable fields of a raw 15-field parameter record.

        Returns a 19-vector: one-hot object type (3), position (3), camera
        angle (1) and the object / floor / wall 1 / wall 2 colours (3 each).
        """
        # num,type,pos_x,pos_y,pos_z,scale,camera_angle,lamp_pos_x,lamp_pos_y,lamp_pos_z,
        # obj_color,floor_color,wall1_color,wall2_color,path
        (_num, type_obj, x, y, z, _scale, camera_angle,
         _lamp_x, _lamp_y, _lamp_z,
         obj_color, floor_color, wall1_color, wall2_color, _path) = params

        # one hot encode type_obj
        type_obj = np.eye(3)[int(type_obj)]

        return np.array([
            *type_obj,
            x, y, z,
            camera_angle,
            *self._parse_color(obj_color),
            *self._parse_color(floor_color),
            *self._parse_color(wall1_color),
            *self._parse_color(wall2_color),
        ])

    def __len__(self) -> int:
        """Number of items: three per row in ``is_one`` mode, else one per row."""
        # Translation mode indexes rows directly, whatever ``is_one`` says.
        if self.is_one and not self.get_item_for_translation:
            return self.length * 3
        return self.length

    def __repr__(self):
        return f'{self.__class__.__name__} ({self.db_path})'


class BlenderDataModule(CustomDataModule):
    """Data module building its splits from a ``BlenderLMDBDataset``."""

    def _fetch_base_dataset(self) -> Tuple[data.Dataset, data.Dataset, data.Dataset]:
        """
        Build the full Blender dataset from ``self.p.data_params`` and split it.

        Returns the (train, valid, test) datasets produced by
        ``self.split_dataset``.
        """
        cfg: BlenderParams = self.p.data_params

        # Every constructor argument comes from the configuration, except the
        # target transform which is left unset.
        dataset_kwargs = dict(
            db_path=cfg.root,
            return_params=cfg.return_params,
            target_transform=None,
            return_domain=cfg.return_domain,
            is_one=cfg.is_one,
            get_item_for_translation=cfg.get_item_for_translation,
            use_pickle=cfg.use_pickle,
            return_indice=cfg.return_indice,
        )
        full_dataset = BlenderLMDBDataset(**dataset_kwargs)

        train_split, valid_split, test_split = self.split_dataset(full_dataset)
        return train_split, valid_split, test_split
