from __future__ import annotations
import math
import os
from pathlib import Path
import jax
import torch
import numpy as np
import jax.numpy as jnp
from torch.utils.data import Dataset
from typing import Dict, List
from functools import partial
import itertools
import random
import numpy.typing as npt
import flax
import lz4.frame
import h5py
import copy
from torch.utils.data import IterableDataset
from torch.utils.data import get_worker_info

# Setup import path
import sys
# Project root = two directories above this file; prepend it to sys.path so the
# ``util.*`` imports below can be resolved.
BASEDIR = Path(__file__).parent.parent
_basedir_str = str(BASEDIR)
if _basedir_str not in sys.path:
    sys.path.insert(0, _basedir_str)

import util.transform_util as tutil
from util.structs import SceneData
import util.render_util as rutil
import util.latent_obj_util as loutil
import util.camera_util as cutil

def aggregate_data_from_directory(data_dir_path: Path, category):
    """Load every datapoint file of a directory and concatenate them along axis 0.

    Args:
        data_dir_path: directory holding the datapoint files.
        category: ``'none'`` to load every entry of the directory; otherwise one
            or more category names joined by ``'__'``, each selecting the files
            matching ``*<category>.pkl``.

    Returns:
        A pytree shaped like one file's ``item`` dict whose leaves are the
        per-file leaves concatenated along axis 0. ``depths`` is always ``None``.

    Raises:
        ValueError: if no file matches.
    """
    # sorted() keeps the concatenation order reproducible.
    if category == 'none':
        dataset_filenames = sorted(data_dir_path.iterdir())
    else:
        dataset_filenames = []
        for cat in category.split('__'):
            dataset_filenames.extend(sorted(data_dir_path.glob(f'*{cat}.pkl')))
    if not dataset_filenames:
        raise ValueError(f"no datapoint files for category {category!r} in {data_dir_path}")

    entire_dataset_list: List[Dict] = []
    # iterdir()/glob() already return paths prefixed with data_dir_path, so they
    # are opened as-is (joining them onto data_dir_path again breaks relative dirs).
    for file_path in dataset_filenames:
        with file_path.open("rb") as f:
            np_data = np.load(f, allow_pickle=True)  # NOTE: unpickles; trusted files only
            data = np_data["item"].item()
        data["depths"] = None
        entire_dataset_list.append(data)
    # Concatenate along batch dimension
    return jax.tree_util.tree_map(lambda *x: np.concatenate(x, 0), *entire_dataset_list)


def pytree_collate(batch: List[Dict]):
    """Collate structurally identical numpy pytrees into a single pytree.

    Every leaf of the result is the matching leaves of ``batch`` stacked along a
    new leading axis.
    """
    def _stack_leaves(*leaves):
        return np.stack(leaves, axis=0)

    return jax.tree_util.tree_map(_stack_leaves, *batch)


import flax
import pickle
def pkl_to_npz(pkl_file_path:Path):
    """Convert a pickled datapoint into a compressed ``.npz`` file next to it.

    The unpickled object goes through ``flax.serialization.to_state_dict`` and
    is stored under the key ``item``. Nothing is done if the target exists.

    Args:
        pkl_file_path: pickle file; the output keeps its stem and gets the
            ``.npz`` suffix (whatever the input suffix is).
    """
    npz_file_path = pkl_file_path.with_suffix('.npz')
    if npz_file_path.exists():
        print(f'pass {npz_file_path}')
        return
    # NOTE: pickle.load can execute arbitrary code; only convert trusted files.
    with open(pkl_file_path, 'rb') as f:
        datapnt = pickle.load(f)
    numpy_datapoints = flax.serialization.to_state_dict(datapnt)
    with open(npz_file_path, 'wb') as f:
        np.savez_compressed(f, item=numpy_datapoints)
    print(f'save {npz_file_path}')

def del_if_exist(data, key):
    """Delete ``key`` from the mapping ``data`` in place; do nothing when it is absent."""
    if key not in data:
        return
    del data[key]

