import torch
import itertools
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch.nn.functional as F
import torch.nn as nn


from models.OTGM.sconv_archs import SiameseSConvOnNodes, SiameseNodeFeaturesToEdgeFeatures
from src.feature_align import feature_align
from src.factorize_graph_matching import construct_aff_mat
from src.utils.pad_tensor import pad_tensor
from src.lap_solvers.sinkhorn import Sinkhorn
from src.lap_solvers.hungarian import hungarian

from src.utils.config import cfg

from src.backbone import *

from src.loss_func import *

# Resolve the backbone base class from its name in the config; the name is
# presumably one brought into scope by `from src.backbone import *`.
# NOTE(review): eval() executes arbitrary config text; a getattr on the
# src.backbone module or an explicit name->class mapping would be safer.
CNN = eval(cfg.BACKBONE)


def lexico_iter(lex):
    """Return an iterator over every unordered pair of items in ``lex``.

    Pairs are emitted in index order, (lex[0], lex[1]), (lex[0], lex[2]), ...,
    and no item is paired with itself.
    """
    pairs = itertools.combinations(lex, r=2)
    return pairs


def normalize_over_channels(x, eps=1e-12):
    """Scale ``x`` to unit L2 norm along dimension 1 (the channel axis).

    Args:
        x: tensor whose dimension 1 holds the channels.
        eps: lower bound applied to each norm before dividing. Positions whose
            channel vector is all zeros therefore stay zero instead of
            becoming NaN (0 / 0).

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Equivalent to x / max(||x||_2, eps) along dim 1.
    return torch.nn.functional.normalize(x, p=2, dim=1, eps=eps)


def concat_features(embeddings, num_vertices):
    """Flatten per-sample features into one row-per-vertex tensor.

    Each ``embedding`` is cut to its first ``num_v`` columns (last dimension).
    The pieces are joined along the last dimension, and the result is returned
    with its first two dimensions swapped, so vertices index the rows.
    """
    trimmed = []
    for emb, n_valid in zip(embeddings, num_vertices):
        trimmed.append(emb[:, :n_valid])
    stacked = torch.cat(trimmed, dim=-1)
    return stacked.transpose(0, 1)


class InnerProduct(nn.Module):
    """Cosine-similarity matrices between paired sets of feature vectors.

    Args:
        output_dim: required feature width; each call asserts it matches
            ``X.shape[1]`` and ``Y.shape[1]``.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.d = output_dim

    def _forward(self, X, Y):
        """Return ``normalize(X) @ normalize(Y)^T`` (rows L2-normalized on the last dim)."""
        dim_x, dim_y = X.shape[1], Y.shape[1]
        assert dim_x == dim_y == self.d, (dim_x, dim_y, self.d)
        normalize = torch.nn.functional.normalize
        x_unit = normalize(X, dim=-1)
        y_unit = normalize(Y, dim=-1)
        return x_unit @ y_unit.transpose(0, 1)

    def forward(self, Xs, Ys):
        """Apply ``_forward`` to each (X, Y) pair from ``Xs``/``Ys``; returns a list."""
        sims = []
        for x_item, y_item in zip(Xs, Ys):
            sims.append(self._forward(x_item, y_item))
        return sims


