"""
Main model implementation
"""
import torch
import torch.nn.functional as F
from .encoder import ImageEncoder
from .code import PositionalEncoding
from .model_util import make_encoder, make_mlp
import torch.nn as nn

import torch.autograd.profiler as profiler
from util import repeat_interleave
import os
import os.path as osp
import warnings
import numpy as np
import copy


class PixelNeRFNet(torch.nn.Module):
    def __init__(self, conf, stop_encoder_grad=False, separate_heads=False, init_ckpt=None, uniseg=False, **kwargs):
        """
        Build the image encoder(s), the optional positional encoding and the
        coarse / fine MLP(s) described by the config.

        :param conf PyHocon config subtree 'model'
        :param stop_encoder_grad if true, forward() detaches the sampled
        encoder latent
        :param separate_heads if true, build one MLP per output group
        (density, rgb, seg, feat) instead of a single joint MLP
        :param init_ckpt optional checkpoint path that load_weights uses as the
        initialization checkpoint
        :param uniseg passed to encoder.index as custom_sample in forward()
        :param kwargs accepted and ignored
        """
        super().__init__()
        self.encoder = make_encoder(conf["encoder"])
        # Image features?
        self.use_encoder = conf.get_bool("use_encoder", True)
        self.use_xyz = conf.get_bool("use_xyz", False)
        self.init_ckpt = init_ckpt

        # Must use some feature..
        assert self.use_encoder or self.use_xyz
        self.uniseg = uniseg

        # Whether to shift z to align in canonical frame, so that all objects,
        # regardless of camera distance to center, are centered at z=0.
        # Only makes sense in ShapeNet-type setting.
        self.normalize_z = conf.get_bool("normalize_z", True)

        # Stop ConvNet gradient (freeze weights)
        self.stop_encoder_grad = stop_encoder_grad
        # Positional encoding
        self.use_code = conf.get_bool("use_code", False)
        # Positional encoding applies to viewdirs
        self.use_code_viewdirs = conf.get_bool("use_code_viewdirs", True)
        # Enable view directions
        self.use_viewdirs = conf.get_bool("use_viewdirs", False)
        # Global image features?
        self.use_global_encoder = conf.get_bool("use_global_encoder", False)

        d_latent = self.encoder.latent_size if self.use_encoder else 0
        d_in = 3 if self.use_xyz else 1

        dirs_inside_code = self.use_viewdirs and self.use_code_viewdirs
        dirs_outside_code = self.use_viewdirs and not self.use_code_viewdirs
        if dirs_inside_code:
            # View directions go through the positional encoding as well
            d_in += 3
        if self.use_code and d_in > 0:
            # Positional encoding for x,y,z OR view z
            self.code = PositionalEncoding.from_conf(conf["code"], d_in=d_in)
            d_in = self.code.d_out
        if dirs_outside_code:
            # Raw view directions are concatenated after the encoding
            d_in += 3

        if self.use_global_encoder:
            # Global image feature
            self.global_encoder = ImageEncoder.from_conf(conf["global_encoder"])
            self.global_latent_size = self.global_encoder.latent_size
            d_latent += self.global_latent_size

        self.use_rgb_head = conf.get_bool("use_rgb_head", True)
        self.use_seg_head = conf.get_bool("use_seg_head", True)
        self.use_feat_head = conf.get_bool("use_feat_head", False)
        self.n_classes = conf.n_classes
        self.n_feat = conf.n_feat

        # Output widths in channel order: density first, then every enabled head.
        head_dims = [1]
        if self.use_rgb_head:
            head_dims.append(3)
        if self.use_seg_head:
            head_dims.append(self.n_classes)
        if self.use_feat_head:
            head_dims.append(self.n_feat)
        d_out = sum(head_dims)

        self.separate_heads = separate_heads
        self.mlp_coarse = torch.nn.ModuleList()
        self.mlp_fine = torch.nn.ModuleList()
        if self.separate_heads:
            # One coarse and one fine MLP per head, created pairwise in channel order.
            for head_dim in head_dims:
                self.mlp_coarse.append(
                    make_mlp(conf["mlp_coarse"], d_in, d_latent, d_out=head_dim)
                )
                self.mlp_fine.append(
                    make_mlp(conf["mlp_fine"], d_in, d_latent, d_out=head_dim, allow_empty=True)
                )
        else:
            # Single joint MLP producing all d_out channels.
            self.mlp_coarse = make_mlp(conf["mlp_coarse"], d_in, d_latent, d_out=d_out)
            self.mlp_fine = make_mlp(
                conf["mlp_fine"], d_in, d_latent, d_out=d_out, allow_empty=True
            )

        self.latent_size = self.encoder.latent_size

        # Note: this is world -> camera, and bottom row is omitted
        self.register_buffer("poses", torch.empty(1, 3, 4), persistent=False)
        self.register_buffer("image_shape", torch.empty(2), persistent=False)

        self.d_in = d_in
        self.d_out = d_out
        self.d_latent = d_latent
        self.register_buffer("focal", torch.empty(1, 2), persistent=False)
        # Principal point
        self.register_buffer("c", torch.empty(1, 2), persistent=False)

        self.num_objs = 0
        self.num_views_per_obj = 1
        self.bckgd_pred = conf.get_string("bckgd_pred", "mlp")

    def encode(self, images, poses, focal, z_bounds=None, c=None):
        """
        :param images (NS, 3, H, W)
        NS is number of input (aka source or reference) views
        :param poses (NS, 4, 4)
        :param focal focal length () or (2) or (NS) or (NS, 2) [fx, fy]
        :param z_bounds ignored argument (used in the past)
        :param c principal point None or () or (2) or (NS) or (NS, 2) [cx, cy],
        default is center of image
        """
        # print(id(self))
        # print(images.device, poses.device, focal.device)
        self.num_objs = images.size(0)
        if len(images.shape) == 5:
            assert len(poses.shape) == 4
            assert poses.size(1) == images.size(
                1
            )  # Be consistent with NS = num input views
            self.num_views_per_obj = images.size(1)
            images = images.reshape(-1, *images.shape[2:])
            poses = poses.reshape(-1, 4, 4)
        else:
            self.num_views_per_obj = 1

        self.encoder(images)
        rot = poses[:, :3, :3].transpose(1, 2)  # (B, 3, 3)
        trans = -torch.bmm(rot, poses[:, :3, 3:])  # (B, 3, 1)
        self.poses = torch.cat((rot, trans), dim=-1)  # (B, 3, 4)

        self.image_shape[0] = images.shape[-1]
        self.image_shape[1] = images.shape[-2]

        # Handle various focal length/principal point formats
        if len(focal.shape) == 0:
            # Scalar: fx = fy = value for all views
            focal = focal[None, None].repeat((1, 2))
        elif len(focal.shape) == 1:
            # Vector f: fx = fy = f_i *for view i*
            # Length should match NS (or 1 for broadcast)
            focal = focal.unsqueeze(-1).repeat((1, 2))
        else:
            focal = focal.clone()
        self.focal = focal.float()
        self.focal[..., 1] *= -1.0

        if c is None:
            # Default principal point is center of image
            c = (self.image_shape * 0.5).unsqueeze(0)
        elif len(c.shape) == 0:
            # Scalar: cx = cy = value for all views
            c = c[None, None].repeat((1, 2))
        elif len(c.shape) == 1:
            # Vector c: cx = cy = c_i *for view i*
            c = c.unsqueeze(-1).repeat((1, 2))
        self.c = c

        if self.use_global_encoder:
            self.global_encoder(images)

    def forward(self, xyz, coarse=True, viewdirs=None, far=False, sigma_activations=True, no_activations=False):
        """
        Predict density plus every enabled head (rgb / seg / feat) at world
        space points xyz.
        Please call encode first!
        :param xyz (SB, B, 3)
        SB is batch of objects
        B is batch of points (in rays)
        NS is number of input views
        :param coarse if true (or if there is no fine MLP) the coarse MLP is
        evaluated, otherwise the fine one
        :param viewdirs view directions, reshaped to (SB, B, 3, 1); must not be
        None when use_viewdirs is enabled
        :param far not used by this method
        :param sigma_activations if false, the density channel is returned
        without relu / sigmoid
        :param no_activations if true, return the raw MLP output reshaped to
        (-1, B, d_out) and skip all post-processing
        :return (SB, B, C) with density first, followed by sigmoid(rgb), seg
        and feat for the heads that are enabled
        """
        with profiler.record_function("model_inference"):
            SB, B, _ = xyz.shape
            NS = self.num_views_per_obj

            # Transform query points into the camera spaces of the input views
            xyz = repeat_interleave(xyz, NS)  # (SB*NS, B, 3)
            xyz_rot = torch.matmul(self.poses[:, None, :3, :3], xyz.unsqueeze(-1))[
                ..., 0
            ]
            xyz = xyz_rot + self.poses[:, None, :3, 3]
            if self.d_in > 0:
                # * Encode the xyz coordinates
                # normalize_z selects the rotation-only coordinates (no translation)
                if self.use_xyz:
                    if self.normalize_z:
                        z_feature = xyz_rot.reshape(-1, 3)  # (SB*NS*B, 3)
                    else:
                        z_feature = xyz.reshape(-1, 3)  # (SB*NS*B, 3)
                else:
                    # Only the (negated) camera-space z is used as positional input
                    if self.normalize_z:
                        z_feature = -xyz_rot[..., 2].reshape(-1, 1)  # (SB*NS*B, 1)
                    else:
                        z_feature = -xyz[..., 2].reshape(-1, 1)  # (SB*NS*B, 1)

                if self.use_code and not self.use_code_viewdirs:
                    # Positional encoding (no viewdirs)
                    z_feature = self.code(z_feature)

                if self.use_viewdirs:
                    # * Encode the view directions
                    assert viewdirs is not None
                    # Viewdirs to input view space
                    viewdirs = viewdirs.reshape(SB, B, 3, 1)
                    viewdirs = repeat_interleave(viewdirs, NS)  # (SB*NS, B, 3, 1)
                    viewdirs = torch.matmul(
                        self.poses[:, None, :3, :3], viewdirs
                    )  # (SB*NS, B, 3, 1)
                    viewdirs = viewdirs.reshape(-1, 3)  # (SB*NS*B, 3)
                    z_feature = torch.cat(
                        (z_feature, viewdirs), dim=1
                    )  # (SB*NS*B, 4 or 6) when no encoding was applied above

                if self.use_code and self.use_code_viewdirs:
                    # Positional encoding (with viewdirs)
                    z_feature = self.code(z_feature)

                mlp_input = z_feature

            if self.use_encoder:
                # Grab encoder's latent code.
                # Perspective projection of the camera-space points
                uv = -xyz[:, :, :2] / xyz[:, :, 2:]  # (SB*NS, B, 2)
                # A single shared focal / principal point row is broadcast as is
                uv *= repeat_interleave(
                    self.focal.unsqueeze(1), NS if self.focal.shape[0] > 1 else 1
                )
                uv += repeat_interleave(
                    self.c.unsqueeze(1), NS if self.c.shape[0] > 1 else 1
                )  # (SB*NS, B, 2)
                latent = self.encoder.index(
                    uv, None, self.image_shape, custom_sample=self.uniseg
                )  # (SB * NS, latent, B)

                if self.stop_encoder_grad:
                    latent = latent.detach()
                latent = latent.transpose(1, 2).reshape(
                    -1, self.latent_size
                )  # (SB * NS * B, latent)

                if self.d_in == 0:
                    # z_feature not needed
                    mlp_input = latent
                else:
                    mlp_input = torch.cat((latent, z_feature), dim=-1)

            if self.use_global_encoder:
                # Concat global latent code if enabled
                global_latent = self.global_encoder.latent
                assert mlp_input.shape[0] % global_latent.shape[0] == 0
                num_repeats = mlp_input.shape[0] // global_latent.shape[0]
                global_latent = repeat_interleave(global_latent, num_repeats)
                mlp_input = torch.cat((global_latent, mlp_input), dim=-1)

            # Camera frustum culling stuff, currently disabled
            combine_index = None
            dim_size = None

            # Run main NeRF network
            
            if self.separate_heads:
                # One MLP per head; their outputs are concatenated channel-wise
                if coarse or self.mlp_fine[0] is None:
                    mlp_output = torch.cat([head(
                        mlp_input,
                        combine_inner_dims=(self.num_views_per_obj, B),
                        combine_index=combine_index,
                        dim_size=dim_size,
                    ) for head in self.mlp_coarse], dim=-1)
                else:
                    mlp_output = torch.cat([head(
                        mlp_input,
                        combine_inner_dims=(self.num_views_per_obj, B),
                        combine_index=combine_index,
                        dim_size=dim_size,
                    ) for head in self.mlp_fine], dim=-1)
            else:
                if coarse or self.mlp_fine is None:
                    mlp_output = self.mlp_coarse(
                        mlp_input,
                        combine_inner_dims=(self.num_views_per_obj, B),
                        combine_index=combine_index,
                        dim_size=dim_size,
                    )
                else:
                    mlp_output = self.mlp_fine(
                        mlp_input,
                        combine_inner_dims=(self.num_views_per_obj, B),
                        combine_index=combine_index,
                        dim_size=dim_size,
                    )
                    

            # Interpret the output
            mlp_output = mlp_output.reshape(-1, B, self.d_out)
            if no_activations:
                return mlp_output
            # Channel 0 is always the density
            sigma = mlp_output[..., 0].unsqueeze(-1)
            if sigma_activations:
                if self.bckgd_pred == "one_minus_sigma":
                    sigma = torch.sigmoid(sigma)
                else:
                    sigma = torch.relu(sigma)
            output_list = [sigma]
            if self.use_rgb_head:
                rgb = mlp_output[..., 1:4]
                output_list.append(torch.sigmoid(rgb))

            # TODO: label
            if self.use_seg_head:
                # For "constant" and "one_minus_sigma" the first (background)
                # class is synthesized and only the last n_classes-1 channels
                # come from the MLP.
                # NOTE(review): with n_classes == 1 the slice below becomes
                # [..., 0:] and takes every channel - confirm n_classes > 1.
                # NOTE(review): the seg slices count from the end of the
                # channel axis; if the feat head is enabled too, those are
                # feat channels - confirm the intended layout.
                if self.bckgd_pred == "constant":
                    seg = mlp_output[..., -self.n_classes+1:]
                    seg = torch.cat([torch.ones_like(sigma, requires_grad=True), seg], dim=-1)
                elif self.bckgd_pred == "one_minus_sigma":
                    seg = mlp_output[..., -self.n_classes+1:]
                    seg = torch.cat([torch.ones_like(sigma, requires_grad=True) - sigma, seg], dim=-1)
                else:
                    seg = mlp_output[..., -self.n_classes:]

                output_list.append(seg)
            
            if self.use_feat_head:
                # NOTE(review): this takes every channel after the density, so
                # it also contains the raw rgb / seg channels when those heads
                # are enabled - confirm this is intended.
                h = mlp_output[..., 1:]
                output_list.append(h)
            

            output = torch.cat(output_list, dim=-1)
            output = output.reshape(SB, B, -1)
        return output

    def load_weights(self, args, opt_init=False, strict=True, device=None, load_best=False):
        """
        Helper for loading weights according to argparse arguments.
        Your can put a checkpoint at checkpoints/<exp>/pixel_nerf_init to use as initialization.
        :param opt_init if true, loads from init checkpoint instead of usual even when resuming
        """
        # TODO: make backups
        # import pdb; pdb.set_trace()
        self.init_tu_culo = (args.opt_init or opt_init)
        if not self.init_tu_culo and not args.resume and not load_best:
            return

        if self.init_tu_culo:
            if self.init_ckpt is not None:
                model_path = self.init_ckpt
            else:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "pixel_nerf_init")
            if args.resume:
                resume_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "pixel_nerf_latest")
                if os.path.exists(resume_path):
                    model_path = resume_path
            if load_best:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "best_pixel_nerf_latest")
        else:
            if load_best:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "best_pixel_nerf_latest")
            else:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "pixel_nerf_latest")

        if device is None:
            device = self.poses.device

        if os.path.exists(model_path):
            print("Load", model_path)
            self.load_state_dict(
                torch.load(model_path, map_location=device), strict=strict
            )
        elif not opt_init:
            warnings.warn(
                (
                    "WARNING: {} does not exist, not loaded!! Model will be re-initialized.\n"
                    + "If you are trying to load a pretrained model, STOP since it's not in the right place. "
                    + "If training, unless you are starting a new experiment, please remember to pass --resume."
                ).format(model_path)
            )
        return self

    def save_weights(self, args, best = False, opt_init=False):
        """
        Helper for saving weights according to argparse arguments
        :param opt_init if true, saves from init checkpoint instead of usual
        """
        from shutil import copyfile
        if not best:
            ckpt_name = "pixel_nerf_init" if opt_init else "pixel_nerf_latest"
            backup_name = "pixel_nerf_init_backup" if opt_init else "pixel_nerf_backup"
        else:
            ckpt_name = "best_pixel_nerf_init" if opt_init else "best_pixel_nerf_latest"
            backup_name = "best_pixel_nerf_init_backup" if opt_init else "best_pixel_nerf_backup"

        ckpt_path = osp.join(args.checkpoints_path, args.name, ckpt_name)
        ckpt_backup_path = osp.join(args.checkpoints_path, args.name, backup_name)

        if osp.exists(ckpt_path):
            copyfile(ckpt_path, ckpt_backup_path)
        torch.save(self.state_dict(), ckpt_path)
        return self