def _mesh_names_to_idx(mesh_name, obj_id_names: List[str]):
    """Map every mesh path of ``mesh_name`` to its position in ``obj_id_names``.

    Args:
        mesh_name: 2-D array (indexed ``[i, j]``) of mesh paths given as ``str``
            or ``bytes``; empty slots hold ``None`` or ``'None'``.
        obj_id_names: candidate names; an entry matches when the id returned by
            ``extract_obj_id_name`` is a substring of it (first match wins).

    Returns:
        ``(idx_array, valid)``: an int32 array shaped like ``mesh_name`` holding
        the matched position or -1, and one bool per row telling whether any
        slot of that row matched.
    """
    idx_array = -np.ones(mesh_name.shape, dtype=np.int32)
    valid = []
    for i, row in enumerate(mesh_name):
        row_valid = False
        for j, nm in enumerate(row):
            # if byte array, convert to string
            if isinstance(nm, bytes):
                nm = nm.decode()
            if nm is None or nm == 'None':
                continue
            obj_id_nm = extract_obj_id_name(nm)
            for oid_idx, oid_target in enumerate(obj_id_names):
                if obj_id_nm in oid_target:
                    idx_array[i, j] = oid_idx
                    row_valid = True
                    break
        valid.append(row_valid)
    return idx_array, valid


def _collapse_uniform_scale(info):
    """Replace the per-axis ``info['scale']`` by its first component, in place.

    Asserts that components 0, 1 and 2 of the last axis are equal, and that the
    ``obj_posquats`` entries whose collapsed scale is 0 (presumably padding
    slots) are numerically zero.
    """
    scale = info["scale"]
    assert np.all(scale[..., 0] == scale[..., 1])
    assert np.all(scale[..., 1] == scale[..., 2])
    info["scale"] = scale[..., 0]
    assert np.sum(np.abs(info['obj_posquats'][info["scale"] == 0])) < 1e-5


def preprocess_datapoint(data, obj_id_names:List[str]):
    """Normalize a raw datapoint dict into the grouped layout used for training.

    Works on a batched datapoint or on a single one (detected by 4-D ``rgbs``;
    a leading batch axis is added temporarily and removed again at the end):

    * flat ``ObjInfo_*`` / ``EnvInfo_*`` keys are grouped into ``obj_info`` /
      ``env_info`` and ``cam_posquats`` / ``cam_intrinsics`` into ``cam_info``,
      unless already grouped;
    * ``seg`` becomes a boolean mask, True where its value equals one of the
      object uids or the last env uid;
    * keys not needed downstream are dropped (``RobotInfo*``, ``depths``,
      ``table_params``, ``robot_params``, ``uid_clss``, ``nvren_info``, padded
      convex meshes, uid lists);
    * ``mesh_name`` in ``obj_info``/``env_info`` is replaced by ``idx``, the
      position in ``obj_id_names`` (-1 when unmatched);
    * the per-axis ``scale`` is collapsed to one value (asserted uniform).

    For batched input, rows in which no *object* mesh matched are dropped.

    Args:
        data: datapoint dict; the grouping step may modify the caller's dict.
        obj_id_names: names used to map mesh paths to indices.

    Returns:
        The processed dict.
    """
    # A single datapoint has 4-D rgbs; give it a temporary batch axis.
    extended = len(data['rgbs'].shape) == 4
    if extended:
        data = jax.tree_util.tree_map(lambda x: x[None], data)

    if 'obj_info' not in data:
        data['obj_info'] = {k.replace('ObjInfo_', ''): data[k] for k in data.keys() if 'ObjInfo' in k}
        data = {k: data[k] for k in data.keys() if 'ObjInfo' not in k}
    if 'cam_info' not in data:
        cam_info_key = ['cam_posquats', 'cam_intrinsics']
        data['cam_info'] = {k: data[k] for k in cam_info_key}
        for ck in cam_info_key:
            del data[ck]
    if 'env_info' not in data:
        data['env_info'] = {k.replace('EnvInfo_', ''): data[k] for k in data.keys() if 'EnvInfo' in k}
        data = {k: data[k] for k in data.keys() if 'EnvInfo' not in k}

    # NOTE(review): relies on seg[..., None] broadcasting against the uid list; confirm shapes.
    alloable_uid_list = np.concatenate([data['obj_info']['uid_list'], data['env_info']['uid_list'][..., -1:]], axis=-1)
    data['seg'] = np.any(data['seg'][..., None] == alloable_uid_list, axis=-1)

    for key in [k for k in data if "RobotInfo" in k]:
        del_if_exist(data, key)
    for key in ("depths", "table_params", "robot_params", "uid_clss", "nvren_info"):
        del_if_exist(data, key)
    for key in ("obj_cvx_verts_padded", "obj_cvx_faces_padded", "uid_list"):
        del_if_exist(data['obj_info'], key)
    del_if_exist(data['env_info'], "uid_list")

    obj_idx, obj_valid = _mesh_names_to_idx(data['obj_info']['mesh_name'], obj_id_names)
    del_if_exist(data['obj_info'], "mesh_name")
    data['obj_info']['idx'] = obj_idx
    _collapse_uniform_scale(data['obj_info'])

    # Environment meshes are mapped the same way, but only object meshes decide
    # row validity so the mask below has exactly one entry per batch row.
    env_idx, _ = _mesh_names_to_idx(data['env_info']['mesh_name'], obj_id_names)
    del_if_exist(data['env_info'], "mesh_name")
    data['env_info']['idx'] = env_idx
    _collapse_uniform_scale(data['env_info'])

    if extended:
        data = jax.tree_util.tree_map(lambda x: x.squeeze(0), data)
    else:
        keep = np.asarray(obj_valid, dtype=bool)
        data = jax.tree_util.tree_map(lambda x: x[keep], data)
    return data

