import cv2,torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
import scipy.signal
from skimage.metrics import structural_similarity

def mse2psnr(x):
    """Convert a mean-squared-error tensor to PSNR: ``-10 * log10(x)``.

    The base-10 logarithm is obtained as ``log(x) / log(10)``. The constant
    ``10`` is built with ``torch.Tensor([10.])`` (shape ``(1,)``), so the
    result broadcasts against that shape: a 0-dim input yields a ``(1,)``
    output.
    """
    # NOTE(review): the constant lives on the default tensor device; confirm
    # callers pass `x` on that same device.
    ln_ten = torch.log(torch.Tensor([10.]))
    return -10. * torch.log(x) / ln_ten


def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    Colorize a depth map with an OpenCV color map.

    depth: (H, W) numpy array; NaNs are replaced by 0 before normalization.
    minmax: optional (min, max) pair used for normalization. When None, the
        smallest strictly positive depth and the largest depth are used.
    cmap: OpenCV color map identifier passed to cv2.applyColorMap.

    Returns (colored image as produced by cv2.applyColorMap, [min, max]).
    """

    x = np.nan_to_num(depth) # change nan to 0
    if minmax is None:
        positive = x[x > 0]
        # get minimum positive depth (ignore background); fall back to the
        # global minimum when no depth is positive instead of raising
        mi = np.min(positive) if positive.size else np.min(x)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
    # Values outside [mi, ma] (e.g. zero-valued background, or anything
    # outside a user-supplied range) would wrap around in the uint8 cast.
    x = np.clip(x, 0.0, 1.0)
    x = (255 * x).astype(np.uint8)
    x_ = cv2.applyColorMap(x, cmap)
    return x_, [mi, ma]

def init_log(log, keys):
    """Reset every entry of `log` named in `keys` to a fresh zero accumulator.

    Each entry becomes its own float64 tensor of shape (1,). The same `log`
    mapping is mutated in place and returned.
    """
    for name in keys:
        # dtype=float on a torch tensor means float64
        log[name] = torch.zeros(1, dtype=torch.float64)
    return log

def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    Colorize a depth map and return it as a torch image tensor.

    depth: (H, W) numpy array, or an object exposing .detach().cpu().numpy()
        (e.g. a torch tensor). NaNs are replaced by 0 before normalization.
    minmax: optional (min, max) pair used for normalization. When None, the
        smallest strictly positive depth and the largest depth are used.
    cmap: OpenCV color map identifier passed to cv2.applyColorMap.

    Returns (tensor of shape (3, H, W) from torchvision's ToTensor, [min, max]).
    """
    if not isinstance(depth, np.ndarray):
        # detach() so tensors that still track gradients can be converted
        depth = depth.detach().cpu().numpy()

    x = np.nan_to_num(depth) # change nan to 0
    if minmax is None:
        positive = x[x > 0]
        # get minimum positive depth (ignore background); fall back to the
        # global minimum when no depth is positive instead of raising
        mi = np.min(positive) if positive.size else np.min(x)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
    # Values outside [mi, ma] (e.g. zero-valued background, or anything
    # outside a user-supplied range) would wrap around in the uint8 cast.
    x = np.clip(x, 0.0, 1.0)
    x = (255 * x).astype(np.uint8)
    # NOTE(review): OpenCV color maps are BGR-ordered while PIL assumes RGB;
    # the channel order is left untouched here - confirm what callers expect.
    x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
    x_ = T.ToTensor()(x_)  # (3, H, W)
    return x_, [mi, ma]

def N_to_reso(n_voxels, bbox):
    """Pick a per-axis grid resolution holding roughly `n_voxels` cubic voxels.

    bbox: pair (xyz_min, xyz_max) of tensors with one entry per axis.

    The voxel edge length is the d-th root of (box volume / n_voxels), where
    d is the number of axes; each axis extent is divided by that edge and
    truncated to an integer. Returns a plain Python list of ints.
    """
    lower, upper = bbox
    extent = upper - lower
    n_axes = len(lower)
    edge = (extent.prod() / n_voxels).pow(1 / n_axes)
    resolution = (extent / edge).long()
    return resolution.tolist()

def cal_n_samples(reso, step_ratio=0.5):
    """Return the norm of `reso` divided by `step_ratio`, truncated to int.

    reso: sequence of per-axis resolutions; its norm is taken with
        np.linalg.norm.
    step_ratio: divisor applied to that norm.
    """
    diagonal = np.linalg.norm(reso)
    n_steps = diagonal / step_ratio
    return int(n_steps)