class PixelNeSFNet(PixelNeRFNet):
    def __init__(self, conf, stop_encoder_grad=False, **kwargs):
        """
        PixelNeRF backbone plus a 3D U-Net and an MLP that operate on a dense
        grid of backbone predictions.

        :param conf PyHocon config subtree 'model'
        :param stop_encoder_grad forwarded to the PixelNeRFNet base class; when
        true, encode() also detaches the grid predicted by the backbone
        :param kwargs accepted and ignored
        """
        # The caller's flag used to be replaced by a hard-coded False here,
        # which made the `if self.stop_encoder_grad` branch in encode() dead.
        super().__init__(conf, stop_encoder_grad=stop_encoder_grad)
        from .unet_3d import MinkUNet

        # Resolution of the dense grid, per axis
        self.grid_N = 64
        # Proxy bound to this instance that dispatches to the PixelNeRFNet
        # implementations of encode() / forward().
        # NOTE: `super` objects cannot be pickled or deep-copied, so save this
        # model via state_dict() rather than pickling the whole module.
        self.pixelnerf = super()
        self.net_3d = MinkUNet(in_nchannel=self.d_out, out_nchannel=self.d_out)
        self.nesf_mlp = make_mlp(conf["mlp_fine"], d_in=self.d_out, d_latent=0,
                            d_out=self.d_out, allow_empty=True)
        

    def encode(self, images, poses, focal, z_bounds=None, c=None):
        """
        :param images (NS, 3, H, W)
        NS is number of input (aka source or reference) views
        :param poses (NS, 4, 4)
        :param focal focal length () or (2) or (NS) or (NS, 2) [fx, fy]
        :param z_bounds ignored argument (used in the past)
        :param c principal point None or () or (2) or (NS) or (NS, 2) [cx, cy],
        default is center of image
        """
        self.BS = images.shape[0]

        self.pixelnerf.encode(images=images, poses=poses, focal=focal,
                              z_bounds=z_bounds, c=c)
        t = torch.linspace(-1.0, 1.0, self.grid_N)
        self.nesf_grid_coords = torch.stack(torch.meshgrid(t, t, t), -1).reshape(-1,3).float().to(poses.device)
        # t = np.linspace(-1.0, 1.0, N)
        # nesf_grid = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32)
        # nesf_grid = torch.from_numpy(nesf_grid).to(self.device)
        # print(nesf_grid.shape, images.shape)
        self.nesf_grid_coords = self.nesf_grid_coords.unsqueeze(0).repeat(self.BS, 1, 1) # (BS, NxNxN, 3)
        # self.nesf_grid_coords *= self.grid_N//2
        self.nesf_grid = self.pixelnerf.forward(self.nesf_grid_coords, coarse=False,
                                        viewdirs=torch.zeros_like(self.nesf_grid_coords)) # (BS, NxNxN, C)
        if self.stop_encoder_grad:
            self.nesf_grid = self.nesf_grid.detach()

        self.nesf_grid_coords = self.nesf_grid_coords.reshape(-1,self.nesf_grid_coords.shape[-1])
        self.nesf_grid = self.nesf_grid.reshape(-1,self.nesf_grid.shape[-1])
        # self.nesf_grid = self.nesf_grid.reshape(BS, self.grid_N,self.grid_N,self.grid_N, self.nesf_grid.shape[-1])
        # self.nesf_grid = self.nesf_grid.permute((0,4,1,2,3))

    def forward(self, xyz, coarse=True, viewdirs=None, far=False):
        nesf_coords = torch.round( ( (self.nesf_grid_coords+1)*(self.grid_N-1) ) / 2 ).long()
        feats_3d = self.net_3d(([nesf_coords], self.nesf_grid, self.BS)) # 
        feats_3d = feats_3d.reshape(self.BS, self.grid_N, self.grid_N, 
            self.grid_N, -1).permute((0,4,1,2,3)).contiguous()
        out_feats = torch.nn.functional.grid_sample(feats_3d, xyz.unsqueeze(1).unsqueeze(1), mode='bilinear',
                                                    padding_mode='border', align_corners=True)
        out_feats = out_feats.squeeze().transpose(-2,-1).contiguous()
        if len(out_feats.shape) == 2:
            out_feats = out_feats.unsqueeze(0)
        nesf_output = self.nesf_mlp(out_feats)
        pixelnerf_output = self.pixelnerf.forward(xyz, coarse=coarse, viewdirs=viewdirs)

        # sigma = pixelnerf_output[..., 0].unsqueeze(-1)
        # output_list = [sigma]
        sigma = nesf_output[..., 0].unsqueeze(-1)
        output_list = [torch.relu(sigma)]

        # rgb = pixelnerf_output[..., 1:4]
        # output_list.append(rgb)
        rgb = nesf_output[..., 1:4]
        output_list.append(torch.sigmoid(rgb))

        seg = nesf_output[..., 4:]
        output_list.append(seg)

        output = torch.cat(output_list, dim=-1)
        # print(output.shape)
        output = output.reshape(self.BS, -1, nesf_output.shape[-1])
        # print(output.shape)
        return output