class EstimatorDataset(Dataset):
    """In-memory estimator dataset tailored for FLAX.

    The selected ``*.npz`` files are run through ``preprocess_datapoint`` and
    concatenated along the leading axis into ``self.entire_data`` at
    construction time; ``__getitem__`` simply indexes that pytree.
    """
    def __init__(
            self,
            data_dir_path:Path,
            size_limit:int,
            ds_obj_no:int,
            sdf_dirs:List[str],
            ds_type='train',
    ):
        """Load and preprocess the matching files.

        Args:
            data_dir_path: directory with ``*.npz`` files whose stem ends in
                ``_<max_obj_no>``.
            size_limit: for ``'train'`` (or a cyclic file list) loading stops once
                more than this many datapoints are loaded; training keeps only the
                last ``size_limit``. Also the window used by ``push``. Values
                below 10 tile the loaded data 500 times.
            ds_obj_no: keep only files whose trailing ``_<n>`` equals this value.
            sdf_dirs: forwarded to ``preprocess_datapoint`` as ``obj_id_names``.
            ds_type: ``'test'`` loads files prefixed ``val_``; any other value
                loads files not prefixed ``val``/``test``.

        Raises:
            FileNotFoundError: if no file matches.
        """
        self.sdf_dirs = sdf_dirs
        self.size_limit = size_limit
        self.ds_type = ds_type

        # filter max num: keep files whose stem ends in ``_<ds_obj_no>``
        dataset_filenames = [
            df for df in sorted(data_dir_path.glob('*.npz'))
            if int(df.name.split('.')[0].split('_')[-1]) == ds_obj_no
        ]

        if self.ds_type == 'test':
            dataset_filenames = [df for df in dataset_filenames if df.name.split('_')[0] == 'val']
        else:
            dataset_filenames = [df for df in dataset_filenames if df.name.split('_')[0] not in ('val', 'test')]
        shelf_ds_no = len([df for df in dataset_filenames if 'shelf' in df.name.split('_')])
        table_ds_no = len(dataset_filenames) - shelf_ds_no
        print(f"ds type: {self.ds_type} // fn loaded: {len(dataset_filenames)} // shelf no: {shelf_ds_no} // table no: {table_ds_no}")
        if not dataset_filenames:
            raise FileNotFoundError(f"no '*.npz' files for ds_type={self.ds_type!r}, ds_obj_no={ds_obj_no} in {data_dir_path}")

        random.shuffle(dataset_filenames)
        if self.ds_type == 'test' or len(dataset_filenames) < 5:
            self.dataset_filenames = dataset_filenames
            self.cyclic_dataset = False
        else:
            # Endless iterator, consumed further by push().
            self.dataset_filenames = itertools.cycle(dataset_filenames)
            self.cyclic_dataset = True

        # A cyclic iterator never ends on its own, so it must always be capped.
        stop_at_limit = self.ds_type == 'train' or self.cyclic_dataset
        entire_dataset_batched = None
        for fname in self.dataset_filenames:
            print(f'load {str(fname)}')
            data = self._load_npz(fname)
            if entire_dataset_batched is None:
                entire_dataset_batched = data
            else:
                entire_dataset_batched = jax.tree_util.tree_map(
                    lambda *x: np.concatenate(x, 0), entire_dataset_batched, data)
            if stop_at_limit and entire_dataset_batched["rgbs"].shape[0] > size_limit:
                break
        if self.ds_type == 'train':
            entire_dataset_batched = jax.tree_util.tree_map(lambda x: x[-self.size_limit:], entire_dataset_batched)
        if self.size_limit < 10:
            # Tiny datasets are tiled 500x (presumably for quick debug/overfit runs).
            entire_dataset_batched = jax.tree_util.tree_map(lambda x: x.repeat(500, axis=0), entire_dataset_batched)
        self.entire_data = entire_dataset_batched
        print(f"ds type {self.ds_type} finish / ds num: {entire_dataset_batched['rgbs'].shape[0]}")

    def _load_npz(self, fname):
        """Read the dict stored under ``item`` in ``fname`` and preprocess it."""
        with fname.open("rb") as f:
            # NOTE: allow_pickle unpickles the payload; only load trusted files.
            np_data = np.load(f, allow_pickle=True)
            data = np_data["item"].item()
        return preprocess_datapoint(data, self.sdf_dirs)

    def __len__(self):
        """Dataset size"""
        return self.entire_data["rgbs"].shape[0]

    def __getitem__(self, index) -> SceneData:
        """Return datapoint ``index`` (each leaf indexed on its leading axis)."""
        # NOTE(ssh): add online pre-processing here if needed (no torch tensors).
        return jax.tree_util.tree_map(lambda x: x[index], self.entire_data)

    def push(self):
        """Stream more files from the cyclic file iterator into memory.

        Files are appended until the total exceeds ``size_limit``; then only the
        most recent ``size_limit`` datapoints are kept. No-op when the file list
        is not cyclic.
        """
        if not self.cyclic_dataset:
            print("non cyclic dataset / no need to push")
            return
        for fname in self.dataset_filenames:
            print(f"fn {str(fname)} start pushing")
            data = self._load_npz(fname)
            if 'seg' not in data:
                data['seg'] = None
            self.entire_data = jax.tree_util.tree_map(lambda *x: np.concatenate(x, 0), self.entire_data, data)
            if self.entire_data['rgbs'].shape[0] > self.size_limit:
                print(f"fn {str(fname)} pushed ds len: {data['rgbs'].shape[0]}")
                break
        self.entire_data = jax.tree_util.tree_map(lambda x: x[-self.size_limit:], self.entire_data)


