from math import log
from loguru import logger

import torch
import torch.nn.functional as F
from einops import repeat, rearrange, reduce
from kornia.utils import create_meshgrid

from .geometry import warp_kpts

##############  ↓  Coarse-Level supervision  ↓  ##############


@torch.no_grad()
def mask_pts_at_padded_regions(grid_pt, mask):
    """Zero the grid points that fall where `mask` is falsy (modifies `grid_pt` in place).

    The original note says this is needed for the megadepth dataset, whose
    images contain zero-padding.

    Args:
        grid_pt: point grid indexed like the flattened mask with a trailing
            axis of size 2 (the mask is expanded to 'n (h w) 2' to index it).
        mask: 'n h w' tensor; entries that are zero/False mark padded cells.

    Returns:
        The same `grid_pt` tensor, with the masked entries set to 0.
    """
    # Flatten the spatial axes and duplicate the flag for both coordinates.
    valid = repeat(mask, 'n h w -> n (h w) c', c=2).bool()
    padded = ~valid
    grid_pt[padded] = 0
    return grid_pt


@torch.no_grad()
def spvs_coarse(data, config):  # build the labels needed for supervised training
    """Build coarse-level ground-truth matches by warping grid points both ways.

    A coarse cell `i` of image 0 is matched to cell `j` of image 1 when the
    warp of `i` rounds to `j` and the warp of `j` rounds back to `i`.

    Update:
        data (dict): {
            "conf_matrix_gt": [N, hw0, hw1],
            'spv_b_ids': [M]
            'spv_i_ids': [M]
            'spv_j_ids': [M]
            'spv_w_pt0_i': [N, hw0, 2], in original image resolution
            'spv_pt1_i': [N, hw1, 2], in original image resolution
        }

    NOTE:
        - for scannet dataset, there're 3 kinds of resolution {i, c, f}
        - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
    """
    # 1. misc
    image0, image1 = data['image0'], data['image1']
    device = image0.device
    N, _, H0, W0 = image0.shape
    _, _, H1, W1 = image1.shape
    scale = config['LOFTR']['RESOLUTION'][0]  # eg: 8 for (FPN 8 2)
    if 'scale0' in data:
        scale0 = scale * data['scale0'][:, None]
        scale1 = scale * data['scale1'][:, None]
    else:
        scale0 = scale1 = scale
    # coarse feature-map sizes
    h0, w0 = H0 // scale, W0 // scale
    h1, w1 = H1 // scale, W1 // scale

    # 2. warp grids
    def coarse_grid(h, w):
        """Pixel-coordinate meshgrid flattened to [N, h*w, 2]."""
        return create_meshgrid(h, w, False, device).reshape(1, h * w, 2).repeat(N, 1, 1)

    # create kpts in meshgrid and resize them to image resolution
    grid_pt0_i = scale0 * coarse_grid(h0, w0)
    grid_pt1_i = scale1 * coarse_grid(h1, w1)

    # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
    if 'mask0' in data:
        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])

    # warp kpts bi-directionally and resize them to coarse-level resolution
    # (no depth consistency check, since it leads to worse results experimentally)
    # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
    _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
    _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
    w_pt0_c = w_pt0_i / scale1  # image-0 points expressed in image 1 (coarse level)
    w_pt1_c = w_pt1_i / scale0

    # 3. check if mutual nearest neighbor
    def nearest_flat_index(pts_c, w, h):
        """Round warped points to cells and flatten them row-major; out-of-bound cells map to 0."""
        cells = pts_c.round().long()
        xs, ys = cells[..., 0], cells[..., 1]
        flat = xs + ys * w
        # corner case: out of boundary (any violated bound is enough)
        outside = (xs < 0) | (xs >= w) | (ys < 0) | (ys >= h)
        flat[outside] = 0
        return flat

    # for every cell of image 0, the nearest cell index in image 1 (coarse level), (N, L)
    nearest_index1 = nearest_flat_index(w_pt0_c, w1, h1)
    nearest_index0 = nearest_flat_index(w_pt1_c, w0, h0)

    # loop_back[b, i] = nearest_index0[b, nearest_index1[b, i]]
    loop_back = torch.gather(nearest_index0, 1, nearest_index1)
    # (N, L): True where a cell of image 0 maps back onto itself
    correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None]
    correct_0to1[:, 0] = False  # ignore the top-left corner

    # 4. construct a gt conf_matrix
    conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
    b_ids, i_ids = torch.where(correct_0to1)  # batch index / cell index in image 0
    j_ids = nearest_index1[b_ids, i_ids]  # matching cell index in image 1

    conf_matrix_gt[b_ids, i_ids, j_ids] = 1  # mark matched pairs with confidence 1
    data.update({'conf_matrix_gt': conf_matrix_gt})

    # 5. save coarse matches(gt) for training fine level
    if b_ids.numel() == 0:
        logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
        # this won't affect fine-level loss calculation
        b_ids = torch.tensor([0], device=device)
        i_ids = torch.tensor([0], device=device)
        j_ids = torch.tensor([0], device=device)

    # gt keypoint pairs on the coarse-level feature map
    data.update({
        'spv_b_ids': b_ids,
        'spv_i_ids': i_ids,
        'spv_j_ids': j_ids
    })

    # 6. save intermediate results (for fast fine-level computation)
    data.update({
        'spv_w_pt0_i': w_pt0_i,  # image-0 points warped into image 1 (original resolution), (N, L, 2)
        'spv_pt1_i': grid_pt1_i  # grid of image-1 points, (N, HW, 2)
    })


