import torch.utils.data as data
import numpy as np
import random
import numpy.ma as ma
import yaml
import json
from PIL import Image, ImageFilter
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from lib.transformations import quaternion_from_matrix
import open3d as o3d

class PoseDataset(data.Dataset):
    """RGB-D object pose dataset yielding cropped, normalised per-object samples.

    Files are read from ``root`` using the layout visible in the path
    templates below: ``data/<NN>/{train,test}.txt``, ``data/<NN>/rgb``,
    ``data/<NN>/depth``, ``data/<NN>/mask``, ``data/<NN>/gt.yml``,
    ``segnet_results/<NN>_label`` and ``models/pcd/obj_<NN>.pcd``.

    Args:
        mode: ``'train'`` reads ``train.txt`` and repeats the sample list 16
            times; any other value reads ``test.txt``. ``'test'`` keeps only
            every 10th line of each split file, and ``'eval'`` takes masks
            from ``segnet_results`` instead of ``data/<NN>/mask``.
        num_pt: number of model points kept per object (farthest point
            sampling).
        root: dataset root directory.
        add_noise: when truthy, colour jitter is applied to the RGB image.
        noise_trans: half-width of the uniform translation noise added in
            ``'train'`` mode.
    """

    def __init__(self, mode, num_pt, root, add_noise, noise_trans):
        self.objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
        self.obj_idx = {1: 0, 2: 1, 4: 2, 5: 3, 6: 4, 8: 5, 9: 6, 10: 7, 11: 8, 12: 9, 13: 10, 14: 11, 15: 12}
        self.mode = mode

        self.list_rgb = []
        self.list_depth = []
        self.list_label = []
        self.list_obj = []
        self.list_rank = []
        self.meta = {}
        self.cld = {}
        self.cld_rgb = {}
        self.root = root
        self.noise_trans = noise_trans
        self.num_pt = num_pt

        split_name = 'train' if self.mode == 'train' else 'test'
        for item in self.objlist:
            obj_tag = '%02d' % item
            obj_dir = '{0}/data/{1}'.format(self.root, obj_tag)

            # The context manager guarantees the split file is closed.
            with open('{0}/{1}.txt'.format(obj_dir, split_name)) as input_file:
                for line_no, input_line in enumerate(input_file, start=1):
                    # 'test' mode sub-samples: only lines 10, 20, 30, ... of
                    # each split file are kept.
                    if self.mode == 'test' and line_no % 10 != 0:
                        continue
                    if input_line[-1:] == '\n':
                        input_line = input_line[:-1]
                    self.list_rgb.append('{0}/rgb/{1}.png'.format(obj_dir, input_line))
                    self.list_depth.append('{0}/depth/{1}.png'.format(obj_dir, input_line))
                    if self.mode == 'eval':
                        self.list_label.append(
                            '{0}/segnet_results/{1}_label/{2}_label.png'.format(self.root, obj_tag, input_line))
                    else:
                        self.list_label.append('{0}/mask/{1}.png'.format(obj_dir, input_line))

                    self.list_obj.append(item)
                    self.list_rank.append(int(input_line))

            meta_file = '{0}/gt.yml'.format(obj_dir)
            with open(meta_file, 'r') as fd:
                gt_raw = yaml.safe_load(fd)
            # The JSON round trip turns any non-string mapping keys into
            # strings; __getitem__ looks frames up by a string rank.
            self.meta[item] = json.loads(json.dumps(gt_raw))

            input_cloud = o3d.io.read_point_cloud('{0}/models/pcd/obj_{1}.pcd'.format(self.root, obj_tag))
            raw_xyz = np.asarray(input_cloud.points).astype(np.float32)
            raw_rgb = np.asarray(input_cloud.colors).astype(np.float32) * 255.0
            # Sub-sample the model cloud to num_pt points; colours use the
            # same indices so they stay aligned with the coordinates.
            fps_input = torch.tensor(raw_xyz.reshape((1, -1, 3)), dtype=torch.float32)
            xyz_ids = farthest_point_sample(fps_input, num_pt).cpu().numpy()
            self.cld[item] = raw_xyz[xyz_ids[0, :], :]
            self.cld_rgb[item] = raw_rgb[xyz_ids[0, :], :]

            print("Object {0} buffer loaded".format(item))

        if mode == 'train':
            # Four successive doublings in the original == 16 repetitions.
            self.list_rgb = self.list_rgb * 16
            self.list_depth = self.list_depth * 16
            self.list_label = self.list_label * 16
            self.list_obj = self.list_obj * 16
            self.list_rank = self.list_rank * 16

        self.length = len(self.list_rgb)
        print('data length: {}'.format(self.length))

        # Pinhole camera intrinsics used to back-project depth pixels.
        self.cam_cx = 325.26110
        self.cam_cy = 242.04899
        self.cam_fx = 572.41140
        self.cam_fy = 573.57043

        # Per-pixel column (xmap) and row (ymap) indices of a 480x640 image.
        col_grid, row_grid = np.meshgrid(np.arange(640), np.arange(480))
        self.xmap = col_grid.astype(np.float32)
        self.ymap = row_grid.astype(np.float32)

        self.add_noise = add_noise
        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        self.resize_img_width = 128
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.obj_radius = [0.15] * 17
        self.symmetry_obj_idx = []

        # Three reference offsets per object, scaled by that object's radius.
        prim_offsets = ([[0, 0, 0]], [[0.1, 0, 0]], [[0, 0.1, 0]])
        self.prim_groups = []
        for i in range(len(self.objlist)):
            self.prim_groups.append([
                torch.tensor(offset, dtype=torch.float).permute(1, 0).contiguous() / self.obj_radius[i]
                for offset in prim_offsets
            ])

    def __getitem__(self, index):
        """Build one training/evaluation sample.

        Returns:
            dict with ``'rgb'`` (3xHxW, normalised), ``'xyz'`` (3xHxW),
            ``'mask'`` (1xHxW), ``'target_r'`` (3x3), ``'target_t'`` (3),
            ``'model_xyz'`` and ``'class_id'`` tensors, where H = W =
            ``self.resize_img_width``.

        Raises:
            ValueError: if object 2 has no ground-truth entry with a
                matching ``obj_id`` for the requested frame.
        """
        img = Image.open(self.list_rgb[index])
        depth = np.array(Image.open(self.list_depth[index]))
        label = np.array(Image.open(self.list_label[index]))
        obj = self.list_obj[index]
        idx = self.obj_idx[obj]
        rank = '{}'.format(self.list_rank[index])

        frame_entries = self.meta[obj][rank]
        if obj == 2:
            # Frames of object 2 may list several objects; pick its own entry.
            meta = next((entry for entry in frame_entries if entry['obj_id'] == 2), None)
            if meta is None:
                raise ValueError('no ground-truth entry with obj_id 2 for frame {}'.format(rank))
        else:
            meta = frame_entries[0]

        if self.mode == 'eval':
            mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
        else:
            mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0]

        mask = mask_label
        rmin, rmax, cmin, cmax = get_bbox(mask)
        mask_crop = mask[rmin:rmax, cmin:cmax]

        if self.add_noise:
            img = self.trancolor(img)

        img = np.array(img)[:, :, :3]
        img_masked = img[rmin:rmax, cmin:cmax]

        target_r = np.resize(np.array(meta['cam_R_m2c']), (3, 3))
        # Depth and translation are divided by 1000 (presumably mm -> m;
        # confirm against the data).
        cam_scale = 1000.0
        target_t = np.array(meta['cam_t_m2c']) / cam_scale

        depth_crop = depth[rmin:rmax, cmin:cmax, np.newaxis].astype(np.float32)
        xmap_masked = self.xmap[rmin:rmax, cmin:cmax, np.newaxis]
        ymap_masked = self.ymap[rmin:rmax, cmin:cmax, np.newaxis]

        # Back-project the depth crop into camera-frame coordinates.
        pt2 = depth_crop / cam_scale
        pt0 = (xmap_masked - self.cam_cx) * pt2 / self.cam_fx
        pt1 = (ymap_masked - self.cam_cy) * pt2 / self.cam_fy
        depth_xyz = np.concatenate((pt0, pt1, pt2), axis=2)

        # Centre the cloud on the mean of the masked points with non-zero depth.
        depth_mask_xyz = depth_xyz * mask_crop[:, :, np.newaxis]
        choose = depth_mask_xyz[:, :, 2].flatten().nonzero()[0]
        mask_xyz = depth_xyz.reshape(-1, 3)[choose]
        # NOTE: an empty selection makes the mean NaN; the finite check
        # further down resets rgb/xyz/mask but does not touch target_t.
        mean_xyz = mask_xyz.mean(axis=0).reshape((1, 1, 3))
        xyz = (depth_xyz - mean_xyz) * mask_crop[:, :, np.newaxis]

        target_t = target_t - mean_xyz.reshape(3)

        # Zero out points farther from the centre than the object radius.
        dis_xyz = np.sqrt((xyz * xyz).sum(axis=2))
        within_radius = np.where(dis_xyz > self.obj_radius[idx], 0.0, 1.0).astype(np.float32)
        xyz = xyz * within_radius[:, :, np.newaxis]

        if self.mode == 'train':
            # Random translation augmentation applied to input and target alike.
            noise_t = np.asarray(
                [np.random.uniform(-self.noise_trans, self.noise_trans) for _ in range(3)]
            ).astype(np.float32)
            xyz += noise_t.reshape((1, 1, 3))
            target_t += noise_t.reshape((3))

        # Normalise all coordinates by the object radius.
        xyz = xyz / self.obj_radius[idx]
        model_points = self.cld[obj].T / cam_scale
        model_points = model_points / self.obj_radius[idx]
        target_t = target_t / self.obj_radius[idx]

        rgb, xyz, mask = resize(img_masked, xyz, mask_crop, self.resize_img_width, self.resize_img_width)
        # Reset the sample if ANY of the arrays holds a NaN or Inf. The
        # previous check joined the six tests with 'and' and so never fired.
        all_finite = np.isfinite(xyz).all() and np.isfinite(rgb).all() and np.isfinite(mask).all()
        if not all_finite:
            rgb = np.zeros(rgb.shape, dtype=np.float32)
            xyz = np.zeros(xyz.shape, dtype=np.float32)
            mask = np.zeros(mask.shape, dtype=np.float32)
            print('???????????????????? find nan\\inf')

        rgb = torch.from_numpy(rgb.astype(np.float32)).permute(2, 0, 1).contiguous()
        xyz = torch.from_numpy(xyz.astype(np.float32)).permute(2, 0, 1).contiguous()
        if mask.sum() == 0.0:
            # An all-zero mask is replaced by an all-ones mask.
            mask = np.ones(mask.shape, dtype=np.float32)
        mask = torch.from_numpy(mask.astype(np.float32)).unsqueeze(dim=0)

        rgb = self.norm(rgb)
        return {
            'rgb': rgb,
            'xyz': xyz,
            'mask': mask,
            'target_r': torch.from_numpy(target_r.astype(np.float32)).view(3, 3),
            'target_t': torch.from_numpy(target_t.astype(np.float32)).view(3),
            'model_xyz': torch.from_numpy(model_points.astype(np.float32)),
            'class_id': torch.LongTensor([int(idx)])}

    def __len__(self):
        """Return the number of samples (after any train-mode repetition)."""
        return self.length

    def get_sym_list(self):
        """Return the list of symmetric object indices (empty here)."""
        return self.symmetry_obj_idx

    def get_num_points_mesh(self):
        """Return the number of model points kept per object."""
        return self.num_pt