class Backbone(CNN):
    """Keypoint-embedding network: CNN features -> graph message passing -> projection head.

    Inherits from the class selected by ``cfg.BACKBONE``. ``node_layers`` and
    ``edge_layers`` (used in ``forward``) are not defined here and are
    presumably supplied by that base class — TODO confirm.
    """
    def __init__(self):
        super(Backbone, self).__init__()
        # Input width is 2 * FEATURE_CHANNEL because forward() concatenates the
        # node-level and edge-level CNN features of each keypoint (presumably
        # FEATURE_CHANNEL channels each — TODO confirm).
        self.message_pass_node_features = SiameseSConvOnNodes(input_node_dim=cfg.OTGM.FEATURE_CHANNEL * 2)
        self.build_edge_features_from_node_features = SiameseNodeFeaturesToEdgeFeatures(
            total_num_nodes=self.message_pass_node_features.num_node_features
        )
        # Cosine-similarity affinity computed on the 256-d output of self.projection.
        self.vertex_affinity = InnerProduct(256)
        self.rescale = cfg.PROBLEM.RESCALE
        # Learnable log inverse temperature, initialised to log(1 / SOFTMAXTEMP).
        # NOTE(review): inside this class it is only clamped in forward(), never read.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / cfg.OTGM.SOFTMAXTEMP))

        # Projection head 1024 -> 1024 -> 256. Assumes the message-passed node
        # features are 1024-d — TODO confirm against cfg.OTGM.FEATURE_CHANNEL.
        # NOTE(review): BatchNorm1d in training mode rejects inputs with a single row.
        self.projection = nn.Sequential(
            nn.Linear(1024, 1024, bias=True),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 256, bias=True),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )

    def forward(self, data_dict, online=True):
        """Embed the two graphs of each pair and optionally solve the matching.

        Args:
            data_dict: batch dict; this method reads 'images', 'Ps', 'ns',
                'pyg_graphs', 'batch_size' and 'gt_perm_mat'.
            online: when False, only ``node_feature_list`` is returned; when
                True, Hungarian matchings are computed as well.

        Returns:
            ``node_feature_list`` (online=False) or ``(node_feature_list, x_list)``
            (online=True). ``node_feature_list`` holds two projected tensors of
            equal length. Row k of the first is a ground-truth-matched graph-1
            keypoint, and row k of the second is its graph-2 counterpart.
            ``x_list`` holds one ``hungarian`` result per graph pair. Any other
            value of ``online`` falls through and returns None.
        """
        with torch.no_grad():
            self.logit_scale.clamp_(0, 4.6052)  # 4.6052 ~= ln(100): keeps exp(logit_scale) in [1, 100]

        images = data_dict['images']
        points = data_dict['Ps']
        n_points = data_dict['ns']
        graphs = data_dict['pyg_graphs']
        batch_size = data_dict['batch_size']
        num_graphs = len(images)
        orig_graph_list = []

        for image, p, n_p, graph in zip(images, points, n_points, graphs):
            # CNN feature maps at two depths (layers presumably from the CNN base class).
            nodes = self.node_layers(image)
            edges = self.edge_layers(nodes)

            # Divide by the L2 norm along dim 1 (channels).
            nodes = normalize_over_channels(nodes)
            edges = normalize_over_channels(edges)

            # Sample both feature maps at the keypoints, then flatten the valid
            # keypoints of all images into one (total_keypoints, channels) tensor.
            # Assumes feature_align returns (batch, channels, max_keypoints) — TODO confirm.
            # NOTE(review): `F` here shadows the module-level torch.nn.functional alias in this method.
            U = feature_align(nodes, p, n_p, self.rescale)
            F = feature_align(edges, p, n_p, self.rescale)
            U = concat_features(U, n_p)
            F = concat_features(F, n_p)
            node_features = torch.cat((U, F), dim=1)

            graph.x = node_features
            graph = self.message_pass_node_features(graph)
            orig_graph = self.build_edge_features_from_node_features(graph)
            orig_graph_list.append(orig_graph)

        # For every pair of graphs: cosine affinities between projected node
        # features, one matrix per item of orig_graph (each item exposes `.x`).
        unary_affs_list = [
            self.vertex_affinity([self.projection(item.x) for item in g_1], [self.projection(item.x) for item in g_2])
            for (g_1, g_2) in lexico_iter(orig_graph_list)
        ]

        keypoint_number_list = []  # per-sample count of ground-truth matches (sum of gt_perm_mat)
        node_feature_list = [] # projected matched features: graph 1, then graph 2

        # Dense (batch, n1, d) / (batch, n2, d) buffers, zero-padded per sample.
        # `node_features` is the last loop iteration's tensor, used here only for
        # its width and device. This assumes message passing keeps that width.
        # Only graphs 0 and 1 are read (two-graph matching).
        node_feature_graph1 = torch.zeros([batch_size, data_dict['gt_perm_mat'].shape[1], node_features.shape[1]],
                                         device=node_features.device)
        node_feature_graph2 = torch.zeros([batch_size, data_dict['gt_perm_mat'].shape[2], node_features.shape[1]],
                                         device=node_features.device)
        for index in range(batch_size):
            node_feature_graph1[index, :orig_graph_list[0][index].x.shape[0]] = orig_graph_list[0][index].x
            node_feature_graph2[index, :orig_graph_list[1][index].x.shape[0]] = orig_graph_list[1][index].x
            keypoint_number_list.append(torch.sum(data_dict['gt_perm_mat'][index]))
        number = int(sum(keypoint_number_list)) # total matched keypoints across the batch

        # Reorder graph-2 features so that row i holds the ground-truth match of
        # graph-1 node i. Assuming gt_perm_mat is 0/1, unmatched rows become zero.
        node_feature_graph2 = torch.bmm(data_dict['gt_perm_mat'], node_feature_graph2)
        final_node_feature_graph1 = torch.zeros([number, node_features.shape[1]], device=node_features.device)
        final_node_feature_graph2 = torch.zeros([number, node_features.shape[1]], device=node_features.device)
        count = 0
        # Stack the first k rows of every sample, where k is that sample's match count.
        # NOTE(review): assumes the matched graph-1 keypoints occupy the first k
        # rows. An unmatched keypoint placed earlier would be mixed in — TODO confirm.
        for index in range(batch_size):
            final_node_feature_graph1[count: count + int(keypoint_number_list[index])] \
                = node_feature_graph1[index, :int(keypoint_number_list[index])]
            final_node_feature_graph2[count: count + int(keypoint_number_list[index])] \
                = node_feature_graph2[index, :int(keypoint_number_list[index])]
            count += int(keypoint_number_list[index])
        node_feature_list.append(self.projection(final_node_feature_graph1))
        node_feature_list.append(self.projection(final_node_feature_graph2))

        if online == False:
            return node_feature_list
        elif online == True:
            x_list = []
            for unary_affs, (idx1, idx2) in zip(unary_affs_list, lexico_iter(range(num_graphs))):
                # Pad the per-sample affinity matrices to a common size, then solve the assignment.
                Kp = torch.stack(pad_tensor(unary_affs), dim=0)
                x = hungarian(Kp, n_points[idx1], n_points[idx2])
                x_list.append(x)
            return node_feature_list, x_list