def compute_supervision_coarse(data, config):
    """Dispatch coarse-level supervision for the batch's dataset.

    Args:
        data (dict): batch dict; `data['dataset_name']` lists the source of
            every sample and must contain a single distinct name.
        config: network config forwarded to `spvs_coarse`.

    Raises:
        AssertionError: if the batch mixes several dataset names.
        ValueError: if the dataset name is neither 'scannet' nor 'megadepth'
            (compared case-insensitively).
    """
    distinct_sources = set(data['dataset_name'])
    assert len(distinct_sources) == 1, "Do not support mixed datasets training!"

    data_source = data['dataset_name'][0]
    if data_source.lower() not in ('scannet', 'megadepth'):
        raise ValueError(f'Unknown data source: {data_source}')
    spvs_coarse(data, config)


##############  ↓  Fine-Level supervision  ↓  ##############

@torch.no_grad()
def spvs_fine(data, config):
    """Compute the fine-level ground-truth offsets for the coarse matches.

    Reads `spv_w_pt0_i` / `spv_pt1_i` and the match indices `b_ids`, `i_ids`,
    `j_ids` from `data`, and stores the offset between each warped image-0
    point and its matched image-1 grid point, divided by the fine stride
    (times `data['scale1']` when 'scale0' is present) and by half the fine
    window size.

    Update:
        data (dict):{
            "expec_f_gt": [M, 2]}
    """
    # 1. misc
    warped_pts0 = data['spv_w_pt0_i']
    grid_pts1 = data['spv_pt1_i']
    loftr_cfg = config['LOFTR']
    fine_stride = loftr_cfg['RESOLUTION'][1]
    half_window = loftr_cfg['FINE_WINDOW_SIZE'] // 2

    # 2. get coarse prediction
    batch_idx = data['b_ids']
    idx0 = data['i_ids']
    idx1 = data['j_ids']

    # 3. compute gt
    if 'scale0' in data:
        fine_stride = fine_stride * data['scale1'][batch_idx]
    offset_i = warped_pts0[batch_idx, idx0] - grid_pts1[batch_idx, idx1]
    # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
    data["expec_f_gt"] = offset_i / fine_stride / half_window  # [M, 2]


def compute_supervision_fine(data, config):
    """Dispatch fine-level supervision based on the first sample's dataset name.

    Args:
        data (dict): batch dict; only `data['dataset_name'][0]` is inspected here.
        config: network config forwarded to `spvs_fine`.

    Raises:
        NotImplementedError: if the dataset name is neither 'scannet' nor
            'megadepth' (compared case-insensitively).
    """
    data_source = data['dataset_name'][0]
    if data_source.lower() not in ('scannet', 'megadepth'):
        raise NotImplementedError
    spvs_fine(data, config)