class data_prefetcher():
    """Wrap a loader and copy the next batch to the GPU on a side CUDA stream.

    Each batch is expected to be a dict of tensors. Entries listed in
    ``_FLOAT_KEYS`` are moved with ``.cuda(non_blocking=True).float()`` and
    those in ``_INT_KEYS`` with ``.cuda(non_blocking=True).int()``. Keys that
    a batch does not contain are skipped, so datasets emitting only a subset
    of these entries no longer trigger a ``KeyError``. Entries not listed
    are left untouched.
    """

    _FLOAT_KEYS = (
        'rgb', 'xyz', 'mask',
        'target_xyz', 'model_xyz',
        'target_r', 'target_t',
        'gt_x', 'gt_y', 'gt_z',
    )
    _INT_KEYS = ('class_id',)

    def __init__(self, loader, device=0):
        """Create the side stream on ``device`` and start loading the first batch."""
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream(device=device)
        self.preload()

    def preload(self):
        """Fetch the next batch and enqueue its host-to-device copies.

        Sets ``self.next_input`` to ``None`` once the loader is exhausted.
        """
        try:
            self.next_input = next(self.loader)
        except StopIteration:
            self.next_input = None
            return
        batch = self.next_input
        with torch.cuda.stream(self.stream):
            for key in self._FLOAT_KEYS:
                if key in batch:
                    batch[key] = batch[key].cuda(non_blocking=True).float()
            for key in self._INT_KEYS:
                if key in batch:
                    batch[key] = batch[key].cuda(non_blocking=True).int()

    def next(self):
        """Return the prefetched batch (or ``None``) and start loading the following one."""
        # Make the consumer stream wait until the side-stream copies are done.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_input
        self.preload()
        return batch

