import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange, reduce, repeat
from torchvision.transforms.functional import resize

# Large finite value used in place of -infinity when masking similarity scores
# (keeps softmax numerically well-defined, unlike float('inf')).
INF = 1_000_000_000.0

def mask_border(m, b: int, v):
    """Overwrite a border of width ``b`` on every spatial axis of ``m`` in place.

    Both the leading ``b`` and trailing ``b`` entries of axes 1..4 are set to
    ``v``; the batch axis (0) is left alone. Nothing is returned.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1], modified in place.
        b (int): border width; a value <= 0 leaves ``m`` untouched.
        v (m.dtype): value written into the border cells.
    """
    if b > 0:
        head, tail = slice(None, b), slice(-b, None)
        for axis in (1, 2, 3, 4):
            # full slices for every axis before the one being bordered
            prefix = (slice(None),) * axis
            m[prefix + (head,)] = v
            m[prefix + (tail,)] = v


def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    """Mask borders of width ``bd`` in place, respecting per-sample padding.

    The leading border of each spatial axis starts at index 0. The trailing
    border is placed relative to each sample's valid (unpadded) extent, read
    from the padding masks, so it covers the last ``bd`` valid cells plus any
    padding after them.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1], modified in place.
        bd (int): border width; a value <= 0 leaves ``m`` untouched.
        v (m.dtype): value written into the border cells.
        p_m0 (torch.Tensor): [N, H0, W0] padding mask of image 0 (1/True = valid).
        p_m1 (torch.Tensor): [N, H1, W1] padding mask of image 1 (1/True = valid).
    """
    if bd <= 0:
        return

    m[:, :bd] = v
    m[:, :, :bd] = v
    m[:, :, :, :bd] = v
    m[:, :, :, :, :bd] = v

    # valid extent per sample: tallest valid column / widest valid row
    h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
    h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
    sizes = zip(h0s.tolist(), w0s.tolist(), h1s.tolist(), w1s.tolist())
    for b_idx, (h0, w0, h1, w1) in enumerate(sizes):
        # Clamp at 0: if the valid extent is narrower than the border, a
        # negative start would wrap around and mask only the last few cells
        # instead of the whole extent.
        m[b_idx, max(h0 - bd, 0):] = v
        m[b_idx, :, max(w0 - bd, 0):] = v
        m[b_idx, :, :, max(h1 - bd, 0):] = v
        m[b_idx, :, :, :, max(w1 - bd, 0):] = v


def compute_max_candidates(p_m0, p_m1):
    """Upper bound on the number of match candidates summed over a batch.

    For each pair, the valid (unpadded) area of each image is taken as its
    tallest valid column times its widest valid row. The smaller of the two
    areas bounds that pair's candidates, and the per-pair bounds are summed.

    Args:
        p_m0, p_m1 (torch.Tensor): padded masks, [N, H, W] each (1 = valid).

    Returns:
        torch.Tensor: 0-d tensor holding the summed bound.
    """
    def _valid_area(pad_mask):
        height = pad_mask.sum(1).max(-1).values
        width = pad_mask.sum(-1).max(-1).values
        return height * width

    areas = torch.stack((_valid_area(p_m0), _valid_area(p_m1)), dim=-1)
    return areas.min(dim=-1).values.sum()



