import math
import os
from dataclasses import dataclass

import numpy as np
import threestudio
import torch
import torch.nn.functional as F

from threestudio.models.background.base import BaseBackground
from threestudio.models.geometry.base import BaseGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.renderers.base import Rasterizer
from threestudio.utils.typing import *

import nvdiffrast.torch as dr
from .facemesh_batch_render import FaceMeshBatchRender


def dot(x, y):
    """Inner product of ``x`` and ``y`` along the last axis, keeping that axis with size 1."""
    product = x * y
    return product.sum(dim=-1, keepdim=True)


def length(x, eps=1e-20):
    """Euclidean norm of ``x`` over its last axis (kept as a size-1 dim).

    The squared norm is clamped to at least ``eps`` before the square root, so the
    result is never smaller than ``sqrt(eps)``.
    """
    squared = dot(x, x)
    return squared.clamp(min=eps).sqrt()


def safe_normalize(x, eps=1e-20):
    """Scale ``x`` to unit length along its last axis; ``eps`` keeps the divisor non-zero."""
    norm = length(x, eps=eps)
    return x / norm


def make_divisible(x, m=8):
    """Round ``x`` up to the nearest multiple of ``m`` and return it as an ``int``."""
    n_chunks = math.ceil(x / m)
    return int(n_chunks * m)


def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of NHWC images to ``size`` = (height, width).

    Minification (target strictly smaller in both dimensions) uses ``min`` as the
    interpolation mode. Every other allowed case uses ``mag``; ``bilinear`` and
    ``bicubic`` are run with ``align_corners=True``. Shrinking one dimension while
    growing the other is rejected with an ``AssertionError``.
    The result is a contiguous NHWC tensor.
    """
    src_h, src_w = x.shape[1], x.shape[2]
    dst_h, dst_w = size[0], size[1]
    shrinking_or_same = src_h >= dst_h and src_w >= dst_w
    growing = src_h < dst_h and src_w < dst_w
    assert shrinking_or_same or growing, "Trying to magnify image in one dimension and minify in the other"

    nchw = x.permute(0, 3, 1, 2)  # interpolate expects channels-first
    if src_h > dst_h and src_w > dst_w:
        resized = F.interpolate(nchw, size, mode=min)
    elif mag in ('bilinear', 'bicubic'):
        resized = F.interpolate(nchw, size, mode=mag, align_corners=True)
    else:
        # align_corners is only accepted by the (bi/tri)linear and bicubic modes
        resized = F.interpolate(nchw, size, mode=mag)
    return resized.permute(0, 2, 3, 1).contiguous()

def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a single HWC image by routing it through ``scale_img_nhwc``."""
    batched = x.unsqueeze(0)  # HWC -> 1HWC
    resized = scale_img_nhwc(batched, size, mag=mag, min=min)
    return resized[0]  # 1HWC -> HWC