class CombinedGraphModel(nn.Module):
    """Entropic optimal-transport loss between two embedding sets, plus an InfoNCE term.

    The transport cost mixes two terms:
      * a direct term: pairwise L2 distances between ``x`` and ``z``;
      * a structural term: ``|cdist(x, x') - cdist(z, z')|`` (see ``compute_l``).

    Args:
        lambda_param: weight of the direct distance term; the structural term
            gets ``1 - lambda_param``.
        epsilon: entropic regularization strength used by Sinkhorn.
    """

    def __init__(self, lambda_param=0.5, epsilon=0.01):
        super(CombinedGraphModel, self).__init__()
        self.lambda_param = lambda_param
        self.epsilon = epsilon
        # Kept for backward compatibility. The InfoNCE loss feeds raw scores to
        # cross_entropy, which applies log-softmax itself.
        self.softmax = nn.Softmax(dim=1)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, x, z, x_prime, z_prime, P_A, P_B):
        """Return ``<T, C> + InfoNCE`` for the Sinkhorn plan ``T`` of cost ``C``.

        ``cdist(x, x_prime)`` and ``cdist(z, z_prime)`` must have the same
        shape, and that shape must broadcast with ``cdist(x, z)``.
        """
        cost_matrix = self.lambda_param * torch.cdist(x, z, p=2) + \
                      (1 - self.lambda_param) * self.compute_l(x, z, x_prime, z_prime)
        T = self.sinkhorn_iterations(cost_matrix)
        InfoNCE_loss = self.compute_InfoNCE_loss(P_A, P_B, T)
        ot_loss = (T * cost_matrix).sum()
        return ot_loss + InfoNCE_loss

    def compute_l(self, x, z, x_prime, z_prime):
        """Elementwise ``|cdist(x, x_prime) - cdist(z, z_prime)|`` (L2 distances)."""
        c1 = torch.cdist(x, x_prime, p=2)
        c2 = torch.cdist(z, z_prime, p=2)
        return torch.abs(c1 - c2)

    def sinkhorn_iterations(self, cost_matrix, max_iter=100, tau=1e-3):
        """Entropic OT plan between uniform marginals, computed in the log domain.

        Runs the same alternating scaling updates as the classic form
        ``u = r / (K v)``, ``v = c / (K^T u)`` with ``K = exp(-C / epsilon)``.
        The log-potentials ``f = log u`` and ``g = log v`` are updated with
        ``logsumexp``, so small ``epsilon`` values (e.g. 0.01) do not underflow
        ``K`` to zero and produce NaN.

        Args:
            cost_matrix: (n, m) cost tensor.
            max_iter: maximum number of update sweeps.
            tau: stop once both potentials change by less than ``tau`` in one sweep.

        Returns:
            (n, m) transport plan ``diag(u) K diag(v)``. Its column sums equal
            ``1/m`` exactly after the last update; its row sums approach ``1/n``.
        """
        n, m = cost_matrix.shape
        opts = dict(device=cost_matrix.device, dtype=cost_matrix.dtype)
        log_K = -cost_matrix / self.epsilon
        log_r = (torch.ones(n, **opts) / n).log()
        log_c = (torch.ones(m, **opts) / m).log()
        # u = v = 1 initially, i.e. zero log-potentials.
        f = torch.zeros_like(log_r)
        g = torch.zeros_like(log_c)

        for _ in range(max_iter):
            f_prev, g_prev = f, g
            f = log_r - torch.logsumexp(log_K + g.unsqueeze(0), dim=1)
            g = log_c - torch.logsumexp(log_K + f.unsqueeze(1), dim=0)
            if torch.max(torch.abs(f - f_prev)) < tau and torch.max(torch.abs(g - g_prev)) < tau:
                break

        return torch.exp(f.unsqueeze(1) + log_K + g.unsqueeze(0))

    def compute_InfoNCE_loss(self, P_A, P_B, T):
        """Symmetric InfoNCE loss whose positive for row i is column i.

        The scores ``P_A @ T @ P_B^T`` (and the reverse direction) are used
        directly as logits, with class-index targets ``0..n-1``. The score
        matrices must be square.
        """
        scores_AB = P_A @ T @ P_B.T
        scores_BA = P_B @ T.T @ P_A.T
        targets_AB = torch.arange(scores_AB.size(0), device=scores_AB.device)
        targets_BA = torch.arange(scores_BA.size(0), device=scores_BA.device)
        return self.cross_entropy(scores_AB, targets_AB) + self.cross_entropy(scores_BA, targets_BA)


