import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import pickle
import argparse
import ray
import jax
import flax
import nvisii
from pathlib import Path
from tqdm import tqdm
import datetime
import lz4.frame
import pybullet as pb

# Setup import path
import sys
BASEDIR = Path(__file__).parent.parent
sys.path.insert(0, str(BASEDIR))

# import util.io_util as ioutil
import dataset.render_with_nvisii as rwn
import dataset.scene_generation as sg
import util.structs as structs

# Typing
from typing import Tuple, List



def _load_object_set_dirs(sdf_list_path: str, valid_datasets: List[str]) -> List[str]:
    """Collect object-set paths from the SDF pickles listed in ``sdf_list_path``.

    ``sdf_list_path`` is a text file with one pickle path per line; blank lines
    are skipped. Each pickle is expected to be a mapping with ``'dataset'`` and
    ``'path'`` keys. Only entries whose ``'dataset'`` is in ``valid_datasets``
    are kept, in file order.
    """
    with open(sdf_list_path, 'r') as f:
        sdf_paths = [line.strip() for line in f if line.strip()]
    object_set_dirs: List[str] = []
    for sdf_path in sdf_paths:
        # NOTE: pickle.load can execute arbitrary code; these files must be trusted local artifacts.
        with open(sdf_path, 'rb') as f:
            sdf_loaded = pickle.load(f)
        if sdf_loaded['dataset'] in valid_datasets:
            object_set_dirs.append(sdf_loaded['path'])
    return object_set_dirs


def _generate_serial(scene_kwargs: dict, num_iterations: int, batch_size: int, num_views: int) -> list:
    """Debug-mode generation: build scenes one at a time in this process with the PyBullet GUI.

    Returns a list holding one pytree whose leaves stack every generated
    datapoint along a new leading axis. The list is empty if nothing was
    generated. This matches the per-round batches returned by ``_generate_parallel``.
    """
    datapoints = []
    for _ in tqdm(range(num_iterations)):
        scene = sg.SceneCls(**scene_kwargs, gui=True)
        # gen_batch_dataset returns a list of per-datapoint pytrees (flattened the same way as in the parallel path).
        datapoints.extend(scene.gen_batch_dataset(batch_size, num_views))
        pb.disconnect()
    if not datapoints:
        return []
    return [jax.tree_util.tree_map(lambda *x: np.stack(x, 0), *datapoints)]


def _generate_parallel(scene_kwargs: dict, num_iterations: int, reset_interval: int,
                       num_envs: int, batch_size: int, num_views: int) -> list:
    """Generate scenes with ``num_envs`` Ray actors, one ``gen_batch_dataset`` call per actor per round.

    Ray is restarted and the actors are recreated every ``reset_interval``
    rounds. Ray is always shut down on exit, even if generation raises.
    Returns one pytree per round, whose leaves stack that round's datapoints
    along a new leading axis.
    """
    batches = []
    try:
        for i in tqdm(range(num_iterations)):
            if i % reset_interval == 0:
                # Best-effort restart; a failing shutdown must not abort generation.
                try:
                    ray.shutdown()
                except Exception:
                    pass
                remote_scene_cls = ray.remote(sg.SceneCls)
                ray_actors = [remote_scene_cls.remote(**scene_kwargs, gui=False) for _ in range(num_envs)]
            imgs_list = ray.get([ra.gen_batch_dataset.remote(batch_size, num_views) for ra in ray_actors])
            # Each actor returns a list of per-datapoint pytrees; flatten them, then stack them.
            round_datapoints = [dp for actor_batch in imgs_list for dp in actor_batch]
            batches.append(jax.tree_util.tree_map(lambda *x: np.stack(x, 0), *round_datapoints))
    finally:
        ray.shutdown()
    return batches


def _save_datapoint(batched_datapoints, index: int, save_path: Path, drop_fields: Tuple[str, ...]) -> None:
    """Write datapoint ``index`` of ``batched_datapoints`` to ``save_path`` as an lz4-compressed pickle.

    Fields named in ``drop_fields`` are set to ``None`` before slicing. Every
    ``None`` entry is then omitted from the saved state dict.
    """
    stripped = batched_datapoints.replace(**{name: None for name in drop_fields})
    single = jax.tree_util.tree_map(lambda x: x[index], stripped)
    state = flax.serialization.to_state_dict(single)
    state = {k: v for k, v in state.items() if v is not None}
    with lz4.frame.open(str(save_path), mode='wb') as fp:
        fp.write(pickle.dumps(state))


