import torch
import torch.nn as nn
from torch import distributions as dist
import torch.nn.functional as F
import numpy as np

from models.decoder_module import decoder
from models.encoder_module import voxels
from models.renderer_module.single_variance import SingleVarianceNetwork
from vgn.utils.transform import Rotation, Transform
from torchvision.transforms import Resize
from models.renderer_module.rays import *
from models.renderer_module.renderer import *
from utils.misc import EasyDict


class AGATENet(nn.Module):
    def __init__(self, cfg):
        """Build the encoder, the four local decoders and (optionally) the renderer.

        Args:
            cfg: configuration object read through attribute access. Fields
                used here: ``padding``, ``decoder_padding``,
                ``encoder.{type, c_dim, plane_resolution}``,
                ``grasp_decoder.{sample_mode, hidden_size, concat_feat,
                decoder_input}`` and ``geometry_decoder.{type, decoder_input}``.
                When ``geometry_decoder.type == "rendered_depth"`` the fields
                ``geometry_decoder.{n_rays, n_pts_per_ray, n_importance,
                up_sample_steps, sample_near_gt_range}`` are read as well.
        """
        super().__init__()
        self.cfg = cfg
        # All sub-modules are placed on CUDA when it is available, else on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        latent_dim = cfg.encoder.c_dim
        feature_padding = cfg.padding
        plane_res = cfg.encoder.plane_resolution

        # ---- encoder -------------------------------------------------------
        # NOTE(review): no dataset is available in this constructor, so the
        # 'idx' branch below fails on len(None) -- confirm whether that
        # encoder type is still meant to be supported.
        dataset = None
        enc_type = self.cfg.encoder.type
        if enc_type == 'idx':
            self.encoder = nn.Embedding(len(dataset), latent_dim).to(self.device)
        elif enc_type is None:
            self.encoder = None
        else:
            # Any other (non-None) type selects the tri-plane voxel encoder.
            self.encoder = voxels.LocalVoxelEncoder(
                c_dim=latent_dim,
                padding=feature_padding,
                plane_type=['xz', 'xy', 'yz'],
                plane_resolution=plane_res,
                unet=True,
                unet_kwargs={'depth': 3, 'merge_mode': 'concat', 'start_filts': 32},
            ).to(self.device)

        # ---- decoders ------------------------------------------------------
        self.supervision = cfg.geometry_decoder.type
        self.decoder_padding = cfg.decoder_padding
        self.detach_tsdf = False

        grasp_cfg = self.cfg.grasp_decoder
        shared_kwargs = dict(sample_mode=grasp_cfg.sample_mode,
                             hidden_size=grasp_cfg.hidden_size,
                             concat_feat=grasp_cfg.concat_feat)

        # A decoder receives the 3-D query position as extra input only in 'pos' mode.
        grasp_in_dim = 3 if cfg.grasp_decoder.decoder_input == 'pos' else 0
        geo_in_dim = 3 if cfg.geometry_decoder.decoder_input == 'pos' else 0

        def build_decoder(in_dim, out_dim):
            # All four heads share every setting except input/output width.
            return decoder.LocalDecoder(
                in_dim=in_dim, c_dim=latent_dim, padding=feature_padding,
                out_dim=out_dim, **shared_kwargs).to(self.device)

        # Creation order is kept as-is so parameter ordering stays unchanged.
        self.decoder_qual = build_decoder(grasp_in_dim, 1)
        self.decoder_rot = build_decoder(grasp_in_dim, 4)
        self.decoder_width = build_decoder(grasp_in_dim, 1)
        self.decoder_geo = build_decoder(geo_in_dim, 1)

        # ---- renderer (depth supervision only) -------------------------------
        if self.supervision == "rendered_depth":
            geo_cfg = cfg.geometry_decoder
            self.deviation_network = SingleVarianceNetwork(init_val=0.3).to(self.device)
            self.renderer = NeuSRenderer(
                self.decoder_geo,
                geo_cfg.decoder_input,
                self.device,
                self.deviation_network,
                geo_cfg.n_rays,
                geo_cfg.n_pts_per_ray,
                geo_cfg.n_importance,
                geo_cfg.up_sample_steps,
                geo_cfg.sample_near_gt_range,
            )

    def to(self, device):
        ''' Puts the model to the device.

        In addition to moving parameters and buffers, ``self.device`` is kept
        in sync, because other methods of this class create tensors with
        ``device=self.device``. Previously only an unused ``_device``
        attribute was written, leaving ``self.device`` stale after a move.

        NOTE(review): the renderer was handed ``self.device`` at construction
        time; whether it stores it and needs updating too is not visible here.

        Args:
            device (device): pytorch device (``str``, ``int`` index or
                ``torch.device``). Any other value accepted by
                ``nn.Module.to`` (e.g. a dtype) is forwarded unchanged and
                leaves ``self.device`` untouched.
        Returns:
            the model itself, as returned by ``nn.Module.to``.
        '''
        model = super().to(device)
        # Kept for backward compatibility with code that reads ``_device``.
        model._device = device
        if isinstance(device, (str, int, torch.device)):
            model.device = torch.device(device)
        return model

    def forward(self, batch_dict, iteration_count=None,
                grasp_branch=True, geometry_branch=True,
                **kwargs):
        ''' Encodes the TSDF volume and runs the requested prediction branches.

        Args:
            batch_dict: batch with attribute access. ``tsdf`` is always read;
                ``point_grasp`` is read by the grasp branch and when the TSDF
                batch size is 1; ``camera_extrinsic``, ``camera_intrinsic``
                and ``depth_img`` are read for "rendered_depth" supervision;
                ``point_occ`` is read for "occupancy" supervision.
            iteration_count: passed through to the renderer unchanged.
            grasp_branch (bool): whether to decode grasp predictions.
            geometry_branch (bool): whether to run the geometry supervision path.
            **kwargs: passed through to ``self.renderer.render``.
        Returns: (EasyDict)
            grasp (shapes as stated by the original author):
                grasp_label: B*1, grasp quality
                grasp_rotation: B*1*4, grasp rotation
                grasp_width: B*1, grasp width
            geometry supervision (one of):
                every entry of the renderer output plus ``approximate_sdf``
                occupancy: B*2048*1, occupancy values
        Raises:
            ValueError: if ``self.supervision`` is not one of
                "rendered_depth", "occupancy" or "none".
        '''
        assert len(batch_dict.tsdf.shape) == 4, "tsdf should be 4D (B, R**3)"
        pred_dict = EasyDict()

        feature = self.encoder(batch_dict.tsdf)
        if batch_dict.tsdf.shape[0] == 1:
            # A single scene is shared by all grasp queries: tile its features
            # along the batch axis to match the number of query entries.
            for plane_name, plane_feat in feature.items():
                feature[plane_name] = plane_feat.repeat(batch_dict.point_grasp.shape[0], 1, 1, 1)

        if grasp_branch:
            label, rotation, width = self.decode_grasp(batch_dict, feature)
            pred_dict.grasp_label = label
            pred_dict.grasp_rotation = rotation
            pred_dict.grasp_width = width

        if not geometry_branch:
            return pred_dict

        if self.supervision == "rendered_depth":
            original_size = 0.3
            bounding_box = self.get_bounding_box(self.decoder_padding, original_size)

            # Convert each sample's first extrinsic entry to a matrix via Transform.
            extrinsics = batch_dict.camera_extrinsic[:, 0, :]
            pose_matrices = [Transform.from_list(list(extrinsics[n].cpu())).as_matrix()
                             for n in range(batch_dict.camera_extrinsic.shape[0])]
            camera_extrinsic_matrix = torch.tensor(pose_matrices, device=self.device)

            render_out = self.renderer.render(feature,
                                              batch_dict.camera_intrinsic,
                                              camera_extrinsic_matrix,
                                              batch_dict.depth_img,
                                              bounding_box, original_size, self.decoder_padding,
                                              iteration_count,
                                              **kwargs)
            for name, value in render_out.items():
                pred_dict[name] = value

            # Ground-truth depth of the sampled rays, looked up in the flattened image.
            depth_flat = batch_dict.depth_img[:, 0, :]
            depth_flat = depth_flat.reshape(depth_flat.shape[0], -1)
            per_sample = [depth_flat[b, pred_dict.ray_index[b, :]]
                          for b in range(depth_flat.shape[0])]
            pixel_gt = torch.stack(per_sample, dim=0)
            # Difference between the observed depth and each sample's depth
            # along the ray, used as an approximate SDF target.
            pred_dict.approximate_sdf = pixel_gt[:, :, None, None] - pred_dict.mid_z_vals

        elif self.supervision == "occupancy":
            points = batch_dict.point_occ
            if self.cfg.geometry_decoder.decoder_input == 'pos':
                geo_input = points
            else:
                geo_input = None
            occ_logits = self.decoder_geo(points, feature, input=geo_input)[:, :, 0]
            pred_dict.occupancy = torch.sigmoid(occ_logits)

        elif self.supervision != "none":
            raise ValueError(f"Unknown supervision type: {self.supervision}")

        return pred_dict

    def get_bounding_box(self, padding, original_size):
        '''Generate the sample space after padding (real world coordinates).

        The cube ``[0, original_size]^3`` is scaled by ``(1 + padding)`` about
        its own centre.

        The ``original_size`` argument used to be overwritten with a
        hard-coded 0.3 and was therefore ignored; it is honoured now. Every
        call inside this class passes 0.3, so their results are unchanged.

        Args:
            padding (float): relative enlargement of the cube (0 keeps it as is).
            original_size (float): edge length of the un-padded cube.
        Returns:
            np.ndarray of 6 values: ``[min, min, min, max, max, max]``.
        '''
        half_size = original_size * 0.5
        bounding_box = np.array([0.0, 0.0, 0.0,
                                 original_size, original_size, original_size])
        # Centre on the origin, scale, then move back to the original centre.
        bounding_box -= half_size
        bounding_box *= (1 + padding)
        bounding_box += half_size
        return bounding_box

    def decode_grasp(self, input_batch, feature, **kwargs):
        ''' Returns grasp quality, rotation and width for the query points.

        Args:
            input_batch: batch whose ``point_grasp`` holds the query points
                (per the original note: range -0.5 ~ 0.5, [batch_size, 1, 3]).
            feature: latent code from the encoder (per the original note: a
                dict keyed by 'xz', 'xy', 'yz' or 'grid').
            **kwargs: forwarded to each of the three decoders.
        Returns:
            (qual, rot, width): sigmoid quality with its last axis squeezed,
            rotation normalised along axis 2 with axis 1 squeezed, and raw
            width with its last axis squeezed.
        Raises:
            NotImplementedError: if ``cfg.grasp_decoder.decoder_input`` is
                neither 'pos' nor 'None'.
        '''
        # Rescale the query points by the decoder padding.
        query = input_batch.point_grasp / (1 + self.decoder_padding)

        mode = self.cfg.grasp_decoder.decoder_input
        if mode == 'pos':
            extra_input = query
        elif mode == 'None':
            extra_input = None
        else:
            raise NotImplementedError

        # Quality: NaN logits are replaced by 0 before the sigmoid.
        qual_logits = self.decoder_qual(query, feature, extra_input, **kwargs)
        nan_mask = torch.isnan(qual_logits)
        qual_logits = torch.where(nan_mask, torch.zeros_like(qual_logits), qual_logits)
        qual = torch.sigmoid(qual_logits).squeeze(-1)

        # Rotation: normalised along axis 2.
        rot = F.normalize(self.decoder_rot(query, feature, extra_input, **kwargs), dim=2)
        rot = rot.squeeze(1)

        # Width: raw decoder output.
        width = self.decoder_width(query, feature, extra_input, **kwargs).squeeze(-1)

        return qual, rot, width

    def compute_loss(self, batch_dict, prediction):
        """Combine the grasp losses with the configured geometry loss.

        Args:
            batch_dict: ground-truth data. Reads ``grasp_label``,
                ``grasp_rotation`` and ``grasp_width``; additionally
                ``depth_img`` ("rendered_depth") or ``occupancy`` ("occupancy").
            prediction: network output as produced by ``forward``.
        Returns:
            loss (EasyDict): ``loss_qual``, ``loss_rot``, ``loss_width``,
            ``loss_all`` and, depending on the supervision, ``loss_geo``,
            ``loss_eik`` (training mode only), ``loss_near`` and ``loss_free``.
        Raises:
            ValueError: if ``self.supervision`` is not a known type.
        """
        loss = EasyDict()

        # ---- grasp terms (kept per-sample until the final mean) --------------
        qual_term = self.qual_loss_fn(prediction.grasp_label, batch_dict.grasp_label)
        rot_term = self.rot_loss_fn(prediction.grasp_rotation, batch_dict.grasp_rotation)
        width_term = self.width_loss_fn(prediction.grasp_width, batch_dict.grasp_width)
        loss.loss_qual = qual_term.mean()
        loss.loss_rot = rot_term.mean()
        loss.loss_width = width_term.mean()
        # Rotation and width terms are weighted by the ground-truth grasp label.
        grasp_total = qual_term + batch_dict.grasp_label * (1.0 * rot_term + 16 * width_term)
        loss.loss_all = grasp_total.mean()

        if self.supervision == "rendered_depth":
            geo_cfg = self.cfg.geometry_decoder

            # Depth term: rendered depth vs. ground-truth depth of the sampled rays.
            depth_flat = batch_dict.depth_img[:, 0, :]
            depth_flat = depth_flat.reshape(depth_flat.shape[0], -1)
            depth_gt = torch.gather(depth_flat, 1, prediction.ray_index)
            loss.loss_geo = self.depth_loss_fn(prediction.rendered_depth.squeeze(),
                                               depth_gt,
                                               mask=prediction.valid_ray_mask.squeeze()).mean()
            loss.loss_all += loss.loss_geo * geo_cfg.loss_depth_img_weight

            # The gradient-error term is only added in training mode.
            if self.training:
                loss.loss_eik = prediction.gradient_error
                loss.loss_all += loss.loss_eik * geo_cfg.loss_eik_weight

            # Expand the per-ray validity mask to one flag per sampled point.
            n_pts_per_ray = int(prediction.sdf.shape.numel() / prediction.valid_ray_mask.shape.numel())
            point_mask = prediction.valid_ray_mask.repeat(1, 1, n_pts_per_ray).reshape(-1)
            sdf = prediction.sdf.reshape(-1, 1)[point_mask]
            approximate_sdf = prediction.approximate_sdf.reshape(-1, 1)[point_mask]
            near = approximate_sdf <= 0.05
            far = ~near

            # Points selected by ``near``: L1 distance to the approximate SDF.
            near_count = near.sum()
            if near_count > 0:
                loss.loss_near = torch.abs(sdf[near] - approximate_sdf[near]).sum() / near_count
                loss.loss_all += loss.loss_near * geo_cfg.loss_near_weight

            # Remaining points: larger of the two penalties, clamped at zero.
            far_count = far.sum()
            if far_count > 0:
                candidates = torch.stack((torch.exp(-5 * sdf[far]) - 1.0,
                                          sdf[far] - approximate_sdf[far]), dim=1)
                free_space_loss = candidates.max(1)[0]
                loss.loss_free = free_space_loss.clamp(min=0.0).sum() / far_count
                loss.loss_all += loss.loss_free * geo_cfg.loss_free_weight

        elif self.supervision == "occupancy":
            loss.loss_geo = self.occ_loss_fn(prediction.occupancy, batch_dict.occupancy).mean()
            loss.loss_all += loss.loss_geo

        elif self.supervision == "none":
            # No geometry supervision: report a zero term so the key exists.
            loss.loss_geo = torch.zeros_like(loss.loss_all)
            loss.loss_all += loss.loss_geo

        else:
            raise ValueError("Unknown supervision type: {}".format(self.supervision))

        return loss

    def render_full_image(self, batch_dict, iteration_count, **kwargs):
        """Encode the TSDF and delegate to ``self.renderer.render_full_image``.

        ``self.renderer`` is only created when the supervision type is
        "rendered_depth".

        Args:
            batch_dict: batch providing ``tsdf``, ``camera_extrinsic``,
                ``camera_intrinsic`` and ``depth_img``.
            iteration_count: forwarded to the renderer unchanged.
            **kwargs: forwarded to the renderer.
        Returns:
            whatever ``self.renderer.render_full_image`` returns.
        """
        world_size = 0.3
        feature = self.encoder(batch_dict.tsdf)
        bounding_box = self.get_bounding_box(self.decoder_padding, world_size)

        # Convert each sample's first extrinsic entry to a matrix via Transform.
        extrinsics = batch_dict.camera_extrinsic[:, 0, :]
        pose_matrices = [Transform.from_list(list(extrinsics[n].cpu())).as_matrix()
                         for n in range(batch_dict.camera_extrinsic.shape[0])]
        camera_extrinsic_matrix = torch.tensor(pose_matrices, device=self.device)

        return self.renderer.render_full_image(
            feature,
            batch_dict.camera_intrinsic,
            camera_extrinsic_matrix,
            batch_dict.depth_img,
            bounding_box, world_size, self.decoder_padding,
            iteration_count,
            **kwargs)

    def extract_geometry(self, batch_dict, iteration_count, resolution, **kwargs):
        """Encode the TSDF and delegate to ``self.renderer.extract_geometry``.

        Args:
            batch_dict: batch providing ``tsdf``.
            iteration_count: forwarded to the renderer unchanged.
            resolution: forwarded to the renderer unchanged.
            **kwargs: forwarded to the renderer.
        Returns:
            whatever ``self.renderer.extract_geometry`` returns.
        """
        world_size = 0.3
        latent = self.encoder(batch_dict.tsdf)
        padded_box = self.get_bounding_box(self.decoder_padding, world_size)
        return self.renderer.extract_geometry(
            latent,
            resolution,
            padded_box, world_size, self.decoder_padding,
            iteration_count,
            **kwargs)

    def predict_grasp(self, batch):
        """Run only the grasp branch of ``forward`` with gradients disabled.

        The placeholder ``EasyDict()`` that used to be created and then
        immediately overwritten has been removed.

        Args:
            batch: batch object accepted by ``forward``.
        Returns:
            the prediction returned by ``forward(batch, geometry_branch=False)``.
        """
        with torch.no_grad():
            return self.forward(batch, geometry_branch=False)

    def qual_loss_fn(self, pred, target):
        """Element-wise binary cross-entropy between predicted and target quality.

        ``pred`` is passed to ``F.binary_cross_entropy`` as probabilities;
        no reduction is applied.
        """
        bce = F.binary_cross_entropy(input=pred, target=target, reduction="none")
        return bce

    def quat_loss_fn(self, pred, target):
        """Return ``1 - |<pred, target>|`` with the dot product taken over axis 1.

        The absolute value makes the loss identical for ``target`` and
        ``-target``.
        """
        dot = (pred * target).sum(dim=1)
        return 1.0 - dot.abs()

    def rot_loss_fn(self, pred, target):
        """Rotation loss against the better-matching of two target rotations.

        ``target[:, 0]`` and ``target[:, 1]`` are each scored with
        ``quat_loss_fn``; the element-wise smaller loss is returned.
        """
        first, second = (self.quat_loss_fn(pred, target[:, k]) for k in (0, 1))
        return torch.minimum(first, second)

    def width_loss_fn(self, pred, target):
        """Element-wise squared error between predicted and target width (no reduction)."""
        squared_error = F.mse_loss(input=pred, target=target, reduction="none")
        return squared_error

    def occ_loss_fn(self, pred, target):
        """Binary cross-entropy on occupancy probabilities, averaged over the last axis."""
        per_point = F.binary_cross_entropy(input=pred, target=target, reduction="none")
        return per_point.mean(dim=-1)

    def depth_loss_fn(self, pred, target, mask=None):
        """Masked L1 depth loss, averaged per sample.

        The advertised default ``mask=None`` used to fail on ``tensor * None``;
        it now means "no masking". ``reshape`` replaces ``view`` so that
        non-contiguous inputs are accepted as well.

        Args:
            pred: predicted depth; the first axis is treated as the batch axis.
            target: ground-truth depth, broadcastable to ``pred``.
            mask: optional multiplicative weight (e.g. a 0/1 validity mask).
                Masked-out entries still count in the mean's denominator, as
                before.
        Returns:
            1-D tensor with one mean loss value per batch entry.
        """
        pixel_loss = F.l1_loss(pred, target, reduction="none")
        if mask is not None:
            pixel_loss = pixel_loss * mask
        # Flatten everything except the batch axis before averaging.
        return pixel_loss.reshape(pixel_loss.shape[0], -1).mean(-1)
    
    def infer_geo(self, inputs, p_tsdf, **kwargs):
        """Encode ``inputs`` and evaluate the geometry decoder at ``p_tsdf``.

        Args:
            inputs: encoder input.
            p_tsdf: query points handed to ``self.decoder_geo``.
            **kwargs: forwarded to ``self.decoder_geo``.
        Returns:
            the raw output of ``self.decoder_geo``.
        """
        latent = self.encoder(inputs)
        return self.decoder_geo(p_tsdf, latent, **kwargs)

    def query_feature(self, pos, c):
        """Delegate to ``self.decoder_qual.query_feature`` for points ``pos`` and code ``c``."""
        sampled = self.decoder_qual.query_feature(pos, c)
        return sampled

    def decode_feature(self, pos, feature):
        '''[Not used] Grasp heads evaluated through each decoder's ``compute_out``.

        Counterpart of ``decode_grasp`` that calls ``.compute_out()`` on the
        decoders. Unlike ``decode_grasp`` it applies no padding rescale, no
        NaN handling and no squeezing.

        Returns:
            (qual, rot, width): sigmoid quality, rotation normalised along
            axis 2, raw width.
        '''
        qual = torch.sigmoid(self.decoder_qual.compute_out(pos, feature))
        rot = F.normalize(self.decoder_rot.compute_out(pos, feature), dim=2)
        width = self.decoder_width.compute_out(pos, feature)
        return qual, rot, width

    def decode_occ(self, pos, c, **kwargs):
        ''' Wraps the geometry decoder output in a Bernoulli distribution.

        Args:
            pos (tensor): points
            c (tensor): latent conditioned code c
            **kwargs: forwarded to ``self.decoder_geo``
        Returns:
            ``torch.distributions.Bernoulli`` whose logits are the decoder output.
        '''
        return dist.Bernoulli(logits=self.decoder_geo(pos, c, **kwargs))

    def grad_refine(self, x, pos, bound_value=0.0125, lr=1e-6, num_step=1):
        """Refine grasp positions by gradient ascent on the predicted quality.

        Fixes over the previous version:
          * ``forward`` was called as ``forward(x, pos_tmp)`` and its result
            unpacked as a 3-tuple, which matches neither the signature nor the
            return value of ``forward`` in this class. A batch exposing
            ``tsdf`` and ``point_grasp`` is now passed and the grasp entries
            are read from the returned prediction.
          * ``p.requres_grad`` (typo) only created a new attribute, so the
            parameters were never frozen; they are now frozen for real and
            their previous flags are restored, even on error.
          * ``pos`` is detached before cloning so the optimised tensor is
            always a leaf, as the optimiser requires.

        The model is switched to eval mode and, as before, left there.

        Args:
            x: encoder input, forwarded as ``batch_dict.tsdf``.
            pos: initial grasp positions, forwarded as ``batch_dict.point_grasp``.
            bound_value (float): maximum per-coordinate displacement from ``pos``.
            lr (float): SGD learning rate.
            num_step (int): number of gradient steps.
        Returns:
            (qual_out, pos_tmp, rot_out, width_out): predictions at the
            refined positions and the refined positions themselves. The
            clamp to ``pos +/- bound_value`` is applied once after all steps.
        """
        # Local import keeps this block self-contained; a plain namespace is
        # enough because forward() reads the batch via attribute access only.
        from types import SimpleNamespace

        def predict(points):
            out = self.forward(SimpleNamespace(tsdf=x, point_grasp=points),
                               geometry_branch=False)
            return out.grasp_label, out.grasp_rotation, out.grasp_width

        anchor = pos.detach()
        l_bound = anchor - bound_value
        u_bound = anchor + bound_value
        pos_tmp = anchor.clone().requires_grad_(True)
        optimizer = torch.optim.SGD([pos_tmp], lr=lr)

        self.eval()
        # Freeze the network so only the positions are updated.
        params = list(self.parameters())
        previous_flags = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)
        try:
            for _ in range(num_step):
                optimizer.zero_grad()
                qual_out, _, _ = predict(pos_tmp)
                # Maximise quality == minimise its negative.
                loss = -qual_out.sum()
                loss.backward()
                optimizer.step()

            with torch.no_grad():
                # Keep the refined positions inside the allowed box around pos.
                pos_tmp = torch.maximum(torch.minimum(pos_tmp, u_bound), l_bound)
                qual_out, rot_out, width_out = predict(pos_tmp)
        finally:
            for p, flag in zip(params, previous_flags):
                p.requires_grad_(flag)

        return qual_out, pos_tmp, rot_out, width_out
