import torch
import torch.nn as nn
import torch.nn.functional as F

from dinoreg.modules.ops import point_to_node_partition, index_select
from dinoreg.modules.registration import get_node_correspondences
from dinoreg.modules.sinkhorn import LearnableLogOptimalTransport
from dinoreg.modules.vgtransformer import (
    VGTransformer,
    SuperPointMatching,
    SuperPointTargetGenerator,
    LocalGlobalRegistration,
)
from dinoreg.modules.visual import XYZMapHW, PatchMapping, normalize_hw, ConvFusionBlock

from backbone import KPConvFPN


class DINOReg(nn.Module):
    """Coarse-to-fine point cloud registration using geometric and visual features.

    The forward pass combines KPConv-FPN point features with DINO image
    features that are read from ``data_dict`` (this module does not run the
    image backbone itself), refines the fused coarse features with a
    ``VGTransformer``, matches coarse nodes, and then matches the points inside
    each matched node pair with optimal transport followed by
    ``LocalGlobalRegistration``.
    """

    def __init__(self, cfg):
        """Build all sub-modules from the configuration.

        Args:
            cfg: Configuration object. The fields read here live under
                ``cfg.model``, ``cfg.geometric_branch``, ``cfg.visual_branch``,
                ``cfg.vgtransformer``, ``cfg.coarse_matching`` and
                ``cfg.fine_matching``.
        """
        super(DINOReg, self).__init__()
        self.num_points_in_patch = cfg.model.num_points_in_patch
        self.matching_radius = cfg.model.ground_truth_matching_radius

        # Geometric branch: point-cloud feature backbone.
        self.kpconv = KPConvFPN(
            cfg.geometric_branch.input_dim,
            cfg.geometric_branch.output_dim,
            cfg.geometric_branch.init_dim,
            cfg.geometric_branch.kernel_size,
            cfg.geometric_branch.init_radius,
            cfg.geometric_branch.init_sigma,
            cfg.geometric_branch.group_norm,
        )

        self.window_radius = cfg.visual_branch.window_radius
        # NOTE(review): presumably the channel width of the coarsest KPConv stage
        # (init_dim doubled once per stage) - confirm against KPConvFPN.
        pc_hidden_size = cfg.geometric_branch.init_dim * (2 ** cfg.geometric_branch.num_stages)
        # Fuses per-node point features with the image features gathered around each node.
        self.fusion = ConvFusionBlock(pc_hidden_size, cfg.visual_branch.backbone_output_size, cfg.visual_branch.fusion_output_size, self.window_radius)

        # The transformer consumes the fused features, so its input width is the fusion output size.
        self.transformer = VGTransformer(
            cfg.visual_branch.fusion_output_size,
            cfg.vgtransformer.output_dim,
            cfg.vgtransformer.hidden_dim,
            cfg.vgtransformer.num_heads,
            cfg.vgtransformer.blocks,
            cfg.vgtransformer.sigma_d,
            cfg.vgtransformer.sigma_a,
            cfg.vgtransformer.angle_k,
        )

        # Used only in training mode (see forward) to pick node pairs from the ground truth.
        self.coarse_target = SuperPointTargetGenerator(
            cfg.coarse_matching.num_targets, cfg.coarse_matching.overlap_threshold
        )

        # Selects node correspondences from the normalized coarse features.
        self.coarse_matching = SuperPointMatching(
            cfg.coarse_matching.num_correspondences, cfg.coarse_matching.dual_normalization
        )

        # Turns per-patch matching scores into point correspondences and a transform estimate.
        self.fine_matching = LocalGlobalRegistration(
            cfg.fine_matching.topk,
            cfg.fine_matching.acceptance_radius,
            mutual=cfg.fine_matching.mutual,
            confidence_threshold=cfg.fine_matching.confidence_threshold,
            use_dustbin=cfg.fine_matching.use_dustbin,
            use_global_score=cfg.fine_matching.use_global_score,
            correspondence_threshold=cfg.fine_matching.correspondence_threshold,
            correspondence_limit=cfg.fine_matching.correspondence_limit,
            num_refinement_steps=cfg.fine_matching.num_refinement_steps,
        )

        self.optimal_transport = LearnableLogOptimalTransport(cfg.model.num_sinkhorn_iterations)

    def forward(self, data_dict):
        """Run the full registration pipeline on one reference/source pair.

        Args:
            data_dict: Mapping that must provide the keys ``'features'``,
                ``'transform'``, ``'lengths'``, ``'points'``, ``'ref_img'``,
                ``'src_img'``, ``'ref_dino'``, ``'src_dino'``,
                ``'ref_patch_size'``, ``'src_patch_size'``, ``'aug_ref'``,
                ``'aug_src'`` and ``'intrinsics'``. The whole mapping is also
                forwarded to the KPConv backbone, which may read further keys.
                At each level of ``'points'`` the reference points come first
                and the source points follow; ``'lengths'[level][0]`` is the
                number of reference points at that level.

        Returns:
            dict with the keys ``ref_points_c``, ``src_points_c``,
            ``ref_points_f``, ``src_points_f``, ``ref_points``, ``src_points``,
            ``gt_node_corr_indices``, ``gt_node_corr_overlaps``,
            ``ref_feats_c``, ``src_feats_c``, ``ref_feats_f``, ``src_feats_f``,
            ``ref_node_corr_indices``, ``src_node_corr_indices``,
            ``ref_node_corr_knn_points``, ``src_node_corr_knn_points``,
            ``ref_node_corr_knn_masks``, ``src_node_corr_knn_masks``,
            ``matching_scores``, ``ref_corr_points``, ``src_corr_points``,
            ``corr_scores`` and ``estimated_transform``.
        """
        output_dict = {}

        # Inputs are detached: no gradient flows back into the data pipeline.
        feats = data_dict['features'].detach()
        transform = data_dict['transform'].detach()

        # Three resolution levels are used: last (coarse, "_c"), index 1 (fine, "_f")
        # and index 0 (no suffix).
        ref_length_c = data_dict['lengths'][-1][0].item()
        ref_length_f = data_dict['lengths'][1][0].item()
        ref_length = data_dict['lengths'][0][0].item()
        points_c = data_dict['points'][-1].detach()
        points_f = data_dict['points'][1].detach()
        points = data_dict['points'][0].detach()

        # Split each concatenated tensor into its reference part and source part.
        ref_points_c = points_c[:ref_length_c]
        src_points_c = points_c[ref_length_c:]
        ref_points_f = points_f[:ref_length_f]
        src_points_f = points_f[ref_length_f:]
        ref_points = points[:ref_length]
        src_points = points[ref_length:]

        # 1. KPConv-FPN
        feats_list = self.kpconv(feats, data_dict)
        # The last entry is split with the coarse length; the first entry is split with
        # the level-1 length further below, so it is expected to align with points_f.
        feats_c = feats_list[-1]
        feats_f = feats_list[0]
        ref_pc_feats_c = feats_c[:ref_length_c]
        src_pc_feats_c = feats_c[ref_length_c:]

        # 2. DINOv2
        # Image features are precomputed and read from data_dict; only the last entry is used.
        ref_img = data_dict['ref_img']
        src_img = data_dict['src_img']
        ref_dino_feats = data_dict['ref_dino'][-1]
        src_dino_feats = data_dict['src_dino'][-1]
        # Drop the first token along dim 1.
        # NOTE(review): presumably the CLS token, leaving only patch tokens - confirm.
        ref_dino_feats = ref_dino_feats[:, 1:]
        src_dino_feats = src_dino_feats[:, 1:]
        # batch_size and num_channels are taken from the reference features and reused for the source.
        batch_size, _, num_channels = ref_dino_feats.shape
        # Reshape the token sequence into a 2D patch grid, then keep only batch element 0.
        ref_dino_feats = ref_dino_feats.reshape(batch_size, data_dict['ref_patch_size'][0], data_dict['ref_patch_size'][1], num_channels)
        ref_dino_feats = ref_dino_feats.permute(0, 3, 1, 2)[0]  # [C, H, W]
        src_dino_feats = src_dino_feats.reshape(batch_size, data_dict['src_patch_size'][0], data_dict['src_patch_size'][1], num_channels)
        src_dino_feats = src_dino_feats.permute(0, 3, 1, 2)[0]  # [C, H, W]

        # 3.1. Feature Assignment
        aug_ref = data_dict['aug_ref'].detach()
        aug_src = data_dict['aug_src'].detach()
        intrinsics = data_dict['intrinsics'].detach()

        # NOTE(review): presumably projects coarse points to image coordinates and returns
        # a mask of points that land inside the image - confirm with XYZMapHW.
        ref_hw_c, ref_img_mask = XYZMapHW(ref_points_c, aug_ref, intrinsics, ref_img.height, ref_img.width)
        src_hw_c, src_img_mask = XYZMapHW(src_points_c, aug_src, intrinsics, src_img.height, src_img.width)
        
        # Keep only the masked coarse nodes. ref_points_c / src_points_c are re-bound here,
        # so every later stage (and the returned coarse points) uses this subset.
        ref_hw_c = ref_hw_c[ref_img_mask, :]
        ref_points_c = ref_points_c[ref_img_mask, :]
        ref_pc_feats_c = ref_pc_feats_c[ref_img_mask, :]
        src_hw_c = src_hw_c[src_img_mask, :]
        src_points_c = src_points_c[src_img_mask, :]
        src_pc_feats_c = src_pc_feats_c[src_img_mask, :]

        # Gather image features for each retained coarse node.
        ref_vis_feats_c = PatchMapping(ref_hw_c, ref_img.height, ref_img.width, ref_dino_feats, self.window_radius)
        src_vis_feats_c = PatchMapping(src_hw_c, src_img.height, src_img.width, src_dino_feats, self.window_radius)

        # 3.2. Feature Fusion
        # Reference and source are concatenated for a single fusion call, then split back.
        ref_feats_len = ref_pc_feats_c.shape[0]
        pc_feats_c = torch.cat([ref_pc_feats_c, src_pc_feats_c], dim = 0)
        vis_feats_c = torch.cat([ref_vis_feats_c, src_vis_feats_c], dim = 0)
        fused_feats_c = self.fusion(pc_feats_c, vis_feats_c)
        ref_feats_c = fused_feats_c[:ref_feats_len]
        src_feats_c = fused_feats_c[ref_feats_len:]
        
        output_dict['ref_points_c'] = ref_points_c
        output_dict['src_points_c'] = src_points_c
        output_dict['ref_points_f'] = ref_points_f
        output_dict['src_points_f'] = src_points_f
        output_dict['ref_points'] = ref_points
        output_dict['src_points'] = src_points

        # 4. Generate Ground Truth Node Correspondences
        # Assign fine points to the retained coarse nodes.
        _, ref_node_masks, ref_node_knn_indices, ref_node_knn_masks = point_to_node_partition(
            ref_points_f, ref_points_c, self.num_points_in_patch
        )
        _, src_node_masks, src_node_knn_indices, src_node_knn_masks = point_to_node_partition(
            src_points_f, src_points_c, self.num_points_in_patch
        )

        # Append one all-zero row so that an index equal to the number of fine points is valid.
        # NOTE(review): presumably the knn indices use that value as a padding index - confirm
        # with point_to_node_partition.
        ref_padded_points_f = torch.cat([ref_points_f, torch.zeros_like(ref_points_f[:1])], dim=0)
        src_padded_points_f = torch.cat([src_points_f, torch.zeros_like(src_points_f[:1])], dim=0)
        ref_node_knn_points = index_select(ref_padded_points_f, ref_node_knn_indices, dim=0)
        src_node_knn_points = index_select(src_padded_points_f, src_node_knn_indices, dim=0)

        gt_node_corr_indices, gt_node_corr_overlaps = get_node_correspondences(
            ref_points_c,
            src_points_c,
            ref_node_knn_points,
            src_node_knn_points,
            transform,
            self.matching_radius,
            ref_masks=ref_node_masks,
            src_masks=src_node_masks,
            ref_knn_masks=ref_node_knn_masks,
            src_knn_masks=src_node_knn_masks,
        )

        output_dict['gt_node_corr_indices'] = gt_node_corr_indices
        output_dict['gt_node_corr_overlaps'] = gt_node_corr_overlaps

        # 5. Conditional Transformer
        # Image coordinates are normalized by the image size before entering the transformer.
        ref_hw_c = normalize_hw(ref_hw_c, [ref_img.height, ref_img.width])
        src_hw_c = normalize_hw(src_hw_c, [src_img.height, src_img.width])
        # A leading dimension of size 1 is added for the transformer and removed again below.
        ref_feats_c, src_feats_c = self.transformer(
            ref_points_c.unsqueeze(0),
            src_points_c.unsqueeze(0),
            ref_hw_c.unsqueeze(0),
            src_hw_c.unsqueeze(0),
            ref_feats_c.unsqueeze(0),
            src_feats_c.unsqueeze(0),
        )
        # L2-normalize each node feature vector.
        ref_feats_c_norm = F.normalize(ref_feats_c.squeeze(0), p=2, dim=1)
        src_feats_c_norm = F.normalize(src_feats_c.squeeze(0), p=2, dim=1)

        output_dict['ref_feats_c'] = ref_feats_c_norm
        output_dict['src_feats_c'] = src_feats_c_norm

        # 6.1. Head for Fine Level Matching
        ref_feats_f = feats_f[:ref_length_f]
        src_feats_f = feats_f[ref_length_f:]
        output_dict['ref_feats_f'] = ref_feats_f
        output_dict['src_feats_f'] = src_feats_f

        # 6.2. Select Topk Nearest Node Correspondences
        # Correspondence selection is done without gradient tracking.
        with torch.no_grad():
            ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_matching(
                ref_feats_c_norm, src_feats_c_norm, ref_node_masks, src_node_masks
            )

            # The predicted correspondences are always what gets stored in the output.
            output_dict['ref_node_corr_indices'] = ref_node_corr_indices
            output_dict['src_node_corr_indices'] = src_node_corr_indices

            # 6.3. Random Select Ground Truth Node Correspondences During Training
            # In training mode the local variables are overwritten, so the fine stage below
            # runs on ground-truth-derived node pairs while output_dict keeps the predictions.
            if self.training:
                ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_target(
                    gt_node_corr_indices, gt_node_corr_overlaps
                )

        # 7.1. Generate Batched Node Points & Feats
        ref_node_corr_knn_indices = ref_node_knn_indices[ref_node_corr_indices]  # (P, K)
        src_node_corr_knn_indices = src_node_knn_indices[src_node_corr_indices]  # (P, K)
        ref_node_corr_knn_masks = ref_node_knn_masks[ref_node_corr_indices]  # (P, K)
        src_node_corr_knn_masks = src_node_knn_masks[src_node_corr_indices]  # (P, K)
        ref_node_corr_knn_points = ref_node_knn_points[ref_node_corr_indices]  # (P, K, 3)
        src_node_corr_knn_points = src_node_knn_points[src_node_corr_indices]  # (P, K, 3)

        # Same zero-row padding as for the fine points, applied to the fine features.
        ref_padded_feats_f = torch.cat([ref_feats_f, torch.zeros_like(ref_feats_f[:1])], dim=0)
        src_padded_feats_f = torch.cat([src_feats_f, torch.zeros_like(src_feats_f[:1])], dim=0)
        ref_node_corr_knn_feats = index_select(ref_padded_feats_f, ref_node_corr_knn_indices, dim=0)  # (P, K, C)
        src_node_corr_knn_feats = index_select(src_padded_feats_f, src_node_corr_knn_indices, dim=0)  # (P, K, C)

        output_dict['ref_node_corr_knn_points'] = ref_node_corr_knn_points
        output_dict['src_node_corr_knn_points'] = src_node_corr_knn_points
        output_dict['ref_node_corr_knn_masks'] = ref_node_corr_knn_masks
        output_dict['src_node_corr_knn_masks'] = src_node_corr_knn_masks

        # 7.2. Optimal Transport
        # Pairwise dot products between the points of each matched node pair,
        # scaled by the square root of the fine feature dimension.
        matching_scores = torch.einsum('bnd,bmd->bnm', ref_node_corr_knn_feats, src_node_corr_knn_feats)  # (P, K, K)
        matching_scores = matching_scores / feats_f.shape[1] ** 0.5
        matching_scores = self.optimal_transport(matching_scores, ref_node_corr_knn_masks, src_node_corr_knn_masks)

        # Stored before the optional slicing below, so the output keeps the full tensor.
        output_dict['matching_scores'] = matching_scores

        # 8. Generate Final Correspondences During Testing
        with torch.no_grad():
            # Without the dustbin, the last row and column of every score matrix are dropped.
            # NOTE(review): presumably these are the dustbin entries added by the optimal
            # transport layer - confirm with LearnableLogOptimalTransport.
            if not self.fine_matching.use_dustbin:
                matching_scores = matching_scores[:, :-1, :-1]

            ref_corr_points, src_corr_points, corr_scores, estimated_transform = self.fine_matching(
                ref_node_corr_knn_points,
                src_node_corr_knn_points,
                ref_node_corr_knn_masks,
                src_node_corr_knn_masks,
                matching_scores,
                node_corr_scores,
            )

            output_dict['ref_corr_points'] = ref_corr_points
            output_dict['src_corr_points'] = src_corr_points
            output_dict['corr_scores'] = corr_scores
            output_dict['estimated_transform'] = estimated_transform

        return output_dict


def create_model(config):
    """Construct a ``DINOReg`` model from the given configuration.

    Args:
        config: Configuration object passed unchanged to ``DINOReg``.

    Returns:
        The newly constructed ``DINOReg`` instance.
    """
    return DINOReg(config)


def main():
    """Build the model from the default config and print its parameters and layout."""
    # Imported locally so that importing this module does not require `config`.
    from config import make_cfg

    model = create_model(make_cfg())
    # First the state-dict keys, then the module tree.
    print(model.state_dict().keys())
    print(model)


# Run the model-construction smoke test only when executed as a script.
if __name__ == '__main__':
    main()
