# Training on a set of multiple objects (e.g. ShapeNet or DTU)
# tensorboard logs available in logs/<expname>

import sys
import os

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
)

import warnings
import trainlib
from model import make_model, loss
from render import NeRFRenderer, UnisurfRenderer, common
from data import get_split_dataset
import util
import numpy as np
import torch.nn.functional as F
import torch
from dotmap import DotMap



from torchmetrics import MetricCollection
# from torchmetrics import IoU, Accuracy
from metrics import MeanAccuracy, MeanIoU, Accuracy, Chamfer_FScore


def save_pc(PC, PC_color, filename):
    """Write a colored point cloud to ``filename`` as a PLY file.

    ``PC`` and ``PC_color`` are concatenated column-wise; the resulting rows
    are stored as vertices with the fields x, y, z (float32) and
    red, green, blue (uint8).
    """
    from plyfile import PlyData, PlyElement

    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]
    rows = np.concatenate((PC, PC_color), axis=1)
    # A structured array is built from one tuple per vertex.
    records = np.array(list(map(tuple, rows)), dtype=vertex_dtype)
    vertex_element = PlyElement.describe(records, 'vertex')
    PlyData([vertex_element]).write(filename)

def extra_args(parser):
    """Register the training-specific command line options on ``parser``.

    Args:
        parser: an ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same parser, so the call can be chained.
    """
    parser.add_argument(
        "--batch_size", "-B", type=int, default=4, help="Object batch size ('SB')"
    )
    parser.add_argument(
        "--nviews",
        "-V",
        type=str,
        default="1",
        help="Number of source views (multiview); put multiple (comma delim) to pick randomly per batch ('NV')",
    )
    parser.add_argument(
        "--freeze_enc",
        action="store_true",
        default=None,
        help="Freeze encoder weights and only train MLP",
    )
    parser.add_argument(
        "--separate_heads",
        action="store_true",
        default=None,
        help="Use separate heads for rgb, seg, density",
    )

    parser.add_argument(
        "--no_bbox_step",
        type=int,
        default=300000,
        help="Step to stop using bbox sampling",
    )
    # NOTE(review): the previous help text was a copy of --freeze_enc's. The
    # flag is not read anywhere in this script; confirm the exact semantics
    # against its consumer (presumably trainlib).
    parser.add_argument(
        "--fixed_test",
        action="store_true",
        default=None,
        help="Keep the test data fixed (flag is read outside this script)",
    )
    parser.add_argument(
        "--calc_metrics",
        action="store_true",
        default=None,
        help="calculate metrics during eval",
    )
    parser.add_argument(
        "--sigmoid_scaling",
        type=float,
        default=1.0,
        help="Model sigmoid scaling",
    )
    return parser


# Module-level setup: parse arguments and config, then build the datasets, the
# model and the renderer that the trainer class below reads as globals.
args, conf = util.args.parse_args(extra_args, training=True, default_ray_batch_size=128)

n_gpus = len(args.gpu_id)
device = util.get_cuda(args.gpu_id[0])

dset, val_dset, _ = get_split_dataset(
    args.dataset_format,
    args.datadir,
    load_pc=args.calc_metrics,
    level=args.level,
    category=args.category,
)
n_classes = dset.n_classes
# The renderer config mirrors the class count and head setup of the model.
conf["model"]['n_classes'] = n_classes
conf["renderer"]['n_classes'] = n_classes
conf["renderer"]['use_rgb_head'] = conf["model"]['use_rgb_head']
conf["renderer"]['use_seg_head'] = conf["model"]['use_seg_head']

renderer_type = conf["renderer"]["type"].lower()
# Fail fast: every renderer-specific branch in this script only handles these
# two types; an unknown one would otherwise leave `renderer` undefined and
# surface later as a NameError.
if renderer_type not in ("nerf", "unisurf"):
    raise ValueError(
        "Unsupported renderer type '{}'; expected 'nerf' or 'unisurf'".format(
            renderer_type
        )
    )

print(
    "dset z_near {}, z_far {}, lindisp {}".format(dset.z_near, dset.z_far, dset.lindisp)
)

net = make_model(conf["model"], sigmoid_scaling=args.sigmoid_scaling,
                 separate_heads=args.separate_heads, init_ckpt=args.init_ckpt).to(device=device)