# Module-level cache of LPIPS models, keyed by backbone name.
__LPIPS__ = {}
def init_lpips(net_name, device):
    """Build an LPIPS (version 0.1) model in eval mode on `device`.

    net_name: backbone identifier, must be 'alex' or 'vgg'.
    device: target passed to the model's .to().

    The lpips package is imported lazily, only when this function runs.
    """
    supported_backbones = ('alex', 'vgg')
    assert net_name in supported_backbones
    import lpips
    print(f'init_lpips: lpips_{net_name}')
    model = lpips.LPIPS(net=net_name, version='0.1')
    model = model.eval()
    return model.to(device)

def rgb_lpips(np_gt, np_im, net_name, device):
    """Compute the LPIPS distance between two channel-last images.

    np_gt, np_im: numpy arrays with three axes, channel last; they are
        permuted to channel-first before being fed to the network.
    net_name: LPIPS backbone name; the model is created on first use and
        cached in the module-level __LPIPS__ dict.
    device: device on which the metric is evaluated.

    Returns the distance as a Python float.
    """
    if net_name not in __LPIPS__:
        __LPIPS__[net_name] = init_lpips(net_name, device)
    # The cache is keyed by name only, so a model created for another device
    # may be returned; move it to the requested device (a no-op when it is
    # already there) so the model and inputs always agree.
    model = __LPIPS__[net_name].to(device)
    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
    # Only a scalar is returned, so there is no need to build an autograd graph.
    with torch.no_grad():
        # NOTE(review): normalize=True is meant for inputs in [0, 1] per the
        # lpips package - confirm callers pass images in that range.
        return model(gt, im, normalize=True).item()


def findItem(items, target):
    """Return the first element of `items` that starts with `target`.

    The prefix test compares the slice ``item[:len(target)]`` with `target`,
    so it applies to any sliceable sequence. Returns None when nothing
    matches.
    """
    matches = (candidate for candidate in items
               if candidate[:len(target)] == target)
    return next(matches, None)


''' Evaluation metrics (ssim, lpips)
'''
# def rgb_ssim(img0, img1, max_val,
#              filter_size=11,
#              filter_sigma=1.5,
#              k1=0.01,
#              k2=0.03,
#              return_map=False):
#     # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
#     assert len(img0.shape) == 3
#     assert img0.shape[-1] == 3
#     assert img0.shape == img1.shape

#     # Construct a 1D Gaussian blur filter.
#     hw = filter_size // 2
#     shift = (2 * hw - filter_size + 1) / 2
#     f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
#     filt = np.exp(-0.5 * f_i)
#     filt /= np.sum(filt)

#     # Blur in x and y (faster than the 2D convolution).
#     def convolve2d(z, f):
#         return scipy.signal.convolve2d(z, f, mode='valid')

#     filt_fn = lambda z: np.stack([
#         convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
#         for i in range(z.shape[-1])], -1)
#     mu0 = filt_fn(img0)
#     mu1 = filt_fn(img1)
#     mu00 = mu0 * mu0
#     mu11 = mu1 * mu1
#     mu01 = mu0 * mu1
#     sigma00 = filt_fn(img0**2) - mu00
#     sigma11 = filt_fn(img1**2) - mu11
#     sigma01 = filt_fn(img0 * img1) - mu01

#     # Clip the variances and covariances to valid values.
#     # Variance must be non-negative:
#     sigma00 = np.maximum(0., sigma00)
#     sigma11 = np.maximum(0., sigma11)
#     sigma01 = np.sign(sigma01) * np.minimum(
#         np.sqrt(sigma00 * sigma11), np.abs(sigma01))
#     c1 = (k1 * max_val)**2
#     c2 = (k2 * max_val)**2
#     numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
#     denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
#     ssim_map = numer / denom
#     ssim = np.mean(ssim_map)
#     return ssim_map if return_map else ssim

def rgb_ssim(gt_frame: np.ndarray, eval_frame: np.ndarray):
    """
    Structural similarity between two color frames via scikit-image.

    gt_frame: (H, W, 3)
    eval_frame: (H, W, 3)

    Both frames must share shape and dtype. The metric is evaluated with
    Gaussian weighting (sigma 1.5), population covariance, channels on the
    last axis and a fixed data range of 1.0.
    """
    assert gt_frame.shape == eval_frame.shape
    assert gt_frame.dtype == eval_frame.dtype

    ssim_options = dict(
        channel_axis=-1,
        data_range=1.0,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
    )
    return structural_similarity(gt_frame, eval_frame, **ssim_options)


