import torch
import torch.nn as nn
import numpy as np
from einops.einops import rearrange

from .backbone import build_backbone
from .utils.position_encoding import PositionEncodingSine
from .utils.position_encoding3d import PositionalEncoding3D
from .lotfr_module import LocalFeatureTransformer, FinePreprocess
from .depth_predictor.depth_predictor import DepthPredictor
from .depth_predictor.depth_transformer import DepthGuidedEncoder, TransformerEncoderGeneral
from .utils.coarse_matching import CoarseMatching
from .utils.fine_matching import FineMatching
from src.utils.metrics import estimate_pose


class PALoFTR(nn.Module):
    """Pose-aware LoFTR matcher.

    A first LoFTR-style branch with 2D sine position encodings (`get_pose`)
    matches the two images and estimates their relative pose with
    `estimate_pose`. That pose is fed to the 3D position encoding
    (`PositionalEncoding3D`) of view 0 for the main coarse transformer. The
    usual coarse matching and fine-level refinement follow.
    """

    # Settings for the pose-estimation branch (`get_pose`).
    # Number of most-confident coarse matches kept, selected over the whole batch.
    POSE_TOPK = 1000
    # Pixel threshold and confidence forwarded to `estimate_pose`.
    POSE_PIXEL_THR = 0.5
    POSE_CONF = 0.99999

    def __init__(self, config):
        super().__init__()
        # Misc
        self.config = config

        # Modules
        self.backbone = build_backbone(config)  # ResNet + FPN here
        self.pos_encoding = PositionalEncoding3D(config['pe'])

        # The depth-guided variant is currently disabled in favour of a plain encoder:
        # self.depth_predictor = DepthPredictor(config['depth_predictor'])
        # self.loftr_coarse = DepthGuidedEncoder(config['depth_loftr'])
        self.loftr_coarse = TransformerEncoderGeneral(config['depth_loftr'])

        # ------------- Pose-estimation branch (used by get_pose) -------------
        self.pos_encoding2d = PositionEncodingSine(
            config['pose_coarse']['d_model'], temp_bug_fix=True
        )
        self.pose_coarse = LocalFeatureTransformer(config['pose_coarse'])
        self.pose_preprocess = FinePreprocess(config, pose=True)
        self.pose_fine = LocalFeatureTransformer(config['fine'])
        # ----------------------------------------------------------------------

        self.coarse_matching = CoarseMatching(config['match_coarse'])
        self.fine_preprocess = FinePreprocess(config)
        self.loftr_fine = LocalFeatureTransformer(config['fine'])
        self.fine_matching = FineMatching()

        # Projects 256-channel coarse backbone features to 128 channels before the
        # 3D-PE coarse transformer (originally marked as a debug setting for d_model=128).
        self.proj = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=(1, 1)),
            nn.GroupNorm(32, 128)
        )

        # Groups the matching-head modules (not the backbone, not the pose branch)
        # under one attribute; these are the same module objects registered above.
        self.matcher_head = nn.ModuleList([
            self.pos_encoding,
            # self.depth_predictor,
            self.loftr_coarse,
            self.coarse_matching,
            self.fine_preprocess,
            self.loftr_fine,
            self.fine_matching,
            self.proj])

    def forward(self, data):
        """Estimate the relative pose, then match coarse-to-fine. Updates `data` in place.

        Args:
            data (dict): {
                'image0': (torch.Tensor): (N, 1, H, W)
                'image1': (torch.Tensor): (N, 1, H, W)
                'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
                'mask1'(optional) : (torch.Tensor): (N, H, W)
                'K0', 'K1': intrinsics, always required
                training only: 'T_0to1', 'conf_matrix_gt' and the 'spv_*' entries
                    copied into data['sub_data']
            }

        Update:
            data gains 'bs', 'hw0_i', 'hw1_i', 'hw0_c', 'hw1_c', 'hw0_f', 'hw1_f',
            'sub_data' (inputs/outputs of the pose branch), plus whatever
            `self.coarse_matching` and `self.fine_matching` write.

        Returns:
            (depth_logits0, depth_logits1, weighted_depth0, weighted_depth1), all None
            while the depth predictor is disabled.
        """
        # 1. Local Feature CNN
        data.update({
            'bs': data['image0'].size(0),  # Batch size
            'hw0_i': data['image0'].shape[2:],  # (H, W)
            'hw1_i': data['image1'].shape[2:]
        })

        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
            # get coarse and fine level features from backbone in a single pass
            feats_c, feats_m, feats_f = self.backbone(
                torch.cat([data['image0'], data['image1']], dim=0))
            # split the stacked batch back into the two views
            (feat_c0, feat_c1), (feat_m0, feat_m1), (feat_f0, feat_f1) = (
                feats_c.split(data['bs']), feats_m.split(data['bs']), feats_f.split(data['bs']))
        else:  # handle different input shapes
            (feat_c0, feat_m0, feat_f0), (feat_c1, feat_m1, feat_f1) = (
                self.backbone(data['image0']), self.backbone(data['image1']))

        # record height and width for both coarse and fine level feature maps
        data.update({
            'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
            'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
        })

        # Flattened coarse padding masks, shared by the pose branch and the main matcher.
        mask_c0 = mask_c1 = None
        if 'mask0' in data:
            mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)

        # 2. Relative pose estimation. Its intermediate results live in data['sub_data'],
        #    which the loss calculation needs.
        data['sub_data'] = self._build_sub_data(data)
        T_0to1 = self.get_pose(data['sub_data'], feat_c0, feat_c1, feat_f0, feat_f1,
                               mask_c0=mask_c0, mask_c1=mask_c1)

        # Project coarse features to the 128-channel width used by the coarse transformer.
        feat_c0 = self.proj(feat_c0)
        feat_c1 = self.proj(feat_c1)

        # 3. Depth predictor (disabled: no depth outputs are produced)
        # depth_logits0, depth_embed0, weighted_depth0 = self.depth_predictor([feat_m0, feat_c0])
        # depth_logits1, depth_embed1, weighted_depth1 = self.depth_predictor([feat_m1, feat_c1])
        depth_logits0, weighted_depth0 = None, None
        depth_logits1, weighted_depth1 = None, None

        # 4. 3D position encoding; only view 0 receives the estimated pose.
        pos_embed0 = self.pos_encoding(data['hw0_i'], data['hw0_c'], data['K0'], T_0to1)
        pos_embed1 = self.pos_encoding(data['hw1_i'], data['hw1_c'], data['K1'])

        # 5. Coarse-level transformer on flattened features and position encodings
        feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
        feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
        pos_embed0 = rearrange(pos_embed0, 'n c h w -> n (h w) c')
        pos_embed1 = rearrange(pos_embed1, 'n c h w -> n (h w) c')
        # feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, pos_embed0, pos_embed1, depth_embed0, depth_embed1, mask_c0, mask_c1)
        feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, pos_embed0, pos_embed1, mask_c0, mask_c1)

        # 6. Coarse-level matching
        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1,
                             dmap_c0=weighted_depth0, dmap_c1=weighted_depth1)

        # 7. Fine-level refinement
        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
        if feat_f0_unfold.size(0) != 0:
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)

        # 8. Fine-level matching
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)

        return depth_logits0, depth_logits1, weighted_depth0, weighted_depth1

    def _build_sub_data(self, data):
        """Copy the entries the pose branch (and its loss) needs into a fresh dict.

        In training this also copies 'conf_matrix_gt', the 'spv_*' entries and the
        ground-truth 'T_0to1'. A missing key raises KeyError.
        """
        keys = ['bs', 'hw0_i', 'hw1_i', 'hw0_c', 'hw1_c', 'hw0_f', 'hw1_f']
        if self.training:
            keys += ['conf_matrix_gt', 'spv_b_ids', 'spv_i_ids', 'spv_j_ids',
                     'spv_w_pt0_i', 'spv_pt1_i', 'K0', 'K1', 'T_0to1']
        else:
            keys += ['K0', 'K1']
        return {key: data[key] for key in keys}

    @staticmethod
    def _keep_top_matches(data, topk):
        """Keep the `topk` most confident coarse matches in `data`, in ascending confidence order.

        The selection is over the whole batch, so some pairs may keep few or no matches.
        'm_bids', 'mkpts0_c', 'mkpts1_c' and 'mconf' are always filtered.
        'b_ids', 'i_ids', 'j_ids' and 'gt_mask' are filtered with the same indices when
        they are aligned one-to-one with 'mconf'. In LoFTR, the fine preprocess gathers
        its windows through b_ids/i_ids, so filtering them keeps the fine offsets paired
        with the kept coarse keypoints. This is assumed to hold for the pose preprocess too.
        """
        mconf = data['mconf']
        keep = torch.argsort(mconf)[-topk:]
        # Decide alignment before 'mconf' itself is shortened.
        aligned = [key for key in ('b_ids', 'i_ids', 'j_ids', 'gt_mask')
                   if key in data and data[key].shape[0] == mconf.shape[0]]
        for key in ('m_bids', 'mkpts0_c', 'mkpts1_c', 'mconf', *aligned):
            data[key] = data[key][keep]

    @staticmethod
    def _to_homogeneous(R, t):
        """Build the 4x4 transform [[R, t], [0, 0, 0, 1]] from a 3x3 rotation R and a 3-vector t.

        Returns None if R or t contains NaN or Inf.
        """
        R = torch.as_tensor(R)
        t = torch.as_tensor(t).reshape(3)
        if not (torch.isfinite(R).all() and torch.isfinite(t).all()):
            return None
        T = torch.eye(4)
        T[:3, :3] = R
        T[:3, 3] = t
        return T

    @torch.no_grad()
    def get_pose(self, data, feat_c0, feat_c1, feat_f0, feat_f1, mask_c0=None, mask_c1=None):
        """Estimate the relative pose of each image pair with the 2D-PE matching branch.

        Runs under `torch.no_grad()`, so this call produces no gradients for the
        pose-branch modules.

        Args:
            data (dict): pose-branch scratch dict (see `_build_sub_data`). Coarse and
                fine match results are written into it.
            feat_c0, feat_c1: coarse backbone features in (n, c, h, w) layout.
            feat_f0, feat_f1: fine backbone features.
            mask_c0, mask_c1 (optional): flattened coarse padding masks. When neither
                is given, they are taken from data['mask0'] / data['mask1'] if present.

        Returns:
            torch.Tensor: (N, 4, 4) homogeneous transforms on data['K0'].device. When
            `estimate_pose` returns None or a non-finite result for a pair, that pair
            gets the ground-truth data['T_0to1'] in training and the identity otherwise.
        """
        # 1. Coarse matches from the 2D-PE transformer
        feat_c0 = rearrange(self.pos_encoding2d(feat_c0), 'n c h w -> n (h w) c')
        feat_c1 = rearrange(self.pos_encoding2d(feat_c1), 'n c h w -> n (h w) c')

        if mask_c0 is None and mask_c1 is None and 'mask0' in data:
            mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)

        feat_c0, feat_c1 = self.pose_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1,
                             pick_sample=False)

        # Keep only the most confident matches for the fine stage and the pose solver.
        self._keep_top_matches(data, self.POSE_TOPK)

        # 2. Fine-level refinement
        feat_f0_unfold, feat_f1_unfold = self.pose_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
        if feat_f0_unfold.size(0) != 0:
            feat_f0_unfold, feat_f1_unfold = self.pose_fine(feat_f0_unfold, feat_f1_unfold)

        # 3. Fine-level matches
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)

        # 4. Per-pair relative pose
        m_bids = data['m_bids'].cpu().numpy()
        pts0 = data['mkpts0_f'].cpu().numpy()
        pts1 = data['mkpts1_f'].cpu().numpy()
        K0 = data['K0'].cpu().numpy()
        K1 = data['K1'].cpu().numpy()
        n_pairs = K0.shape[0]
        # Start from identities so every output keeps a [0, 0, 0, 1] bottom row.
        T_0to1 = torch.eye(4, device=data['K0'].device).repeat(n_pairs, 1, 1)
        for b in range(n_pairs):
            in_pair = m_bids == b
            ret = estimate_pose(pts0[in_pair], pts1[in_pair], K0[b], K1[b],
                                self.POSE_PIXEL_THR, conf=self.POSE_CONF)
            T = None if ret is None else self._to_homogeneous(ret[0], ret[1])
            if T is None:
                # Failed or non-finite estimate: use GT while training (helps training),
                # identity otherwise.
                T = data['T_0to1'][b] if self.training else torch.eye(4)
            T_0to1[b] = T

        return T_0to1