def main(args: argparse.Namespace):
    """Generate synthetic multi-view scene data and save one ``.lz4`` file per datapoint.

    The pipeline has four steps:

    1. Load the list of object sets.
    2. Generate scenes. When ``args.visualize_for_debug == 2`` this runs
       serially with the PyBullet GUI; otherwise it runs in parallel Ray actors.
    3. Optionally re-render the RGB images with NVISII.
    4. Write each datapoint to ``<save_dir>/<num_objs>_<num_views>/``. In
       debug mode the images are plotted instead of saved.

    Args:
        args: Parsed command-line options (see the ``__main__`` argument parser).

    Raises:
        ValueError: if no data was generated (e.g. ``--num_iterations 0``).
    """
    # Env config
    PIXEL_SIZE = [int(v) for v in args.pixel_size.split("-")]   # "{height-width}" string -> [height, width]
    BATCH_SIZE = 4 if args.visualize_for_debug else args.inner_itr_no
    NUM_VIEWS = args.num_views
    NUM_OBJ = args.num_objs
    NUM_ITERATIONS: int = args.num_iterations
    RAY_RESET_INTERVAL: int = args.ray_reset_interval
    NUM_RAY_ENVS: int = args.num_ray_envs
    SCENE_TYPE: str = args.scene_type
    USE_NVISII: bool = args.use_nvisii
    # Debug mode keeps the simulator RGB so it can be plotted and compared with the NVISII output.
    # Otherwise no_rgb follows USE_NVISII, presumably because NVISII overwrites the RGB anyway.
    NO_RGB: bool = False if args.visualize_for_debug else USE_NVISII
    ADD_DISTRACTOR: bool = args.add_distractor
    # A single notion of "validation" shared by scene generation and output file naming.
    IS_VALIDATION = bool(args.validation)
    # IO config
    HDR_DIR = Path(args.hdr_dir)
    TEXTURE_DIR = Path(args.texture_dir)

    cur_timesteps = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    SAVE_DIR = Path(args.save_dir)
    SAVE_DIR.mkdir(exist_ok=True)
    SAVE_DIR = SAVE_DIR/f"{NUM_OBJ}_{NUM_VIEWS}"
    file_prefix = "val_" if IS_VALIDATION else ""

    def datapoint_path(i: int) -> Path:
        """Return the output file for datapoint ``i``; validation files carry a ``val_`` prefix."""
        return SAVE_DIR/f"{file_prefix}{cur_timesteps}_{i:05d}_{args.camera_type}_{args.scene_type}_{NUM_VIEWS}_{NUM_OBJ}.lz4"

    # NOTE(review): the path is relative to the current working directory.
    odr_list = _load_object_set_dirs('dataset/sdf_dirs.txt', ['NOCS', 'GoogleScannedObjects'])

    # reduce the number of objects
    if args.used_objset_no is not None and args.used_objset_no > 0:
        # NOTE(review): `integers` samples with replacement, so duplicate object sets are possible.
        # It is kept as-is (fixed seed 42) so that previously generated datasets stay reproducible.
        random_idx = np.random.default_rng(seed=42).integers(0, len(odr_list), size=(args.used_objset_no,))
        odr_list = np.array(odr_list)[random_idx].tolist()

    # Constructor arguments shared by the serial (debug) and parallel generators; only `gui` differs.
    scene_kwargs = dict(
        object_set_dir_list=odr_list,
        camera_type=args.camera_type,
        max_obj_no=NUM_OBJ,
        pixel_size=PIXEL_SIZE,
        scene_type=SCENE_TYPE,
        no_rgb=NO_RGB,
        robot_gen=False,
        validation=IS_VALIDATION,
    )
    if args.visualize_for_debug == 2:
        datapoints_list = _generate_serial(scene_kwargs, NUM_ITERATIONS, BATCH_SIZE, NUM_VIEWS)
    else:
        datapoints_list = _generate_parallel(
            scene_kwargs, NUM_ITERATIONS, RAY_RESET_INTERVAL, NUM_RAY_ENVS, BATCH_SIZE, NUM_VIEWS)
    if not datapoints_list:
        raise ValueError("No data generated; --num_iterations must be positive.")

    # Aggregate generated data: concatenate the per-round batches along the leading axis.
    batched_datapoints: structs.SceneData = jax.tree_util.tree_map(lambda *x: np.concatenate(x, 0), *datapoints_list)
    num_generated = batched_datapoints.rgbs.shape[0]

    if USE_NVISII:
        print("Re-rendering using NVISII")
        nvisii_ren = rwn.NvisiiRender(
            pixel_size=PIXEL_SIZE,
            hdr_dir=str(HDR_DIR),
            texture_dir=str(TEXTURE_DIR))
        try:
            nvisii_rgbs = np.zeros((num_generated, NUM_VIEWS, nvisii_ren.option.height, nvisii_ren.option.width, 3), dtype=np.uint8)
            origin_rgbs = batched_datapoints.rgbs   # simulator renders, used for debug comparison plots
            if args.visualize_for_debug:
                import matplotlib.pyplot as plt
            else:
                SAVE_DIR.mkdir(exist_ok=True)
            for i in tqdm(range(num_generated)):
                # Render one datapoint taken from the batch.
                datapoint = jax.tree_util.tree_map(lambda x: x[i], batched_datapoints)
                datapoint_rgbs = nvisii_ren.get_rgb_for_datapoint(datapoint, add_distractor=ADD_DISTRACTOR)
                if args.visualize_for_debug:
                    # Top row: NVISII render. Bottom row: original simulator render.
                    num_imgs = len(datapoint_rgbs)
                    plt.figure(figsize=(15, 10))
                    for j, img_ in enumerate(datapoint_rgbs):
                        plt.subplot(2, num_imgs, j+1)
                        plt.imshow(img_)
                        plt.axis('off')
                        plt.subplot(2, num_imgs, j+num_imgs+1)
                        plt.imshow(origin_rgbs[i, j])
                        plt.axis('off')
                    plt.tight_layout()
                    plt.show()
                # Update corresponding datapoints.
                # NOTE(review): after the first replace, datapoints not yet rendered carry all-zero rgbs.
                nvisii_rgbs[i] = datapoint_rgbs
                batched_datapoints = batched_datapoints.replace(rgbs=nvisii_rgbs)

                if not args.visualize_for_debug:
                    # NOTE(review): unlike the non-NVISII path, table_params/robot_params are kept here; confirm intended.
                    _save_datapoint(batched_datapoints, i, datapoint_path(i), drop_fields=('nvren_info',))
                    print(f'File updated at iter={i}. Time: {datetime.datetime.now():%Y-%m-%d %H:%M:%S}')
        finally:
            nvisii.deinitialize()
    elif not args.visualize_for_debug:
        SAVE_DIR.mkdir(exist_ok=True)
        for i in tqdm(range(num_generated)):
            _save_datapoint(batched_datapoints, i, datapoint_path(i),
                            drop_fields=('nvren_info', 'table_params', 'robot_params'))
    else:
        import matplotlib.pyplot as plt
        for i in tqdm(range(num_generated)):
            # Plot all views of one datapoint in a single row.
            datapoint_rgbs = batched_datapoints.rgbs[i]
            plt.figure(figsize=(15, 10))
            for j, img_ in enumerate(datapoint_rgbs):
                plt.subplot(1, len(datapoint_rgbs), j+1)
                plt.imshow(img_)
            plt.tight_layout()
            plt.show()