import torch.nn as nn
class TVLoss(nn.Module):
    """Squared total-variation penalty over the last two axes of a 4-D tensor.

    Squared differences between neighbours along axis 2 and along axis 3 are
    summed, each sum is divided by the element count of one sample of the
    corresponding difference tensor, and the total is scaled by
    ``2 * TVLoss_weight / batch_size``.
    """

    def __init__(self, TVLoss_weight=1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        n_batch = x.size(0)
        rows, cols = x.size(2), x.size(3)

        # shifted views whose difference gives neighbour deltas
        lower, upper = x[:, :, 1:, :], x[:, :, :rows - 1, :]
        right, left = x[:, :, :, 1:], x[:, :, :, :cols - 1]

        vertical_energy = torch.pow(lower - upper, 2).sum()
        horizontal_energy = torch.pow(right - left, 2).sum()

        normalized = (vertical_energy / self._tensor_size(lower)
                      + horizontal_energy / self._tensor_size(right))
        return self.TVLoss_weight * 2 * normalized / n_batch

    def _tensor_size(self, t):
        # number of elements in one sample (axes 1, 2 and 3)
        dims = t.size()
        return dims[1] * dims[2] * dims[3]

class colorTVLoss(nn.Module):
    """Total-variation penalty on `y`, down-weighted where `x` varies.

    For each spatial direction the squared neighbour differences of `x` are
    averaged over the channel axis and turned into weights ``exp(-tv_x)``.
    The neighbour differences of `y` are multiplied by those weights, squared
    and averaged; the result is scaled by ``2 * TVLoss_weight / batch_size``.
    """

    def __init__(self, TVLoss_weight=1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x, y):
        """
        x: 4-D guidance tensor; axes 2 and 3 are treated as height and width.
        y: 4-D tensor to regularize; its spatial size may differ from x's.
        """
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        h_tv_x = torch.mean(torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2), dim=1, keepdim=True)
        w_tv_x = torch.mean(torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2), dim=1, keepdim=True)
        weight_h_x = torch.exp(-h_tv_x)
        weight_w_x = torch.exp(-w_tv_x)

        h_y = y.size()[2]
        w_y = y.size()[3]
        # Interpolate the weights to y's size. Compare both spatial axes: a
        # height-only check skips the resize when only the widths differ and
        # the multiplication below then fails to broadcast.
        if (h_x, w_x) != (h_y, w_y):
            weight_h_x = F.interpolate(weight_h_x, size=(h_y - 1, w_y), mode='bilinear', align_corners=False)
            weight_w_x = F.interpolate(weight_w_x, size=(h_y, w_y - 1), mode='bilinear', align_corners=False)
        h_tv_y = torch.mean(torch.pow((y[:, :, 1:, :] - y[:, :, :-1, :]) * weight_h_x, 2))
        w_tv_y = torch.mean(torch.pow((y[:, :, :, 1:] - y[:, :, :, :-1]) * weight_w_x, 2))

        return self.TVLoss_weight * 2 * (h_tv_y + w_tv_y) / batch_size

    def _tensor_size(self, t):
        # number of elements in one sample (axes 1, 2 and 3)
        return t.size()[1] * t.size()[2] * t.size()[3]


def DSLoss(depth_map_patch):
    """Depth smoothness: mean squared neighbour difference along axes 1 and 0.

    depth_map_patch: tensor with three axes. The mean squared difference
    between neighbours along axis 1 and the one along axis 0 are added.
    """
    along_axis1 = depth_map_patch[:, 1:, :] - depth_map_patch[:, :-1, :]
    along_axis0 = depth_map_patch[1:, :, :] - depth_map_patch[:-1, :, :]
    return along_axis1.pow(2).mean() + along_axis0.pow(2).mean()

def ColorDSLoss(color_map_patch,depth_map_patch):
    """Depth smoothness weighted by how much the color patch varies.

    For each of axes 1 and 0, the mean squared neighbour difference of the
    color patch gives one scalar weight ``exp(-mean)``. The depth neighbour
    differences are multiplied by that weight, squared and averaged, and the
    two directional terms are added.
    """
    # NOTE(review): each weight is a single scalar for the whole patch, not
    # a per-pixel map - confirm this is intended.
    color_step1 = color_map_patch[:, 1:, :] - color_map_patch[:, :-1, :]
    color_step0 = color_map_patch[1:, :, :] - color_map_patch[:-1, :, :]
    weight1 = torch.exp(-color_step1.pow(2).mean())
    weight0 = torch.exp(-color_step0.pow(2).mean())

    depth_step1 = depth_map_patch[:, 1:, :] - depth_map_patch[:, :-1, :]
    depth_step0 = depth_map_patch[1:, :, :] - depth_map_patch[:-1, :, :]
    term1 = (depth_step1 * weight1).pow(2).mean()
    term0 = (depth_step0 * weight0).pow(2).mean()

    return term1 + term0