class PixelUNISURFNet(torch.nn.Module):
    def __init__(self, conf, dim=3,
                hidden_size=256,
                octaves_pe_views=4,
                stop_encoder_grad=False,
                sigmoid_scaling=1.0,
                init_ckpt=None,
                **kwargs):
        """
        Wrap a feature-only PixelNeRFNet backbone and add weight-normalized
        MLPs for appearance (rgb) and segmentation.

        :param conf PyHocon config subtree 'model'
        :param dim width of points / view directions / normals, and of the rgb output
        :param hidden_size width of the four hidden layers of both MLPs
        :param octaves_pe_views not used by this constructor
        :param stop_encoder_grad forwarded to the PixelNeRFNet backbone
        :param sigmoid_scaling scale applied inside the occupancy sigmoid in forward()
        :param init_ckpt optional checkpoint path used by load_weights for initialization
        :param kwargs accepted and ignored
        """
        super().__init__()

        # The backbone only predicts density + features and never sees view
        # directions; rgb / seg are produced by the MLPs built below.
        backbone_conf = copy.deepcopy(conf)
        backbone_conf['use_rgb_head'] = False
        backbone_conf['use_seg_head'] = False
        backbone_conf['use_feat_head'] = True
        backbone_conf['use_viewdirs'] = False
        self.pixelnerf = PixelNeRFNet(conf=backbone_conf, stop_encoder_grad=stop_encoder_grad, uniseg=True)
        self.use_rgb_head = conf.get_bool("use_rgb_head", True)
        self.use_seg_head = conf.get_bool("use_seg_head", True)
        self.sigmoid_scaling = sigmoid_scaling
        self.n_classes = conf.n_classes
        self.init_ckpt = init_ckpt

        # TODO: normals dimensions
        # MLP input = backbone features (d_out minus the density channel)
        # + points + view directions + normals
        dim_embed_view = self.pixelnerf.d_out - 1 + dim + dim + dim

        # Appearance network first, segmentation network second; the layers
        # are registered as lin_rgb0.. and lin_seg0.. respectively.
        for prefix, final_width in (("lin_rgb", dim), ("lin_seg", self.n_classes)):
            widths = [dim_embed_view] + [hidden_size] * 4 + [final_width]
            self.num_layers_app = len(widths)
            for idx in range(self.num_layers_app - 1):
                layer = nn.utils.weight_norm(nn.Linear(widths[idx], widths[idx + 1]))
                setattr(self, prefix + str(idx), layer)

        self.softplus = nn.Softplus(beta=100)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        
    def gradient(self, p):
        with torch.enable_grad():
            p.requires_grad_(True)
            if self.init_tu_culo:
                y = (self.pixelnerf.forward(p, coarse=True, viewdirs=torch.zeros_like(p))[...,:1])
            else:
                y = self.pixelnerf.forward(p, coarse=True, viewdirs=torch.zeros_like(p))[...,:1]
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=p,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True, allow_unused=True)[0]
            # import pdb; pdb.set_trace()
            return gradients.unsqueeze(1)

    def infer_app(self, points, normals, view_dirs, feature_vectors):
        # print(points.shape, normals.squeeze(1).shape, view_dirs.unsqueeze(0).shape, feature_vectors.shape)
        rendering_input = torch.cat([points, view_dirs, normals.squeeze(1), feature_vectors], dim=-1)
        out = {}
        if self.use_rgb_head:
            rgb = rendering_input.clone()
            for l in range(0, self.num_layers_app - 1):
                lina = getattr(self, "lin_rgb" + str(l))
                # import pdb; pdb.set_trace()
                rgb = lina(rgb)
                if l < self.num_layers_app - 2:
                    rgb = self.relu(rgb)
            rgb = self.tanh(rgb) * 0.5 + 0.5
            out['rgb'] = rgb

        if self.use_seg_head:
            seg = rendering_input.clone()
            for l in range(0, self.num_layers_app - 1):
                lina = getattr(self, "lin_seg" + str(l))
                # import pdb; pdb.set_trace()
                seg = lina(seg)
                if l < self.num_layers_app - 2:
                    seg = self.relu(seg)
            # seg = self.tanh(seg) * 0.5 + 0.5
            out['seg'] = seg
        return out

    def encode(self, images, poses, focal, z_bounds=None, c=None):
        self.pixelnerf.encode(images=images, poses=poses, focal=focal, z_bounds=z_bounds, c=c)

    def forward(self, xyz, coarse=True, viewdirs=None, far=False, only_occupancy=False, return_addocc=False):
        # print(xyz)
        with profiler.record_function("model_unisurf"):
            # print(xyz.shape)
            output = self.pixelnerf(xyz, coarse=True, viewdirs=None, no_activations=True)
            # import pdb; pdb.set_trace()
            # normals = F.normalize(normals)
            # if self.init_tu_culo:
            #     # print("Since it is loading previous weights the sigmoid function is changing FORWARD: UNISURF")
            #     sigma = torch.sigmoid(output[...,:1] * 1.0 - 5)
            # else:
            #     # print("Since it is NOT loading previous weights the sigmoid function is changing FORWARD: UNISURF")
            #     sigma = torch.sigmoid(output[...,:1] * -10.0 )
            sigma = torch.sigmoid((output[...,:1] - 5) * self.sigmoid_scaling)
            
        if only_occupancy:
            return sigma
        normals =  self.gradient(xyz)
        # print(xyz.shape, normals.shape, output.shape)
        rgb = self.infer_app(xyz, normals, viewdirs, output[...,1:])
        if return_addocc:
            return rgb, sigma
        out_list = [sigma]
        if self.use_rgb_head:
            out_list.append(rgb['rgb'])
        if self.use_seg_head:
            out_list.append(rgb['seg'])
        return torch.cat(out_list, dim=-1)

    def load_weights(self, args, strict=True, device=None, opt_init=False, load_best=False):
        """
        Helper for loading weights according to argparse arguments.
        Your can put a checkpoint at checkpoints/<exp>/pixel_nerf_init to use as initialization.
        :param opt_init if true, loads from init checkpoint instead of usual even when resuming
        """
        # print(id(self))
        # TODO: make backups
        # self.opt_initt = True
        # import pdb; pdb.set_trace()
        

        self.init_tu_culo = (args.opt_init or opt_init)
        if not self.init_tu_culo and not args.resume and not load_best:
            return

        if self.init_tu_culo:
            if self.init_ckpt is not None:
                model_path = self.init_ckpt
            else:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "pixel_nerf_init")
            if args.resume:
                resume_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "pixel_nerf_latest")
                if os.path.exists(resume_path):
                    model_path = resume_path
                    self.init_tu_culo = False
            if load_best:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "best_pixel_nerf")
                self.init_tu_culo = False
        else:
            if load_best:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "best_pixel_nerf")
            else:
                model_path = "%s/%s/%s" % (args.checkpoints_path, args.name, "pixel_nerf_latest")
        

        if device is None:
            device = self.pixelnerf.poses.device
        # print("Trying to load ", model_path)
        if os.path.exists(model_path):
            print("Load", model_path)

            pretrained_dict = torch.load(model_path, map_location=device)
            
            if self.init_tu_culo:
                model_dict = copy.deepcopy(self.state_dict())
                # 1. filter out unnecessary keys
                for k, v in self.state_dict().items():
                    k_no_pixelnerf = k.replace('pixelnerf.', '')
                    if 'lin_in.weight' in k:
                        model_dict[k] = pretrained_dict[k_no_pixelnerf][:,:-3]
                        continue
                    if 'lin_out.weight' in k:
                        model_dict[k][:1] = pretrained_dict[k_no_pixelnerf][:1]
                        continue
                    if 'lin_out.bias' in k:
                        model_dict[k][:1] = pretrained_dict[k_no_pixelnerf][:1]
                        continue

                    if k in pretrained_dict:
                        model_dict[k] = pretrained_dict[k]
                    elif k_no_pixelnerf in pretrained_dict:
                        model_dict[k] = pretrained_dict[k_no_pixelnerf]
                pretrained_dict = model_dict
            self.load_state_dict(pretrained_dict)
            # import pdb; pdb.set_trace()

        elif not self.init_tu_culo:
            warnings.warn(
                (
                    "WARNING: {} does not exist, not loaded!! Model will be re-initialized.\n"
                    + "If you are trying to load a pretrained model, STOP since it's not in the right place. "
                    + "If training, unless you are starting a new experiment, please remember to pass --resume."
                ).format(model_path)
            )
        return self

    def save_weights(self, args, best = False, opt_init=False):
        """
        Helper for saving weights according to argparse arguments
        :param opt_init if true, saves from init checkpoint instead of usual
        """
        from shutil import copyfile
        if not best:
            ckpt_name = "pixel_nerf_init" if opt_init else "pixel_nerf_latest"
            backup_name = "pixel_nerf_init_backup" if opt_init else "pixel_nerf_backup"
        else:
            ckpt_name = "best_pixel_nerf_init" if opt_init else "best_pixel_nerf"
            backup_name = "best_pixel_nerf_init_backup" if opt_init else "best_pixel_nerf_backup"

        ckpt_path = osp.join(args.checkpoints_path, args.name, ckpt_name)
        ckpt_backup_path = osp.join(args.checkpoints_path, args.name, backup_name)

        if osp.exists(ckpt_path):
            copyfile(ckpt_path, ckpt_backup_path)
        torch.save(self.state_dict(), ckpt_path)
        return self