import torch
import torch.nn as nn
from torch import distributions as dist
import torch.nn.functional as F
import numpy as np

from models.decoder_module import decoder
from models.encoder_module import voxels
from models.renderer_module.single_variance import SingleVarianceNetwork
from vgn.utils.transform import Rotation, Transform
from torchvision.transforms import Resize
from models.renderer_module.rays import *
from models.renderer_module.renderer import *
from utils.misc import EasyDict
from models.AGATE_architecture import AGATENet
from typing import Any, List, Dict, Set, Tuple, Union


class BERYLNet(AGATENet):
    '''Grasp network built on ``AGATENet`` with a voxel encoder and local implicit decoders.

    Components created in ``__init__``:
      * ``encoder``: a ``LocalVoxelEncoder`` over the planes 'xz', 'xy', 'yz'.
      * ``decoder_qual`` / ``decoder_rot`` / ``decoder_width``: grasp decoders
        queried at the grasp point, optionally expanded with a local grid and/or
        samples along the viewing ray (see ``expand_sample_pts``).
      * ``decoder_geo``: geometry decoder, additionally wrapped in a
        ``NeuSRenderer`` when ``cfg.geometry_decoder.type == "rendered_depth"``.
    '''

    def __init__(self, cfg):
        super().__init__(cfg)
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.cfg = cfg

        c_dim = cfg.encoder.c_dim
        padding = cfg.padding
        plane_resolution = cfg.encoder.plane_resolution

        # The encoder type is fixed here; the branches below mirror the
        # configurable variant this code was derived from.
        encoder_type = 'voxel_simple_local'
        encoder_kwargs = {'plane_type': ['xz', 'xy', 'yz'],
                          'plane_resolution': plane_resolution, 'unet': True,
                          'unet_kwargs':
                          {'depth': 3, 'merge_mode': 'concat', 'start_filts': 32}}
        if encoder_type == 'idx':
            # An embedding-table encoder needs a dataset to size the table and
            # none is available here; fail clearly instead of calling len(None).
            raise ValueError("encoder_type 'idx' requires a dataset, which BERYLNet does not have")
        elif encoder_type is not None:
            self.encoder = voxels.LocalVoxelEncoder(
                c_dim=c_dim, padding=padding,
                **encoder_kwargs
            ).to(self.device)
        else:
            self.encoder = None

        self.supervision = cfg.geometry_decoder.type
        self.decoder_padding = cfg.decoder_padding
        self.detach_tsdf = False

        decoder_kwargs = {'sample_mode': cfg.grasp_decoder.sample_mode,
                          'hidden_size': cfg.grasp_decoder.hidden_size,
                          'concat_feat': cfg.grasp_decoder.concat_feat}

        # expand_sample_pts = [x_num, y_num, z_num, step, along_ray_step]
        grasp_expand_sample_pts = np.prod(cfg.grasp_decoder.expand_sample_pts[:3])
        grasp_sample_along_ray = cfg.grasp_decoder.expand_sample_pts[4] > 0

        # Width of the extra per-point input fed to each decoder.
        if self.cfg.grasp_decoder.decoder_input in ['pos', 'dir']:
            grasp_in_dim = 3
        elif self.cfg.grasp_decoder.decoder_input == 'pos_dir':
            grasp_in_dim = 6
        else:
            grasp_in_dim = 0
        geo_in_dim = 3 if cfg.geometry_decoder.decoder_input == 'pos' else 0

        # The three grasp heads share one configuration; they are built in the
        # same order as before so parameter registration order is unchanged.
        grasp_decoder_args = dict(
            in_dim=grasp_in_dim, c_dim=c_dim, padding=padding, out_dim=1,
            expand_sample_pts=grasp_expand_sample_pts,
            with_ray_feature=grasp_sample_along_ray, for_grasp=True,
            **decoder_kwargs)
        self.decoder_qual = decoder.LocalDecoder(**grasp_decoder_args).to(self.device)
        self.decoder_rot = decoder.LocalDecoder(**grasp_decoder_args).to(self.device)
        self.decoder_width = decoder.LocalDecoder(**grasp_decoder_args).to(self.device)
        self.decoder_geo = decoder.LocalDecoder(
            in_dim=geo_in_dim, c_dim=c_dim, padding=padding, out_dim=1,
            **decoder_kwargs).to(self.device)

        if self.supervision == "rendered_depth":
            self.deviation_network = SingleVarianceNetwork(init_val=0.3).to(self.device)
            self.renderer = NeuSRenderer(self.decoder_geo, cfg.geometry_decoder.decoder_input,
                                         self.device,
                                         self.deviation_network,
                                         cfg.geometry_decoder.n_rays,
                                         cfg.geometry_decoder.n_pts_per_ray,
                                         cfg.geometry_decoder.n_importance,
                                         cfg.geometry_decoder.up_sample_steps,
                                         cfg.geometry_decoder.sample_near_gt_range)

    @staticmethod
    def _camera_direction(quat_source) -> np.ndarray:
        '''Return the translation part of inv(Transform(R, [0, 0, 1])) as a numpy 3-vector.

        R is ``Rotation.from_quat`` of the first four entries of ``quat_source``.
        Tensors are detached and moved to the CPU first, so CUDA inputs work.
        '''
        if torch.is_tensor(quat_source):
            quat_source = quat_source.detach().cpu().numpy()
        quat = np.asarray(quat_source, dtype=float)[:4]
        camera_M = Transform(Rotation.from_quat(quat), np.array([0, 0, 1])).as_matrix()
        return np.linalg.inv(camera_M)[:3, 3]

    def decode_grasp(self, input_batch, feature, **kwargs):
        ''' Returns grasp prediction for the batch's grasp points.
        Args:
            input_batch: batch exposing ``point_grasp`` (points in -0.5 ~ 0.5,
                [batch_size, 1, 3]) and ``camera_extrinsic`` (indexed as
                [batch, view, ...]; the first four entries of view 0 are read
                as a quaternion).
            feature (tensor): latent conditioned code, dict['xz', 'xy', 'yz' or 'grid']
            **kwargs: forwarded unchanged to each grasp decoder.
        Returns:
            (qual, rot, width): sigmoid quality (NaN logits mapped to 0 before the
            sigmoid), rotation output scaled by pi, and raw width output.
        Raises:
            NotImplementedError: for an unknown ``cfg.grasp_decoder.decoder_input``.
        '''
        batch_size = len(input_batch.point_grasp)
        pos = input_batch.point_grasp / (1 + self.decoder_padding)

        # Per-sample viewing direction derived from the camera orientation.
        direction_vector = torch.zeros_like(pos, device=self.device)
        for i in range(batch_size):
            vector = self._camera_direction(input_batch.camera_extrinsic[i, 0])
            direction_vector[i, 0] = torch.from_numpy(vector)

        mode = self.cfg.grasp_decoder.decoder_input
        if mode == 'pos':
            decoder_input = pos.clone()
        elif mode == 'dir':
            decoder_input = direction_vector
        elif mode == 'pos_dir':
            decoder_input = torch.cat([pos.clone(), direction_vector], dim=-1)
        elif mode == 'None':
            decoder_input = None
        else:
            raise NotImplementedError(f"unsupported grasp decoder_input: {mode!r}")

        pos = self.expand_sample_pts(pos, direction_vector)

        qual = self.decoder_qual(pos, feature, decoder_input, **kwargs)
        qual = torch.where(torch.isnan(qual), torch.zeros_like(qual), qual)
        qual = torch.sigmoid(qual).squeeze(-1)
        # NOTE(review): rot squeezes dim 1 while qual/width squeeze dim -1; the two
        # agree only if the decoder output is 2-D [B, 1] — confirm the shape.
        rot = self.decoder_rot(pos, feature, decoder_input, **kwargs).squeeze(1) * np.pi
        width = self.decoder_width(pos, feature, decoder_input, **kwargs).squeeze(-1)

        return qual, rot, width

    def expand_sample_pts(self, pts: torch.Tensor, dir: torch.Tensor) -> torch.Tensor:
        '''Augment each query point with a local grid and/or samples along its ray.

        Args:
            pts: [B, 1, 3] query points.
            dir: [B, 1, 3] viewing directions; used as each point's local z axis.

        Returns:
            [B, 1 + expanded_num, 3]. Row 0 of each sample is the original point,
            followed by the x_num * y_num * z_num grid points (only when that
            product is > 1), then points stepped by ``along_ray_step * dir`` while
            they remain inside the open box (-0.5, 0.5)^3 (only when
            along_ray_step > 0). Samples with fewer points are padded by
            repeating their last point.
        '''
        x_num, y_num, z_num, step, along_ray_step = self.cfg.grasp_decoder.expand_sample_pts
        use_grid = np.prod(self.cfg.grasp_decoder.expand_sample_pts[:3]) > 1
        use_ray = along_ray_step > 0

        # Grid offsets (in local x/y/z units) are the same for every sample;
        # z is the outer loop and x the inner one.
        grid_offsets = [
            (step * (x_index - (x_num - 1) / 2),
             step * (y_index - (y_num - 1) / 2),
             step * (z_index - (z_num - 1) / 2))
            for z_index in range(z_num)
            for y_index in range(y_num)
            for x_index in range(x_num)
        ] if use_grid else []

        world_z = torch.tensor([0., 0., 1.], device=self.device, dtype=dir.dtype)
        world_y = torch.tensor([0., 1., 0.], device=self.device, dtype=dir.dtype)

        result = []
        for i, p in enumerate(pts):
            local_z = dir[i][0]
            local_x = torch.cross(local_z, world_z, dim=-1)
            if local_x.norm() < 1e-6:
                # Direction (anti)parallel to world z: the cross product vanishes
                # and the grid would collapse onto local_z, so use world y instead.
                local_x = torch.cross(local_z, world_y, dim=-1)
            local_x = F.normalize(local_x, dim=-1)
            local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)

            origin = p[0]
            expanded_pts = [origin]
            for off_x, off_y, off_z in grid_offsets:
                expanded_pts.append(origin + off_x * local_x + off_y * local_y + off_z * local_z)

            if use_ray:
                pt = origin.clone()
                while bool((pt.abs() < 0.5).all()):
                    expanded_pts.append(pt.clone())
                    pt += along_ray_step * local_z

            result.append(torch.stack(expanded_pts, dim=0).unsqueeze(0))

        # Ray sampling yields a different count per sample; pad to a common length.
        max_len = max(r.shape[1] for r in result)
        result = [
            torch.cat([r, r[:, -1:, :].repeat(1, max_len - r.shape[1], 1)], dim=1)
            if r.shape[1] < max_len else r
            for r in result
        ]
        return torch.cat(result, dim=0)

    def rot_loss_fn(self, pred, target):
        '''Angular loss ``1 - cos(2 * (pred - target))``.

        Zero when pred == target; because of the factor 2 it is also zero when the
        difference is a multiple of pi, i.e. the loss is pi-periodic.
        '''
        return 1 - torch.cos(2 * (pred - target))

    def predict_grasp(self, batch):
        '''Run the grasp branch without gradients and express rotations as quaternions.

        ``prediction.grasp_rotation`` (the angle from ``forward``) is replaced by a
        CPU tensor of shape [batch_size, 4] holding ``(R * Rz(angle)).as_quat()``,
        where R is a frame whose z axis is the negated camera direction from
        ``_camera_direction``.
        '''
        with torch.no_grad():
            prediction = self.forward(batch, geometry_branch=False)

        batch_size = batch.point_grasp.shape[0]

        predicted_rot = prediction.grasp_rotation
        prediction.grasp_rotation = torch.zeros(batch_size, 4)
        for index in range(batch_size):
            normal = self._camera_direction(batch.camera_extrinsic[index, 0])
            z_axis = -normal
            x_axis = np.r_[1.0, 0.0, 0.0]
            # Pick another reference axis when x is (nearly) parallel to z.
            if np.isclose(np.abs(np.dot(x_axis, z_axis)), 1.0, 1e-4):
                x_axis = np.r_[0.0, 1.0, 0.0]
            y_axis = np.cross(z_axis, x_axis)
            x_axis = np.cross(y_axis, z_axis)
            # NOTE(review): x_axis/y_axis are orthogonal to z_axis but not
            # normalized, so from_matrix receives a non-orthonormal matrix. Kept
            # as-is because training labels may rely on the same construction.
            R = Rotation.from_matrix(np.vstack((x_axis, y_axis, z_axis)).T)

            ori = R * Rotation.from_euler("z", predicted_rot[index].cpu())
            prediction.grasp_rotation[index] = torch.tensor(ori.as_quat())
        return prediction
