import logging

import torch
import torch.nn as nn

from lib.gaussian.gs_utils import *

from gsplat.project_gaussians_2d import project_gaussians_2d_batch, project_gaussians_2d
from gsplat.rasterize_sum import rasterize_gaussians_sum_batch, rasterize_gaussians_sum
from omegaconf import DictConfig

# Module-level logger shared by everything in this file; the logger name is a
# fixed string rather than being derived from __name__.
logger = logging.getLogger("lib.gaussian.gaussianimage_cholesky_batch")

class GaussianImage_Cholesky_Batch(nn.Module):
    """A batch of images, each represented by a set of 2D Gaussians.

    The Gaussians of all ``batch_size`` images live in flat tensors with
    ``batch_size * num_points`` rows, one contiguous run of rows per image.
    Every Gaussian has an ``xy`` centre, three ``cholesky`` values (offset by
    ``cholesky_bound`` before being handed to the kernels), three
    ``features`` values and an opacity, which is trainable only when
    ``use_opacity`` is enabled.

    Config entries read from ``cfg``: ``num_points``, ``H``, ``W``,
    ``BLOCK_W`` and ``BLOCK_H`` by attribute access, plus ``batch_size``,
    ``use_opacity``, ``precision``, ``kernel_dtype`` and ``eps`` through
    ``cfg.get`` with defaults.
    """

    # Accepted ``precision`` names mapped to the torch attribute holding the
    # compute dtype. The attribute is looked up lazily so a torch build
    # without float8 only fails when "fp8" is actually requested.
    _COMPUTE_DTYPE_NAMES = {
        "fp32": "float32",
        "fp16": "float16",
        "bf16": "bfloat16",
        "fp8": "float8_e4m3fn",
    }

    def __init__(self, cfg: "DictConfig", device):
        """Create randomly initialised Gaussians for every image of the batch.

        Args:
            cfg: configuration object (see the class docstring for the keys).
            device: device on which all parameters and buffers are created.

        Raises:
            ValueError: if ``precision`` is not one of fp32/fp16/bf16/fp8.
        """
        super().__init__()
        self.device = device
        self.cfg = cfg
        self.init_num_points = self.cfg.num_points
        self.batch_size = cfg.get("batch_size", 1)
        self.use_opacity = cfg.get("use_opacity", False)
        if self.use_opacity:
            logger.info("GaussianImage_Cholesky: using opacity")
        logger.info("GaussianImage_Cholesky: batch size set to %s", self.batch_size)
        logger.info("GaussianImage_Cholesky: init %s points per image", self.init_num_points)

        self.H, self.W = self.cfg.H, self.cfg.W
        self.BLOCK_W, self.BLOCK_H = self.cfg.BLOCK_W, self.cfg.BLOCK_H
        self.tile_bounds = self._tile_bounds(self.H, self.W)

        self.precision = str(cfg.get("precision", "fp32")).lower()
        self.kernel_dtype_pref = str(cfg.get("kernel_dtype", "fp32")).lower()
        logger.info(
            "GaussianImage_Cholesky: using precision = %s, kernel_dtype = %s",
            self.precision, self.kernel_dtype_pref,
        )

        if self.precision not in self._COMPUTE_DTYPE_NAMES:
            raise ValueError(f"Unsupported precision: {self.precision}")
        # Parameters are always stored in fp32, whatever the compute precision.
        self.param_dtype = torch.float32
        self.compute_dtype = getattr(torch, self._COMPUTE_DTYPE_NAMES[self.precision])
        self._use_kernel_native_dtype = (self.kernel_dtype_pref == "native")

        # Reduced precisions get a larger safety margin for clamping.
        base_eps = float(cfg.get("eps", 1e-6))
        if self.precision in ("fp16", "bf16", "fp8"):
            base_eps = max(base_eps, 1e-4)
        self._eps_for_clamp = base_eps

        total_points = self.init_num_points * self.batch_size
        self.params = nn.ParameterDict({
            "xy":       nn.Parameter(2 * torch.rand(total_points, 2, device=device, dtype=self.param_dtype) - 1),
            "cholesky": nn.Parameter(torch.rand(total_points, 3, device=device, dtype=self.param_dtype)),
            "features": nn.Parameter(torch.rand(total_points, 3, device=device, dtype=self.param_dtype)),
        })

        opacity_all = torch.ones((total_points, 1), device=device, dtype=self.param_dtype)
        if self.use_opacity:
            # Trainable opacity; ``_opacity`` aliases the same parameter.
            self.params["opacity"] = nn.Parameter(opacity_all)
            self._opacity = self.params["opacity"]
        else:
            # Constant opacity of one, kept as a non-trainable buffer.
            self.register_buffer("_opacity", opacity_all)

        self.register_buffer(
            'num_points_per_image',
            torch.full((self.batch_size,), self.init_num_points, dtype=torch.long, device=device),
        )
        self.register_buffer(
            'background',
            torch.ones(self.batch_size, 3, device=device, dtype=self.param_dtype),
        )
        self.register_buffer(
            'cholesky_bound',
            torch.tensor([[0.5, 0.0, 0.5]], device=device, dtype=self.param_dtype),
        )

    def _tile_bounds(self, height, width):
        """Return the (tiles_x, tiles_y, 1) grid covering a height x width image."""
        return (
            (width + self.BLOCK_W - 1) // self.BLOCK_W,
            (height + self.BLOCK_H - 1) // self.BLOCK_H,
            1,
        )

    def _xy_clamp_range(self):
        """Return the (lo, hi) interval, slightly inside [-1, 1], used to clamp xy."""
        margin = self._eps_for_clamp * 10
        return -1 + margin, 1 - margin

    def clamp(self):
        """Clamp the stored xy parameter in place to slightly inside [-1, 1]."""
        lo, hi = self._xy_clamp_range()
        # ``.data`` keeps the update invisible to autograd, as before.
        self.params['xy'].data.clamp_(lo, hi)

    def _to_kernel_dtype(self, *tensors):
        """Return contiguous copies of ``tensors`` in the dtype passed to the kernels.

        With ``kernel_dtype == "native"`` the tensors keep their dtype,
        otherwise they are converted to float32.
        """
        if self._use_kernel_native_dtype:
            return [t.contiguous() for t in tensors]
        return [t.to(torch.float32).contiguous() for t in tensors]

    def _autocast_context(self):
        """Return a CUDA autocast context, enabled only for reduced compute dtypes."""
        use_amp = self.compute_dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn)
        return torch.amp.autocast(device_type="cuda", dtype=self.compute_dtype, enabled=use_amp)

    @property
    def get_total_points(self):
        """Total number of Gaussians over all images, as a Python int."""
        return self.num_points_per_image.sum().item()

    @property
    def offsets(self):
        """Index of the first Gaussian of every image in the flat tensors."""
        counts = self.num_points_per_image
        # Exclusive prefix sum, computed on the buffer's own device so it stays
        # valid after the module has been moved with ``.to()``.
        return torch.cumsum(counts, dim=0) - counts

    @property
    def get_xyz(self):
        """xy centres clamped to [-1, 1]."""
        return self.params['xy'].clamp(-1.0, 1.0)

    @property
    def get_features(self):
        """Per-Gaussian feature values."""
        return self.params['features']

    @property
    def get_opacity(self):
        """Per-Gaussian opacity: the trainable parameter or the constant buffer."""
        return self.params["opacity"] if self.use_opacity else self._opacity

    @property
    def get_cholesky_elements(self):
        """Cholesky values with ``cholesky_bound`` added."""
        return self.params['cholesky'] + self.cholesky_bound

    @property
    def get_total_storage(self):
        """Number of bytes occupied by all entries of ``self.params``."""
        return sum(p.numel() * p.element_size() for p in self.params.values())

    def quantize_self(self, dtype=None):
        """Convert every parameter in ``self.params`` to ``dtype`` in place.

        Args:
            dtype: target dtype; defaults to the current compute dtype.

        Raises:
            ValueError: if an explicit ``dtype`` is not fp32, fp16, bf16 or fp8.
        """
        if dtype is None:
            dtype = self.compute_dtype
        elif dtype not in (torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn):
            raise ValueError("Only fp32, fp16, bf16, fp8 are supported for quantization.")

        dtypes_before = [param.dtype for param in self.params.values()]
        assert all(d == dtypes_before[0] for d in dtypes_before), "All parameters must have the same dtype before quantization."
        logger.info(
            "Before quantization: total params = %s, param_dtype = %s, total storage = %s bytes.",
            self.get_total_points, dtypes_before[0], self.get_total_storage,
        )

        for name, param in self.params.items():
            if param.dtype != dtype:
                param.data = param.data.to(dtype)
                logger.info("Quantized parameter '%s' to %s.", name, dtype)

        self.param_dtype = dtype
        self.compute_dtype = dtype
        assert all(param.dtype == dtype for param in self.params.values()), "All parameters must have the same dtype after quantization."
        logger.info(
            "After quantization: total params = %s, param_dtype = %s, total storage = %s bytes.",
            self.get_total_points, dtype, self.get_total_storage,
        )

    def _render_full_batch(self, xy, chol, clamp_xy):
        """Project and rasterize the whole batch from the given centres and cholesky values.

        Args:
            xy: centres for all Gaussians of all images.
            chol: raw cholesky values (``cholesky_bound`` is added here).
            clamp_xy: when true, clamp ``xy`` to the range used by ``clamp()``.

        Returns:
            ``{"render": tensor}`` with the kernel output permuted by
            ``(0, 3, 1, 2)`` (presumably from (B, H, W, 3) to (B, 3, H, W)).
        """
        # The batched kernels take one point count for all images.
        num_points = self.num_points_per_image[0].item()

        with self._autocast_context():
            compute = self.compute_dtype
            xy = xy.to(compute)
            chol = chol.to(compute)
            feats = self.get_features.to(compute)
            opac = self.get_opacity.to(compute)
            bg = self.background.to(compute)

            if clamp_xy:
                lo, hi = self._xy_clamp_range()
                xy = xy.clamp(lo, hi)

            xy_k, chol_k, feats_k, opac_k, bg_k = self._to_kernel_dtype(xy, chol, feats, opac, bg)
            L_k = chol_k + self.cholesky_bound

            xys, depths, radii, conics, num_tiles_hit = project_gaussians_2d_batch(
                self.batch_size, num_points, xy_k, L_k,
                self.H, self.W, self.tile_bounds
            )
            out_img = rasterize_gaussians_sum_batch(
                self.batch_size, num_points,
                xys, depths, radii, conics, num_tiles_hit,
                feats_k, opac_k,
                self.H, self.W, self.BLOCK_H, self.BLOCK_W,
                background=bg_k
            )

        return {"render": out_img.permute(0, 3, 1, 2).contiguous()}

    def forward(self):
        """Render every image of the batch at the configured H x W resolution."""
        return self._render_full_batch(self.params['xy'], self.params['cholesky'], clamp_xy=False)

    def forward_subset(self, indices):
        """Render only the images selected by ``indices``.

        Args:
            indices: list/tuple of image indices, an integer tensor of
                indices, or a boolean mask with ``batch_size`` entries.

        Returns:
            ``{"render": tensor}`` with one image per selected index, in the
            order given; an empty (0, 3, H, W) tensor if nothing is selected.

        Raises:
            TypeError: if ``indices`` is not a list, tuple or tensor.
            ValueError: if a boolean mask has the wrong size or an index is
                outside ``[0, batch_size - 1]``.
            NotImplementedError: if the selected images do not all hold the
                same number of points.
        """
        if isinstance(indices, (list, tuple)):
            idx = torch.tensor(indices, dtype=torch.long, device=self.device).reshape(-1)
        elif isinstance(indices, torch.Tensor):
            if indices.dtype == torch.bool:
                if indices.numel() != self.batch_size:
                    raise ValueError("Boolean mask size must equal batch_size.")
                idx = torch.nonzero(indices.to(device=self.device), as_tuple=False).flatten()
            else:
                idx = indices.to(self.device).long().reshape(-1)
        else:
            raise TypeError("indices must be List[int], Tuple[int], or torch.Tensor.")

        if idx.numel() == 0:
            return {"render": self.background.new_empty((0, 3, self.H, self.W))}

        if (idx.min() < 0) or (idx.max() >= self.batch_size):
            raise ValueError(f"indices must be within [0, {self.batch_size-1}].")

        # The batched kernels need the same point count for every image.
        N_all = self.num_points_per_image[idx]
        if not torch.all(N_all == N_all[0]):
            raise NotImplementedError(
                "forward_subset requires all selected images to have the same "
                "number of points; see crop_forward_padding for a padded variant."
            )
        N = int(N_all[0].item())
        B_sub = int(idx.numel())

        # Flat row indices of every Gaussian of every selected image.
        starts = self.offsets[idx]
        arangeN = torch.arange(N, device=starts.device)
        gather_idx = (starts[:, None] + arangeN[None, :]).reshape(-1)

        xy_flat = self.get_xyz[gather_idx]
        L_flat = self.get_cholesky_elements[gather_idx]
        feats_flat = self.get_features[gather_idx]
        opac_flat = self.get_opacity[gather_idx]
        # Must be selected before it is converted for the kernels.
        bg_sub = self.background[idx]

        with self._autocast_context():
            xy_k, L_k, feats_k, opac_k, bg_sub_k = self._to_kernel_dtype(
                xy_flat, L_flat, feats_flat, opac_flat, bg_sub
            )
            xys, depths, radii, conics, num_tiles_hit = project_gaussians_2d_batch(
                B_sub, N, xy_k, L_k, self.H, self.W, self.tile_bounds
            )
            out_img = rasterize_gaussians_sum_batch(
                B_sub, N, xys, depths, radii, conics, num_tiles_hit,
                feats_k, opac_k,
                self.H, self.W, self.BLOCK_H, self.BLOCK_W,
                background=bg_sub_k
            )

        return {"render": out_img.permute(0, 3, 1, 2).contiguous()}

    def crop_forward_loop(self, crop_ndc, out_size):
        """Render one crop per image, processing the images one at a time.

        Args:
            crop_ndc: tensor with one ``(x0, y0, x1, y1)`` row per image, in
                the same coordinates as the xy centres.
            out_size: ``(H_crop, W_crop)`` resolution of the rendered crops.

        Returns:
            ``{"render": tensor}`` of shape (batch_size, 3, H_crop, W_crop).

        Raises:
            ValueError: if ``crop_ndc`` does not have ``batch_size`` rows.
        """
        H_crop, W_crop = out_size
        if crop_ndc.shape[0] != self.batch_size:
            raise ValueError("crop_ndc batch size mismatch")

        all_xy = self.get_xyz
        all_L = self.get_cholesky_elements
        all_feats = self.get_features
        all_opac = self.get_opacity
        # Converted once so the loop works on plain Python numbers.
        starts = self.offsets.tolist()
        counts = self.num_points_per_image.tolist()
        crops = crop_ndc.tolist()
        tile_bounds_crop = self._tile_bounds(H_crop, W_crop)

        output_images = []
        for start, count, crop, bg_i in zip(starts, counts, crops, self.background):
            rows = slice(start, start + count)
            xy_c, L_c, feat_c, opa_c = self._crop_gaussians_single(
                all_xy[rows], all_L[rows], all_feats[rows], all_opac[rows], crop
            )

            if xy_c.shape[0] == 0:
                # No Gaussian centre falls inside the crop: return the plain
                # background, as crop_forward_padding does for an empty batch.
                img = bg_i.view(1, 3, 1, 1).expand(1, 3, H_crop, W_crop)
            else:
                xys, depths, radii, conics, tiles_hit = project_gaussians_2d(
                    xy_c, L_c, H_crop, W_crop, tile_bounds_crop
                )
                img_hwc = rasterize_gaussians_sum(
                    xys, depths, radii, conics, tiles_hit, feat_c, opa_c,
                    H_crop, W_crop, self.BLOCK_H, self.BLOCK_W, background=bg_i
                )
                img = img_hwc.view(1, H_crop, W_crop, 3).permute(0, 3, 1, 2)

            output_images.append(img)

        return {"render": torch.cat(output_images, dim=0).contiguous()}

    def crop_forward_padding(self, crop_ndc, out_size):
        """Render one crop per image in a single batched kernel call.

        Gaussians whose centre falls outside their image's crop are dropped;
        the survivors are packed to the front of a zero-padded
        (batch_size, max_points) layout so all images share one point count.

        Args:
            crop_ndc: tensor with one ``(x0, y0, x1, y1)`` row per image, in
                the same coordinates as the xy centres.
            out_size: ``(H_crop, W_crop)`` resolution of the rendered crops.

        Returns:
            ``{"render": tensor}`` of shape (batch_size, 3, H_crop, W_crop).

        Raises:
            ValueError: if ``crop_ndc`` does not have ``batch_size`` rows.
        """
        H_crop, W_crop = out_size
        B = self.batch_size
        if crop_ndc.shape[0] != B:
            raise ValueError("crop_ndc batch size mismatch")
        N = self.num_points_per_image[0].item()

        xy_bnd = self.get_xyz.view(B, N, 2)
        L_bnd = self.get_cholesky_elements.view(B, N, 3)
        feats_bnd = self.get_features.view(B, N, 3)
        opac_bnd = self.get_opacity.view(B, N, 1)

        # Map every crop rectangle onto [-1, 1]^2: shift by its centre and
        # divide by its half extent; the cholesky values scale accordingly.
        crop = crop_ndc.to(device=xy_bnd.device, dtype=xy_bnd.dtype)
        x0, y0, x1, y1 = crop.T
        sx, sy = (x1 - x0) / 2, (y1 - y0) / 2
        center = torch.stack([(x0 + x1) / 2, (y0 + y1) / 2], dim=-1).view(B, 1, 2)
        xy_scale = torch.stack([sx, sy], dim=-1).view(B, 1, 2)
        L_scale = torch.stack([sx, sx, sy], dim=-1).view(B, 1, 3)

        xy_c_bnd = (xy_bnd - center) / xy_scale
        L_c_bnd = L_bnd / L_scale

        # Keep only Gaussians whose centre lies inside the crop.
        keep = (xy_c_bnd.abs() <= 1).all(dim=-1)          # (B, N)
        max_points = int(keep.sum(dim=1).max())

        if max_points == 0:
            return {"render": self.background.view(B, 3, 1, 1)
                                     .expand(B, 3, H_crop, W_crop).contiguous()}

        # Destination of every kept Gaussian: its image row and its rank
        # among the kept Gaussians of that image.
        rows = keep.nonzero(as_tuple=True)[0]
        slots = (keep.cumsum(dim=1) - 1)[keep]

        def pack(src):
            # Scatter the kept rows of ``src`` into a zero-padded flat tensor.
            width = src.shape[-1]
            padded = src.new_zeros(B, max_points, width)
            padded[rows, slots] = src[keep]
            return padded.view(-1, width)

        xy_flat, L_flat = pack(xy_c_bnd), pack(L_c_bnd)
        feats_flat, opac_flat = pack(feats_bnd), pack(opac_bnd)

        tile_bounds_crop = self._tile_bounds(H_crop, W_crop)
        xys, depths, radii, conics, tiles_hit = project_gaussians_2d_batch(
            B, max_points, xy_flat, L_flat, H_crop, W_crop, tile_bounds_crop
        )
        out = rasterize_gaussians_sum_batch(
            B, max_points, xys, depths, radii, conics, tiles_hit,
            feats_flat, opac_flat, H_crop, W_crop,
            self.BLOCK_H, self.BLOCK_W, background=self.background
        )

        return {"render": out.permute(0, 3, 1, 2).contiguous()}

    @staticmethod
    def _crop_gaussians_single(xy, L, feats, opac, crop_ndc):
        """Re-express one image's Gaussians in the coordinates of a crop.

        Args:
            xy, L, feats, opac: per-Gaussian tensors of a single image.
            crop_ndc: ``(x0, y0, x1, y1)`` crop rectangle.

        Returns:
            ``(xy_c, L_c, feats_c, opac_c)`` restricted to the Gaussians whose
            remapped centre lies in [-1, 1]^2, with centres and cholesky
            values rescaled to the crop.
        """
        x0, y0, x1, y1 = crop_ndc
        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
        sx, sy = (x1 - x0) / 2, (y1 - y0) / 2

        xy_c = torch.stack([(xy[:, 0] - cx) / sx, (xy[:, 1] - cy) / sy], dim=-1)
        keep = (xy_c.abs() <= 1).all(dim=-1)

        L_kept = L[keep]
        L_c = torch.stack(
            [L_kept[:, 0] / sx, L_kept[:, 1] / sx, L_kept[:, 2] / sy], dim=-1
        )
        return xy_c[keep], L_c, feats[keep], opac[keep]

    def _deform_gaussians(self, strength, num_control_points, influence_radius):
        """Return randomly deformed copies of the xy and cholesky parameters.

        Control points are picked among the Gaussian centres. Control point
        ``e`` with signed strength ``s`` displaces a centre ``p`` by
        ``s * w * (p - e) / (|p - e|^2 + eps)`` with
        ``w = exp(-|p - e|^2 / (2 * influence_radius^2))``. The cholesky
        values are updated with the Jacobian of that displacement field.
        Both results are detached float32 tensors.
        """
        # Work on detached fp32 copies so the math also runs after the
        # parameters have been quantized to a reduced dtype.
        xy = self.params['xy'].detach().to(torch.float32).clone()
        chol = self.params['cholesky'].detach().to(torch.float32).clone()
        device = xy.device
        total_points = xy.shape[0]
        num_control_points = min(num_control_points, total_points)

        picked = torch.randperm(total_points, device=device)[:num_control_points]
        control_points = xy[picked]                                     # (C, 2)
        # Signed strengths drawn uniformly from [-strength, strength).
        strengths = (torch.rand(num_control_points, 1, device=device) - 0.5) * 2 * strength
        s_row = strengths.T                                             # (1, C)

        diff = xy.unsqueeze(1) - control_points.unsqueeze(0)            # (N, C, 2)
        dist_sq = torch.sum(diff ** 2, dim=-1)                          # (N, C)
        sigma_sq = influence_radius ** 2
        weights = torch.exp(-dist_sq / (2 * sigma_sq))                  # (N, C)
        eps = 1e-6
        d2 = (dist_sq + eps).unsqueeze(-1)                              # (N, C, 1)

        displacement = torch.sum(
            diff * weights.unsqueeze(-1) * s_row.unsqueeze(-1) / d2, dim=1
        )                                                               # (N, 2)

        # Jacobian of the displacement, built from the undeformed positions.
        # d/dp [r / (|r|^2 + eps)] = I / (|r|^2 + eps) - 2 r r^T / (|r|^2 + eps)^2
        eye = torch.eye(2, device=device, dtype=xy.dtype)
        d2_mat = d2.unsqueeze(-1)                                       # (N, C, 1, 1)
        outer = torch.einsum('nci,ncj->ncij', diff, diff)
        grad_unit = eye.reshape(1, 1, 2, 2) / d2_mat - 2 * outer / d2_mat ** 2
        # d/dp w = -r / sigma^2 * w
        grad_w = -diff / sigma_sq * weights.unsqueeze(-1)               # (N, C, 2)
        s_mat = s_row.unsqueeze(-1).unsqueeze(-1)                       # (1, C, 1, 1)
        J_contribs = (
            grad_unit * weights.unsqueeze(-1).unsqueeze(-1) * s_mat
            + torch.einsum('nci,ncj->ncij', diff / d2, grad_w) * s_mat
        )
        J_T = eye + torch.sum(J_contribs, dim=1)                        # (N, 2, 2)

        # Lower-triangular factor built from the offset cholesky values.
        chol_elements = chol + self.cholesky_bound
        L = torch.zeros(total_points, 2, 2, device=device, dtype=xy.dtype)
        L[:, 0, 0] = chol_elements[:, 0]
        L[:, 1, 0] = chol_elements[:, 1]
        L[:, 1, 1] = chol_elements[:, 2]
        L_new = torch.bmm(J_T, L)

        # NOTE(review): J_T @ L is in general not lower triangular; its
        # upper-right entry is dropped here, which matches the previous
        # behaviour but is only an approximation of the transformed
        # covariance — confirm this is intended.
        chol_new = torch.stack(
            [L_new[:, 0, 0], L_new[:, 1, 0], L_new[:, 1, 1]], dim=-1
        )
        # Remove the bound again; the render path adds it back.
        chol_new = chol_new - self.cholesky_bound

        return xy + displacement, chol_new

    def forward_aug_def(self, strength: float = 0.1, num_control_points: int = 5, influence_radius: float = 0.5):
        """Render the batch after a random smooth deformation of the Gaussians.

        The deformation is applied to detached copies of the xy and cholesky
        parameters, so those two receive no gradient from this render.

        Args:
            strength (float): bound on the per-control-point strength; each
                control point draws its signed strength uniformly from
                [-strength, strength). Values <= 0 disable the deformation.
            num_control_points (int): number of Gaussian centres picked at
                random as control points (capped at the number of Gaussians).
                Values <= 0 disable the deformation.
            influence_radius (float): sigma of the Gaussian falloff that
                weights each control point's influence.

        Returns:
            ``{"render": tensor}`` in the same layout as ``forward()``.

        Raises:
            ValueError: if a deformation is requested with
                ``influence_radius == 0`` (it would only produce NaNs).
        """
        if strength <= 0 or num_control_points <= 0:
            return self.forward()
        if self.params['xy'].shape[0] == 0:
            return self.forward()
        if influence_radius == 0:
            raise ValueError("influence_radius must be non-zero.")

        xy_aug, chol_aug = self._deform_gaussians(strength, num_control_points, influence_radius)
        return self._render_full_batch(xy_aug, chol_aug, clamp_xy=True)
