from typing import *
from numpy import ndarray
from torch import Tensor

import os
import json
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.nn.functional as tF
from kiui.cam import orbit_camera, undo_orbit_camera

from src.data.utils.chunk_dataset import ChunkedDataset
from src.options import Options
from src.utils import normalize_normals, unproject_depth


# class GObjaverseParquetDataset(ChunkedDataset):
class GObjaverseDataset:
    """Map-style dataset over an unzipped GObjaverse render directory.

    Each item is an object id taken from the JSON list at ``items_path``. Its
    renders are read from ``<data_path>/<item_id>/campos_512_v4/<vid>/`` for the
    ``NUM_VIEWS`` view ids ``00000``..``00039``. ``__getitem__`` selects
    ``opt.num_views`` views, the first ``opt.num_input_views`` of which come from
    an even or random view-sampling strategy. It returns a dict of stacked
    per-view tensors (see ``get_trainable_data_from_raw_data``).
    """

    # Placeholder locations kept from the original hard-coded values; override
    # them with the `data_path` / `items_path` keyword arguments.
    DEFAULT_DATA_PATH = "/path/to/gobjaverse_unzip"
    DEFAULT_ITEMS_PATH = "/path/to/objaverse_filter/gobj_merged.json"

    # GObjaverse layout constants (hard-coded for this dataset)
    NUM_VIEWS = 40  # views rendered per object
    DUPLICATE_VIEWS = (24, 39)  # duplicated views; never sampled as extra views
    TOPDOWN_VIEWS = (25, 26)  # up (25) and down (26) views, elevation +/- 90
    # File-name suffixes stored for every view id
    VIEW_FILE_SUFFIXES = (".png", ".json", "_albedo.png", "_nd.png", "_mr.png")

    def __init__(self, opt: "Options", training: bool = True, batch_size: int = 8, *args,
                 data_path: str = DEFAULT_DATA_PATH, items_path: str = DEFAULT_ITEMS_PATH,
                 **kwargs):
        """Load the item list and, optionally, the negative prompt embeddings.

        Args:
            opt: experiment options. This method reads `fxfy`, `prompt_embed_dir`
                and `pretrained_model_name_or_path`; more fields are read in
                `get_trainable_data_from_raw_data`.
            training: if True, use the training split (all but the last
                `batch_size` items). Otherwise use the evaluation split (the last
                `batch_size` items).
            batch_size: number of items held out for the evaluation split.
            *args, **kwargs: accepted for call compatibility and ignored.
            data_path: root directory of the unzipped GObjaverse renders.
            items_path: JSON file holding the list of item ids.
        """
        self.opt = opt
        self.training = training
        self.data_path = data_path

        # Default camera intrinsics (fx, fy, cx, cy)
        self.fxfycxcy = torch.tensor([opt.fxfy, opt.fxfy, 0.5, 0.5], dtype=torch.float32)  # (4,)

        # Negative ("null" prompt) embeddings; None when not configured or not on disk
        self.negative_prompt_embed = None
        self.negative_pooled_prompt_embed = None
        self.negative_prompt_attention_mask = None
        if opt.prompt_embed_dir is not None:
            embed_dir = opt.prompt_embed_dir
            self.negative_prompt_embed = self._load_optional_npy(f"{embed_dir}/null.npy")
            self.negative_pooled_prompt_embed = self._load_optional_npy(f"{embed_dir}/null_pooled.npy")
            self.negative_prompt_attention_mask = self._load_optional_npy(f"{embed_dir}/null_attention_mask.npy")

            if "xl" in opt.pretrained_model_name_or_path:  # SDXL: zero out negative prompt embedding
                if self.negative_prompt_embed is not None and self.negative_pooled_prompt_embed is not None:
                    self.negative_prompt_embed = torch.zeros_like(self.negative_prompt_embed)
                    self.negative_pooled_prompt_embed = torch.zeros_like(self.negative_pooled_prompt_embed)

        with open(items_path, "r", encoding="utf-8") as f:
            items = json.load(f)

        # Naive split: hold out the last `batch_size` items for evaluation.
        # The boundary is computed explicitly because `items[:-0]` would be empty.
        split = max(len(items) - batch_size, 0)
        self.items = items[:split] if training else items[split:]

    @staticmethod
    def _load_optional_npy(path: str) -> Optional[Tensor]:
        """Load a `.npy` file as a float32 tensor, or return None if the file does not exist."""
        try:
            return torch.from_numpy(np.load(path)).float()
        except FileNotFoundError:
            return None

    def __len__(self):
        """Number of items in the selected split."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return the training dict for the `idx`-th item (see `get_trainable_data_from_raw_data`)."""
        item_id = self.items[idx]
        campos_dir = os.path.join(self.data_path, item_id, "campos_512_v4")

        # Map per-view file names (e.g. "00003_nd.png") to their paths on disk.
        # "__key__" names the item in error messages; "uid" selects its prompt embeddings.
        raw_data = {"__key__": item_id, "uid": item_id}
        for vid in range(self.NUM_VIEWS):
            view_dir = os.path.join(campos_dir, f"{vid:05d}")
            for suffix in self.VIEW_FILE_SUFFIXES:
                file_name = f"{vid:05d}{suffix}"
                raw_data[file_name] = os.path.join(view_dir, file_name)
        return self.get_trainable_data_from_raw_data([raw_data])

    def _select_view_indices(self, sample: Dict[str, Union[str, bytes]], num_views: int,
                             num_input_views: int) -> List[int]:
        """Choose `num_views` view ids whose PNG and JSON files are available.

        The first `num_input_views` ids come from the even or random sampling
        strategy. The rest are extra available views in random order; if there
        are too few, already-chosen views are repeated.

        Raises:
            ValueError: if no valid set of input views is found after 100 retries.
        """
        if self.opt.load_even_views or not self.training:
            pick_func = self._pick_even_view_indices
        else:
            pick_func = self._pick_random_view_indices

        # Sample the input views, retrying while any of their files are missing
        view_idxs = pick_func(num_input_views)
        num_tries = 0
        while not self._check_views_exist(sample, view_idxs):
            view_idxs = pick_func(num_input_views)
            num_tries += 1
            if num_tries > 100:  # TODO: make `100` configurable
                raise ValueError(
                    f"Cannot find {num_input_views} views in {sample.get('__key__', '<unknown>')}")

        # Skip already-picked, duplicated and (optionally) top-down views as extra views
        except_idxs = set(view_idxs) | set(self.DUPLICATE_VIEWS)
        if self.opt.exclude_topdown_views:
            except_idxs |= set(self.TOPDOWN_VIEWS)

        for i in np.random.permutation(self.NUM_VIEWS):
            if len(view_idxs) >= num_views:
                break
            if i not in except_idxs and self._check_views_exist(sample, [i]):
                view_idxs.append(i)

        # Randomly repeat views if not enough views
        while len(view_idxs) < num_views:
            view_idxs.append(np.random.choice(view_idxs))
        return view_idxs

    def get_trainable_data_from_raw_data(self, raw_data_list) -> Dict[str, Tensor]:
        """Convert one raw sample into stacked per-view training tensors.

        Args:
            raw_data_list: a single-element list holding a dict that maps per-view
                file names (e.g. "00003.png", "00003.json", "00003_nd.png") to
                file paths, plus "__key__" / "uid" entries naming the item.

        Returns:
            Dict of tensors stacked over V = `opt.num_views` views. It always has
            "fxfycxcy" (V, 4), "image" (V, 3, H, W), "mask" (V, 1, H, W),
            "C2W" (V, 4, 4) and "cam_pose" (V, 3). Depending on `opt`, it also has
            "canny", "albedo", "normal", "coord", "depth", "mr" and prompt-embedding
            entries. Image-like entries are resized to `opt.input_res`.

        Raises:
            ValueError: if no valid set of input views can be found.
        """
        assert len(raw_data_list) == 1
        sample: Dict[str, Union[str, bytes]] = raw_data_list[0]

        V, V_in = self.opt.num_views, self.opt.num_input_views
        assert V >= V_in
        view_idxs = self._select_view_indices(sample, V, V_in)

        # The normal/depth PNG is the source for normals, coordinates and depth alike
        load_nd = self.opt.load_normal or self.opt.load_coord or self.opt.load_depth

        return_dict = defaultdict(list)
        init_azi = None
        for vid in view_idxs:
            return_dict["fxfycxcy"].append(self.fxfycxcy)  # (V, 4); fixed intrinsics for GObjaverse

            image = self._load_png(sample[f"{vid:05d}.png"])  # (4, 512, 512)
            mask = image[3:4]  # (1, 512, 512)
            image = image[:3] * mask + (1. - mask)  # (3, 512, 512), to white bg
            return_dict["image"].append(image)  # (V, 3, H, W)
            return_dict["mask"].append(mask)  # (V, 1, H, W)

            if self.opt.load_canny:
                # Give OpenCV a contiguous (H, W, C) copy instead of a permuted view
                rgb = np.ascontiguousarray(image.permute(1, 2, 0).numpy())
                gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
                canny = cv2.Canny((gray * 255.).astype(np.uint8), 100., 200.)
                canny = torch.from_numpy(canny).unsqueeze(0).float().repeat(3, 1, 1) / 255.  # (3, 512, 512) in [0, 1]
                canny = -canny + 1.  # 0->1, 1->0, i.e., white bg
                return_dict["canny"].append(canny)  # (V, 3, H, W)

            c2w = self._load_camera_from_json(sample[f"{vid:05d}.json"])
            # Blender world + OpenCV cam -> OpenGL world & cam; https://kit.kiui.moe/camera
            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1  # invert up and forward direction
            return_dict["original_C2W"].append(torch.from_numpy(c2w).float())  # (V, 4, 4); for normal normalization only

            # Relative azimuth w.r.t. the first view
            ele, azi, dis = undo_orbit_camera(c2w)  # elevation: [-90, 90] from +y(-90) to -y(90)
            if init_azi is None:
                init_azi = azi
            azi = (azi - init_azi) % 360.  # azimuth: [0, 360] from +z(0) to +x(90)
            # To avoid numerical errors for elevation +/- 90 (GObjaverse index 25 (up) & 26 (down))
            ele_sign = ele >= 0
            ele = abs(ele) - 1e-8
            ele = ele * (1. if ele_sign else -1.)

            new_c2w = torch.from_numpy(orbit_camera(ele, azi, dis)).float()
            return_dict["C2W"].append(new_c2w)  # (V, 4, 4)
            return_dict["cam_pose"].append(torch.tensor(
                [np.deg2rad(ele), np.deg2rad(azi), dis], dtype=torch.float32))  # (V, 3)

            # Albedo
            if self.opt.load_albedo:
                albedo = self._load_png(sample[f"{vid:05d}_albedo.png"])  # (3, 512, 512)
                albedo = albedo * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["albedo"].append(albedo)  # (V, 3, H, W)
            # Normal & Depth
            if load_nd:
                nd = self._load_png(sample[f"{vid:05d}_nd.png"], uint16=True)  # (4, 512, 512)
                if self.opt.load_normal:
                    normal = nd[:3] * 2. - 1.  # (3, 512, 512) in [-1, 1]
                    normal[0, ...] *= -1  # to OpenGL world convention
                    return_dict["normal"].append(normal)  # (V, 3, H, W)
                if self.opt.load_coord or self.opt.load_depth:
                    depth = nd[3] * 5.  # (512, 512); NOTE: depth is scaled by 1/5 in my data preprocessing
                    return_dict["depth"].append(depth)  # (V, H, W)
            # Metal & Roughness
            if self.opt.load_mr:
                mr = self._load_png(sample[f"{vid:05d}_mr.png"])  # (3, 512, 512); (metallic, roughness, padding)
                mr = mr * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["mr"].append(mr)  # (V, 3, H, W)

        for key in list(return_dict.keys()):
            return_dict[key] = torch.stack(return_dict[key], dim=0)

        if self.opt.load_normal:
            # Normalize normals by the first view and transform the first view to a fixed azimuth (i.e., 0)
            # Ensure `normals` and `original_C2W` are in the same camera convention
            normals = normalize_normals(return_dict["normal"].unsqueeze(0), return_dict["original_C2W"].unsqueeze(0), i=0).squeeze(0)
            normals = torch.einsum("brc,bvchw->bvrhw", return_dict["C2W"][0, :3, :3].unsqueeze(0), normals.unsqueeze(0)).squeeze(0)
            normals = normals * 0.5 + 0.5  # [0, 1]
            normals = normals * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["normal"] = normals
        # Original C2W is only used for normal normalization; never return it
        return_dict.pop("original_C2W", None)

        # OpenGL to COLMAP camera for Gaussian renderer
        return_dict["C2W"][:, :3, 1:3] *= -1

        # Whether scale the object w.r.t. the first view to a fixed size
        if self.opt.norm_camera:
            scale = self.opt.norm_radius / (torch.norm(return_dict["C2W"][0, :3, 3], dim=-1))
        else:
            scale = 1.
        return_dict["C2W"][:, :3, 3] *= scale
        return_dict["cam_pose"][:, 2] *= scale

        if self.opt.load_coord:
            # Unproject depth map to 3D world coordinate
            coords = unproject_depth(return_dict["depth"].unsqueeze(0) * scale,
                return_dict["C2W"].unsqueeze(0), return_dict["fxfycxcy"].unsqueeze(0)).squeeze(0)
            coords = coords * 0.5 + 0.5  # [0, 1]
            coords = coords * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["coord"] = coords
            if not self.opt.load_depth:
                return_dict.pop("depth")

        if self.opt.load_depth:
            depths = return_dict["depth"].unsqueeze(1) * return_dict["mask"]  # (V, 1, 512, 512), to black bg
            assert depths.min() == 0.
            if self.opt.normalize_depth:
                H, W = depths.shape[-2:]
                depths = depths.reshape(V, -1)
                depths_max = depths.max(dim=-1, keepdim=True).values
                # Clamp so an all-background view (max depth 0) yields 0 instead of NaN
                depths = depths / depths_max.clamp_min(1e-8)  # [0, 1]
                depths = depths.reshape(V, 1, H, W)
            depths = -depths + 1.  # 0->1, 1->0, i.e., white bg
            return_dict["depth"] = depths.repeat(1, 3, 1, 1)

        # Resize to the input resolution
        for key in ["image", "mask", "albedo", "normal", "coord", "depth", "mr", "canny"]:
            if key in return_dict:
                return_dict[key] = tF.interpolate(
                    return_dict[key], size=(self.opt.input_res, self.opt.input_res),
                    mode="bilinear", align_corners=False, antialias=True
                )

        # Handle anti-aliased normal, coord and depth (GObjaverse renders anti-aliased normal & depth)
        if self.opt.load_normal:
            return_dict["normal"] = return_dict["normal"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_coord:
            return_dict["coord"] = return_dict["coord"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_depth:
            return_dict["depth"] = return_dict["depth"] * return_dict["mask"] + (1. - return_dict["mask"])

        # Load precomputed caption embeddings
        if self.opt.prompt_embed_dir is not None:
            uid = sample["uid"]
            if isinstance(uid, bytes):
                uid = uid.decode("utf-8")
            # NOTE(review): embeddings are looked up by the last path component of the uid
            # with any extension stripped; confirm this matches how they were saved.
            uid = uid.split("/")[-1].split(".")[0]
            return_dict["prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}.npy"))
            if "xl" in self.opt.pretrained_model_name_or_path or "3" in self.opt.pretrained_model_name_or_path:  # SDXL or SD3
                return_dict["pooled_prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_pooled.npy"))
            if "PixArt" in self.opt.pretrained_model_name_or_path:  # PixArt-alpha, PixArt-Sigma
                return_dict["prompt_attention_mask"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_attention_mask.npy"))

        for key in return_dict.keys():
            assert isinstance(return_dict[key], Tensor), f"Value of the key [{key}] is not a Tensor, but {type(return_dict[key])}."

        return dict(return_dict)

    def _load_png(self, png_path: str, uint16: bool = False) -> Tensor:
        """Read a PNG file as a (C, H, W) float32 tensor scaled to [0, 1], channels in RGB order.

        Args:
            png_path: path to the PNG file.
            uint16: divide by 65535 (16-bit PNG) instead of 255.

        Raises:
            ValueError: if OpenCV cannot decode the file.
        """
        with open(png_path, "rb") as f:
            png_bytes = f.read()
        png = cv2.imdecode(np.frombuffer(png_bytes, np.uint8), cv2.IMREAD_UNCHANGED)  # (H, W[, C]) in [0, 255] or [0, 65535]
        if png is None:
            raise ValueError(f"Failed to decode image: {png_path}")
        if png.ndim == 2:  # single-channel image: add a channel axis
            png = png[:, :, None]

        png = png.astype(np.float32) / (65535. if uint16 else 255.)  # (H, W, C) in [0, 1]
        png[:, :, :3] = png[:, :, :3][..., ::-1]  # BGR -> RGB
        png_tensor = torch.from_numpy(png).nan_to_num_(0.)  # there are nan in GObjaverse gt normal
        return png_tensor.permute(2, 0, 1)  # (C, H, W) in [0, 1]

    def _load_camera_from_json(self, json_path: str) -> ndarray:
        """Build a 4x4 float64 camera-to-world matrix (OpenCV convention) from a camera JSON.

        The "x", "y", "z" axes and the "origin" entry fill columns 0-3 of the
        top three rows; the bottom row stays [0, 0, 0, 1].
        """
        with open(json_path, "r", encoding="utf-8") as f:
            json_dict = json.load(f)

        c2w = np.eye(4)  # float64
        for col, key in enumerate(("x", "y", "z", "origin")):
            c2w[:3, col] = np.array(json_dict[key])
        return c2w

    def _pick_even_view_indices(self, num_views: int = 4) -> List[int]:
        """Pick `num_views` azimuthally evenly spaced view ids from one GObjaverse ring.

        With probability 2/3 the ring is views 0-23; otherwise it is views 27-38.
        """
        assert 12 % num_views == 0  # `12` for even-view sampling in GObjaverse

        if np.random.rand() < 2/3:
            index0 = np.random.choice(range(24))  # 0~23: 24 views in ele from [5, 30]; hard-coded for GObjaverse
            return [(index0 + (24 // num_views)*i) % 24 for i in range(num_views)]
        else:
            index0 = np.random.choice(range(12))  # 27~38: 12 views in ele from [-5, 5]; hard-coded for GObjaverse
            return [((index0 + (12 // num_views)*i) % 12 + 27) for i in range(num_views)]

    def _pick_random_view_indices(self, num_views: int = 4) -> List[int]:
        """Pick `num_views` distinct random view ids, optionally excluding the top-down views."""
        assert num_views <= self.NUM_VIEWS

        indices = set(range(self.NUM_VIEWS))
        if self.opt.exclude_topdown_views:
            indices -= set(self.TOPDOWN_VIEWS)
        return np.random.choice(list(indices), num_views, replace=False).tolist()

    def _check_views_exist(self, sample: Dict[str, Union[str, bytes]], vids: List[int]) -> bool:
        """Return True if every view in `vids` has both its PNG and its camera JSON.

        Path values (str) must point to existing files. Other values (e.g. raw
        bytes) only need to be present and not None.
        """
        for vid in vids:
            for key in (f"{vid:05d}.png", f"{vid:05d}.json"):
                value = sample.get(key)
                if value is None:
                    return False
                if isinstance(value, str) and not os.path.isfile(value):
                    return False
        return True
