import cv2,torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
import scipy.signal

def mse2psnr(x):
    """Convert a mean-squared-error tensor to PSNR in dB: -10 * log10(x).

    Args:
        x: torch.Tensor holding MSE value(s).

    Returns:
        A tensor of PSNR values. The log(10) constant is kept as a
        1-element tensor, so a 0-dim input yields a result of shape (1,).
    """
    # Build the constant on x's device: a 1-element CPU tensor cannot be
    # combined with a CUDA tensor (only 0-dim CPU tensors are auto-moved).
    ln10 = torch.log(torch.tensor([10.], device=x.device))
    return -10. * torch.log(x) / ln10


def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """Colorize a depth map with an OpenCV colormap.

    Args:
        depth: array of shape (H, W). NaNs are replaced via np.nan_to_num.
        minmax: optional (min, max) pair used for normalization. When None,
            the minimum is the smallest strictly positive depth (zeros are
            treated as background) and the maximum is the largest depth.
        cmap: colormap id forwarded to cv2.applyColorMap.

    Returns:
        (image, [mi, ma]) where image is the cv2.applyColorMap output for
        the uint8-quantized depth and [mi, ma] is the range actually used.
    """
    x = np.nan_to_num(depth)  # change nan to 0
    if minmax is None:
        positive = x[x > 0]
        # A map with no positive depth would make np.min raise on an empty
        # selection; fall back to 0 so the call still yields an image.
        mi = np.min(positive) if positive.size else 0.0
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)  # normalize to 0~1
    # Background pixels (below mi) and values outside a user-supplied range
    # fall outside [0, 1]; clip so the uint8 cast cannot wrap around.
    x = np.clip(x, 0.0, 1.0)
    x = (255 * x).astype(np.uint8)
    x_ = cv2.applyColorMap(x, cmap)
    return x_, [mi, ma]

def init_log(log, keys):
    """Reset every entry of `log` named in `keys` to a fresh zero tensor.

    Each key receives its own 1-element float64 tensor. The mapping is
    modified in place and also returned.
    """
    for name in keys:
        log[name] = torch.zeros(1, dtype=torch.float64)
    return log

def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """Colorize a depth map and return it as a (3, H, W) tensor.

    Args:
        depth: (H, W) numpy array, or an object exposing
            .detach().cpu().numpy() (e.g. a torch.Tensor).
        minmax: optional (min, max) pair used for normalization. When None,
            the minimum is the smallest strictly positive depth (zeros are
            treated as background) and the maximum is the largest depth.
        cmap: colormap id forwarded to cv2.applyColorMap.

    Returns:
        (image, [mi, ma]) where image is the ToTensor() conversion of the
        color-mapped depth and [mi, ma] is the range actually used.
    """
    if not isinstance(depth, np.ndarray):
        # detach() first: .numpy() refuses tensors that require grad.
        depth = depth.detach().cpu().numpy()

    x = np.nan_to_num(depth)  # change nan to 0
    if minmax is None:
        positive = x[x > 0]
        # A map with no positive depth would make np.min raise on an empty
        # selection; fall back to 0 so the call still yields an image.
        mi = np.min(positive) if positive.size else 0.0
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)  # normalize to 0~1
    # Background pixels (below mi) and values outside a user-supplied range
    # fall outside [0, 1]; clip so the uint8 cast cannot wrap around.
    x = np.clip(x, 0.0, 1.0)
    x = (255 * x).astype(np.uint8)
    # NOTE(review): OpenCV documents applyColorMap output as BGR while PIL
    # treats the array as RGB, so red/blue appear swapped - left unchanged
    # to keep existing visualizations stable; confirm whether intended.
    x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
    x_ = T.ToTensor()(x_)  # (3, H, W)
    return x_, [mi, ma]

def N_to_reso(n_voxels, bbox):
    """Pick a per-axis grid resolution giving roughly `n_voxels` cubic voxels.

    `bbox` unpacks into (xyz_min, xyz_max) tensors. The voxel edge length is
    the d-th root of (box volume / n_voxels); each axis extent is divided by
    that edge and truncated to an integer. Returns a list of ints.
    """
    xyz_min, xyz_max = bbox
    n_axes = len(xyz_min)
    extent = xyz_max - xyz_min
    edge = (extent.prod() / n_voxels).pow(1 / n_axes)
    resolution = (extent / edge).long()
    return resolution.tolist()

def cal_n_samples(reso, step_ratio=0.5):
    """Return the truncated number of steps of size `step_ratio` along the
    grid diagonal, i.e. int(||reso||_2 / step_ratio)."""
    diagonal = np.linalg.norm(reso)
    n_steps = diagonal / step_ratio
    return int(n_steps)