def extract_obj_id_name(path):
    """Derive an object-id name from a mesh/SDF path string.

    * A bare file name (no ``'/'``) yields its text before the first ``'.'``.
    * Otherwise the result is ``"<ds_name>_<stem>"``. ``stem`` is the last
      component up to its first ``'.'``. ``ds_name`` is the third-from-last
      component, or the fourth-from-last when the third-from-last is ``'cvx'``
      and the file name starts with ``'cvx'``.

    Raises IndexError if a ``'/'``-separated path has too few components.
    """
    parts = path.split('/')
    if len(parts) == 1:
        return path.split('.')[0]
    file_name = parts[-1]
    stem = file_name.split('.')[0]
    in_cvx_dir = parts[-3] == 'cvx' and file_name.startswith('cvx')
    ds_name = parts[-4] if in_cvx_dir else parts[-3]
    return f"{ds_name}_{stem}"

class EstimatorDatasetsplit(Dataset):
    """Estimator dataset tailored for FLAX that reads one lz4 file per item.

    Only the file list is kept in memory; each ``__getitem__`` decompresses,
    unpickles and preprocesses a single ``*.lz4`` file.
    """
    def __init__(
            self,
            data_dir_path:Path,
            ds_obj_no:int,
            sdf_dirs:List[str],
            ds_type='train',
    ):
        """Collect the ``*.lz4`` files of ``data_dir_path`` and select a split.

        Args:
            data_dir_path: directory with files whose stem ends in ``_<max_obj_no>``.
            ds_obj_no: keep only files whose trailing ``_<n>`` equals this value.
            sdf_dirs: ``.pkl`` paths converted with ``extract_obj_id_name`` into the
                names used to map mesh names to indices.
            ds_type: ``'single_ds'`` / ``'single_ds_test'`` repeat the third file
                (debugging); ``'test'`` takes the last ``ceil(n / 20)`` files of a
                seeded shuffle; any other value takes the remaining files.

        Raises:
            NotImplementedError: if ``sdf_dirs`` are not ``.pkl`` paths.
        """
        if sdf_dirs[0].split('.')[-1] != 'pkl':
            raise NotImplementedError
        # sdf dir -> obj id (dataset name + obj mesh name)
        self.obj_id_name = [extract_obj_id_name(sd) for sd in sdf_dirs]
        self.ds_type = ds_type

        # filter max num: keep files whose stem ends in ``_<ds_obj_no>``
        dataset_filenames = [
            df for df in sorted(data_dir_path.glob('*.lz4'))
            if int(df.name.split('.')[0].split('_')[-1]) == ds_obj_no
        ]

        if self.ds_type == 'single_ds':
            # Debug: test with a single file repeated.
            self.dataset_filenames = [dataset_filenames[2]] * (16*100)
            print('single dataset test')
        elif self.ds_type == 'single_ds_test':
            self.dataset_filenames = [dataset_filenames[2]] * (16*2)
            print('single dataset test')
        else:
            # Seeded shuffle so the train/test split is identical across runs.
            np.random.default_rng(0).shuffle(dataset_filenames)
            split_at = -len(dataset_filenames) // 20  # == -ceil(n / 20)
            if self.ds_type == 'test':
                self.dataset_filenames = dataset_filenames[split_at:]
            else:
                self.dataset_filenames = dataset_filenames[:split_at]
            shelf_ds_no = len([df for df in self.dataset_filenames if 'shelf' in df.name.split('_')])
            table_ds_no = len(self.dataset_filenames) - shelf_ds_no
            print(f"ds type: {self.ds_type} // fn loaded: {len(self.dataset_filenames)} // shelf no: {shelf_ds_no} // table no: {table_ds_no}")

    def __len__(self):
        """Dataset size"""
        return len(self.dataset_filenames)

    def __getitem__(self, index) -> SceneData:
        """Load and preprocess the datapoint stored in file ``index``.

        A file that fails to load is reported and skipped: the following files
        (wrapping around) are tried in turn until one succeeds.

        Raises:
            IndexError: if ``index`` is out of range.
            RuntimeError: if no file of the dataset can be loaded.
        """
        n = len(self.dataset_filenames)
        if not -n <= index < n:
            raise IndexError(f"index {index} out of range for dataset of size {n}")
        last_err = None
        for offset in range(n):
            fname = self.dataset_filenames[(index + offset) % n]
            try:
                with lz4.frame.open(str(fname), "r") as fp:
                    bytes_data = fp.read()
                # NOTE: pickle.loads can execute arbitrary code; trusted data only.
                data = pickle.loads(bytes_data)
                data = preprocess_datapoint(data, self.obj_id_name)
            except Exception as err:
                print(f"Error in processing datapoint from {fname}")
                last_err = err
                continue
            return data
        raise RuntimeError(
            f"failed to load any datapoint ({n} files tried, starting at index {index})"
        ) from last_err
    