def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for scene-data generation.

    Integer flags used as switches (``--use_nvisii``, ``--add_distractor``,
    ``--validation``) take ``0``/``1``. ``--visualize_for_debug`` also accepts
    ``2``, which selects serial generation with the PyBullet GUI.
    """
    parser = argparse.ArgumentParser(description="Generate synthetic multi-view scene datasets.")
    parser.add_argument('--num_views', type=int, default=5, help="Camera views rendered per scene")
    parser.add_argument('--num_objs', type=int, default=10, help="Maximum number of objects per scene")
    parser.add_argument('--used_objset_no', type=int, default=0,
                        help="If > 0, randomly pick this many object sets (fixed seed) instead of using all")
    parser.add_argument('--save_dir', type=str, default='scene_data2', help="Root output directory")
    parser.add_argument('--hdr_dir', type=str, default='../hdr', help="HDR directory passed to the NVISII renderer")
    parser.add_argument('--texture_dir', type=str, default="assets/texture",
                        help="Texture directory passed to the NVISII renderer")
    # Other image sizes used previously: 56-98, 70-112, 98-168, 126-210, 140-224, 154-280.
    parser.add_argument('--pixel_size', type=str, default="238-420", help="Training image size {height-width}")
    parser.add_argument('--num_iterations', type=int, default=4, help="Number of generation rounds")
    parser.add_argument('--ray_reset_interval', type=int, default=2,
                        help="Restart Ray and recreate the actors every N rounds")
    parser.add_argument('--num_ray_envs', type=int, default=10, help="Number of parallel Ray scene-generation actors")
    parser.add_argument('--use_nvisii', type=int, default=1, help="1: re-render RGB images with NVISII")
    parser.add_argument('--scene_type', type=str, default="table", help="Scene layout, e.g. 'table', 'shelf', 'flat'")
    parser.add_argument('--add_distractor', type=int, default=0,
                        help="1: add distractors when re-rendering with NVISII")
    parser.add_argument('--visualize_for_debug', type=int, default=0,
                        help="0: normal run; 1: plot images instead of saving (batch size 4); "
                             "2: like 1, but generate serially with the PyBullet GUI")
    parser.add_argument('--validation', type=int, default=0,
                        help="1: generate a validation split (output files prefixed 'val_')")
    parser.add_argument('--inner_itr_no', type=int, default=8,
                        help="Datapoints per gen_batch_dataset call (debug mode uses 4)")
    parser.add_argument('--camera_type', type=str, default='d435', help="Camera model passed to scene generation")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())