def get_bbox(label):
    """Return a padded, square-ish crop window around the non-zero region of a mask.

    The tight bounding box of ``label`` is turned into a square whose side is
    the larger box dimension, enlarged by ``int(side / 8)`` on every side,
    and then shifted so that it lies inside the image. If the padded square
    is larger than the image along an axis, the window is clipped to the
    image along that axis (it is then no longer square).

    Args:
        label: 2-D array; non-zero / True entries mark the object.

    Returns:
        ``(rmin, rmax, cmin, cmax)`` with half-open row range ``rmin:rmax``
        and column range ``cmin:cmax``.

    Raises:
        IndexError: if ``label`` has no non-zero entry.
    """
    n_rows = label.shape[0]
    n_cols = label.shape[1]

    row_hits = np.where(np.any(label, axis=1))[0]
    col_hits = np.where(np.any(label, axis=0))[0]
    rmin, rmax = row_hits[0], row_hits[-1] + 1
    cmin, cmax = col_hits[0], col_hits[-1] + 1

    wid = max(rmax - rmin, cmax - cmin)
    half = int(wid / 2)
    extend_wid = int(wid / 8)
    center_r = int((rmin + rmax) / 2)
    center_c = int((cmin + cmax) / 2)

    rmin = center_r - half - extend_wid
    rmax = center_r + half + extend_wid
    cmin = center_c - half - extend_wid
    cmax = center_c + half + extend_wid

    # Shift the window back inside the image, keeping its size where possible.
    if rmin < 0:
        rmax += -rmin
        rmin = 0
    if cmin < 0:
        cmax += -cmin
        cmin = 0
    if rmax > n_rows:
        rmin -= rmax - n_rows
        rmax = n_rows
    if cmax > n_cols:
        cmin -= cmax - n_cols
        cmax = n_cols

    # A window larger than the image makes the second shift push the lower
    # bound negative again; a negative slice start would silently select the
    # wrong region, so clip it to the image edge.
    rmin = max(rmin, 0)
    cmin = max(cmin, 0)
    return rmin, rmax, cmin, cmax

