import os
import torch
import torch.nn.functional as F
import glob
import imageio
import numpy as np
from util import get_image_to_tensor_balanced, get_mask_to_tensor, get_label_to_tensor
from .data_util import load_labels, load_rgb, load_pose, load_seg, load_pts, load_labels


class SRNDatasetOrig(torch.utils.data.Dataset):
    """
    Dataset from SRN (V. Sitzmann et al. 2020), with per-view segmentation
    labels and optional labelled point clouds.

    Each item is one object instance directory under
    ``<root>/semantic_srn_data/<Category>/<Category>.<stage>/`` containing an
    ``intrinsics.txt`` plus ``rgb/``, ``pose/`` and ``seg/`` subdirectories.
    Views are paired across those subdirectories by sorted filename order.
    """

    def __init__(
        self, path, stage="train", image_size=(128, 128), world_scale=1.0, load_pc=False, category="chair", **kwargs,
    ):
        """
        :param path a path whose parent directory contains ``semantic_srn_data``
            (only ``os.path.dirname(path)`` is used)
        :param stage train | val | test; selects the ``<Category>.<stage>`` split directory
        :param image_size result image size (H, W), or a single int for square
            output (resizes if different)
        :param world_scale amount to scale entire world by
        :param load_pc if True, also load the per-instance point cloud and its labels
        :param category object category; capitalized to form the directory name
        :param kwargs accepted and ignored
        :raises FileNotFoundError if the split directory does not exist
        """
        super().__init__()
        self.category = category.capitalize()
        root = os.path.dirname(path)
        # e.g. <root>/semantic_srn_data/Chair/Chair; split dirs are "<that>.<stage>".
        path = os.path.join(root, "semantic_srn_data", self.category, self.category)
        self.base_path = path + "." + stage
        self.dataset_name = os.path.basename(path)

        print("Loading SRN dataset", self.base_path, "name:", self.dataset_name)
        self.stage = stage
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        if not os.path.exists(self.base_path):
            raise FileNotFoundError(f"SRN split directory not found: {self.base_path}")

        # One entry per object instance: every "<instance>/intrinsics.txt" in the split.
        self.intrins = sorted(
            glob.glob(os.path.join(self.base_path, "*", "intrinsics.txt"))
        )
        self.image_to_tensor = get_image_to_tensor_balanced()
        self.mask_to_tensor = get_mask_to_tensor()
        self.seg_to_tensor = get_label_to_tensor()

        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        # Stored as a tuple so it can compare equal to ``torch.Size`` in
        # __getitem__ (a list never does, which would force a needless resize).
        self.image_size = tuple(image_size)
        self.world_scale = world_scale
        # Right-multiplying a pose by diag(1, -1, -1, 1) negates its camera y and
        # z axes (presumably an OpenCV -> OpenGL camera convention change).
        self._coord_trans = torch.diag(
            torch.tensor([1, -1, -1, 1], dtype=torch.float32)
        )

        self.load_pc = load_pc
        # Number of entries in the category's "-level-1.txt" label list.
        classes = load_labels(path + '-level-1.txt')
        self.n_classes = len(classes)

        # Fixed ray near/far bounds, shared by all categories.
        self.z_near = 0.01
        self.z_far = 2.5
        # Presumably: sample depths linearly in depth rather than disparity.
        self.lindisp = False

    def __len__(self):
        """Number of object instances in this split."""
        return len(self.intrins)

    @staticmethod
    def _mask_bbox(mask, rgb_path=None):
        """
        Tight bounding box of the non-zero pixels of ``mask``.

        :param mask array indexed as (H, W) or (H, W, 1)
        :param rgb_path source image path, used only in the error message
        :return float32 tensor ``[cmin, rmin, cmax, rmax]`` (x, y, x, y), inclusive
        :raises RuntimeError if the mask is entirely zero
        """
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        rnz = np.where(rows)[0]
        cnz = np.where(cols)[0]
        if len(rnz) == 0:
            raise RuntimeError(
                f"ERROR: Bad image at {rgb_path}, please investigate!"
            )
        rmin, rmax = rnz[[0, -1]]
        cmin, cmax = cnz[[0, -1]]
        return torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32)

    def _load_view(self, rgb_path, pose_path, seg_path):
        """
        Load one view of an instance.

        :return (img_tensor, mask_tensor, pose, bbox, seg_tensor)
        """
        img = load_rgb(rgb_path)
        # ``img`` is indexed channel-first: channel 3 drives the mask, channels 0-2 are colour.
        # NOTE(review): assumes load_rgb returns values in [-1, 1] (colour is mapped
        # to [0, 1] below) and that channel 3 is an alpha/coverage channel — confirm.
        mask = (1 - img[3])[..., None].astype(np.uint8) * 255
        mask_tensor = self.mask_to_tensor(mask)
        img = (img[:3].transpose(1, 2, 0) + 1) / 2
        img_tensor = self.image_to_tensor(img)

        pose = torch.from_numpy(load_pose(pose_path)) @ self._coord_trans

        seg = load_seg(seg_path)[..., None]
        # NOTE(review): the * 255 presumably undoes a [0, 1] normalisation inside
        # seg_to_tensor to recover integer class ids — confirm.
        seg_tensor = self.seg_to_tensor(seg) * 255.0

        bbox = self._mask_bbox(mask, rgb_path)
        return img_tensor, mask_tensor, pose, bbox, seg_tensor

    def __getitem__(self, index):
        """
        Load every view of instance ``index``.

        :return dict with keys path, img_id, focal, c, images, masks, bbox,
            poses, labels (plus pts, pts_rgb, pts_labels when ``load_pc``)
        :raises RuntimeError if the rgb/pose/seg view lists are empty or
            mismatched, or a view has an empty mask
        """
        intrin_path = self.intrins[index]
        dir_path = os.path.dirname(intrin_path)
        rgb_paths = sorted(glob.glob(os.path.join(dir_path, "rgb", "*")))
        pose_paths = sorted(glob.glob(os.path.join(dir_path, "pose", "*")))
        seg_paths = sorted(glob.glob(os.path.join(dir_path, "seg", "*")))
        # Views are paired by sorted position; zip() below would silently drop
        # extras, so a mismatch must be an error (not an assert, which -O strips).
        if not len(rgb_paths) == len(pose_paths) == len(seg_paths):
            raise RuntimeError(
                f"Mismatched view counts in {dir_path}: rgb={len(rgb_paths)}, "
                f"pose={len(pose_paths)}, seg={len(seg_paths)}"
            )
        if not rgb_paths:
            raise RuntimeError(f"No views found in {dir_path}")

        with open(intrin_path, "r") as intrinfile:
            lines = intrinfile.readlines()
            focal, cx, cy, _ = map(float, lines[0].split())
            # Last line holds "height width"; parsed but not used below (the
            # resize scale is taken from the loaded tensors instead).
            _height, _width = map(int, lines[-1].split())

        views = [
            self._load_view(rgb_path, pose_path, seg_path)
            for rgb_path, pose_path, seg_path in zip(rgb_paths, pose_paths, seg_paths)
        ]
        all_imgs, all_masks, all_poses, all_bboxes, all_segs = (
            torch.stack(group) for group in zip(*views)
        )

        if all_imgs.shape[-2:] != self.image_size:
            in_h, in_w = all_imgs.shape[-2:]
            scale_y = self.image_size[0] / in_h
            scale_x = self.image_size[1] / in_w
            # focal is a single scalar; it keeps the height-ratio convention.
            focal *= scale_y
            cx *= scale_x
            cy *= scale_y
            # bbox columns are (cmin, rmin, cmax, rmax) = (x, y, x, y).
            all_bboxes *= torch.tensor(
                [scale_x, scale_y, scale_x, scale_y], dtype=all_bboxes.dtype
            )

            all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area")
            all_masks = F.interpolate(all_masks, size=self.image_size, mode="area")
            # Nearest-neighbour so label ids are never blended.
            all_segs = F.interpolate(all_segs, size=self.image_size, mode="nearest")

        if self.world_scale != 1.0:
            focal *= self.world_scale
            all_poses[:, :3, 3] *= self.world_scale
        focal = torch.tensor(focal, dtype=torch.float32)

        result = {
            "path": dir_path,
            "img_id": index,
            "focal": focal,
            "c": torch.tensor([cx, cy], dtype=torch.float32),
            "images": all_imgs,
            "masks": all_masks,
            "bbox": all_bboxes,
            "poses": all_poses,
            "labels": all_segs
        }
        if self.load_pc:
            pts_path = os.path.join(dir_path, 'point_cloud', 'sample-points-all-pts-nor-rgba-10000.txt')
            labels_path = os.path.join(dir_path, 'point_cloud', 'sample-points-label-10000.txt')
            result['pts'], result['pts_rgb'] = load_pts(pts_path)
            result['pts_labels'] = load_labels(labels_path)
        return result