@threestudio.register("facemesh-rasterizer")
class FaceMesh(Rasterizer, FaceMeshBatchRender):
    """Rasterize a UV-textured face mesh with nvdiffrast.

    ``forward`` renders, for a single camera:
    - albedo, sampled from ``geometry.raw_albedo`` and passed through a sigmoid;
    - depth;
    - a coverage mask;
    - interpolated normals;
    - a view-cosine map.
    """

    @dataclass
    class Config(Rasterizer.Config):
        debug: bool = False
        # During training the background is inverted when
        # ``np.random.rand() > invert_bg_prob``. Inversion therefore happens with
        # probability ``1 - invert_bg_prob``; the default 1.0 never inverts.
        invert_bg_prob: float = 1.0
        # RGB colour stored as ``self.background_tensor`` by ``configure``.
        back_ground_color: Tuple[float, float, float] = (1, 1, 1)
        # Use nvdiffrast's CUDA rasterizer instead of its OpenGL one.
        force_cuda_rast: bool = False

    cfg: Config

    def configure(
        self,
        geometry: BaseGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Set up the base rasterizer, the default background tensor and the nvdiffrast context."""
        threestudio.info(
            "[Note] FaceMesh rasterizer doesn't support material and background now."
        )
        super().configure(geometry, material, background)
        # Stored for other consumers; ``forward`` takes its background from its
        # ``bg_color`` argument instead.
        self.background_tensor = torch.tensor(
            self.cfg.back_ground_color, dtype=torch.float32, device="cuda"
        )

        if self.cfg.force_cuda_rast:
            threestudio.info("RasterizeCudaContext")
            self.glctx = dr.RasterizeCudaContext()
        else:
            threestudio.info("RasterizeGLContext")
            self.glctx = dr.RasterizeGLContext()

    @staticmethod
    def _vertex_normals(v: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """Return unit per-vertex normals for the triangle mesh ``(v, f)``.

        Each face adds its unit normal to its three vertices. Vertices whose summed
        normal is (near) zero fall back to +Z, and the result is normalised.
        """
        i0, i1, i2 = f[:, 0].long(), f[:, 1].long(), f[:, 2].long()
        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

        # dim=-1 is explicit: without it torch.cross uses the first size-3
        # dimension, which is the face axis when the mesh has exactly 3 faces.
        face_normals = safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))

        vn = torch.zeros_like(v)
        vn.index_add_(0, i0, face_normals)
        vn.index_add_(0, i1, face_normals)
        vn.index_add_(0, i2, face_normals)

        # Replace degenerate normals, then normalise. Unnormalised vertex normals
        # would let vertices with more incident faces dominate the barycentric
        # interpolation done at render time.
        vn = torch.where(
            torch.sum(vn * vn, -1, keepdim=True) > 1e-20,
            vn,
            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
        )
        return safe_normalize(vn)

    def forward(
        self,
        viewpoint_camera,
        bg_color: torch.Tensor,
        **kwargs
    ) -> Dict[str, Any]:
        """Render the mesh from ``viewpoint_camera``.

        Args:
            viewpoint_camera: camera exposing ``image_height``, ``image_width``,
                ``world_view_transform``, ``full_proj_transform`` and ``c2w``.
                The matrices are applied to row vectors (``v @ M``).
            bg_color: background composited where no triangle covers a pixel.
                Must be on GPU. Presumably shape [3], broadcast against
                [H, W, 3] — TODO confirm.
            **kwargs: ignored.

        Returns:
            dict with:
            - "render": [3, H, W] albedo over the background, clamped to [0, 1].
            - "render_kd": None.
            - "depth": [1, H, W] interpolated z after ``world_view_transform``
              (not negated).
            - "mask": [1, H, W] binary coverage.
            - "normal": [3, H, W] interpolated unit normal mapped from
              [-1, 1] to [0, 1].
            - "viewcos": [H, W, 1] z of the rotated normal (left channels-last).
            - "full_proj_transform": the camera's matrix, passed through.
        """
        if self.training:
            invert_bg_color = np.random.rand() > self.cfg.invert_bg_prob
        else:
            invert_bg_color = False
        bg_color = (1.0 - bg_color) if invert_bg_color else bg_color

        geometry = self.geometry
        mesh = geometry.mesh

        # Super-sampling is currently disabled: ssaa is fixed at 1, so the
        # down-scaling branch below never runs.
        ssaa = 1
        h0 = viewpoint_camera.image_height
        w0 = viewpoint_camera.image_width
        h, w = h0, w0

        # Vertex positions, optionally displaced by learned per-vertex offsets.
        if geometry.train_geo:
            v = mesh.v + geometry.v_offsets  # [N, 3]
        else:
            v = mesh.v

        # Homogeneous coordinates, shared by the view and projection transforms.
        v_h = F.pad(v, pad=(0, 1), mode='constant', value=1.0)  # [N, 4]
        v_cam = torch.matmul(v_h, viewpoint_camera.world_view_transform).float().unsqueeze(0)
        v_clip = torch.matmul(v_h, viewpoint_camera.full_proj_transform).float().unsqueeze(0)
        # NOTE(review): nvdiffrast expects int32 triangle indices in mesh.f — assumed here.
        rast, rast_db = dr.rasterize(self.glctx, v_clip, mesh.f, (h, w))

        # The last rast channel is triangle_id + 1, so 0 means "no triangle".
        alpha = (rast[0, ..., 3:] > 0).float()  # [H, W, 1]
        depth, _ = dr.interpolate(v_cam[..., [2]], rast, mesh.f)  # [1, H, W, 1]
        depth = depth.squeeze(0)  # [H, W, 1]

        # UV coordinates (with screen-space derivatives for mip-mapping) -> albedo.
        texc, texc_db = dr.interpolate(
            mesh.vt.unsqueeze(0).contiguous(), rast, mesh.ft,
            rast_db=rast_db, diff_attrs='all',
        )
        albedo = dr.texture(
            geometry.raw_albedo.unsqueeze(0), texc, uv_da=texc_db,
            filter_mode="linear-mipmap-linear",
        )  # [1, H, W, 3]
        albedo = torch.sigmoid(albedo)

        # Normals are recomputed from the (possibly displaced) vertices when the
        # geometry is being optimised; otherwise the mesh's own normals are used.
        vn = self._vertex_normals(v, mesh.f) if geometry.train_geo else mesh.vn
        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, mesh.fn)
        normal = safe_normalize(normal[0])  # [H, W, 3]

        # rotated normal (where [0, 0, 1] always faces camera)
        # NOTE(review): multiplies by c2w[:3, :3].T; whether that is the
        # world->camera rotation depends on how the camera stores ``c2w`` — confirm.
        pose = viewpoint_camera.c2w
        rot_normal = normal @ (pose[:3, :3].T)
        viewcos = rot_normal[..., [2]]  # [H, W, 1]

        # Anti-alias silhouette edges, then composite over the background.
        albedo = dr.antialias(albedo, rast, v_clip, mesh.f).squeeze(0)  # [H, W, 3]
        albedo = alpha * albedo + (1 - alpha) * bg_color

        if ssaa != 1:
            albedo = scale_img_hwc(albedo, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))

        # HWC -> CHW for the image-like outputs (viewcos is intentionally left HWC).
        albedo = albedo.permute(2, 0, 1)
        depth = depth.permute(2, 0, 1)
        alpha = alpha.permute(2, 0, 1)
        normal = normal.permute(2, 0, 1)

        return {
            "render": albedo.clamp(0, 1),
            "render_kd": None,
            "depth": depth,
            "mask": alpha,
            "normal": (normal + 1) / 2,
            "viewcos": viewcos,
            "full_proj_transform": viewpoint_camera.full_proj_transform,
        }

