from math import e
import torch
import torch.nn.functional as F
from utils import model_utils
import GMN.graphembeddingnetwork as gmngen
from loguru import logger
from torch_scatter import scatter_add

class FloodFillBase(torch.nn.Module):
    """Differentiable flood fill over an ``n x n`` grid of (permuted) adjacency cells.

    ``n`` is ``conf.dataset.max_node_set_size``. Grid cells are numbered
    row-major, so cell ``(r, c)`` has index ``r * n + c``. Every cell with
    ``r, c >= 1`` receives directed edges from three neighbours:

    * its upper neighbour ``(r-1, c)``;
    * its left neighbour ``(r, c-1)``;
    * its upper-left neighbour ``(r-1, c-1)``.

    ``flood_fill`` runs ``n`` rounds of learned message passing along these
    edges. It then returns, per batch element, the maximum of a learned
    per-cell scalar over the cells that are not masked out.

    Args:
        conf: configuration object; reads ``conf.training.device`` and
            ``conf.dataset.max_node_set_size``.
        tensor_dim: number of feature channels per grid cell.
    """

    def __init__(self, conf, tensor_dim):
        super(FloodFillBase, self).__init__()
        self.conf = conf
        self.tensor_dim = tensor_dim
        self.device = self.conf.training.device
        self.max_set_size = self.conf.dataset.max_node_set_size
        self.edge_list = self.get_edge_list_for_flood_fill()
        self.list_of_elists, self.list_of_target_nodes = self.get_edge_list_for_opt_flood_fill()

        # Sanity check: the wavefront-grouped edges must be exactly the same
        # (src, dest) pairs as the flat edge list. Each pair is encoded as one
        # integer. Sorting the src and dest rows independently would not detect
        # mispaired edges. With max_set_size == 1 there are no wavefronts, so
        # we fall back to an empty (2, 0) tensor instead of torch.cat([]).
        num_cells = self.max_set_size * self.max_set_size
        if self.list_of_elists:
            opt_edges = torch.cat(self.list_of_elists, dim=-1)
        else:
            opt_edges = self.edge_list.new_empty((2, 0))
        flat_keys = (self.edge_list[0] * num_cells + self.edge_list[1]).sort()[0]
        opt_keys = (opt_edges[0] * num_cells + opt_edges[1]).sort()[0]
        assert torch.equal(flat_keys, opt_keys)

        # One round per grid row. Presumably chosen so messages can travel the
        # n-1 diagonal steps from (0, 0) to (n-1, n-1) — confirm.
        self.flood_fill_iters = self.max_set_size
        logger.info(f'Doing {self.flood_fill_iters} number of iterations')

        # NOTE(review): not referenced anywhere in this class.
        self.threshold_network = torch.nn.Identity()
        # Maps concatenated (source, destination) cell states to an edge message.
        self.message_net = torch.nn.Sequential(
            torch.nn.Linear(2 * self.tensor_dim, 2 * self.tensor_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * self.tensor_dim, self.tensor_dim)
            )
        # Maps (aggregated messages, initial cell state) to the next cell state.
        self.scoring_net = torch.nn.Sequential(
            torch.nn.Linear(2 * self.tensor_dim, self.tensor_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.tensor_dim, self.tensor_dim)
            )
        # Reduces a final cell state to a single scalar score.
        self.mcis_score_network = torch.nn.Sequential(
                torch.nn.Linear(self.tensor_dim, self.tensor_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(self.tensor_dim, 1),
            )

    def get_edge_list_for_flood_fill(self):
        """Return the flat ``(2, 3 * (n-1)**2)`` grid edge list.

        Row 0 holds source cell indices and row 1 holds destination cell
        indices. Edges are grouped as all horizontal (left -> cell) edges, then
        all vertical (up -> cell) edges, then all diagonal (up-left -> cell)
        edges. Only cells with row and column >= 1 are destinations.
        """
        node_idx = torch.arange(self.max_set_size * self.max_set_size, device=self.device\
                                ).reshape(self.max_set_size, self.max_set_size)
        src_nodes_horz = node_idx[1:, :-1].flatten()
        dest_nodes_horz = node_idx[1:, 1:].flatten()
        src_nodes_vert = node_idx[:-1, 1:].flatten()
        dest_nodes_vert = node_idx[1:, 1:].flatten()
        src_nodes_diag = node_idx[:-1, :-1].flatten()
        dest_nodes_diag = node_idx[1:, 1:].flatten()
        all_src_nodes = torch.cat((src_nodes_horz, src_nodes_vert, src_nodes_diag))
        all_dest_nodes = torch.cat((dest_nodes_horz, dest_nodes_vert, dest_nodes_diag))
        all_edge_list = torch.stack((all_src_nodes, all_dest_nodes))
        return all_edge_list

    def get_edge_list_for_opt_flood_fill(self):
        """Group the grid edges by anti-diagonal wavefront.

        Destination cells ``(r, c)`` with ``r, c >= 1`` are grouped by
        ``r + c``, which ranges over ``2 .. 2n-2``. All of a cell's
        predecessors lie on earlier wavefronts.

        Returns:
            ``(list_of_elists, list_of_target_nodes)``, with one entry per
            wavefront in increasing ``r + c`` order:

            * each elist is a ``(2, 3k)`` int64 tensor of (src, dest) indices;
            * each target tensor is a ``(k,)`` int64 tensor of that wavefront's
              destination cells, in increasing row order.
        """
        n = self.max_set_size

        def cell(row, col):
            # Row-major flat index; identical to the arange/reshape numbering
            # used by get_edge_list_for_flood_fill.
            return row * n + col

        list_of_elists = []
        list_of_target_nodes = []
        for wave_idx in range(2, 2 * n - 1):
            elist_src, elist_dest, target_cell_list = [], [], []
            for row in range(1, n):
                col = wave_idx - row
                if not 0 < col < n:
                    continue
                target = cell(row, col)
                # Predecessors: up, left, up-left (same order as before).
                elist_src.extend([cell(row - 1, col), cell(row, col - 1), cell(row - 1, col - 1)])
                elist_dest.extend([target] * 3)
                target_cell_list.append(target)
            elist = torch.tensor([elist_src, elist_dest], device=self.device)
            targets = torch.tensor(target_cell_list, device=self.device)
            # Each wavefront's targets must be exactly the destinations of its edges.
            assert torch.equal(targets.sort()[0], torch.unique(elist[1]))
            list_of_elists.append(elist)
            list_of_target_nodes.append(targets)

        return list_of_elists, list_of_target_nodes

    def batched_masked_max(self, final_node_states, gen_cgraph_mask):
        """Return, per batch element, the maximum of the unmasked cell scores.

        Args:
            final_node_states: ``(B, n*n)`` per-cell scores.
            gen_cgraph_mask: boolean tensor whose last two dims flatten to
                ``n*n``. ``True`` marks a cell to EXCLUDE.

        Returns:
            ``(B,)`` tensor of maxima. A batch element whose cells are all
            excluded yields ``-inf``; the earlier per-row loop raised on that
            input instead.
        """
        excluded = gen_cgraph_mask.flatten(-2, -1)
        # amax distributes gradient evenly among tied maxima, like the
        # full-reduction torch.max(t) the per-row loop used.
        return final_node_states.masked_fill(excluded, float("-inf")).amax(dim=-1)

    def flood_fill(self, node_states, gen_qgraph_mask):
        """Run ``flood_fill_iters`` message-passing rounds and return masked maxima.

        Each round builds a message per edge from the current (source,
        destination) states. The messages are summed into their destination
        cells; cells with no incoming edges (row 0 or column 0) receive zeros.
        The new state is ``scoring_net([aggregated, initial_state])``.

        Args:
            node_states: ``(B, n*n, tensor_dim)`` initial cell states.
            gen_qgraph_mask: boolean cell mask (``True`` = excluded), as in
                ``batched_masked_max``.

        Returns:
            ``(B,)`` tensor of the maximum final cell score per batch element.
        """
        original_node_states = node_states.clone()
        num_cells = original_node_states.shape[-2]
        from_idx, to_idx = self.edge_list
        for _ in range(self.flood_fill_iters):
            edge_inputs = torch.cat([node_states[:, from_idx, :], node_states[:, to_idx, :]], dim=-1)
            messages = self.message_net(edge_inputs)
            aggregated_messages = scatter_add(messages, to_idx, dim=-2, dim_size=num_cells)
            scoring_input = torch.cat([aggregated_messages, original_node_states], dim=-1)
            node_states = self.scoring_net(scoring_input)
        final_node_states = self.mcis_score_network(node_states).squeeze(-1)

        return self.batched_masked_max(final_node_states, gen_qgraph_mask)

    @staticmethod
    def four_d_torch_matmul(t1, t2):
        """Channel-wise batched matmul: ``(B, n, m, d) x (B, m, k, d) -> (B, n, k, d)``.

        NOTE(review): not called anywhere in this class.
        """
        return (t1.permute(0,3,1,2) @ t2.permute(0,3,1,2)).permute(0,2,3,1)

    def forward(self, corpus_batch_adj, gen_cgraph_mask, transport_plan, diagnostic_mode=False):
        """Permute the corpus adjacency with ``transport_plan`` and flood-fill it.

        Args:
            corpus_batch_adj: ``(B, n, n)`` adjacency, or ``(B, n, n, d)`` with
                ``d`` channels per entry.
            gen_cgraph_mask: boolean cell mask (``True`` = excluded) forwarded
                to ``flood_fill``.
            transport_plan: ``(B, n, n)`` transport plan ``P``, presumably a
                Sinkhorn soft permutation.
            diagnostic_mode: accepted for interface compatibility; unused.

        Returns:
            ``(B,)`` scores from ``flood_fill``, computed on the ``(B, n*n, d)``
            grid of ``P^T A P`` entries (``d == 1`` for a 3-D adjacency).
        """
        if corpus_batch_adj.dim() == 3:
            permuted_corpus_batch_adj = (
                transport_plan.transpose(-1, -2) @ corpus_batch_adj @ transport_plan
            ).unsqueeze(-1)
        else:
            # Channel-wise P^T A P: out[b,j,l,d] = sum_{i,k} P[b,i,j] A[b,i,k,d] P[b,k,l]
            permuted_corpus_batch_adj = torch.einsum(
                "bij,bikd,bkl->bjld", transport_plan, corpus_batch_adj, transport_plan
            )

        # Merge the two grid axes into one cell axis: (B, n*n, d).
        flattened_indicator = permuted_corpus_batch_adj.flatten(start_dim=-3, end_dim=-2)

        return self.flood_fill(flattened_indicator, gen_cgraph_mask)




class NANL_FF_clique(torch.nn.Module):
    """Graph matcher combining a Sinkhorn alignment with a flood-fill head.

    Both query and corpus graphs are encoded by a GMN ``GraphEncoder``
    followed by repeated application of one shared ``GraphPropLayer``. A
    transport plan is computed by Sinkhorn iterations from the CORPUS features
    alone. The plan is used to:

    * permute the corpus adjacency, which is scored by ``FloodFillBase``;
    * align corpus node embeddings for a hinge-style comparison with the query.

    Args:
        conf: configuration object; reads ``conf.dataset``, ``conf.training``
            and ``conf.model`` fields.
        gmn_config: dict-like with ``"encoder"`` and ``"graph_embedding_net"``
            sections.
    """

    def __init__(self, conf, gmn_config):
        super().__init__()
        self.conf = conf
        self.gmn_config = gmn_config

        dataset_conf = conf.dataset
        model_conf = conf.model
        self.max_node_set_size = dataset_conf.max_node_set_size
        self.max_edge_set_size = dataset_conf.max_edge_set_size
        self.device = conf.training.device
        self.delta = model_conf.delta
        self.mask_sinkhorn = model_conf.mask_sinkhorn
        self.sinkhorn_temp = model_conf.sinkhorn_temp
        self.sinkhorn_num_iters = model_conf.sinkhorn_num_iters

        num_nodes = self.max_node_set_size
        self.graph_size_to_mask_map = model_utils.graph_size_to_mask_map(
            max_set_size=num_nodes, lateral_dim=num_nodes, device=self.device
        )
        self.set_size_to_mask_map = model_utils.set_size_to_mask_map(
            max_set_size=num_nodes, device=self.device
        )

        # Submodules are created in a fixed order; keep it so parameter
        # initialisation (RNG consumption) and state_dict layout stay stable.
        self.encoder = gmngen.GraphEncoder(**gmn_config["encoder"])

        embedding_conf = gmn_config["graph_embedding_net"]
        prop_config = embedding_conf.copy()
        # The single-layer GraphPropLayer does not take the multi-layer options.
        for multi_layer_key in ("n_prop_layers", "share_prop_params"):
            prop_config.pop(multi_layer_key, None)
        self.prop_layer = gmngen.GraphPropLayer(**prop_config)
        self.propagation_steps = embedding_conf["n_prop_layers"]

        self.sinkhorn_feature_layers = torch.nn.Sequential(
            torch.nn.Linear(prop_config["node_state_dim"], num_nodes),
            torch.nn.ReLU(),
            torch.nn.Linear(num_nodes, num_nodes),
        )

        variant = model_conf.ff_variant
        self.tensor_dim = prop_config["edge_hidden_sizes"][-1] if variant == "SepLRL" else 1

        self.ff_network = FloodFillBase(conf, self.tensor_dim)
        if variant in ("SepLRL", "SepLRL_dense"):
            # NOTE(review): lrl_network is never referenced in forward().
            # Also, for "SepLRL" the flood-fill head expects tensor_dim channels,
            # while forward() passes it the stacked adjacency (presumably 3-D,
            # i.e. one channel). Confirm this variant is wired up.
            self.lrl_network = torch.nn.Sequential(
                torch.nn.Linear(prop_config["edge_hidden_sizes"][-1] + 1, self.tensor_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(self.tensor_dim, self.tensor_dim),
            )

    def get_graph_encodings(self, graph_batch, graph_sizes):
        """Encode a batch of graphs and stack the per-graph node states.

        Runs the encoder once, then the shared propagation layer
        ``propagation_steps`` times. The resulting node states are split per
        graph and stacked, padded to ``max_node_set_size``, via
        ``model_utils.split_and_stack_singles``.
        """
        node_features, edge_features, from_idx, to_idx, _ = model_utils.get_graph_features(graph_batch)

        node_states, edge_states = self.encoder(node_features, edge_features)
        step = 0
        while step < self.propagation_steps:
            node_states = self.prop_layer(node_states, from_idx, to_idx, edge_states)
            step += 1

        return model_utils.split_and_stack_singles(node_states, graph_sizes, self.max_node_set_size)

    def forward(self,
                query_batch_data,
                query_batch_data_node_sizes,
                query_batch_data_edge_sizes,
                query_batch_adj,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
                diagnostic_mode=False):
        """Score query/corpus pairs.

        ``query_batch_data_edge_sizes``, ``query_batch_adj`` and
        ``corpus_batch_data_edge_sizes`` are accepted but not used here.

        Returns:
            When ``diagnostic_mode`` is true, a 9-tuple of intermediates ending
            with the flood-fill output. Otherwise ``(hinge_scores, ff_output)``:

            * ``hinge_scores`` is ``-sum(relu(query - P @ corpus))``.
            * Assuming the stacked features are ``(B, n, D)`` (TODO confirm),
              the ``[:, None]`` broadcast makes it a ``(B, B)`` matrix. After
              ``.T``, rows index corpus graphs and columns index query graphs.
        """
        query_sizes = query_batch_data_node_sizes
        corpus_sizes = corpus_batch_data_node_sizes

        # Corpus first, then query (same call order as before).
        stacked_features_corpus = self.get_graph_encodings(corpus_batch_data, corpus_sizes)
        stacked_features_query = self.get_graph_encodings(query_batch_data, query_sizes)

        transformed_features_corpus = self.sinkhorn_feature_layers(stacked_features_corpus)
        corpus_size_masks = torch.stack([self.graph_size_to_mask_map[size] for size in corpus_sizes])
        masked_features_corpus = corpus_size_masks * transformed_features_corpus

        # True wherever the per-size mask map is 0.
        gen_cgraph_mask = torch.stack([self.set_size_to_mask_map[size] == 0 for size in corpus_sizes])

        sinkhorn_kwargs = dict(
            log_alpha=masked_features_corpus,
            device=self.device,
            temperature=self.sinkhorn_temp,
            noise_factor=0,
            num_iters=self.sinkhorn_num_iters,
        )
        if self.mask_sinkhorn:
            transport_plan = model_utils.pytorch_sinkhorn_iters_mask(mask=gen_cgraph_mask, **sinkhorn_kwargs)
        else:
            transport_plan = model_utils.pytorch_sinkhorn_iters(**sinkhorn_kwargs)

        stacked_corpus_batch_adj = torch.stack(corpus_batch_adj)
        if self.conf.model.ff_variant != "HcHcT":
            adj_for_ff = stacked_corpus_batch_adj
        else:
            # Weight adjacency entries by corpus node-embedding dot products.
            corpus_gram = stacked_features_corpus @ stacked_features_corpus.permute(0, 2, 1)
            adj_for_ff = stacked_corpus_batch_adj * corpus_gram

        ff_output = self.ff_network(adj_for_ff, gen_cgraph_mask, transport_plan)

        if diagnostic_mode:
            return (
                stacked_features_query,
                stacked_features_corpus,
                transport_plan,
                transformed_features_corpus,
                masked_features_corpus,
                corpus_batch_adj,
                corpus_batch_data_node_sizes,
                gen_cgraph_mask,
                ff_output,
            )

        aligned_corpus = transport_plan @ stacked_features_corpus
        hinge = torch.relu(stacked_features_query[:, None, :, :] - aligned_corpus)
        return -hinge.sum(dim=(-1, -2)).T, ff_output