def cal_disparity_loss(pre_depth, gt_depth, weight=None):
    """Mean squared error between inverse depths (disparities).

    Both depths are offset by 1e-6 before inversion. When `weight` is given
    it multiplies the squared error element-wise before averaging.
    """
    eps = 1e-6
    pred_disparity = 1.0 / (pre_depth + eps)
    gt_disparity = 1.0 / (gt_depth + eps)
    squared_error = (pred_disparity - gt_disparity) ** 2
    if weight is not None:
        squared_error = weight * squared_error
    return torch.mean(squared_error)
    
def cal_depth_loss(pre_depth, gy_depth, weight=None):
    """Mean squared error between two depth tensors.

    When `weight` is given it multiplies the squared error element-wise
    before averaging.
    """
    squared_error = (pre_depth - gy_depth) ** 2
    if weight is not None:
        squared_error = weight * squared_error
    return torch.mean(squared_error)
    
def normalize_depth(depth):
    """Shift by the median and scale by the mean absolute deviation from it.

    Both statistics are detached, so gradients reach the output only through
    `depth` itself. A 1e-10 offset guards the division.
    """
    center = torch.median(depth).detach()
    spread = torch.mean(torch.abs(depth - center)).detach()
    return (depth - center) / (spread + 1e-10)


import plyfile
import skimage.measure
def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    ply_filename_out,
    bbox,
    level=0.5,
    offset=None,
    scale=None,
):
    """
    Extract an iso-surface from a 3-D volume with marching cubes and save it
    as a .ply mesh.

    :param pytorch_3d_sdf_tensor: 3-D torch tensor holding the sampled volume
    :param ply_filename_out: string, path of the file to write
    :param bbox: 2-D array-like indexed as bbox[0] (lower corner) and bbox[1]
        (upper corner); the per-axis voxel size is the box extent divided by
        the volume shape
    :param level: iso-value passed to marching cubes
    :param offset: optional value subtracted from the vertices (after scaling)
    :param scale: optional value the vertices are divided by

    This function adapted from: https://github.com/RobotLocomotion/spartan
    """

    # detach()/cpu() so tensors on an accelerator or tracking gradients can
    # be converted as well
    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.detach().cpu().numpy()
    voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))

    verts, faces, normals, values = skimage.measure.marching_cubes(
        numpy_3d_sdf_tensor, level=level, spacing=voxel_size
    )
    faces = faces[...,::-1] # inverse face orientation

    # marching cubes returns vertices relative to the grid origin (already
    # multiplied by `spacing`); shift them by the lower corner of the box
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = bbox[0,0] + verts[:, 0]
    mesh_points[:, 1] = bbox[0,1] + verts[:, 1]
    mesh_points[:, 2] = bbox[0,2] + verts[:, 2]

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset
    mesh_points = np.asarray(mesh_points)

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    # Fill the structured arrays field-wise instead of one Python-level
    # iteration per vertex / face.
    verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    verts_tuple["x"] = mesh_points[:, 0]
    verts_tuple["y"] = mesh_points[:, 1]
    verts_tuple["z"] = mesh_points[:, 2]

    faces_tuple = np.zeros((num_faces,), dtype=[("vertex_indices", "i4", (3,))])
    faces_tuple["vertex_indices"] = faces

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

    ply_data = plyfile.PlyData([el_verts, el_faces])
    print("saving mesh to %s" % (ply_filename_out))
    ply_data.write(ply_filename_out)


