import torch
import torch.nn as nn
import numbers
import numpy as np
import sys
# from ops import get_knn_feats
sys.path.append('../mamba_origin/mamba_ssm/modules')
from mamba_simple import Mamba
sys.path.append('../mamba_origin/mamba_ssm/utils')
from generation import GenerationMixin
from hf import load_config_hf, load_state_dict_hf

from rope import *
import random

def knn(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    Args:
        x: tensor of shape (B, C, N) -- N points with C features each.
        k: number of neighbours to return per point.

    Returns:
        Long tensor of shape (B, N, k) with neighbour indices, ordered from
        nearest to farthest. A point is its own nearest neighbour (distance 0).
    """
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)           # (B, 1, N)
    neg_two_dot = -2 * torch.matmul(x.transpose(2, 1), x)       # (B, N, N)
    # Negated squared Euclidean distance, so that topk picks the closest points.
    neg_sq_dist = -sq_norms - neg_two_dot - sq_norms.transpose(2, 1)
    _, neighbour_idx = neg_sq_dist.topk(k=k, dim=-1)
    return neighbour_idx

def get_graph_feature(x, k=20, idx=None):
    """Build per-edge features ``[x_i, x_i - x_j]`` for each point and its k neighbours.

    Args:
        x: tensor whose first dim is the batch B and third dim the point count N;
            it is viewed as (B, C, N), so (B, C, N) and (B, C, N, 1) both work.
        k: neighbours per point; must match ``idx.size(-1)`` when ``idx`` is given.
        idx: optional (B, N, k) neighbour indices; computed with ``knn`` if None.

    Returns:
        Tensor of shape (B, 2*C, N, k): channels [:C] repeat the centre point,
        channels [C:] hold centre minus neighbour.
    """
    batch_size, num_points = x.size(0), x.size(2)
    x = x.view(batch_size, -1, num_points)
    neighbours = knn(x, k=k) if idx is None else idx

    # Offset each batch's indices so they address rows of the flattened (B*N, C) table.
    offsets = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (neighbours + offsets).view(-1)

    num_dims = x.size(1)
    points = x.transpose(2, 1).contiguous()                      # (B, N, C)
    gathered = points.view(batch_size * num_points, -1)[flat_idx, :]
    gathered = gathered.view(batch_size, num_points, k, num_dims)
    centres = points.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    edge = torch.cat((centres, centres - gathered), dim=3)       # (B, N, k, 2C)
    return edge.permute(0, 3, 1, 2).contiguous()

class trans(nn.Module):
    """Module wrapper that swaps two tensor dimensions, for use inside ``nn.Sequential``.

    Args:
        dim1, dim2: the two dimensions passed to ``torch.transpose``.
    """

    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        return torch.transpose(x, self.dim1, self.dim2)


class PointCN(nn.Module):
    """Pre-activation residual block of two 1x1 convolutions over (B, C, N) features.

    Each half is InstanceNorm -> [BatchNorm] -> ReLU -> Conv1d(kernel_size=1).

    Args:
        channels: input channel count C.
        out_channels: output channel count; defaults to ``channels``.
        use_bn: insert BatchNorm1d after each InstanceNorm1d.
        use_short_cut: add a residual connection (identity, or a 1x1 projection
            when the channel count changes).
    """

    def __init__(self, channels, out_channels=None, use_bn=True, use_short_cut=True):
        nn.Module.__init__(self)
        if not out_channels:
            out_channels = channels

        self.use_short_cut = use_short_cut
        self.shot_cut = None
        if use_short_cut and out_channels != channels:
            # Conv1d, not Conv2d: this block runs on 3-D (B, C, N) tensors. A Conv2d
            # would either reject them or treat them as unbatched (C, H, W) images.
            self.shot_cut = nn.Conv1d(channels, out_channels, kernel_size=1)
        if use_bn:
            self.conv = nn.Sequential(
                    nn.InstanceNorm1d(channels, eps=1e-3),
                    nn.BatchNorm1d(channels),
                    nn.ReLU(True),
                    nn.Conv1d(channels, out_channels, kernel_size=1),
                    nn.InstanceNorm1d(out_channels, eps=1e-3),
                    nn.BatchNorm1d(out_channels),
                    nn.ReLU(True),
                    nn.Conv1d(out_channels, out_channels, kernel_size=1)
                    )
        else:
            self.conv = nn.Sequential(
                    nn.InstanceNorm1d(channels, eps=1e-3),
                    nn.ReLU(),
                    nn.Conv1d(channels, out_channels, kernel_size=1),
                    nn.InstanceNorm1d(out_channels, eps=1e-3),
                    nn.ReLU(),
                    nn.Conv1d(out_channels, out_channels, kernel_size=1)
                    )

    def forward(self, x):
        """Map (B, channels, N) -> (B, out_channels, N)."""
        out = self.conv(x)
        if self.use_short_cut:
            residual = x if self.shot_cut is None else self.shot_cut(x)
            out = out + residual
        return out


class PositionEncoder(nn.Module):
    """Lift 2-D keypoint coordinates (B, 2, N) to (B, channels, N) features.

    Three stages widen the features 2 -> 32 -> 64 -> ``channels``, each stage being
    a 1x1 Conv1d followed by a ``PointCN`` residual block.

    Args:
        channels: output feature width. It used to be ignored (always 128); with
            channels=128 the network is unchanged, including its state_dict keys.
    """

    def __init__(self, channels):
        nn.Module.__init__(self)
        self.position_encoder = nn.Sequential(
            nn.Conv1d(2, 32, kernel_size=1), PointCN(32),
            nn.Conv1d(32, 64, kernel_size=1), PointCN(64),
            nn.Conv1d(64, channels, kernel_size=1), PointCN(channels)
        )

    def forward(self, x):
        return self.position_encoder(x)

# NOTE: an exact duplicate of ``SCE_Layer`` used to be defined here. It was dead code:
# the later ``class SCE_Layer`` in this module (after ``InlinerPredictor``) rebinds
# the name at import time, so only that definition was ever used.


class InitProject(nn.Module):
    """Project raw 4-D correspondences (B, 4, N) to (B, channels, N) features.

    Three stages widen the features 4 -> 32 -> 64 -> ``channels``, each stage being
    a 1x1 Conv1d followed by a ``PointCN`` residual block.

    Args:
        channels: output feature width. It used to be ignored (always 128); with
            channels=128 the network is unchanged, including its state_dict keys.
    """

    def __init__(self, channels):
        nn.Module.__init__(self)
        self.init_project = nn.Sequential(
            nn.Conv1d(4, 32, kernel_size=1), PointCN(32),
            nn.Conv1d(32, 64, kernel_size=1), PointCN(64),
            nn.Conv1d(64, channels, kernel_size=1), PointCN(channels)
        )

    def forward(self, x):
        return self.init_project(x)


class InlinerPredictor(nn.Module):
    """Per-point logit head: (B, channels, N) -> (B, 1, N).

    Widths shrink channels -> 64 -> 16 -> 4 through Conv1d/InstanceNorm/BatchNorm/ReLU
    stages, and a final 1x1 Conv1d produces one unnormalised score per point.
    """

    def __init__(self, channels):
        super().__init__()
        widths = (channels, 64, 16, 4)
        layers = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv1d(width_in, width_out, kernel_size=1),
                nn.InstanceNorm1d(width_out, eps=1e-3),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
            ]
        layers.append(nn.Conv1d(4, 1, kernel_size=1))
        self.inlier_pre = nn.Sequential(*layers)

    def forward(self, d):
        return self.inlier_pre(d)


class SCE_Layer(nn.Module):
    """Collapse the k-neighbour axis of edge features into one feature per point.

    Input: (B, 2*in_channel, N, knn_num), e.g. from ``get_graph_feature``.
    Output: (B, in_channel, N, 1).

    The first conv merges neighbours in groups of 3 (kernel and stride (1, 3)),
    leaving ``knn_num // 3`` columns. The second conv spans exactly those columns.
    For knn_num = 9 or 6 this builds the same layers as before.

    Args:
        knn_num: neighbours per point; must be a positive multiple of 3.
        in_channel: per-point feature width C (the input carries 2*C channels).

    Raises:
        ValueError: if ``knn_num`` is not a positive multiple of 3.
    """

    def __init__(self, knn_num=9, in_channel=128):
        super(SCE_Layer, self).__init__()
        self.knn_num = knn_num
        self.in_channel = in_channel
        if knn_num < 3 or knn_num % 3 != 0:
            raise ValueError(f"knn_num must be a positive multiple of 3, got {knn_num}")
        remaining = knn_num // 3  # columns left after the stride-3 conv
        self.conv = nn.Sequential(
            nn.Conv2d(self.in_channel*2, self.in_channel, (1, 3), stride=(1, 3)),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.in_channel, self.in_channel, (1, remaining)),
            nn.BatchNorm2d(self.in_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class SE(nn.Module):
    """Conv-norm-ReLU followed by squeeze-and-excitation channel gating on (B, C, N).

    The input goes through a 1x1 Conv1d, InstanceNorm1d, BatchNorm1d and ReLU.
    The result is averaged over the point axis and passed through a two-layer MLP
    (C -> C // reduction_ratio -> C) with a sigmoid. The normalised features are
    then rescaled channel-wise by that gate.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        super(SE, self).__init__()
        hidden = num_channels // reduction_ratio
        self.conv0 = nn.Conv1d(num_channels, num_channels, kernel_size=1, stride=1, bias=True)
        self.in0 = nn.InstanceNorm1d(num_channels)
        self.bn0 = nn.BatchNorm1d(num_channels)
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        batch_size, num_channels, _ = input_tensor.size()
        features = self.relu(self.bn0(self.in0(self.conv0(input_tensor))))
        # Squeeze: global average over the point axis -> (B, C).
        pooled = features.view(batch_size, num_channels, -1).mean(dim=2)
        # Excite: per-channel gate in (0, 1).
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(pooled))))
        return features * gate.view(batch_size, num_channels, 1)