# Module-level cache of LPIPS models, keyed by backbone name.
__LPIPS__ = {}
def init_lpips(net_name, device):
    """Build an LPIPS (version 0.1) model in eval mode on `device`.

    `net_name` must be 'alex' or 'vgg'; anything else fails the assertion
    before the `lpips` package is imported.
    """
    assert net_name in ('alex', 'vgg')
    import lpips
    print(f'init_lpips: lpips_{net_name}')
    model = lpips.LPIPS(net=net_name, version='0.1')
    model = model.eval()
    return model.to(device)

def rgb_lpips(np_gt, np_im, net_name, device):
    """Compute the LPIPS distance between two images given as numpy arrays.

    Args:
        np_gt, np_im: arrays whose last axis is moved to the front before
            being fed to the model (presumably (H, W, 3) in [0, 1], since
            normalize=True is passed - confirm against callers).
        net_name: LPIPS backbone name, forwarded to init_lpips.
        device: device on which the model and inputs are placed.

    Returns:
        The LPIPS distance as a Python float.
    """
    if net_name not in __LPIPS__:
        __LPIPS__[net_name] = init_lpips(net_name, device)
    # The cache is keyed by backbone only, so a later call with a different
    # device must move the cached model (no-op when it is already there).
    model = __LPIPS__[net_name].to(device)
    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
    # Only a scalar is returned, so skip building an autograd graph.
    with torch.no_grad():
        return model(gt, im, normalize=True).item()


def findItem(items, target):
    """Return the first element of `items` whose leading len(target) slice
    equals `target`, or None when no element matches."""
    prefix_len = len(target)
    matches = (item for item in items if item[:prefix_len] == target)
    return next(matches, None)


''' Evaluation metrics (ssim, lpips)
'''
def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    """Structural similarity (SSIM) between two (H, W, 3) images.

    Args:
        img0, img1: arrays of identical shape (H, W, 3).
        max_val: dynamic range of the pixel values (scales the constants).
        filter_size: length of the 1D Gaussian window.
        filter_sigma: standard deviation of the Gaussian window.
        k1, k2: SSIM stabilization constants.
        return_map: when True return the per-pixel SSIM map instead of
            its mean.

    Returns:
        The mean SSIM, or the SSIM map. Because 'valid' convolutions are
        used, the map is smaller than the input by filter_size - 1 in both
        spatial dimensions.
    """
    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # Normalized 1D Gaussian window, centered also for even sizes.
    half = filter_size // 2
    offset = (2 * half - filter_size + 1) / 2
    taps = ((np.arange(filter_size) - half + offset) / filter_sigma)**2
    kernel = np.exp(-0.5 * taps)
    kernel /= np.sum(kernel)
    col_kernel = kernel[:, None]
    row_kernel = kernel[None, :]

    def blur(image):
        # Separable blur per channel: columns first, then rows.
        channels = []
        for c in range(image.shape[-1]):
            vertical = scipy.signal.convolve2d(
                image[..., c], col_kernel, mode='valid')
            channels.append(
                scipy.signal.convolve2d(vertical, row_kernel, mode='valid'))
        return np.stack(channels, -1)

    # Local first and second moments.
    mean0 = blur(img0)
    mean1 = blur(img1)
    mean0_sq = mean0 * mean0
    mean1_sq = mean1 * mean1
    mean_prod = mean0 * mean1
    var0 = blur(img0**2) - mean0_sq
    var1 = blur(img1**2) - mean1_sq
    cov = blur(img0 * img1) - mean_prod

    # Variances cannot be negative; the covariance magnitude is bounded by
    # the geometric mean of the variances.
    var0 = np.maximum(0., var0)
    var1 = np.maximum(0., var1)
    cov_bound = np.sqrt(var0 * var1)
    cov = np.sign(cov) * np.minimum(cov_bound, np.abs(cov))

    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mean_prod + c1) * (2 * cov + c2)
    denom = (mean0_sq + mean1_sq + c1) * (var0 + var1 + c2)
    ssim_map = numer / denom
    if return_map:
        return ssim_map
    return np.mean(ssim_map)