@torch.no_grad()
def warping(allrays, ray_idx, H, W, focal, depth_map, c2w_j, startposition_in_allray, patch_mask, patch_size, device):
    """Project points sampled along rays into another view's pixel grid.

    Steps, as performed below:
      1. `depth_map`, `c2w_j` and `startposition_in_allray` are tiled
         patch_size**2 times along axis 0.
      2. Rays selected by `patch_mask` are gathered from `allrays`
         (columns 0:3 origin, 3:6 direction); unselected rows stay zero.
      3. A 3-D point is placed at origin + depth * direction, its x/y signs
         are flipped, the translation column of `c2w_j` is subtracted and
         the transposed rotation block is applied.
      4. The point is projected with `focal` and the image centre
         (W/2, H/2) and rounded to integer pixel coordinates (u, v).

    Returns (start + v * W + u, mask of projections inside the H x W image).
    """
    # NOTE(review): torch.cat repeats the whole tensor back to back
    # ([a0..aN-1, a0..aN-1, ...]); confirm this matches the ordering of
    # `ray_idx` / `patch_mask` supplied by the caller.
    pixels_per_patch = patch_size ** 2

    def tile(tensor):
        return torch.cat([tensor] * pixels_per_patch)

    depths = tile(depth_map)
    poses_j = tile(c2w_j)
    frame_start = tile(startposition_in_allray).to(device)

    # gather only the valid rays; masked-out rows remain all-zero
    rays = torch.zeros((ray_idx.shape[0], 6), device=device)
    rays[patch_mask] = allrays[ray_idx[patch_mask].cpu()].to(device)
    origins = rays[:, 0:3]
    directions = rays[:, 3:]

    # point along each ray, moved into the frame described by c2w_j
    points = origins + depths.unsqueeze(-1) * directions
    points[:, 0:2] = -points[:, 0:2]
    rotation_j = poses_j[:, :, 0:3]
    translation_j = poses_j[:, :, 3:]
    points = points.unsqueeze(-1) - translation_j
    points = torch.matmul(rotation_j.transpose(1, 2), points).squeeze(-1)

    # perspective projection to integer pixel coordinates
    u = torch.round(points[:, 0] / (-points[:, 2]) * focal + W * .5).to(torch.int)
    v = torch.round(points[:, 1] / (-points[:, 2]) * focal + H * .5).to(torch.int)
    within_mask = (u >= 0) & (u < W) & (v >= 0) & (v < H)

    position_in_allray = frame_start + v * W + u
    return position_in_allray, within_mask


@torch.no_grad()
def cal_reprojection_error(rgb, projected_rgb, mask,  patch_size):
    """Mean reprojection error per patch.

    rgb, projected_rgb: colors for the pixels selected by `mask`; their
        squared difference is averaged over the last axis.
    mask: 1-D boolean tensor over all patch pixels.
    patch_size: patch edge length; consecutive groups of patch_size**2
        entries form one patch.

    Pixels outside `mask` keep an error of 1.0 and are included in the patch
    mean. Returns one value per patch.
    """
    pixels_per_patch = patch_size ** 2
    n_pixels = mask.shape[0]

    # default error of 1 for pixels without a valid projection
    per_pixel = torch.ones(n_pixels).to(mask.device)
    per_pixel[mask] = ((rgb - projected_rgb) ** 2).mean(dim=-1)

    per_patch = per_pixel.view(int(n_pixels / pixels_per_patch), pixels_per_patch)
    return per_patch.mean(dim=-1)


# calculate reprojection error with rgb of frame i and rgb warped to frame j
@torch.no_grad()
def patchify(ray_idx, H, W, patch_size, total_frame_len, device):
    """Expand each flat ray index into a patch_size x patch_size pixel window.

    ray_idx: 1-D tensor of flat indices laid out as frame * H * W + v * W + u.
    patch_size: window edge length; offsets run from -(patch_size // 2).
    total_frame_len: number of frames, used to validate the frame index.

    Within one window the row offset varies fastest and the column offset
    slowest. Returns (flat indices of all window pixels, boolean mask that is
    True where the pixel lies inside the H x W image and the frame index is
    in range), both flattened to 1-D.
    """
    half = patch_size // 2
    pixels_per_patch = patch_size ** 2
    frame_area = H * W

    def spread(values):
        # one column per window pixel
        return values.to(device).unsqueeze(-1).repeat(1, pixels_per_patch)

    t_ref = spread(ray_idx // frame_area)  # frame num
    pixel_in_frame = ray_idx % frame_area

    window = torch.arange(patch_size, device=device) - half
    v_ref = spread(pixel_in_frame // W) + window.repeat(patch_size)
    u_ref = spread(pixel_in_frame % W) + window.repeat_interleave(patch_size)

    patch_ray_idx = (t_ref * frame_area + v_ref * W + u_ref).view(-1)  # (batch_size * k * k)

    inside_image = (u_ref >= 0) & (u_ref < W) & (v_ref >= 0) & (v_ref < H)
    valid_frame = (t_ref >= 0) & (t_ref < total_frame_len)
    patch_mask = (inside_image & valid_frame).view(-1)  # mask for pixels out of (H, W)

    return patch_ray_idx, patch_mask