class MambaBlock(nn.Module):
    """Norm/ReLU -> residual Mamba -> Norm/ReLU -> residual SE gating, on (B, dim, N).

    The Mamba layer is wrapped between ``trans(1, 2)`` swaps, so it sees the input
    as (B, N, dim) -- presumably Mamba's (batch, length, d_model) layout; confirm
    against ``mamba_simple.Mamba``.
    """

    def __init__(self, dim):
        super(MambaBlock, self).__init__()
        self.cn0 = self._norm_act(dim)
        self.mamba0 = nn.Sequential(
            trans(1, 2),
            Mamba(dim, bimamba_type=None),
            trans(1, 2),
        )
        self.cn1 = self._norm_act(dim)
        self.ca = SE(dim)

    @staticmethod
    def _norm_act(dim):
        # InstanceNorm -> BatchNorm -> ReLU over (B, dim, N).
        return nn.Sequential(
            nn.InstanceNorm1d(dim, eps=1e-3),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
        )

    def forward(self, x):
        hidden = self.cn0(x)
        hidden = hidden + self.mamba0(hidden)
        hidden = self.cn1(hidden)
        return hidden + self.ca(hidden)


class ConvBlock(nn.Module):
    """3x3 convolution => BatchNorm => ReLU, with an optional second stage.

    With ``mid_channels`` omitted (or equal to ``out_channels``), this is a single
    conv/BN/ReLU stage mapping ``in_channels`` -> ``out_channels``, the same layers
    as before. When a different ``mid_channels`` is given, a second conv/BN/ReLU
    stage maps ``mid_channels`` -> ``out_channels``. Previously the output silently
    had ``mid_channels`` channels in that case.

    Spatial size is preserved (kernel 3, padding 1).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        layers = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        ]
        if mid_channels != out_channels:
            layers += [
                nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.convblock = nn.Sequential(*layers)

    def forward(self, x):
        return self.convblock(x)

class DiffPool(nn.Module):
    """Softly pool N point features into ``output_points`` cluster features.

    A 1x1 conv scores every (cluster, point) pair. The scores are softmax-normalised
    over the points, and each cluster feature is the weighted sum of the point
    features: (B, C, N, 1) -> (B, C, output_points, 1).
    """

    def __init__(self, in_channel, output_points, use_bn=True):
        nn.Module.__init__(self)
        self.output_points = output_points
        layers = [nn.InstanceNorm2d(in_channel, eps=1e-3)]
        if use_bn:
            layers.append(nn.BatchNorm2d(in_channel))
        layers += [nn.ReLU(), nn.Conv2d(in_channel, output_points, kernel_size=1)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # assignment: (B, K, N); each row sums to 1 over the N points.
        assignment = torch.softmax(self.conv(x), dim=2).squeeze(3)
        pooled = x.squeeze(3) @ assignment.transpose(1, 2)   # (B, C, N) @ (B, N, K)
        return pooled.unsqueeze(3)

class DiffUnpool(nn.Module):
    """Scatter cluster features back to the original N points.

    A 1x1 conv on the full-resolution features ``x_up`` scores each (cluster, point)
    pair. The scores are softmax-normalised over the clusters, and every point
    receives the weighted sum of the cluster features ``x_down``.
    Shapes: x_up (B, C, N, 1), x_down (B, C, K, 1) -> (B, C, N, 1).
    """

    def __init__(self, in_channel, output_points, use_bn=True):
        nn.Module.__init__(self)
        self.output_points = output_points
        layers = [nn.InstanceNorm2d(in_channel, eps=1e-3)]
        if use_bn:
            layers.append(nn.BatchNorm2d(in_channel))
        layers += [nn.ReLU(), nn.Conv2d(in_channel, output_points, kernel_size=1)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x_up, x_down):
        # assignment: (B, K, N); each column sums to 1 over the K clusters.
        assignment = torch.softmax(self.conv(x_up), dim=1).squeeze(3)
        unpooled = x_down.squeeze(3) @ assignment             # (B, C, K) @ (B, K, N)
        return unpooled.unsqueeze(3)

class OAFilter(nn.Module):
    """Residual block mixing along both the channel axis and the point axis.

    Input (B, channels, points, 1). ``conv1`` applies a channel-wise 1x1 conv and then
    swaps dims 1 and 2, so ``conv2`` (a 1x1 conv over ``points`` channels) mixes
    information across points. ``conv3`` swaps back and applies another
    channel-wise conv. A residual (identity, or a 1x1 projection when the channel
    count changes) is added at the end.
    """

    def __init__(self, channels, points, out_channels=None, use_bn=True):
        nn.Module.__init__(self)
        if not out_channels:
            out_channels = channels
        self.shot_cut = (
            nn.Conv2d(channels, out_channels, kernel_size=1) if out_channels != channels else None
        )

        stage1 = [nn.InstanceNorm2d(channels, eps=1e-3)]
        if use_bn:
            stage1.append(nn.BatchNorm2d(channels))
        stage1 += [nn.ReLU(), nn.Conv2d(channels, out_channels, kernel_size=1), trans(1, 2)]
        self.conv1 = nn.Sequential(*stage1)

        # Spatial correlation layer: after trans(1, 2) the points occupy dim 1.
        stage2 = [nn.BatchNorm2d(points)] if use_bn else []
        stage2 += [nn.ReLU(), nn.Conv2d(points, points, kernel_size=1)]
        self.conv2 = nn.Sequential(*stage2)

        stage3 = [trans(1, 2), nn.InstanceNorm2d(out_channels, eps=1e-3)]
        if use_bn:
            stage3.append(nn.BatchNorm2d(out_channels))
        stage3 += [nn.ReLU(), nn.Conv2d(out_channels, out_channels, kernel_size=1)]
        self.conv3 = nn.Sequential(*stage3)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = hidden + self.conv2(hidden)
        hidden = self.conv3(hidden)
        residual = x if self.shot_cut is None else self.shot_cut(x)
        return hidden + residual

class GlobalClusterLayer(nn.Module):
    """Pool points into ``cluster_num`` clusters, filter the clusters, then unpool.

    Input and output are (B, in_dim, N, 1).
    """

    def __init__(self, in_dim, cluster_num):
        super().__init__()
        self.down = DiffPool(in_dim, cluster_num)
        self.mlp_cluster = OAFilter(in_dim, cluster_num)
        self.up = DiffUnpool(in_dim, cluster_num)

    def forward(self, feats):
        clusters = self.mlp_cluster(self.down(feats))
        return self.up(feats, clusters)


class LayerBlock(nn.Module):
    """One refinement stage: local k-NN aggregation, Mamba mixing, cluster context, inlier logits.

    Args:
        channels: feature width C of the per-point features ``d``.
        knn_dim: neighbours gathered per point. It is also passed to the SCE layers
            so that their kernels match the neighbour count.
        cluster: number of clusters used by ``GlobalClusterLayer``.
    """

    def __init__(self, channels, knn_dim, cluster):
        nn.Module.__init__(self)
        # Size the SCE layers from the real neighbour count and feature width. The
        # defaults (knn_num=9, in_channel=128) only worked for knn_dim == 9 and
        # channels == 128, and broke at runtime for anything else.
        self.lc0 = SCE_Layer(knn_num=knn_dim, in_channel=channels)
        self.cn0 = PointCN(channels)
        self.mambaBlock0 = MambaBlock(channels)
        self.clusterInter0 = GlobalClusterLayer(channels, cluster)
        self.lc1 = SCE_Layer(knn_num=knn_dim, in_channel=channels)
        self.cn1 = PointCN(channels)
        self.mambaBlock1 = MambaBlock(channels)
        # NOTE(review): FeatureExtractor is not defined in the visible part of this
        # file; presumably it is provided by `from rope import *` -- confirm.
        self.fe = FeatureExtractor(channels)
        self.k_num = knn_dim
        self.inlier_pre = InlinerPredictor(channels)

    def forward(self, xs, d):
        """Refine features and estimate a model from them.

        Args:
            xs: raw correspondences, (B, 1, N, 4), passed to ``weighted_8points``.
            d: per-point features, (B, C, N).

        Returns:
            (features (B, C, N), logits (B, N), e_hat (B, 9)).
        """
        d = d.unsqueeze(3)  # (B, C, N) -> (B, C, N, 1)
        grid_d = d + self.fe(d)

        idx_fn1 = knn(grid_d.squeeze(-1), k=self.k_num)
        grid_d = grid_d + self.lc0(get_graph_feature(grid_d, k=self.k_num, idx=idx_fn1))
        grid_d = self.cn0(grid_d.squeeze(3))
        grid_d = (grid_d + self.mambaBlock0(grid_d)).unsqueeze(3)  # (B, C, N) -> (B, C, N, 1)

        grid_d = grid_d + self.clusterInter0(grid_d)

        idx_fn2 = knn(grid_d.squeeze(-1), k=self.k_num)
        grid_d = grid_d + self.lc1(get_graph_feature(grid_d, k=self.k_num, idx=idx_fn2))
        grid_d = self.cn1(grid_d.squeeze(3))
        grid_d = grid_d + self.mambaBlock1(grid_d)

        logits = torch.squeeze(self.inlier_pre(grid_d), 1)  # (B, 1, N) -> (B, N)
        e_hat = weighted_8points(xs, logits)
        return grid_d, logits, e_hat


class ConvMatch(nn.Module):
    """Stack of ``LayerBlock`` stages over projected correspondence features.

    Args:
        config: object providing ``layer_num`` (number of stages) and
            ``net_channels`` (feature width).
        use_gpu: accepted for interface compatibility; not used by this module.
        knn_num: neighbours per point in each stage.
        cluster: cluster count for each stage's global context layer.
    """

    def __init__(self, config, use_gpu=True, knn_num=9, cluster=500):
        nn.Module.__init__(self)
        self.layer_num = config.layer_num
        self.init_project = InitProject(config.net_channels)
        self.layer_blocks = nn.Sequential(
            *[LayerBlock(config.net_channels, knn_num, cluster) for _ in range(self.layer_num)]
        )

    def forward(self, data):
        """Run every stage and collect its outputs.

        Args:
            data: mapping whose ``'xs'`` entry is a (B, 1, N, C) tensor.

        Returns:
            (list of per-stage logits, list of per-stage e_hat).

        Raises:
            ValueError: if ``data['xs']`` is not 4-D with a singleton dim 1.
        """
        xs = data['xs']
        if xs.dim() != 4 or xs.shape[1] != 1:
            raise ValueError(f"data['xs'] must have shape (B, 1, N, C), got {tuple(xs.shape)}")
        # (B, 1, N, C) -> (B, C, N)
        d = self.init_project(xs.transpose(1, 3).squeeze(3))

        res_logits, res_e_hat = [], []
        for block in self.layer_blocks:
            d, logits, e_hat = block(xs, d)
            res_logits.append(logits)
            res_e_hat.append(e_hat)
        return res_logits, res_e_hat


def batch_symeig(X):
    """Eigenvectors of a batch of symmetric matrices.

    Args:
        X: tensor of shape (b, d, d). Only the upper triangle is read (UPLO='U').

    Returns:
        Tensor of shape (b, d, d) on the same device and dtype as ``X``. Column
        ``[:, :, i]`` is the eigenvector of the i-th eigenvalue, with eigenvalues
        in ascending order (the ``torch.linalg.eigh`` convention).
    """
    # Decompose on the CPU; the original note says this is much faster for these
    # small matrices. eigh is batched, so there is no per-matrix Python loop and
    # no squeeze(), which used to break the d == 1 case.
    _, eigvecs = torch.linalg.eigh(X.cpu(), UPLO='U')
    # Return to the caller's device. The previous hard-coded .cuda() failed on
    # CPU-only machines and moved inputs from cuda:k (k != 0) to the default GPU.
    return eigvecs.to(X.device)


def weighted_8points(x_in, logits):
    """Solve a weighted eight-point system for a 9-vector (flattened 3x3 model).

    Args:
        x_in: correspondences of shape (B, 1, N, 4); the column layout is assumed
            to be (x1, y1, x2, y2) -- confirm against the data loader.
        logits: per-correspondence scores, reshapeable to (B, N, 1).

    Returns:
        (B, 9) unit-norm vectors: the eigenvector of X^T W X with the smallest
        eigenvalue.
    """
    shape = x_in.shape
    batch, num_corr = shape[0], shape[2]
    # tanh then relu maps scores to [0, 1); non-positive logits get zero weight.
    weights = torch.relu(torch.tanh(logits))

    coords = torch.reshape(x_in.squeeze(1), (batch, num_corr, 4)).permute(0, 2, 1)  # (B, 4, N)
    x1, y1, x2, y2 = coords[:, 0], coords[:, 1], coords[:, 2], coords[:, 3]

    # One row of the 9-column constraint matrix per correspondence.
    rows_t = torch.stack([
        x2 * x1, x2 * y1, x2,
        y2 * x1, y2 * y1, y2,
        x1, y1, torch.ones_like(x1),
    ], dim=1)                                                     # (B, 9, N)
    X = rows_t.permute(0, 2, 1)                                   # (B, N, 9)
    wX = torch.reshape(weights, (batch, num_corr, 1)) * X
    XwX = torch.matmul(X.permute(0, 2, 1), wX)                    # (B, 9, 9)

    # eigh orders eigenvalues ascending, so column 0 minimises the weighted residual.
    eigvecs = batch_symeig(XwX)
    e_hat = torch.reshape(eigvecs[:, :, 0], (batch, 9))
    return e_hat / torch.norm(e_hat, dim=1, keepdim=True)

