import os
import os.path as osp
import pickle
from typing import IO, List, Tuple
import lmdb as lmdb
from torch.utils.data import DataLoader, Dataset
import pyarrow as pa
import pandas as pd
import numpy as np
import imageio


class BlenderPNGDataset(Dataset):
    """Dataset of PNG image triplets with one parameter row per image.

    Sample ``index`` is made of the three files ``{index}_A.png``,
    ``{index}_B.png`` and ``{index}_C.png`` and, optionally, of the three
    consecutive rows ``3 * index`` .. ``3 * index + 2`` of the concatenated
    parameter table.

    NOTE(review): the suffixes A/B/C presumably correspond to the shapes
    cube/pyramid/icosphere listed in the ``type`` column -- confirm against
    the code that rendered the images.
    """

    def __init__(
        self,
        path_list_imgs: List[str],
        path_list_csv: List[str],
        return_params: bool = True,
        process_params: bool = True,
        transform=None,
        target_transform=None,
    ):
        """Load and concatenate the parameter tables.

        :param path_list_imgs: image folders, one per CSV file.
        :param path_list_csv: CSV files; each must provide the ``num`` and
            ``type`` columns (``type`` in ``cube``/``pyramid``/``icosphere``).
        :param return_params: if true, ``__getitem__`` also returns the
            parameter rows of the three images.
        :param process_params: if true, only the columns selected by
            ``_process_params`` are returned.
        :param transform: optional callable applied to every image.
        :param target_transform: optional callable applied to every
            parameter array.
        :raises ValueError: if the two lists differ in length, are empty, or
            the total number of rows is not a multiple of 3.
        :raises KeyError: if a ``type`` value is not a known shape name.
        """
        super(BlenderPNGDataset, self).__init__()
        # zip() would silently drop the unpaired entries, so fail loudly.
        if len(path_list_imgs) != len(path_list_csv):
            raise ValueError(
                'path_list_imgs and path_list_csv must have the same length, '
                f'got {len(path_list_imgs)} and {len(path_list_csv)}'
            )
        if len(path_list_csv) == 0:
            raise ValueError('At least one image folder / CSV pair is required')

        self.return_params = return_params
        self.process_params = process_params
        self.transform = transform
        self.target_transform = target_transform

        # region find the length of the dataset
        tables, max_nums = [], []
        for imgs_path, csv_path in zip(path_list_imgs, path_list_csv):
            table = pd.read_csv(csv_path)
            # Remember which folder holds the images described by each row.
            table['file_path'] = imgs_path
            max_nums.append(table['num'].max())
            tables.append(table)

        # Concatenate the tables by ascending maximum ``num`` so that the
        # positional row ``3 * index`` lines up with sample ``index``.
        # sorted() is stable, so ties keep the order given by the caller.
        order = sorted(range(len(tables)), key=max_nums.__getitem__)
        full_params = pd.concat([tables[i] for i in order])
        if len(full_params) % 3 != 0:
            raise ValueError('Dataset length is not a multiple of 3')
        self.full_dataframe = full_params
        # Cast the numpy scalar so that the stored length is a plain int.
        self.dataset_length = int(max(max_nums)) + 1
        # endregion

        # Encode the shape names as integer class ids.
        shape2int = {'cube': 0, 'pyramid': 1, 'icosphere': 2}
        self.full_dataframe['type'] = self.full_dataframe['type'].apply(lambda shape_name: shape2int[shape_name])

    def _get_imgs(self, index: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Read the A, B and C images of sample ``index`` as numpy arrays.

        The folder is taken from the ``file_path`` value of row ``3 * index``.
        """
        base_path = self.full_dataframe.iloc[index*3].file_path
        img_a = np.array(imageio.imread(osp.join(base_path, f'{index}_A.png')))
        img_b = np.array(imageio.imread(osp.join(base_path, f'{index}_B.png')))
        img_c = np.array(imageio.imread(osp.join(base_path, f'{index}_C.png')))

        return img_a, img_b, img_c

    def _process_params(self, params):
        """Keep only the parameter columns that are returned to the caller."""
        columns_to_keep = ['type', 'pos_x', 'pos_y', 'pos_z', 'camera_angle', 'obj_color', 'floor_color', 'wall1_color', 'wall2_color']
        return params[columns_to_keep]

    def _get_params(self, index: int):
        """Return the three parameter rows of sample ``index`` as numpy arrays.

        The rows are read positionally at ``3 * index``, ``3 * index + 1`` and
        ``3 * index + 2``; they are reduced with ``_process_params`` when
        ``self.process_params`` is set.
        """
        first = index * 3
        rows = [self.full_dataframe.iloc[first + offset] for offset in range(3)]

        if self.process_params:
            rows = [self._process_params(row) for row in rows]

        return rows[0].to_numpy(), rows[1].to_numpy(), rows[2].to_numpy()

    def __len__(self) -> int:
        """Return the number of samples (largest ``num`` value plus one)."""
        return self.dataset_length

    def __getitem__(self, index: int):
        """Return the images of sample ``index`` and, optionally, their parameters.

        :return: ``(images, params)`` when ``self.return_params`` is true,
            otherwise only the 3-tuple of images.
        """
        images = self._get_imgs(index)
        if self.transform is not None:
            images = tuple(self.transform(img) for img in images)

        if not self.return_params:
            return images

        params = self._get_params(index)
        if self.target_transform is not None:
            params = tuple(self.target_transform(par) for par in params)
        return images, params


def _raw_reader(path: str) -> bytes:
    """
    Read a file in binary mode.
    :param path: path of the file to read
    :return: the complete content of the file as bytes
    """
    with open(path, 'rb') as f:
        return f.read()


def _dumps_pyarrow(obj) -> 'pa.Buffer':
    """
    Serialize an object with pyarrow.
    :param obj: object to serialize
    :return: the buffer produced by ``to_buffer()`` (a bytes-like object, not
        a file object)
    """
    # NOTE(review): pyarrow's custom serialization API (``pa.serialize``) was
    # reportedly deprecated in pyarrow 2.0 -- confirm that the installed
    # version still provides it. It is kept so the stored format is unchanged.
    return pa.serialize(obj).to_buffer()


def _dumps_pickle(obj) -> bytes:
    """
    Pickle an object with the default protocol.
    :param obj: object to serialize
    :return: the pickled representation as bytes
    """
    payload = pickle.dumps(obj)
    return payload


def _dataset2lmdb(src_dataset, dest_path, name, write_frequency, serialize) -> None:
    """
    Write every sample of a dataset into an LMDB database.

    Sample ``i`` is stored under the key ``str(i).encode()``. Two extra
    entries are added at the end: ``__keys__`` (the list of keys as ``str``)
    and ``__len__`` (the number of samples), both passed through ``serialize``.

    :param src_dataset: dataset to export
    :param dest_path: folder in which the database is created
    :param name: name of the database inside ``dest_path``
    :param write_frequency: commit the write transaction every this many samples
    :param serialize: callable turning an object into a bytes-like value
    :raises ValueError: if ``write_frequency`` is smaller than 1
    """
    # Reject an unusable frequency before anything is written to disk.
    if write_frequency < 1:
        raise ValueError(f'write_frequency must be a positive integer, got {write_frequency}')

    # The identity collate_fn keeps each batch as a plain list of samples;
    # with the DataLoader default batch size of one, batch[0] is the sample.
    data_loader = DataLoader(src_dataset, collate_fn=lambda x: x)
    num_batches = len(data_loader)

    lmdb_path = osp.join(dest_path, name)
    isdir = os.path.isdir(lmdb_path)

    print(f'Generate LMDB to {lmdb_path}')
    with lmdb.open(lmdb_path, subdir=isdir, map_size=int(8e9), readonly=False, meminit=False, map_async=True) as db:
        print(f'{len(src_dataset)=}', f'{len(data_loader)=}')
        txn = db.begin(write=True)
        try:
            for idx, batch in enumerate(data_loader):
                data = batch[0]
                txn.put(f'{idx}'.encode(), serialize(data))
                if idx % write_frequency == 0:
                    print(f'[{idx+1}/{num_batches}]')
                    txn.commit()
                    txn = db.begin(write=True)

            # finish iterating through the dataset
            txn.commit()
        except BaseException:
            # Do not leave a pending write transaction behind on failure.
            txn.abort()
            raise

        keys = [str(k) for k in range(num_batches)]
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', serialize(keys))
            txn.put(b'__len__', serialize(len(keys)))

        print('Flushing database ....')
        db.sync()


def folder2lmdb(
        src_dataset: Dataset, dest_path: str, name: str = 'train',
        write_frequency: int = 500, num_workers: int = 16,
) -> None:
    """
    Export a dataset to an LMDB database, serializing values with pyarrow.

    :param src_dataset: dataset to export
    :param dest_path: folder in which the database is created
    :param name: name of the database inside ``dest_path``
    :param write_frequency: commit the write transaction every this many samples
    :param num_workers: currently unused; kept for interface compatibility
    :raises ValueError: if ``write_frequency`` is smaller than 1
    """
    _dataset2lmdb(src_dataset, dest_path, name, write_frequency, _dumps_pyarrow)


def lmdb2lmdbpickle(
        src_dataset: Dataset, dest_path: str, name: str = 'train',
        write_frequency: int = 500, num_workers: int = 16,
) -> None:
    """
    Export a dataset to an LMDB database, serializing values with pickle.

    :param src_dataset: dataset to export
    :param dest_path: folder in which the database is created
    :param name: name of the database inside ``dest_path``
    :param write_frequency: commit the write transaction every this many samples
    :param num_workers: currently unused; kept for interface compatibility
    :raises ValueError: if ``write_frequency`` is smaller than 1
    """
    _dataset2lmdb(src_dataset, dest_path, name, write_frequency, _dumps_pickle)
