import os
import torch
import torch.nn.functional as F
import glob
import imageio
import numpy as np
from util import get_image_to_tensor_balanced, get_mask_to_tensor, get_label_to_tensor
import json

def read_hier_labels(hier_file, read_id=True):
    """Read an ``<id> <label>`` text file into a dict.

    Every non-blank line must contain at least two whitespace-separated
    tokens: an id and a label. Further tokens on the line are ignored.

    :param hier_file: path of the label file.
    :param read_id: if True, key each label by its id token as written in the
        file (a ``str``); if False, ignore the id token and key each label by
        its 1-based line number (an ``int``).
    :return: dict mapping key -> label string.
    """
    hier_labels = {}
    with open(hier_file, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            # split() with no separator collapses runs of spaces/tabs and drops
            # the newline; split(' ') would yield '' tokens for doubled spaces.
            tokens = line.split()
            if not tokens:
                # Blank lines (e.g. a trailing empty line) carry no label.
                continue
            key = tokens[0] if read_id else line_no
            hier_labels[key] = tokens[1]
    return hier_labels

def get_hier_label_mapping(base_path, cat="Chair", level=1):
    """Build a lookup table from source label ids to level-``level`` label ids.

    Reads three files from ``<base_path>/../after_merging_label_ids/``:

    * ``<cat>-level-<level>.txt``: target labels, keyed by 1-based line number;
    * ``<cat>_before.txt``: source labels, keyed by their id;
    * ``<cat>_merging.txt``: renames applied to source label strings first.

    Each (renamed) source label is assigned the first target label, in file
    order, that is the same ``/``-separated path or an ancestor path of it.

    :return: int16 array indexed by source label id. After the final ``+1``
        shift, matched ids hold ``target_id + 1`` and unmatched/unused ids
        hold 1, so value 0 never occurs.
    """
    label_dir = os.path.join(base_path, '..', 'after_merging_label_ids')
    target_hier = read_hier_labels(
        os.path.join(label_dir, cat + '-level-' + str(level) + '.txt'), read_id=False)
    src_hier = read_hier_labels(os.path.join(label_dir, cat + '_before.txt'))
    merging_hier = read_hier_labels(os.path.join(label_dir, cat + '_merging.txt'))

    # Apply merge renames. merging_hier keys are unique, so a direct lookup is
    # equivalent to scanning every (before, after) pair.
    src_hier = {src_id: merging_hier.get(src_str, src_str)
                for src_id, src_str in src_hier.items()}

    label_mapping = np.zeros(max(int(s) for s in src_hier) + 1, dtype=np.int16)
    for src_id, src_str in src_hier.items():
        for target_id, target_str in target_hier.items():
            # Path-aware match: a plain substring test would let e.g. "a/leg"
            # claim the sibling "a/leg_bar/..." if it appears first.
            if src_str == target_str or src_str.startswith(target_str + '/'):
                label_mapping[int(src_id)] = target_id
                break
    # Shift so that 0 stays free for background; unmatched ids become 1.
    label_mapping += 1
    return label_mapping

def fill_results_dict(results, ori_id_dict, prefix=""):
    """Flatten a part tree into ``ori_id_dict`` (mutated in place).

    Each node must have a ``name``. A node that has a ``children`` key is
    descended into, with its name appended to the path. A node without one is
    a leaf, and its ``id`` is recorded with its full ``/``-joined path.

    :param results: list of node dicts.
    :param ori_id_dict: output dict, leaf id -> path string.
    :param prefix: path prefix (ending in ``/``) for the nodes in ``results``.
    """
    for node in results:
        path = prefix + node['name']
        if "children" in node:
            fill_results_dict(node["children"], ori_id_dict, prefix=path + '/')
        else:
            ori_id_dict[node["id"]] = path

def instance_to_sem_label_mapping(results_path, target_hier):
    """Map leaf-part instance ids of a ``result.json`` to semantic label ids.

    :param results_path: path of a JSON part tree, as consumed by
        ``fill_results_dict`` (nodes with ``name`` and ``children`` or ``id``).
    :param target_hier: dict from ``/``-joined part path to semantic label id.
        Values may be numeric strings; they are converted with ``int()``.
    :return: uint16 array indexed by instance id. Leaves whose path is not in
        ``target_hier``, and ids that are not leaves, map to 0. If the tree has
        no leaves, a single-element zero array is returned.
    """
    # JSON text is UTF-8; do not rely on the platform default encoding.
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)
    ori_id_dict = {}
    fill_results_dict(results, ori_id_dict)
    if not ori_id_dict:
        # max() of an empty dict would raise; keep a valid all-zero table.
        return np.zeros(1, dtype=np.uint16)
    ins_to_sem_mapping = np.zeros(max(ori_id_dict) + 1, dtype=np.uint16)
    for ins_id, part_path in ori_id_dict.items():
        # Explicit int(): target_hier values can be str ids from read_hier_labels.
        ins_to_sem_mapping[ins_id] = int(target_hier.get(part_path, 0))
    return ins_to_sem_mapping