def farthest_point_sample(xyz, npoint):
    """Pick ``npoint`` indices per batch by iterative farthest point sampling.

    Args:
        xyz: tensor of shape (B, N, 3) holding point coordinates.
        npoint: number of indices to select for each batch element.

    Returns:
        Long tensor of shape (B, npoint) with the selected point indices.
        The first index of every batch element is drawn at random.
    """
    dev = xyz.device
    n_batch, n_pts, _ = xyz.shape

    picked = torch.zeros(n_batch, npoint, dtype=torch.long).to(dev)
    # Squared distance from every point to its closest already-picked point.
    nearest_sq = torch.ones(n_batch, n_pts).to(dev) * 1e10
    current = torch.randint(0, n_pts, (n_batch,), dtype=torch.long).to(dev)
    rows = torch.arange(n_batch, dtype=torch.long).to(dev)

    for step in range(npoint):
        picked[:, step] = current
        anchor = xyz[rows, current, :].view(n_batch, 1, 3)
        sq_dist = ((xyz - anchor) ** 2).sum(dim=-1)
        closer = sq_dist < nearest_sq
        nearest_sq[closer] = sq_dist[closer]
        # The next pick is the point farthest from everything chosen so far.
        _, current = torch.max(nearest_sq, -1)
    return picked