import torch.nn as nn
class TVLoss(nn.Module):
    """Squared total-variation penalty over the last two indexed axes.

    The input is indexed with four subscripts, so it is presumably laid
    out as (N, C, H, W) - confirm against callers.
    """

    def __init__(self, TVLoss_weight=1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Return weight * 2 * (h_tv / count_h + w_tv / count_w) / N, where
        h_tv / w_tv are summed squared neighbor differences along axes 2 / 3
        and count_* are the per-sample element counts of those differences."""
        n_batch = x.shape[0]
        lower_rows, upper_rows = x[:, :, 1:, :], x[:, :, :-1, :]
        right_cols, left_cols = x[:, :, :, 1:], x[:, :, :, :-1]

        h_tv = (lower_rows - upper_rows).pow(2).sum()
        w_tv = (right_cols - left_cols).pow(2).sum()

        h_term = h_tv / self._tensor_size(lower_rows)
        w_term = w_tv / self._tensor_size(right_cols)
        return self.TVLoss_weight * 2 * (h_term + w_term) / n_batch

    def _tensor_size(self, t):
        """Number of elements per sample: product of dims 1, 2 and 3."""
        _, d1, d2, d3 = t.size()[:4]
        return d1 * d2 * d3

class colorTVLoss(nn.Module):
    """Edge-aware squared total-variation penalty.

    Neighbor differences of `y` are down-weighted by exp(-g), where g is the
    channel-averaged squared neighbor difference of the guide `x`. Both
    inputs are indexed with four subscripts, so they are presumably laid
    out as (N, C, H, W) - confirm against callers.
    """

    def __init__(self, TVLoss_weight=1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x, y):
        """Return weight * 2 * (h_tv_y + w_tv_y) / N.

        Args:
            x: guide tensor whose gradients define the weights.
            y: tensor being regularized; its spatial size may differ from
                that of `x`, in which case the weights are resized.
        """
        batch_size = x.size()[0]
        h_x, w_x = x.size()[2], x.size()[3]
        h_y, w_y = y.size()[2], y.size()[3]

        # Channel-averaged squared gradients of the guide.
        h_tv_x = torch.mean(torch.pow(x[:, :, 1:, :] - x[:, :, :h_x - 1, :], 2), dim=1, keepdim=True)
        w_tv_x = torch.mean(torch.pow(x[:, :, :, 1:] - x[:, :, :, :w_x - 1], 2), dim=1, keepdim=True)
        weight_h_x = torch.exp(-h_tv_x)
        weight_w_x = torch.exp(-w_tv_x)

        # Interpolate the weights to y's gradient sizes. Compare both
        # dimensions: checking the height alone left a width-only mismatch
        # unhandled, which then failed to broadcast below.
        if (h_x, w_x) != (h_y, w_y):
            weight_h_x = F.interpolate(weight_h_x, size=(h_y - 1, w_y), mode='bilinear', align_corners=False)
            weight_w_x = F.interpolate(weight_w_x, size=(h_y, w_y - 1), mode='bilinear', align_corners=False)

        h_tv_y = torch.mean(torch.pow((y[:, :, 1:, :] - y[:, :, :h_y - 1, :]) * weight_h_x, 2))
        w_tv_y = torch.mean(torch.pow((y[:, :, :, 1:] - y[:, :, :, :w_y - 1]) * weight_w_x, 2))

        return self.TVLoss_weight * 2 * (h_tv_y + w_tv_y) / batch_size

    def _tensor_size(self, t):
        """Number of elements per sample: product of dims 1, 2 and 3."""
        return t.size()[1] * t.size()[2] * t.size()[3]


def DSLoss(depth_map_patch):
    """Depth smoothness loss: sum of the mean squared neighbor differences
    along axis 1 and along axis 0 of a 3-axis depth patch."""
    diff_axis1 = depth_map_patch[:, 1:, :] - depth_map_patch[:, :-1, :]
    diff_axis0 = depth_map_patch[1:, :, :] - depth_map_patch[:-1, :, :]
    return diff_axis1.pow(2).mean() + diff_axis0.pow(2).mean()

def ColorDSLoss(color_map_patch, depth_map_patch):
    """Color-weighted depth smoothness loss over 3-axis patches.

    The mean squared color difference along each axis is reduced to a
    single scalar c, and the depth differences along that axis are scaled
    by exp(-c) before being squared and averaged.

    NOTE(review): the weights are global scalars, not per-pixel maps -
    confirm this is intended.
    """
    color_d1 = color_map_patch[:, 1:, :] - color_map_patch[:, :-1, :]
    color_d0 = color_map_patch[1:, :, :] - color_map_patch[:-1, :, :]
    weight_axis1 = torch.exp(-color_d1.pow(2).mean())
    weight_axis0 = torch.exp(-color_d0.pow(2).mean())

    depth_d1 = depth_map_patch[:, 1:, :] - depth_map_patch[:, :-1, :]
    depth_d0 = depth_map_patch[1:, :, :] - depth_map_patch[:-1, :, :]
    loss_axis1 = (depth_d1 * weight_axis1).pow(2).mean()
    loss_axis0 = (depth_d0 * weight_axis0).pow(2).mean()

    return loss_axis1 + loss_axis0

def cal_disparity_loss(pre_depth, gt_depth, weight=None):
    """Mean squared error between inverse depths (disparities).

    Both depths get 1e-6 added before inversion. When `weight` is given it
    multiplies the squared error element-wise before averaging.
    """
    pred_disp = 1.0 / (pre_depth + 1e-6)
    gt_disp = 1.0 / (gt_depth + 1e-6)
    sq_err = (pred_disp - gt_disp) ** 2
    if weight is not None:
        sq_err = weight * sq_err
    return torch.mean(sq_err)
    
def cal_depth_loss(pre_depth, gy_depth, weight=None):
    """Mean squared error between predicted and reference depth.

    When `weight` is given it multiplies the squared error element-wise
    before averaging.
    """
    sq_err = (pre_depth - gy_depth) ** 2
    if weight is not None:
        sq_err = weight * sq_err
    return torch.mean(sq_err)
    
def normalize_depth(depth):
    """Shift by the median and scale by the mean absolute deviation.

    Both statistics are detached, so gradients flow only through `depth`
    itself. 1e-10 is added to the scale to avoid division by zero.
    """
    center = depth.median()
    spread = (depth - center).abs().mean()
    shifted = depth - center.detach()
    return shifted / (spread.detach() + 1e-10)

def cal_occ_loss(sigma, rgb_ray, reg_range, wb_range, wb_prior=False):
    """Occlusion regularization: mean density inside a masked region.

    Args:
        sigma: densities, multiplied element-wise with the mask.
        rgb_ray: colors; the last axis is averaged to get a brightness.
        reg_range: the first `reg_range` entries along axis 1 are always
            penalized (skipped when not positive).
        wb_range: with `wb_prior`, near-white / near-black entries are
            penalized only at axis-1 indices below `wb_range`.
        wb_prior: enable the white/black background prior.

    Returns:
        Mean of sigma * mask.
    """
    brightness = rgb_ray.mean(-1)

    if wb_prior:
        # Naive background detection: almost pure white or pure black.
        mask = (brightness > 0.99) | (brightness < 0.01)
        mask[:, wb_range:] = 0  # keep the prior only below wb_range
    else:
        mask = torch.zeros_like(brightness)

    if reg_range > 0:
        mask[:, :reg_range] = 1  # always penalize the leading entries

    return (sigma * mask).mean()

import plyfile
import skimage.measure
def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    ply_filename_out,
    bbox,
    level=0.5,
    offset=None,
    scale=None,
):
    """
    Extract an iso-surface from a 3D volume with marching cubes and save it as .ply

    :param pytorch_3d_sdf_tensor: 3D tensor of samples; `.numpy()` is called on it
        directly, so it presumably must be a CPU tensor that does not require
        grad (TODO confirm against callers)
    :param ply_filename_out: string, path of the filename to save to
    :param bbox: indexed as bbox[0], bbox[1] and bbox[0, i]; presumably a (2, 3)
        array/tensor holding the min and max corners of the volume (TODO confirm)
    :param level: iso-value passed to skimage.measure.marching_cubes
    :param offset: if not None, subtracted from the vertices (after scaling)
    :param scale: if not None, the vertices are divided by it

    This function adapted from: https://github.com/RobotLocomotion/spartan
    """

    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
    # per-axis voxel size: bounding-box extent divided by the grid shape
    voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))

    verts, faces, normals, values = skimage.measure.marching_cubes(
        numpy_3d_sdf_tensor, level=level, spacing=voxel_size
    )
    faces = faces[...,::-1] # inverse face orientation

    # transform from voxel coordinates to camera coordinates
    # (vertices are already scaled by `spacing`; only the bbox minimum is added)
    # note x and y are flipped in the output of marching_cubes
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = bbox[0,0] + verts[:, 0]
    mesh_points[:, 1] = bbox[0,1] + verts[:, 1]
    mesh_points[:, 2] = bbox[0,2] + verts[:, 2]

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    # build the structured arrays plyfile expects, then write the file

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    # one record of three float32 coordinates per vertex
    verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])

    for i in range(0, num_verts):
        verts_tuple[i] = tuple(mesh_points[i, :])

    # one record per face holding its three int32 vertex indices
    faces_building = []
    for i in range(0, num_faces):
        faces_building.append(((faces[i, :].tolist(),)))
    faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

    ply_data = plyfile.PlyData([el_verts, el_faces])
    print("saving mesh to %s" % (ply_filename_out))
    ply_data.write(ply_filename_out)