class EstimatorIterableDataset(IterableDataset):
    """Estimator dataset using IterableDataset and h5py for efficient data loading.

    Each HDF5 file is a chunk; every index along the first axis of its datasets
    (all keys are assumed to share that axis) is one datapoint, which is run
    through ``preprocess_datapoint`` before being yielded.
    """

    def __init__(
        self,
        data_dir: Path,
        sdf_dirs,
        dataset_type: str = 'train',
        shuffle: bool = False,
    ):
        """
        Args:
            data_dir (Path): Directory containing the HDF5 (``*.h5``) data files.
            sdf_dirs: ``.pkl`` paths converted with ``extract_obj_id_name`` into
                the names used to map mesh names to indices.
            dataset_type (str): ``'test'`` reads files whose name starts with
                ``'val'``; any other value reads files not starting with
                ``'val'``/``'test'``.
            shuffle (bool): Shuffle the files and the datapoints within each file.

        Raises:
            NotImplementedError: if ``sdf_dirs`` are not ``.pkl`` paths.
            FileNotFoundError: if no HDF5 file matches ``dataset_type``.
        """
        self.data_dir = data_dir
        self.dataset_type = dataset_type
        self.shuffle = shuffle

        if sdf_dirs[0].split('.')[-1] != 'pkl':
            raise NotImplementedError
        self.obj_id_name = [extract_obj_id_name(sd) for sd in sdf_dirs]

        data_files = sorted(self.data_dir.glob('*.h5'))
        if self.dataset_type == 'test':
            data_files = [fp for fp in data_files if fp.name.startswith('val')]
        else:
            data_files = [fp for fp in data_files if not fp.name.startswith(('val', 'test'))]
        if not data_files:
            raise FileNotFoundError(
                f"no '*.h5' files for dataset_type={self.dataset_type!r} in {self.data_dir}")
        if self.shuffle:
            random.shuffle(data_files)
        self.data_files = data_files

        # __len__ assumes every file holds as many datapoints as the first one.
        with h5py.File(str(self.data_files[0]), 'r') as h5_file:
            first_key = list(h5_file.keys())[0]
            self.num_data_points_per_chunk = h5_file[first_key].shape[0]

    def __len__(self):
        return len(self.data_files) * self.num_data_points_per_chunk

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: all files (copied so shuffling is local).
            file_list = list(self.data_files)
        else:
            # DataLoader worker: take a contiguous share of the files.
            per_worker = math.ceil(len(self.data_files) / worker_info.num_workers)
            start_index = worker_info.id * per_worker
            file_list = self.data_files[start_index:start_index + per_worker]
        if self.shuffle:
            random.shuffle(file_list)

        for file_path in file_list:
            with h5py.File(str(file_path), 'r') as h5_file:
                keys = list(h5_file.keys())
                indices = list(range(h5_file[keys[0]].shape[0]))
                if self.shuffle:
                    random.shuffle(indices)
                for idx in indices:
                    raw_point = {key: h5_file[key][idx] for key in keys}
                    try:
                        data_point = preprocess_datapoint(raw_point, self.obj_id_name)
                    except Exception:
                        print(f"Error in processing datapoint from {file_path}")
                        continue
                    # Yield outside the try so GeneratorExit raised by close()
                    # (early stop / garbage collection) is not swallowed.
                    yield data_point