class InstMatching(nn.Module):
    """Instance-level matcher: dual-softmax confidences + mutual nearest neighbours.

    Besides the matched instance indices, it derives for every match the
    matched positions and an axis-aligned box around each instance.
    """

    def __init__(self, config):
        super().__init__()

        # only dual-softmax (differentiable) matching is implemented here
        self.match_type = config['match_type']
        self.conf_thr = config['conf_thr']  # minimum confidence for a match
        self.mask_thr = config['mask_thr']  # threshold on weight maps when building boxes
        # bounds on the distance from the instance position to each box side,
        # as a fraction of the weight-map height (vertical) / width (horizontal)
        self.min_edge = config['min_edge']
        self.max_edge = config['max_edge']
        if self.match_type == 'dual_softmax':
            self.temperature = config['dsmax_temperature']
        else:
            raise NotImplementedError()

    def forward(self, inst0, inst1, data):
        """Match the instances of image 0 against those of image 1.

        Args:
            inst0 (torch.Tensor): [N, Li, C] instance features of image 0.
            inst1 (torch.Tensor): [N, Si, C] instance features of image 1.
            data (dict): batch dictionary; must hold the keys read by
                ``get_inst_match``.

        Raises:
            NotImplementedError: never here; only ``__init__`` raises it for an
                unsupported ``match_type``.

        Updates:
            data (dict): {
                'inst_conf_matrix' (torch.Tensor): [N, Li, Si],
                plus every key returned by ``get_inst_match``
            }
        """
        # scale features by 1/sqrt(C) before the dot product
        inst0, inst1 = (inst / inst.shape[-1] ** .5 for inst in (inst0, inst1))
        inst_sim_matrix = torch.einsum("nlc,nsc->nls", inst0, inst1) / self.temperature
        # dual softmax: product of the column-wise and row-wise softmax
        inst_conf_matrix = F.softmax(inst_sim_matrix, 1) * F.softmax(inst_sim_matrix, 2)
        data.update({'inst_conf_matrix': inst_conf_matrix})

        # predict instance matches from inst_conf_matrix
        data.update(**self.get_inst_match(inst_conf_matrix, data))

    @torch.no_grad()
    def get_inst_match(self, inst_conf_matrix, data):
        """Select mutual-nearest instance matches above ``conf_thr``.

        Args:
            inst_conf_matrix (torch.Tensor): [N, Li, Si]
            data (dict): with keys
                'hw0_inst', 'hw1_inst': (h, w) of the instance grids,
                'hw0_in', 'hw1_in': (h, w) of the input images,
                'weight0', 'weight1': per-instance weight maps indexed by
                    [batch, instance] (assumed [N, Li/Si, h, w] -- confirm),
                'pos_info': per-instance (x, y) positions indexed by
                    [batch, instance] (assumed [N, L, 2] -- confirm).
        Returns:
            inst_matches (dict): {
                'b_ids' (torch.Tensor): [M`],
                'i_ids' (torch.Tensor): [M`],
                'j_ids' (torch.Tensor): [M`],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_inst' (torch.Tensor): [M, 2] as (x, y),
                'mkpts1_inst' (torch.Tensor): [M, 2] as (x, y),
                'iconf' (torch.Tensor): [M],
                'mboxes_0' (torch.Tensor): [M, 4] as (ymin, ymax, xmin, xmax),
                'mboxes_1' (torch.Tensor): [M, 4] as (ymin, ymax, xmin, xmax)
            }
        """
        # 1. confidence thresholding
        mask = inst_conf_matrix > self.conf_thr  # TODO: make it in config

        # 2. mutual nearest: maximum of both its row and its column
        mask = mask \
            * (inst_conf_matrix == inst_conf_matrix.max(dim=2, keepdim=True)[0]) \
            * (inst_conf_matrix == inst_conf_matrix.max(dim=1, keepdim=True)[0])

        # 3. find all valid instance matches
        # only works when at most one 'True' in each row (exact ties could break this)
        mask_v, all_j_ids = mask.max(dim=2)
        b_ids, i_ids = torch.where(mask_v)
        j_ids = all_j_ids[b_ids, i_ids]
        iconf = inst_conf_matrix[b_ids, i_ids, j_ids]

        # matches selected
        inst_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

        # input/instance resolution ratio per image; the height ratio is used
        # for both axes (assumes isotropic scaling)
        scale0 = data['hw0_in'][0] / data['hw0_inst'][0]
        scale1 = data['hw1_in'][0] / data['hw1_inst'][0]

        # 3.5 bounding boxes of the matched pairs, at input resolution
        mweight_0 = data['weight0'][b_ids, i_ids]  # [M, h, w]
        mweight_1 = data['weight1'][b_ids, j_ids]
        if b_ids.shape[0] > 0:
            # upsample only when there is something to upsample; with no
            # matches the empty maps are passed through unchanged
            mweight_0 = resize(mweight_0, data['hw0_in'])
            mweight_1 = resize(mweight_1, data['hw1_in'])
        pos_info_0 = data['pos_info'][b_ids, i_ids] * scale0
        pos_info_1 = data['pos_info'][b_ids, j_ids] * scale1

        mboxes_0 = self.get_bounding_boxes(mweight_0, pos_info_0)
        mboxes_1 = self.get_bounding_boxes(mweight_1, pos_info_1)

        # 4. matched positions at input resolution; ids index a row-major
        # (h, w) instance grid. Each image uses its own scale.
        w0_inst, w1_inst = data['hw0_inst'][1], data['hw1_inst'][1]
        mkpts0_inst = torch.stack([i_ids % w0_inst, i_ids // w0_inst], dim=1) * scale0
        mkpts1_inst = torch.stack([j_ids % w1_inst, j_ids // w1_inst], dim=1) * scale1

        valid = iconf != 0
        inst_matches.update({
            'm_bids': b_ids[valid],
            'mkpts0_inst': mkpts0_inst[valid],
            'mkpts1_inst': mkpts1_inst[valid],
            'iconf': iconf[valid],
            'mboxes_0': mboxes_0[valid],
            'mboxes_1': mboxes_1[valid]
        })

        return inst_matches

    @torch.no_grad()
    def get_bounding_boxes(self, weight_mask, pos_info):
        """Build one box per weight map around the given instance position.

        The tight box of ``weight_mask > mask_thr`` is expressed as distances
        from ``pos_info`` to each side. Each distance is clamped to
        [min_edge, max_edge] times the map height (top/bottom) or width
        (left/right), and the resulting box is clipped to the map.

        Args:
            weight_mask (torch.Tensor): [M, h, w]
            pos_info (torch.Tensor): [M, 2] positions as (x, y), assumed to be
                in the pixel frame of ``weight_mask``.
        Returns:
            mboxes (torch.Tensor): [M, 4] -- up bottom left right
                (ymin ymax xmin xmax).
        """
        num, height, width = weight_mask.shape
        min_edge_h, max_edge_h = self.min_edge * height, self.max_edge * height
        min_edge_w, max_edge_w = self.min_edge * width, self.max_edge * width

        mboxes = torch.zeros([num, 4], device=weight_mask.device)
        masks = weight_mask > self.mask_thr
        for m in range(num):
            seg = masks[m]
            cols = torch.where(torch.any(seg, dim=0))[0]
            rows = torch.where(torch.any(seg, dim=1))[0]
            if cols.shape[0]:
                # tight box with exclusive upper bounds
                x1, x2 = cols[0], cols[-1] + 1
                y1, y2 = rows[0], rows[-1] + 1
            else:
                # empty mask: degenerate box at the origin; the clamps below
                # still give a box of at least min_edge around the position
                x1 = x2 = y1 = y2 = 0
            cx, cy = pos_info[m, 0], pos_info[m, 1]
            # adjust each side between minimal and maximal range according to config
            d_up = torch.clamp(cy - y1, min_edge_h, max_edge_h)
            d_bo = torch.clamp(y2 - cy, min_edge_h, max_edge_h)
            d_l = torch.clamp(cx - x1, min_edge_w, max_edge_w)
            d_r = torch.clamp(x2 - cx, min_edge_w, max_edge_w)
            # stack (not torch.tensor) keeps the values on-device without a host round-trip
            mboxes[m] = torch.stack([cy - d_up, cy + d_bo, cx - d_l, cx + d_r])
        mboxes[:, 0:2] = mboxes[:, 0:2].clamp(0.0, height)
        mboxes[:, 2:] = mboxes[:, 2:].clamp(0.0, width)
        return mboxes


class InstCoarseMatching(nn.Module):
    """LoFTR-style coarse matcher that also builds an instance confidence matrix.

    The coarse confidence matrix comes from dual-softmax or sinkhorn. The
    instance confidence matrix is always a dual softmax and is stored in
    ``data`` but not yet used to gate coarse matches.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # general config
        self.thr = config['thr']
        self.border_rm = config['border_rm']
        # -- # for training fine-level LoFTR
        self.train_coarse_percent = config['train_coarse_percent']
        self.train_pad_num_gt_min = config['train_pad_num_gt_min']

        # we provide 2 options for differentiable matching
        self.match_type = config['match_type']
        if self.match_type == 'dual_softmax':
            self.temperature = config['dsmax_temperature']
        elif self.match_type == 'sinkhorn':
            try:
                from .superglue import log_optimal_transport
            except ImportError as err:
                raise ImportError("download superglue.py first!") from err
            self.log_optimal_transport = log_optimal_transport
            self.bin_score = nn.Parameter(
                torch.tensor(config['skh_init_bin_score'], requires_grad=True))
            self.skh_iters = config['skh_iters']
            self.skh_prefilter = config['skh_prefilter']
            # forward() always builds the instance confidence matrix with a
            # dual softmax, so a temperature is needed in sinkhorn mode too
            # (previously unset here, which made forward() raise AttributeError)
            self.temperature = config['dsmax_temperature']
        else:
            raise NotImplementedError()

    def forward(self, feat_c0, feat_c1, inst0, inst1, data, mask_c0=None, mask_c1=None):
        """
        Args:
            feat_c0 (torch.Tensor): [N, L, C]
            feat_c1 (torch.Tensor): [N, S, C]
            inst0 (torch.Tensor): [N, Li, C]
            inst1 (torch.Tensor): [N, Si, C]
            data (dict)
            mask_c0 (torch.Tensor): [N, L] (optional; given together with mask_c1)
            mask_c1 (torch.Tensor): [N, S] (optional)
        Update:
            data (dict): {
                'inst_conf_matrix' (torch.Tensor): [N, Li, Si],
                'conf_matrix' (torch.Tensor): [N, L, S],
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]},
            NOTE: M' != M during training.
        """
        L, S = feat_c0.size(1), feat_c1.size(1)

        # instance confidence matrix -- dual softmax over 1/sqrt(C)-scaled features
        inst0, inst1 = map(lambda inst: inst / inst.shape[-1]**.5, [inst0, inst1])
        inst_sim_matrix = torch.einsum("nlc,nsc->nls", inst0, inst1) / self.temperature
        inst_conf_matrix = F.softmax(inst_sim_matrix, 1) * F.softmax(inst_sim_matrix, 2)
        data.update({'inst_conf_matrix': inst_conf_matrix})

        # normalize
        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
                               [feat_c0, feat_c1])

        if self.match_type == 'dual_softmax':
            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
                                      feat_c1) / self.temperature
            if mask_c0 is not None:
                # padded positions get -INF so softmax assigns them ~0
                sim_matrix.masked_fill_(
                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                    -INF)
            conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)

        elif self.match_type == 'sinkhorn':
            # sinkhorn, dustbin included
            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
            if mask_c0 is not None:
                sim_matrix[:, :L, :S].masked_fill_(
                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                    -INF)

            # build uniform prior & use sinkhorn
            log_assign_matrix = self.log_optimal_transport(
                sim_matrix, self.bin_score, self.skh_iters)
            assign_matrix = log_assign_matrix.exp()
            conf_matrix = assign_matrix[:, :-1, :-1]

            # filter prediction with dustbin score (only in evaluation mode)
            if not self.training and self.skh_prefilter:
                filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1]  # [N, L]
                filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1]  # [N, S]
                conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
                conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0

            if self.config['sparse_spvs']:
                data.update({'conf_matrix_with_bin': assign_matrix.clone()})

        data.update({'conf_matrix': conf_matrix})

        # predict coarse matches from conf_matrix
        data.update(**self.get_coarse_match(conf_matrix, inst_conf_matrix, data))

    @staticmethod
    def _inst_mask_to_coarse(inst_conf_matrix, data, inst_thr=0.2, scale_inter=4):
        """Threshold the instance confidences and expand them to coarse resolution.

        Every instance cell is repeated ``scale_inter`` times along each
        spatial axis. This therefore requires every ``hw*_c`` to equal
        ``scale_inter`` times the matching ``hw*_inst``; einops raises
        otherwise.

        Returns:
            torch.Tensor: boolean mask of shape [N, h0c*w0c, h1c*w1c], aligned
                with ``conf_matrix``.
        """
        inst_lengths = {
            'h0c': data['hw0_inst'][0],
            'w0c': data['hw0_inst'][1],
            'h1c': data['hw1_inst'][0],
            'w1c': data['hw1_inst'][1]
        }
        axes_lengths = {
            'h0c': data['hw0_c'][0],
            'w0c': data['hw0_c'][1],
            'h1c': data['hw1_c'][0],
            'w1c': data['hw1_c'][1]
        }
        inst_mask = inst_conf_matrix > inst_thr
        inst_mask = rearrange(
            inst_mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
            **inst_lengths
        )
        inst_mask = repeat(
            inst_mask, 'b h0 w0 h1 w1 -> b (h0 s0) (w0 s1) (h1 s2) (w1 s3)',
            s0=scale_inter, s1=scale_inter, s2=scale_inter, s3=scale_inter
        )
        return rearrange(inst_mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
                         **axes_lengths)

    @torch.no_grad()
    def get_coarse_match(self, conf_matrix, inst_conf_matrix, data):
        """
        Args:
            conf_matrix (torch.Tensor): [N, L, S]
            inst_conf_matrix (torch.Tensor): [N, Li, Si]; currently unused --
                instance-level gating is still a TODO (see _inst_mask_to_coarse)
            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'];
                optionally 'mask0'/'mask1' (padding masks) and
                'scale0'/'scale1'; in training also 'spv_b_ids',
                'spv_i_ids', 'spv_j_ids'
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        axes_lengths = {
            'h0c': data['hw0_c'][0],
            'w0c': data['hw0_c'][1],
            'h1c': data['hw1_c'][0],
            'w1c': data['hw1_c'][1]
        }
        _device = conf_matrix.device

        # 0. TODO: gate coarse matches with the instance mask, e.g.
        #    mask = mask * self._inst_mask_to_coarse(inst_conf_matrix, data)
        #    (not built here: it costs an N x L x S tensor and was never used)

        # 1. confidence thresholding
        mask = conf_matrix > self.thr
        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
                         **axes_lengths)
        if 'mask0' not in data:
            mask_border(mask, self.border_rm, False)
        else:
            mask_border_with_padding(mask, self.border_rm, False,
                                     data['mask0'], data['mask1'])
        mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
                         **axes_lengths)

        # 2. mutual nearest
        mask = mask \
            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
            * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])

        # 3. find all valid coarse matches
        # this only works when at most one `True` in each row
        mask_v, all_j_ids = mask.max(dim=2)
        b_ids, i_ids = torch.where(mask_v)
        j_ids = all_j_ids[b_ids, i_ids]
        mconf = conf_matrix[b_ids, i_ids, j_ids]

        # 4. Random sampling of training samples for fine-level LoFTR
        # (optional) pad samples with gt coarse-level matches
        if self.training:
            # NOTE:
            # The sampling is performed across all pairs in a batch without manually balancing
            # #samples for fine-level increases w.r.t. batch_size
            if 'mask0' not in data:
                num_candidates_max = mask.size(0) * max(
                    mask.size(1), mask.size(2))
            else:
                num_candidates_max = compute_max_candidates(
                    data['mask0'], data['mask1'])
            num_matches_train = int(num_candidates_max *
                                    self.train_coarse_percent)
            num_matches_pred = len(b_ids)
            assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"

            # pred_indices is to select from prediction
            if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
                pred_indices = torch.arange(num_matches_pred, device=_device)
            else:
                pred_indices = torch.randint(
                    num_matches_pred,
                    (num_matches_train - self.train_pad_num_gt_min, ),
                    device=_device)

            # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
            gt_pad_indices = torch.randint(
                    len(data['spv_b_ids']),
                    (max(num_matches_train - num_matches_pred,
                        self.train_pad_num_gt_min), ),
                    device=_device)
            mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero

            b_ids, i_ids, j_ids, mconf = map(
                lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
                                       dim=0),
                *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
                     [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))

        # These matches select patches that feed into fine-level network
        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

        # 5. Update with matches in original image resolution
        scale = data['hw0_i'][0] / data['hw0_c'][0]
        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
        mkpts0_c = torch.stack(
            [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
            dim=1) * scale0
        mkpts1_c = torch.stack(
            [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
            dim=1) * scale1

        # These matches is the current prediction (for visualization)
        coarse_matches.update({
            'gt_mask': mconf == 0,
            'm_bids': b_ids[mconf != 0],  # mconf == 0 => gt matches
            'mkpts0_c': mkpts0_c[mconf != 0],
            'mkpts1_c': mkpts1_c[mconf != 0],
            'mconf': mconf[mconf != 0]
        })

        return coarse_matches