##############  ↓  Instance-Aware supervision  ↓  ##############
@torch.no_grad()
def spvs_instance_match(data, config):
    """Build the instance-level ground-truth confidence matrix.

    Same mutual-nearest-neighbour construction as the coarse supervision, but
    on a grid whose stride is 4x the coarse stride.

    Args:
        data (dict): reads 'image0', 'image1', 'depth0', 'depth1', 'T_0to1',
            'T_1to0', 'K0', 'K1', and 'scale0'/'scale1' when 'scale0' is present.
        config: Network configs

    Update:
        data (dict): {
            'scale_inst2coarse': downsampling rate w.r.t. the coarse level (4)
            'inst_conf_matrix_gt': instance-level confidence matrix, [N, hw0, hw1]
        }
    """
    # 1. misc
    image0, image1 = data['image0'], data['image1']
    device = image0.device
    N, _, H0, W0 = image0.shape
    _, _, H1, W1 = image1.shape
    scale = config['LOFTR']['RESOLUTION'][0] * 4  # make it in config
    if 'scale0' in data:
        scale0 = scale * data['scale0'][:, None]
        scale1 = scale * data['scale1'][:, None]
    else:
        scale0 = scale1 = scale
    # instance-level grid sizes
    h0, w0 = H0 // scale, W0 // scale
    h1, w1 = H1 // scale, W1 // scale

    # 2. warp grids
    def instance_grid(h, w):
        """Pixel-coordinate meshgrid flattened to [N, h*w, 2]."""
        return create_meshgrid(h, w, False, device).reshape(1, h * w, 2).repeat(N, 1, 1)

    # create kpts in meshgrid and resize them to image resolution
    grid_pt0_i = scale0 * instance_grid(h0, w0)
    grid_pt1_i = scale1 * instance_grid(h1, w1)

    # warp kpts bi-directionally and resize them back to the grid resolution
    _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
    _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
    w_pt0_c = w_pt0_i / scale1
    w_pt1_c = w_pt1_i / scale0

    # 3. check if mutual nearest neighbor
    def nearest_flat_index(pts_c, w, h):
        """Round warped points to cells and flatten them row-major; out-of-bound cells map to 0."""
        cells = pts_c.round().long()
        xs, ys = cells[..., 0], cells[..., 1]
        flat = xs + ys * w
        # corner case: out of boundary
        outside = (xs < 0) | (xs >= w) | (ys < 0) | (ys >= h)
        flat[outside] = 0
        return flat

    nearest_index1 = nearest_flat_index(w_pt0_c, w1, h1)
    nearest_index0 = nearest_flat_index(w_pt1_c, w0, h0)

    # loop_back[b, i] = nearest_index0[b, nearest_index1[b, i]]
    loop_back = torch.gather(nearest_index0, 1, nearest_index1)
    correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None]
    correct_0to1[:, 0] = False  # ignore top-left corner

    # construct a gt conf_matrix
    inst_conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
    b_ids, i_ids = torch.where(correct_0to1)
    j_ids = nearest_index1[b_ids, i_ids]
    inst_conf_matrix_gt[b_ids, i_ids, j_ids] = 1

    data.update({
        'scale_inst2coarse': 4,
        'inst_conf_matrix_gt': inst_conf_matrix_gt
    })

@torch.no_grad()
def gaussian(M, std, center, device='cpu'):
    """Sample an unnormalized 1D gaussian at the integer positions 0..M-1.

    Args:
        M (int): Number of samples (length of the window).
        std (float): The standard deviation for gaussian distribution
        center (int/float): Position of the peak. For M > 1 it must satisfy
            0 <= center < M.
        device: Device on which the window is created.

    Returns:
        1-D tensor of length M holding exp(-(n - center)^2 / (2 * std^2));
        an empty tensor when M < 1 and a single 1 when M == 1.

    Raises:
        NotImplementedError: if M > 1 and `center` lies outside [0, M).
    """
    if M < 1:
        return torch.tensor([], device=device)
    if M == 1:
        # Use the default floating dtype so this branch agrees with the others
        # (the builtin `float` would silently produce a float64 tensor).
        return torch.ones(1, device=device)
    if center < 0 or center >= M:
        raise NotImplementedError("Gaussian error: the center is out of index")
    n = torch.arange(0, M, device=device) - center
    sig2 = 2 * std * std
    return torch.exp(-n**2 / sig2)