if __name__ == '__main__':
    # Manual smoke test: build the lz4-backed split dataset and run through it
    # once with a DataLoader.
    import util.model_util as mutil
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    models = mutil.Models().load_pretrained_models()

    # Other loaders that can be swapped in here:
    #   EstimatorIterableDataset(data_dir=Path('predicate_scenedata_h5/4_5'),
    #       sdf_dirs=models.canonical_latent_obj_filename_list,
    #       dataset_type='train', shuffle=True)
    #   EstimatorDatasetsplit(Path('predicate_scenedata/4_5'), ds_obj_no=4,
    #       sdf_dirs=models.canonical_latent_obj_filename_list, ds_type='train')
    # e.g. wrapped as DataLoader(dataset, batch_size=2, num_workers=20, pin_memory=False).

    # Configs
    DATA_DIR = Path('sink_scenedata/3_5')  # alternative: Path('scene_data/4_5')
    BATCH_SIZE = 16
    NUM_WORKERS = 20  # not passed below, so the DataLoader loads in the main process
    train_dataset = EstimatorDatasetsplit(
        DATA_DIR, 3, models.canonical_latent_obj_filename_list, 'train')
    _ = train_dataset[0]  # load one sample eagerly before building the loader
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=pytree_collate,
        pin_memory=False,  # Only for torch
        shuffle=True,
        drop_last=True,
    )
    steps_per_epoch = len(train_dataset) // BATCH_SIZE
    # Stats
    print(f"Train examples: {len(train_dataset)}")

    # One pass over the loader (data visualization hook).
    for batch in tqdm(train_loader):
        rgbs = batch['rgbs']

        # print(rgbs.shape)
        # print(1)