def resize(rgb, xyz, mask, width, height):
    """Resample an image crop, its coordinate map and its mask to a fixed size.

    Args:
        rgb: (H, W, C) array, resampled bilinearly.
        xyz: (H, W, C) array, resampled bilinearly.
        mask: (H, W) array, resampled with nearest-neighbour lookup.
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        Tuple ``(rgb, xyz, mask)`` of float32 numpy arrays with shapes
        (height, width, C), (height, width, C) and (height, width).
    """
    out_size = (height, width)

    def _bilinear_hwc(arr):
        # HWC numpy -> 1xCxHxW tensor -> interpolate -> HWC numpy.
        batched = torch.from_numpy(arr.astype(np.float32)).unsqueeze(dim=0)
        batched = batched.permute(0, 3, 1, 2).contiguous()
        scaled = F.interpolate(batched, size=out_size, mode='bilinear')
        return scaled.squeeze(dim=0).permute(1, 2, 0).contiguous().cpu().numpy()

    rgb_out = _bilinear_hwc(rgb)
    xyz_out = _bilinear_hwc(xyz)

    mask_batched = torch.from_numpy(mask.astype(np.float32)).unsqueeze(dim=0).unsqueeze(dim=0)
    mask_scaled = F.interpolate(mask_batched, size=out_size, mode='nearest')
    mask_out = mask_scaled.squeeze(dim=0).squeeze(dim=0).cpu().numpy()

    return rgb_out, xyz_out, mask_out

def ply_vtx(path):
    """Read the vertex coordinates from an ASCII PLY file.

    The parser is positional: the first line must be ``ply``, the next two
    lines are skipped, and the last token of the fourth line is taken as the
    vertex count. Everything up to ``end_header`` is then skipped and the
    first three fields of each following line are read as x, y, z.

    Args:
        path: path of the PLY file.

    Returns:
        Array of float32 vertex coordinates, one row per vertex.

    Raises:
        AssertionError: if the first line is not ``ply`` (only when
            assertions are enabled).
    """
    # 'with' closes the handle even when parsing fails part-way through.
    with open(path) as f:
        # Read the magic line outside the assert: under ``python -O`` the
        # assert is stripped, and a readline() inside it would be skipped,
        # shifting every following header line by one.
        magic = f.readline().strip()
        assert magic == "ply"
        f.readline()
        f.readline()
        N = int(f.readline().split()[-1])
        while f.readline().strip() != "end_header":
            continue
        pts = [np.float32(f.readline().split()[:3]) for _ in range(N)]
    return np.array(pts)