class GraphDenoising(nn.Module):
    """Scores every ordered node pair with an MLP and samples a Bernoulli mask from the scores.

    Args:
        num_nodes: stored on the module; not read by any method of this class.
        hidden_dim: width of each node feature vector. A pair feature is the
            concatenation of two node features (``2 * hidden_dim`` inputs).
    """

    def __init__(self, num_nodes, hidden_dim=128):
        super(GraphDenoising, self).__init__()
        self.num_nodes = num_nodes
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, features_A, features_B):
        """Sample pair masks for both graphs.

        Returns:
            ``(features_A * M_A, features_B * M_B, probs_A, probs_B)``. Here
            ``probs_*`` is the (N, N) sigmoid of the pair logits and ``M_*`` a
            Bernoulli sample from it.
        """
        logits_A = self.compute_logits(features_A)
        logits_B = self.compute_logits(features_B)

        probs_A = torch.sigmoid(logits_A)
        probs_B = torch.sigmoid(logits_B)

        # torch.bernoulli passes no gradient back to probs_*; the MLP is trained
        # only through losses computed on probs_*.
        M_A = torch.bernoulli(probs_A)
        M_B = torch.bernoulli(probs_B)

        # NOTE(review): M_* is (N, N) while features_* is (N, D). This product
        # only broadcasts when D == N (or one side is 1). Confirm the intended masking.
        F_prime_A = features_A * M_A
        F_prime_B = features_B * M_B

        return F_prime_A, F_prime_B, probs_A, probs_B

    def compute_logits(self, features):
        """Return an (N, N) matrix with logits[i, j] = mlp(cat(features[i], features[j])).

        The diagonal (self-pairs) is left at 0. All pairs are scored in a single
        batched MLP call rather than N^2 separate ones; this materializes an
        (N, N, 2 * D) pair tensor.
        """
        num_nodes = features.size(0)
        left = features.unsqueeze(1).expand(num_nodes, num_nodes, -1)
        right = features.unsqueeze(0).expand(num_nodes, num_nodes, -1)
        pair_features = torch.cat([left, right], dim=-1)
        logits = self.mlp(pair_features).squeeze(-1)
        self_pairs = torch.eye(num_nodes, dtype=torch.bool, device=features.device)
        return logits.masked_fill(self_pairs, 0.0)

    def loss(self, probs_A, probs_B):
        """Binary cross-entropy pushing every pair probability toward 0.

        ``probs_*`` are already sigmoid outputs, so plain
        ``binary_cross_entropy`` is used. The ``_with_logits`` variant would
        apply a second sigmoid.
        """
        loss_A = F.binary_cross_entropy(probs_A, torch.zeros_like(probs_A))
        loss_B = F.binary_cross_entropy(probs_B, torch.zeros_like(probs_B))
        return loss_A + loss_B