# Categories considered when SRNDataset is built with category="All"; each is
# kept only if its "<cat>-level-<level>.txt" label file exists.
cats = [
    "Chair", "Table", "Lamp", "Vase", "Display", "Clock",
    "Faucet", "Laptop", "Bottle", "Bed", "Knife", "Mug",
    "Bowl", "Microwave", "Bag", "Dishwasher", "Earphone", "Keyboard",
    "Hat", "TrashCan", "StorageFurniture", "Scissors", "Refrigerator", "Door",
]
# Subset previously listed as ``cats_l3`` (presumably the categories that have
# level-3 labels): Chair, Table, Lamp, Vase, Display, Clock, Faucet, Bottle,
# Bed, Knife, Microwave, Dishwasher, Earphone, TrashCan, StorageFurniture,
# Refrigerator, Door.

class SRNDataset(torch.utils.data.Dataset):
    """
    Dataset from SRN (V. Sitzmann et al. 2020), extended with per-pixel part
    labels and, optionally, labelled point clouds.

    Objects are the directories ``<path>_<stage>/*/`` containing an
    ``intrinsics.txt``. Part labels are remapped to hierarchy ``level`` using
    the label files in ``<path>_<stage>/../after_merging_label_ids/``.
    """

    def __init__(
        self, path, stage="train", image_size=(128, 128), world_scale=1.0, load_pc=False, category="Chair", level=1,
    ):
        """
        :param path dataset prefix; data is read from ``path + "_" + stage``
        :param stage train | val | test
        :param image_size result image size (resizes if different)
        :param world_scale amount to scale entire world by
        :param load_pc also return each object's labelled point cloud
        :param category category name (case-insensitive), or "all" to load every
            category in ``cats`` that has a level-``level`` label file
        :param level part-hierarchy level used as the label target
        """
        super().__init__()
        self.base_path = path + "_" + stage
        self.dataset_name = os.path.basename(path)

        print("Loading SRN dataset", self.base_path, "name:", self.dataset_name)
        self.stage = stage
        assert os.path.exists(self.base_path)

        self.intrins = sorted(
            glob.glob(os.path.join(self.base_path, "*", "intrinsics.txt"))
        )
        self.image_to_tensor = get_image_to_tensor_balanced()
        self.label_to_tensor = get_label_to_tensor()
        self.mask_to_tensor = get_mask_to_tensor()

        self.image_size = image_size
        self.world_scale = world_scale
        # Negates the y and z axes; right-multiplied onto every camera pose.
        self._coord_trans = torch.diag(
            torch.tensor([1, -1, -1, 1], dtype=torch.float32)
        )

        # Near/far bounds are shared by all categories.
        self.z_near = 0.01
        self.z_far = 1.75
        self.lindisp = False
        self.load_pc = load_pc
        # capitalize() lower-cases everything after the first letter, so restore
        # the CamelCase names used by ``cats`` and the label files.
        self.category = category.capitalize()
        if self.category == 'Trashcan': self.category = 'TrashCan'
        if self.category == 'Storagefurniture': self.category = 'StorageFurniture'

        if self.category == 'All':
            # Only categories that have a label file for this level.
            self.cats = [c for c in cats if os.path.exists(os.path.join(self.base_path, '..', 'after_merging_label_ids', c + '-level-'+str(level)+'.txt'))]
            self.category_map = {k:i for i, k in enumerate(self.cats)}
            # Per-category class count = max of that category's label mapping.
            # NOTE(review): np.zeros defaults to float64, so cat_nclasses and
            # cat_start_class (and the per-item n_classes / cat_start_class)
            # are floats; confirm downstream code expects that.
            self.cat_nclasses = np.zeros( len(self.category_map))
            self.cat_pt_label_mapping = [None] * len(self.category_map)
            self.cat_target_hier = [None] * len(self.category_map)
            for cat, idx in self.category_map.items():
                self.cat_pt_label_mapping[idx] = get_hier_label_mapping(self.base_path, cat=cat, level=level)
                self.cat_nclasses[idx] = self.cat_pt_label_mapping[idx].max().astype(np.int16)
                if self.load_pc:
                    # Inverted to map label path -> source label id (str).
                    self.cat_target_hier[idx] = read_hier_labels(os.path.join(self.base_path, '..', 'after_merging_label_ids', cat+'_before.txt'), read_id=True)
                    self.cat_target_hier[idx] = {v: k for k, v in self.cat_target_hier[idx].items()}
            # Categories occupy consecutive global class ranges; the +1 offset
            # keeps global class 0 free.
            self.cat_start_class = np.zeros_like(self.cat_nclasses)
            self.cat_start_class[1:] = np.cumsum(self.cat_nclasses)[:-1].astype(np.int16)
            self.cat_start_class += 1
            self.n_classes = np.sum(self.cat_nclasses).astype(np.int16)+1
            # Keep only objects whose meta.json category was selected above.
            new_intrins = []
            for intrin_path in self.intrins:
                dir_path = os.path.dirname(intrin_path)
                with open(os.path.join(dir_path, 'meta.json')) as f:
                    meta_dict = json.load(f)
                    model_cat = meta_dict['model_cat']
                if model_cat in self.category_map:
                    new_intrins.append(intrin_path)
            self.intrins = new_intrins
        else:
            self.pt_label_mapping = get_hier_label_mapping(self.base_path, cat=self.category, level=level)
            self.n_classes = self.pt_label_mapping.max() + 1
            if self.load_pc:
                # Inverted to map label path -> source label id (str).
                self.target_hier = read_hier_labels(os.path.join(self.base_path, '..', 'after_merging_label_ids', self.category+'_before.txt'), read_id=True)
                self.target_hier = {v: k for k, v in self.target_hier.items()}

    def __len__(self):
        return len(self.intrins)

    def __getitem__(self, index):
        """
        Load all views (and optionally the point cloud) of one object.

        :return dict with "path", "img_id", "focal", "c", "images", "masks",
            "bbox", "poses", "labels"; plus "n_classes" and "cat_start_class"
            when category is "All", and "pts", "pts_normals", "pts_rgb",
            "pts_labels" when load_pc is set.
        """
        intrin_path = self.intrins[index]
        dir_path = os.path.dirname(intrin_path)
        rgb_paths = sorted(glob.glob(os.path.join(dir_path, "rgb", "*")))
        pose_paths = sorted(glob.glob(os.path.join(dir_path, "pose", "*")))
        labels_paths = sorted(glob.glob(os.path.join(dir_path, "labels", "*")))
        # Per-object table indexed by (label-image value - 1); see the remap below.
        label_mapping = np.load(os.path.join(dir_path, 'part_id_to_label_before.npy'))
        if self.category == "All":
            with open(os.path.join(dir_path, 'meta.json')) as f:
                meta_dict = json.load(f)
                model_cat = meta_dict['model_cat']
            cat_idx = self.category_map[model_cat]
            cat_start_class = self.cat_start_class[cat_idx]
            n_classes = self.cat_nclasses[cat_idx]
            if self.load_pc:
                target_hier = self.cat_target_hier[cat_idx]
            pt_label_mapping = self.cat_pt_label_mapping[cat_idx]
        else:
            if self.load_pc:
                target_hier = self.target_hier
            pt_label_mapping = self.pt_label_mapping

        assert len(rgb_paths) == len(pose_paths)
        assert len(rgb_paths) == len(labels_paths)

        # First line: focal cx cy <unused>; last line: height width.
        with open(intrin_path, "r") as intrinfile:
            lines = intrinfile.readlines()
            focal, cx, cy, _ = map(float, lines[0].split())
            height, width = map(int, lines[-1].split())

        all_imgs = []
        all_poses = []
        all_masks = []
        all_bboxes = []
        all_labels = []
        for rgb_path, pose_path, label_path in zip(rgb_paths, pose_paths, labels_paths):
            img = imageio.imread(rgb_path)[..., :3]
            img_tensor = self.image_to_tensor(img)
            # Foreground = pixels where no channel equals 255.
            # NOTE(review): a foreground pixel with any saturated channel (e.g.
            # pure red) also counts as background; confirm `.all` is intended.
            mask = (img != 255).all(axis=-1)[..., None].astype(np.uint8) * 255
            mask_tensor = self.mask_to_tensor(mask)

            # The stored 4x4 matrix is inverted, then its axes are flipped.
            pose = torch.from_numpy(
                np.loadtxt(pose_path, dtype=np.float32).reshape(4, 4)
            ).inverse()
            pose = pose @ self._coord_trans

            # Tight [x_min, y_min, x_max, y_max] box around the mask.
            rows = np.any(mask, axis=1)
            cols = np.any(mask, axis=0)
            rnz = np.where(rows)[0]
            cnz = np.where(cols)[0]
            if len(rnz) == 0:
                raise RuntimeError(
                    f"ERROR: Bad image at {rgb_path}, please investigate!"
                )
            rmin, rmax = rnz[[0, -1]]
            cmin, cmax = cnz[[0, -1]]
            bbox = torch.tensor([cmin, rmin, cmax, rmax], dtype=torch.float32)

            labels = imageio.imread(label_path)

            # Two-stage remap of non-zero pixels (0 is background):
            # value v -> label_mapping[v - 1] (source label id)
            #         -> pt_label_mapping[...] (target-level label).
            # TODO: treat background different from other shape parts? A pixel
            # whose first-stage result is 0 stays 0 and is then
            # indistinguishable from background.
            labels[labels != 0] = label_mapping[labels[labels != 0]-1]
            labels[labels != 0] = pt_label_mapping[labels[labels != 0]]

            # Presumably label_to_tensor scales by 1/255 (ToTensor-style), so
            # multiplying back restores integer ids. TODO confirm in util.
            labels_tensor = self.label_to_tensor(labels) * 255.

            all_imgs.append(img_tensor)
            all_masks.append(mask_tensor)
            all_poses.append(pose)
            all_bboxes.append(bbox)
            all_labels.append(labels_tensor)

        all_imgs = torch.stack(all_imgs)
        all_poses = torch.stack(all_poses)
        all_masks = torch.stack(all_masks)
        all_bboxes = torch.stack(all_bboxes)
        all_labels = torch.stack(all_labels)

        # tuple() so a list-valued image_size does not force a needless resize
        # (torch.Size never compares equal to a list).
        if all_imgs.shape[-2:] != tuple(self.image_size):
            # Intrinsics and boxes are scaled by the height ratio only.
            scale = self.image_size[0] / all_imgs.shape[-2]
            focal *= scale
            cx *= scale
            cy *= scale
            all_bboxes *= scale

            all_imgs = F.interpolate(all_imgs, size=self.image_size, mode="area")
            all_masks = F.interpolate(all_masks, size=self.image_size, mode="nearest")
            all_labels = F.interpolate(all_labels, size=self.image_size, mode="nearest")

        if self.world_scale != 1.0:
            focal *= self.world_scale
            all_poses[:, :3, 3] *= self.world_scale

        focal = torch.tensor(focal, dtype=torch.float32)
        c = torch.tensor([cx, cy], dtype=torch.float32)
        result = {
            "path": dir_path,
            "img_id": index,
            "focal": focal,
            "c": c,
            "images": all_imgs,
            "masks": all_masks,
            "bbox": all_bboxes,
            "poses": all_poses,
            "labels": all_labels,
        }
        if self.category == "All":
            result['n_classes'] = n_classes
            result['cat_start_class'] = cat_start_class
        if self.load_pc:
            # Columns: 0-2 xyz, 3-5 normals, 6-8 rgb, 9 part instance id.
            pc_path = os.path.join(dir_path, "sample-points-all-pts-nor-rgba-label-10000.npy")
            pc = np.load(pc_path)
            result['pts'] = pc[:,:3]
            result['pts_normals'] = pc[:,3:6]
            result['pts_rgb'] = pc[:,6:9]
            results_path = os.path.join(dir_path, 'result.json')
            # instance id -> source label id -> target-level label.
            ins_to_sem_mapping = instance_to_sem_label_mapping(results_path, target_hier)
            sem_labels = ins_to_sem_mapping[pc[:,9].astype(np.uint16)]
            result['pts_labels'] = pt_label_mapping[sem_labels]
        return result