net.stop_encoder_grad = args.freeze_enc
if args.freeze_enc:
    print("Encoder frozen")
    # The encoder lives at a different attribute path depending on the model.
    if renderer_type == "nerf":
        net.encoder.eval()
    elif renderer_type == "unisurf":
        net.pixelnerf.encoder.eval()


if renderer_type == "nerf":
    renderer = NeRFRenderer.from_conf(conf["renderer"], lindisp=dset.lindisp,).to(
        device=device
    )
elif renderer_type == "unisurf":
    renderer = UnisurfRenderer.from_conf(conf["renderer"], lindisp=dset.lindisp,).to(
        device=device
    )

# Parallelize rendering and point queries across the requested GPUs
render_par = renderer.bind_parallel(net, args.gpu_id)
pts_par = renderer.bind_pts_parallel(net, args.gpu_id)
# "--nviews" is a comma-separated list of candidate source-view counts.
nviews = list(map(int, args.nviews.split(',')))

# import pdb; pdb.set_trace()


class PixelNeRFTrainer(trainlib.Trainer):
    def __init__(self):
        """Configure loss weights, loss criteria, renderer state and metrics.

        Reads the module-level ``net``, ``dset``, ``val_dset``, ``args``,
        ``conf``, ``renderer``, ``device`` and ``n_classes``.
        """
        super().__init__(net, dset, val_dset, args, conf["train"], device=device)
        self.renderer_state_path = "%s/%s/_renderer" % (
            self.args.checkpoints_path,
            self.args.name,
        )

        # Loss weights; all but lambda_coarse fall back to 1.0 when unset.
        self.lambda_coarse = conf.get_float("loss.lambda_coarse")
        self.lambda_fine = conf.get_float("loss.lambda_fine", 1.0)
        self.lambda_seg = conf.get_float("loss.lambda_seg", 1.0)
        self.lambda_normals = conf.get_float("loss.lambda_normals", 1.0)
        weights_msg = "lambda coarse {} and fine {}"
        print(weights_msg.format(self.lambda_coarse, self.lambda_fine))

        self.use_rgb_head = conf.get_bool("model.use_rgb_head")
        self.use_seg_head = conf.get_bool("model.use_seg_head")
        heads_msg = "using rgb head: {}, using seg head: {}"
        print(heads_msg.format(self.use_rgb_head, self.use_seg_head))

        if self.use_rgb_head:
            self.rgb_coarse_crit = loss.get_rgb_loss(conf["loss.rgb"], True)
            # A dedicated "rgb_fine" section overrides the shared rgb loss config.
            has_fine_conf = "rgb_fine" in conf["loss"]
            if has_fine_conf:
                print("using fine loss")
            rgb_fine_conf = conf["loss.rgb_fine"] if has_fine_conf else conf["loss.rgb"]
            self.rgb_fine_crit = loss.get_rgb_loss(rgb_fine_conf, False)

        if self.use_seg_head:
            # Coarse and fine criteria are separate objects built from one config.
            self.seg_coarse_crit = loss.get_seg_loss(conf["loss.seg"])
            self.seg_fine_crit = loss.get_seg_loss(conf["loss.seg"])

        # When resuming, restore the renderer state if a checkpoint exists.
        if args.resume and os.path.exists(self.renderer_state_path):
            saved_state = torch.load(self.renderer_state_path, map_location=device)
            renderer.load_state_dict(saved_state)

        self.z_near = dset.z_near
        self.z_far = dset.z_far

        self.use_bbox = args.no_bbox_step > 0
        if self.args.calc_metrics:
            n_metric_classes = n_classes - 1
            seg_metrics = MetricCollection({
                'acc': Accuracy.Accuracy(dist_sync_on_step=True),
                'macc': MeanAccuracy.MeanAccuracy(num_classes=n_metric_classes, dist_sync_on_step=True),
                'miou': MeanIoU.MeanIoU(num_classes=n_metric_classes, dist_sync_on_step=True),
            })
            # Same metric set, tracked separately for 2D pixels and 3D points.
            self.metrics_2D = seg_metrics.clone(prefix='2D_').to(device)
            self.metrics_3D = seg_metrics.clone(prefix='3D_').to(device)
            chamfer_metrics = {
                '3D_L2': Chamfer_FScore.Chamfer_FScore(p=2, dist_sync_on_step=True),
                '3D_L1': Chamfer_FScore.Chamfer_FScore(p=1, dist_sync_on_step=True),
            }
            self.metrics_3D_rec = MetricCollection(chamfer_metrics).to(device)

    def post_batch(self, epoch, batch):
        """Advance the module-level renderer's schedule by one object batch."""
        objects_per_batch = args.batch_size
        renderer.sched_step(objects_per_batch)

    # def post_epoch(self, epoch):
    #     self.train_metrics.reset()

    def extra_save_state(self):
        """Save the module-level renderer's state dict to ``renderer_state_path``."""
        renderer_state = renderer.state_dict()
        torch.save(renderer_state, self.renderer_state_path)


    def reset_metrics(self):
        self.metrics_2D.reset()
        self.metrics_3D.reset()

    def compute_metrics(self):
        all_metrics = {}
        metrics_2D = self.metrics_2D.compute()
        all_metrics.update(metrics_2D)
        metrics_3D = self.metrics_3D.compute()
        all_metrics.update(metrics_3D)
        return all_metrics

    def calc_losses(self, data, is_train=True, global_step=0):
        if "images" not in data:
            return {}
        all_images = data["images"].to(device=device)  # (SB, NV, 3, H, W)

        SB, NV, _, H, W = all_images.shape
        all_poses = data["poses"].to(device=device)  # (SB, NV, 4, 4)
        all_bboxes = data.get("bbox")  # (SB, NV, 4)  cmin rmin cmax rmax
        all_focals = data["focal"]  # (SB)
        all_c = data.get("c")  # (SB)
        all_labels = data["labels"].to(device=device)
        if self.args.calc_metrics:
            all_pts = data["pts"].to(device=device)
            all_pts_labels = data["pts_labels"].to(device=device)

        if self.use_bbox and global_step >= args.no_bbox_step:
            self.use_bbox = False
            print(">>> Stopped using bbox sampling @ iter", global_step)

        if not is_train or not self.use_bbox:
            all_bboxes = None

        all_rgb_gt = []
        all_seg_gt = []
        all_rays = []

        
        if is_train or not args.calc_metrics:
            curr_nviews = nviews[torch.randint(0, len(nviews), ()).item()]
            if curr_nviews == 1:
                image_ord = torch.randint(0, NV, (SB, 1))
            else:
                image_ord = torch.empty((SB, curr_nviews), dtype=torch.long)
        else:
            curr_nviews=1
            image_ord = torch.full((SB, 1), 192, dtype=torch.long)
        for obj_idx in range(SB):
            if all_bboxes is not None:
                bboxes = all_bboxes[obj_idx]
            images = all_images[obj_idx]  # (NV, 3, H, W)
            poses = all_poses[obj_idx]  # (NV, 4, 4)
            focal = all_focals[obj_idx]
            label = all_labels[obj_idx]

            c = None
            if "c" in data:
                c = data["c"][obj_idx]
            if curr_nviews > 1:
                # Somewhat inefficient, don't know better way
                image_ord[obj_idx] = torch.from_numpy(
                    np.random.choice(NV, curr_nviews, replace=False)
                )
            images_0to1 = images * 0.5 + 0.5
            rgb_gt_all = images_0to1
            rgb_gt_all = (
                rgb_gt_all.permute(0, 2, 3, 1).contiguous().reshape(-1, 3)
            )  # (NV, H, W, 3)
            #TODO: check dim labels
            seg_gt_all = (torch.squeeze(label).reshape(-1))
            if all_bboxes is not None:
                pix = util.bbox_sample(bboxes, args.ray_batch_size)
                pix_inds = pix[..., 0] * H * W + pix[..., 1] * W + pix[..., 2]
            else:
                # bboxes = []
                # pix = util.bbox_sample(bboxes, args.ray_batch_size)
                pix_inds = torch.randint(0, NV * H * W, (args.ray_batch_size,))
            if renderer_type == "nerf":
                cam_rays = util.gen_rays(
                    poses, W, H, focal, self.z_near, self.z_far, c=c
                )  # (NV, H, W, 8)
                rays = cam_rays.view(-1, cam_rays.shape[-1])[pix_inds].to(
                    device=device
                )  # (ray_batch_size, 8)
                all_rays.append(rays)
            elif renderer_type == "unisurf":
                net.target_poses = poses
                # print(pix_inds.shape)
                # print(pix.shape)

                pix = pix_inds.clone().unsqueeze(-1).repeat(1,3)
                pix[:,0] = torch.floor(pix[:,0] / (H * W)).int()
                pix[:,1] = torch.floor((pix[:,1] % (H * W))/W).int()
                pix[:,2] = pix[:,2] % W

                all_rays.append(pix.to(device))
                pix_inds = pix[..., 0] * H * W + (127-pix[..., 2]) * W + pix[..., 1]
            rgb_gt = rgb_gt_all[pix_inds]  # (ray_batch_size, 3)
            seg_gt = seg_gt_all[pix_inds]
            all_rgb_gt.append(rgb_gt)
            all_seg_gt.append(seg_gt)
            
            
            

        all_rgb_gt = torch.stack(all_rgb_gt)  # (SB, ray_batch_size, 3)
        all_seg_gt = torch.stack(all_seg_gt)  # (SB, ray_batch_size, 3)
        all_rays = torch.stack(all_rays)  # (SB, ray_batch_size, 8)
        

        image_ord = image_ord.to(device)
        src_images = util.batched_index_select_nd(
            all_images, image_ord
        )  # (SB, NS, 3, H, W)
        src_poses = util.batched_index_select_nd(all_poses, image_ord)  # (SB, NS, 4, 4)

        all_bboxes = all_poses = all_images = None

        all_focals = all_focals.to(device=device)
        all_c = all_c.to(device=device)
        if renderer_type == "nerf":
            net.encode(
                src_images,
                src_poses,
                all_focals,
                c=all_c if all_c is not None else None,
            )
            render_dict = DotMap(render_par(all_rays, want_weights=True))
        elif renderer_type == "unisurf":
            render_dict = DotMap(render_par(all_rays, src_images, src_poses,
                                        all_focals, all_c if all_c is not None else None,
                                        want_weights=True, it=global_step, gt_img=rgb_gt_all))
        coarse = render_dict.coarse
        fine = render_dict.fine
        using_fine = len(fine) > 0
        # import pdb; pdb.set_trace()

        loss_dict = {}
        loss = 0
        if self.use_rgb_head:
            rgb_loss = self.rgb_coarse_crit(coarse.rgb, all_rgb_gt) * self.lambda_coarse
            loss_dict["rc"] = rgb_loss.item()
            if using_fine:
                fine_loss = self.rgb_fine_crit(fine.rgb, all_rgb_gt) * self.lambda_fine
                rgb_loss += fine_loss
                loss_dict["rf"] = fine_loss.item()
            loss += rgb_loss
        if self.use_seg_head:
            if 'n_classes' in data:
                seg_loss = 0.
                for i, seg in enumerate(coarse.seg):
                    s = data['cat_start_class'][i].int().item()
                    n = data['n_classes'][i].int().item()
                    cur_seg_flat = seg.reshape(-1,n_classes)
                    # import pdb; pdb.set_trace()
                    cur_seg_logits = torch.cat([cur_seg_flat[...,:1], cur_seg_flat[...,s:s+n]], axis=-1)
                    seg_loss += self.seg_coarse_crit(cur_seg_logits,
                                            all_seg_gt[i].reshape(-1).long()) * self.lambda_coarse * self.lambda_seg
                seg_loss /= len(coarse.seg)
            else:
                seg_loss = self.seg_coarse_crit(coarse.seg.reshape(-1,n_classes),
                                            all_seg_gt.reshape(-1).long()) * self.lambda_coarse * self.lambda_seg
            loss_dict["sc"] = seg_loss.item()
            if using_fine:
                if 'n_classes' in data:
                    seg_fine_loss = 0
                    for i, seg in enumerate(fine.seg):
                        s = data['cat_start_class'][i].int().item()
                        n = data['n_classes'][i].int().item()
                        cur_seg_flat = seg.reshape(-1,n_classes)
                        cur_seg_logits = torch.cat([cur_seg_flat[...,:1], cur_seg_flat[...,s:s+n]], axis=-1)
                        seg_fine_loss += self.seg_coarse_crit(cur_seg_logits,
                                                all_seg_gt[i].reshape(-1).long()) * self.lambda_fine * self.lambda_seg
                    seg_fine_loss /= len(fine.seg)
                else:
                    seg_fine_loss = self.seg_fine_crit(fine.seg.reshape(-1,n_classes),
                                                all_seg_gt.reshape(-1).long()) * self.lambda_fine * self.lambda_seg
                seg_loss += seg_fine_loss
                loss_dict["sf"] = seg_fine_loss.item() 
            loss += seg_loss

        if renderer_type == "unisurf":
            diff_norm = coarse.diff_norm
            # print(diff_norm)
            if diff_norm is None or diff_norm.shape[0]==0:
                normals_diff_loss = torch.tensor(0.0).to(device).float()
            else:
                normals_diff_loss = diff_norm.mean()
            loss_dict["nc"] = self.lambda_normals * normals_diff_loss.item()
            loss += normals_diff_loss
        if is_train:
            # import pdb; pdb.set_trace()
            loss.backward()
        # if loss.isnan():
            # print([n for n in net.parameters()])
            # print([n.grad for n in net.parameters()])
            # print(loss, rgb_loss, seg_loss, normals_diff_loss)
            # import pdb; pdb.set_trace()
            # assert(True == False)
        loss_dict["t"] = loss.item()
        # import pdb; pdb.set_trace()
        # del loss

        if self.args.calc_metrics and not is_train:
            with torch.no_grad():
                #2D Metrics
                if self.use_seg_head:
                    if using_fine:
                        seg = fine.seg
                    else:
                        seg = coarse.seg
                    if 'n_classes' in data:
                        all_gt = []
                        all_seg = []
                        for i in range(len(data['n_classes'])):
                            s = data['cat_start_class'][i].int().item()
                            n = data['n_classes'][i].int().item()
                            cur_seg_gt = all_seg_gt[i].reshape(-1).long()
                            not_background_2D = cur_seg_gt != 0
                            cur_seg = seg[i]
                            cur_seg = cur_seg[...,s:s+n].argmax(-1).reshape(-1).long()

                            cur_seg = cur_seg[not_background_2D]
                            cur_seg_gt = cur_seg_gt[not_background_2D] - 1
                            all_seg.append(cur_seg)
                            all_gt.append(cur_seg_gt)
                        seg = torch.cat(all_seg, dim=0)
                        all_seg_gt = torch.cat(all_gt, dim=0)
                    else:
                        all_seg_gt = all_seg_gt.reshape(-1).long()
                        not_background_2D = all_seg_gt != 0
                        # if not_background_2D.any():
                        seg = seg[...,1:].argmax(-1).reshape(-1).long()
                        seg = seg[not_background_2D]
                        all_seg_gt = all_seg_gt[not_background_2D] - 1
                            # print(seg.shape)
                    if seg.numel() > 0:
                        metrics_2D = self.metrics_2D(seg, all_seg_gt)
                        metrics_2D = {'shape_'+k: v.item() for (k, v) in metrics_2D.items()}
                        loss_dict.update(metrics_2D)

                    if renderer_type == "nerf":
                        out = pts_par(all_pts, coarse=False, viewdirs=torch.zeros_like(all_pts))
                    elif renderer_type == "unisurf":
                        out = pts_par(all_pts, src_images, src_poses,
                                        all_focals, all_c if all_c is not None else None,
                                        coarse=False, viewdirs=torch.zeros_like(all_pts))
                    # print(out.shape)
                    if 'n_classes' in data:
                        seg = []
                        for i in range(len(data['n_classes'])):
                            s = data['cat_start_class'][i].int().item()
                            n = data['n_classes'][i].int().item()
                            cur_seg = out[i]
                            cur_seg = cur_seg[...,s:s+n].argmax(-1).reshape(-1).long()
                            # print(seg.shape)
                            seg.append(cur_seg)
                        seg = torch.cat(seg, dim=0)
                    else:
                        seg = out[...,-n_classes+1:].argmax(-1).reshape(-1)
                    metrics_3D = self.metrics_3D(seg,
                                                (all_pts_labels-1).reshape(-1).long())
                    metrics_3D = {'shape_'+k: v.item() for (k, v) in metrics_3D.items()}
                    loss_dict.update(metrics_3D)


                N = 64
                t = torch.linspace(-1.0, 1.0, N, device=device, dtype=torch.float32)
                query_pts = torch.stack(torch.meshgrid(t, t, t), -1)
                all_rays_pts = query_pts.reshape(-1,3).unsqueeze(0).repeat(all_pts.shape[0], 1, 1)
                # print(query_pts.shape)
                rays_spl = torch.split(all_rays_pts, 80000, dim=1)
                dense_sigma = []
                dense_seg = []
                for rays in rays_spl:
                    if renderer_type == "nerf":
                        out = pts_par(rays, coarse=False, viewdirs=torch.zeros_like(rays))
                    elif renderer_type == "unisurf":
                        out = pts_par(rays, src_images, src_poses,
                                        all_focals, all_c if all_c is not None else None,
                                        coarse=False, viewdirs=torch.zeros_like(rays))
                    # print(out, out.shape)
                    sigma = out[...,0]
                    dense_sigma.append(sigma)
                    if self.use_seg_head:
                        if 'n_classes' in data:
                            seg = []
                            for i in range(len(data['n_classes'])):
                                s = data['cat_start_class'][i].int().item()
                                n = data['n_classes'][i].int().item()
                                cur_seg = out[i]
                                cur_seg = torch.cat([cur_seg[...,:1], cur_seg[...,s:s+n]], axis=-1).argmax(-1).unsqueeze(0)
                                seg.append(cur_seg)
                            seg = torch.cat(seg, dim=0)
                        else:
                            seg = out[...,-n_classes:].argmax(-1)
                        dense_seg.append(seg)
                dense_sigma = torch.cat(dense_sigma, dim=1)
                if renderer_type == 'nerf':
                    density_th = 1
                elif renderer_type == "unisurf":
                    density_th = 0.5
                valid_pts = dense_sigma > density_th
                if self.use_seg_head:
                    dense_seg = torch.cat(dense_seg, dim=1)
                    valid_pts = torch.logical_and(dense_seg != 0, valid_pts)
                # print(valid_pts.shape, query_pts.shape, all_rays_pts.shape)
                dense_split = [all_rays_pts[i, valid_pts[i]] for i in range(all_rays_pts.shape[0])]
                # print([d.shape for d in dense_split])

                metrics_3D_rec = self.metrics_3D_rec(dense_split, all_pts)
                metrics_3D_rec = {k_t: v_t.item() for (k, v) in metrics_3D_rec.items() for k_t, v_t in zip([k+'_'+m for m in self.metrics_3D_rec['3D_L1'].returned_metrics()], v)}
                # print(metrics_3D_rec)
                # metrics_3D_rec = {k: v.item() for v_t in v for (k, v) in metrics_3D_rec.items()}
                loss_dict.update(metrics_3D_rec)

        return loss_dict

    def train_step(self, data, global_step):
        loss_dict = self.calc_losses(data, is_train=True, global_step=global_step)
        return loss_dict

    def eval_step(self, data, global_step):
        """Compute losses on an evaluation batch with the renderer in eval mode.

        The module-level renderer is switched back to train mode even when
        ``calc_losses`` raises, so a failed evaluation cannot leave it in
        eval mode.
        """
        renderer.eval()
        try:
            loss_dict = self.calc_losses(data, is_train=False, global_step=global_step)
        finally:
            renderer.train()
        return loss_dict

    def vis_step(self, data, global_step, idx=None):
        if "images" not in data:
            return {}
        if idx is None:
            batch_idx = np.random.randint(0, data["images"].shape[0])
        else:
            print(idx)
            batch_idx = idx
        images = data["images"][batch_idx].to(device=device)  # (NV, 3, H, W)
        labels = data["labels"][batch_idx].to(device=device)  # (NV, 1, H, W)
        poses = data["poses"][batch_idx].to(device=device)  # (NV, 4, 4)
        focal = data["focal"][batch_idx : batch_idx + 1]  # (1)
        # if self.args.calc_metrics:
        #     all_pts = data["pts"].to(device=device)
        #     all_pts_labels = data["pts_labels"].to(device=device)

        c = data.get("c")
        if c is not None:
            c = c[batch_idx : batch_idx + 1]  # (1)
        NV, _, H, W = images.shape
  
        if renderer_type == "nerf":
            cam_rays = util.gen_rays(
                poses, W, H, focal, self.z_near, self.z_far, c=c
            )  # (NV, H, W, 8)

        elif renderer_type == "unisurf":
            net.target_poses = poses
            # import pdb; pdb.set_trace()
            mesh = torch.meshgrid(torch.arange(NV), torch.arange(H), torch.arange(W))
            cam_rays = torch.stack(mesh, axis=-1).to(device)
            # print(cam_rays.shape)
        images_0to1 = images * 0.5 + 0.5  # (NV, 3, H, W)

        curr_nviews = nviews[torch.randint(0, len(nviews), (1,)).item()]
        views_src = np.sort(np.random.choice(NV, curr_nviews, replace=False))
        view_dest = np.random.randint(0, NV - curr_nviews)
        for vs in range(curr_nviews):
            view_dest += view_dest >= views_src[vs]
        views_src = torch.from_numpy(views_src)

        # set renderer net to eval mode
        renderer.eval()
        source_views = (
            images_0to1[views_src]
            .permute(0, 2, 3, 1)
            .cpu()
            .numpy()
            .reshape(-1, H, W, 3)
        )

        gt = images_0to1[view_dest].permute(1, 2, 0).cpu().numpy().reshape(H, W, 3)
        seg_gt = labels[view_dest].permute(1, 2, 0).cpu().numpy().reshape(H, W, 1)
        with torch.no_grad():
            test_rays = cam_rays[view_dest]  # (H, W, 8)
            # print(test_rays.shape)
            test_images = images[views_src]  # (NS, 3, H, W)
            
            test_rays = test_rays.reshape(1, H * W, -1)
            # print(test_rays.shape)
            # import pdb; pdb.set_trace()
            src_images = test_images.unsqueeze(0)
            src_poses = poses[views_src].unsqueeze(0)
            all_focals = focal.to(device=device)
            all_rays = test_rays
            all_c = c.to(device=device)
            if renderer_type == "nerf":
                net.encode(
                    src_images,
                    src_poses,
                    all_focals,
                    c=all_c if all_c is not None else None,
                )
                render_dict = DotMap(render_par(all_rays, want_weights=True))
            elif renderer_type == "unisurf":
                # all_rays = all_rays.repeat(2)
                render_dict = DotMap(render_par(all_rays.repeat_interleave(n_gpus,dim=0),
                                                src_images.repeat_interleave(n_gpus,dim=0),
                                                src_poses.repeat_interleave(n_gpus,dim=0),
                                                all_focals.repeat_interleave(n_gpus,dim=0),
                                                all_c.repeat_interleave(n_gpus,dim=0) if all_c is not None else None,
                                                want_weights=True, it=global_step))
            # net.encode(
            #     test_images.unsqueeze(0),
            #     poses[views_src].unsqueeze(0),
            #     focal.to(device=device),
            #     c=c.to(device=device) if c is not None else None,
            # )
            # render_dict = DotMap(render_par(test_rays, want_weights=True))
            coarse = render_dict.coarse
            fine = render_dict.fine

            using_fine = len(fine) > 0

            alpha_coarse_np = coarse.weights[0].sum(dim=-1).cpu().numpy().reshape(H, W)
            print(
                "c alpha min {}, max {}".format(
                    alpha_coarse_np.min(), alpha_coarse_np.max()
                )
            )
            depth_coarse_np = coarse.depth[0].cpu().numpy().reshape(H, W)
            if renderer_type == "unisurf":
                alpha_coarse_np = alpha_coarse_np.transpose(1,0)
                alpha_coarse_np = alpha_coarse_np[::-1]
                depth_coarse_np = depth_coarse_np.transpose(1,0)
                depth_coarse_np = depth_coarse_np[::-1]
            alpha_coarse_cmap = util.cmap(alpha_coarse_np) / 255
            depth_coarse_cmap = util.cmap(depth_coarse_np) / 255

            if self.use_rgb_head:

                rgb_coarse_np = coarse.rgb[0].cpu().numpy().reshape(H, W, 3)
                # import pdb; pdb.set_trace()
                if renderer_type == "unisurf":
                    rgb_coarse_np = rgb_coarse_np.transpose(1,0,2)
                    rgb_coarse_np = rgb_coarse_np[::-1]
                rgb_psnr = rgb_coarse_np
                print("c rgb min {} max {}".format(rgb_coarse_np.min(), rgb_coarse_np.max()))
            if self.use_seg_head:
                if 'n_classes' in data:
                    s = data['cat_start_class'][batch_idx].int().item()
                    n = data['n_classes'][batch_idx].int().item()
                    # print(coarse.seg.shape)
                    seg_coarse_np = torch.cat([coarse.seg[0,...,:1], coarse.seg[0,...,s:s+n]], axis=-1).argmax(-1).long()
                    seg_coarse_np = seg_coarse_np.reshape(H, W, 1).cpu().numpy()
                else:
                    seg_coarse_np = coarse.seg[0].reshape(H, W, -1).argmax(axis=-1, keepdim=True).cpu().numpy()
                if renderer_type == "unisurf":
                    seg_coarse_np = seg_coarse_np.transpose(1,0,2)
                    seg_coarse_np = seg_coarse_np[::-1]
                seg_coarse_cmap = util.cmap(seg_coarse_np) / 255
            if using_fine:
                alpha_fine_np = fine.weights[0].sum(dim=1).cpu().numpy().reshape(H, W)
                print(
                    "f alpha min {}, max {}".format(
                        alpha_fine_np.min(), alpha_fine_np.max()
                    )
                )
                depth_fine_np = fine.depth[0].cpu().numpy().reshape(H, W)
                alpha_fine_cmap = util.cmap(alpha_fine_np) / 255
                depth_fine_cmap = util.cmap(depth_fine_np) / 255
                if self.use_rgb_head:
                    rgb_fine_np = fine.rgb[0].cpu().numpy().reshape(H, W, 3)
                    rgb_psnr = rgb_fine_np
                    print("f rgb min {} max {}".format(rgb_fine_np.min(), rgb_fine_np.max()))
                if self.use_seg_head:
                    if 'n_classes' in data:
                        s = data['cat_start_class'][batch_idx].int().item()
                        n = data['n_classes'][batch_idx].int().item()
                        seg_fine_np = torch.cat([fine.seg[0,...,:1], fine.seg[0,...,s:s+n]], axis=-1).argmax(-1).long()
                        seg_fine_np = seg_fine_np.reshape(H, W, 1).cpu().numpy()
                    else:
                        seg_fine_np = fine.seg[0].reshape(H, W, -1).argmax(axis=-1, keepdim=True).cpu().numpy()
                    seg_fine_cmap = util.cmap(seg_fine_np) / 255

        seg_gt_cmap = util.cmap(seg_gt) / 255
        vis_list = [
            *source_views,
            gt,
            depth_coarse_cmap,
            alpha_coarse_cmap,
            
        ]
        # if renderer_type == "nerf":

        vals = {}
        if self.use_rgb_head:
            vis_list.append(rgb_coarse_np)
            psnr = util.psnr(rgb_psnr, gt)
            vals["psnr"] =  psnr
            print("psnr", psnr)
        if self.use_seg_head:
            vis_list.extend([
                             seg_coarse_cmap,
                             seg_gt_cmap,
                            ])

        vis_coarse = np.hstack(vis_list)
        vis = vis_coarse
        if using_fine:
            
            vis_list = [
                *source_views,
                gt,
                depth_fine_cmap,
                alpha_fine_cmap,
            ]
            if self.use_rgb_head:
                vis_list.append(rgb_fine_np)
            if self.use_seg_head:
                vis_list.extend([
                                 seg_fine_cmap,
                                 seg_gt_cmap,
                                ])

            vis_fine = np.hstack(vis_list)
            vis = np.vstack((vis_coarse, vis_fine))

        # set the renderer network back to train mode
        renderer.train()
        return vis, vals


# Script entry point: build the trainer from the module-level state above and
# run it. start() is not defined on PixelNeRFTrainer in this file, so it is
# inherited from trainlib.Trainer.
trainer = PixelNeRFTrainer()
trainer.start()