class Net(nn.Module):
    """Online Backbone plus an EMA ("momentum") copy, trained with OT and graph-denoising losses.

    Args:
        beta: weight of the OT loss in the combined loss; the denoising loss
            is weighted by ``1 - beta``.
    """

    def __init__(self, beta=0.5):
        super(Net, self).__init__()
        self.onlineNet = Backbone()
        self.momentumNet = Backbone()
        self.momentum = cfg.OTGM.MOMENTUM  # EMA decay used by _momentum_update
        self.beta = beta

        self.backbone_params = list(self.onlineNet.backbone_params)
        self.warmup_step = cfg.OTGM.WARMUP_STEP
        self.epoch_iters = cfg.TRAIN.EPOCH_ITERS

        # (source, EMA target) pairs for copy_params / _momentum_update.
        self.model_pairs = [[self.onlineNet, self.momentumNet]]
        self.copy_params()

        assert cfg.PROBLEM.TYPE == '2GM'

        # NOTE(review): `OptimalTransport` is not defined in this file (it may come
        # from a star import). forward() calls it with four positional arguments.
        # CombinedGraphModel in this file accepts the same `lambda_param` keyword
        # but its forward takes six — confirm which class is intended.
        self.ot_module = OptimalTransport(lambda_param=cfg.OTGM.OT_LAMBDA)
        self.denoising_module = GraphDenoising(num_nodes=cfg.NUM_NODES, hidden_dim=128)

    def forward(self, data_dict, training=False, iter_num=0, epoch=0):
        """Run the online network and, when training, attach the combined loss.

        Args:
            data_dict: batch dict consumed by the backbones; updated in place.
            training: when True, also EMA-update the momentum network and
                compute the loss.
            iter_num: iteration index within the epoch (for the warm-up step count).
            epoch: epoch index (for the warm-up step count).

        Returns:
            ``data_dict``. 'perm_mat' (the first Hungarian matching) and
            'ds_mat' (None) are always set; 'loss' is added when ``training``
            is True.
        """
        step = epoch * self.epoch_iters + iter_num
        if step >= self.warmup_step:
            alpha = cfg.OTGM.ALPHA
        else:
            # Linear ramp of alpha over the first warmup_step iterations.
            alpha = cfg.OTGM.ALPHA * min(1, step / self.warmup_step)
        # NOTE(review): `alpha` is not used below; it looks intended to weight a loss term.

        node_feature_list, x_list = self.onlineNet(data_dict)

        # Set in both modes so evaluation callers (training=False) also get 'perm_mat'.
        outputs = {'perm_mat': x_list[0], 'ds_mat': None}

        if training:
            with torch.no_grad():
                self._momentum_update()
            node_feature_m_list = self.momentumNet(data_dict, online=False)

            ot_loss = self.ot_module(node_feature_list, node_feature_m_list, node_feature_list, node_feature_m_list)
            P_A, P_B, V_A, V_B = data_dict['P_A'], data_dict['P_B'], data_dict['V_A'], data_dict['V_B']
            _, _, probs_A, probs_B = self.denoising_module(V_A, V_B)
            gd_loss = self.denoising_module.loss(probs_A, probs_B)

            outputs['loss'] = self.beta * ot_loss + (1 - self.beta) * gd_loss

        data_dict.update(outputs)
        return data_dict

    @torch.no_grad()
    def copy_params(self):
        """Copy every source parameter into its EMA twin and freeze the twin."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)
                param_m.requires_grad = False

    @torch.no_grad()
    def _momentum_update(self):
        """EMA step: ``param_m = momentum * param_m + (1 - momentum) * param``."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                # In place, so param_m keeps its storage instead of being rebound
                # to a freshly allocated tensor every step.
                param_m.data.mul_(self.momentum).add_(param.data, alpha=1. - self.momentum)