@torch.no_grad()
def gkern2d(kernlen, center, std, device):
    """Generate a 2D gaussian kernel as the outer product of two 1D gaussians.

    Args:
        kernlen (tuple[int, int]): Length of height and width
        center (tuple): Peak position; `center[0]` is used along the
            `kernlen[0]` (height) axis and `center[1]` along the `kernlen[1]`
            (width) axis.
        std (int/float): The standard deviation, shared by both axes
        device: Device on which the kernel is created.

    Returns:
        Tensor of shape [kernlen[0], kernlen[1]].
    """
    along_h = gaussian(kernlen[0], std, center[0], device)
    along_w = gaussian(kernlen[1], std, center[1], device)
    # Outer product via broadcasting; avoids the deprecated `torch.ger`.
    return along_h[:, None] * along_w[None, :]


@torch.no_grad()
def spvs_instance(data, config):  # build the instance-related supervision labels
    """Build gaussian-weighted same-instance labels on a fixed 15x20 region grid.

    Both instance maps are resized (nearest) to the region grid. For every
    grid cell `i`, channel `i` of the label marks the cells holding the same
    instance value as cell `i`, weighted by a 2D gaussian (std 2.5) centred
    on cell `i`.

    Args:
        data (dict): should contain ['instance0'] and ['instance1']
            (the indexing below expects a single channel, i.e. [N, 1, H, W])
        config (dict): accepted for interface parity; not read here
    Update:
        data: {
            'region_labels0': [N, wh, h, w],
            'region_labels1': [N, wh, h, w]
        }
    """
    # 1. misc
    device = data['instance0'].device

    w_region = 20  # TODO: make it in config
    h_region = 15

    # 2. handle region map
    region_map0 = F.interpolate(data['instance0'].float(), (h_region, w_region), mode='nearest')  # [N, 1, h, w]
    region_map1 = F.interpolate(data['instance1'].float(), (h_region, w_region), mode='nearest')

    # One label channel per grid cell; every channel is overwritten below.
    region_labels0 = region_map0.repeat(1, w_region*h_region, 1, 1).float()
    region_labels1 = region_map1.repeat(1, w_region*h_region, 1, 1).float()

    for i in range(w_region * h_region):
        row, col = divmod(i, w_region)
        # The kernel depends only on the cell, so build it once per cell and
        # apply it to the whole batch at once.
        kernel = gkern2d((h_region, w_region), (row, col), 2.5, device)  # [h, w]
        same0 = region_map0 == region_map0[:, :, row:row + 1, col:col + 1]  # [N, 1, h, w]
        same1 = region_map1 == region_map1[:, :, row:row + 1, col:col + 1]
        region_labels0[:, i:i + 1] = same0 * kernel
        region_labels1[:, i:i + 1] = same1 * kernel

    # 3. update data
    data.update({
        'region_labels0': region_labels0,
        'region_labels1': region_labels1
    })
    
    
@torch.no_grad()
def spvs_instance_full(data, config):  # handle instance-related supervision
    """Build binary same-instance labels on the configured region grid.

    Both instance maps are resized (nearest) to
    (config['INSTMATCH']['AWARE']['OUT_H'], config['INSTMATCH']['AWARE']['OUT_W']).
    For every grid cell `i` (row-major), channel `i` of the label is 1 where
    the resized map equals its value at cell `i`, and 0 elsewhere.

    Args:
        data (dict): should contain ['instance0'] and ['instance1']
            (a single channel is expected, i.e. [N, 1, H, W])
        config (dict): provides the region grid size
    Update:
        data: {
            'region_labels0': [N, wh, h, w],
            'region_labels1': [N, wh, h, w]
        }
    """
    # 1. misc
    w_region = config['INSTMATCH']['AWARE']['OUT_W']
    h_region = config['INSTMATCH']['AWARE']['OUT_H']

    def same_instance_labels(instance):
        """Compare every grid cell against every other cell in one broadcast."""
        region_map = F.interpolate(instance.float(), (h_region, w_region), mode='nearest')  # [N, 1, h, w]
        # Row-major list of per-cell reference values, one label channel each.
        refs = region_map.reshape(region_map.shape[0], h_region * w_region, 1, 1)  # [N, hw, 1, 1]
        return (region_map == refs).float()  # [N, hw, h, w]

    # 2. region labels
    data.update({
        'region_labels0': same_instance_labels(data['instance0']),
        'region_labels1': same_instance_labels(data['instance1'])
    })