from abc import ABC
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch_geometric.nn as pyg_nn
from inspect import signature
from loguru import logger

from ..utils import stack_hidden

import torch_geometric.nn as pyg_nn
from inspect import signature
from loguru import logger
import torch
import torch.nn as nn
import copy



# Sentinel used as the default value of the optional ``edge_weight`` arguments
# in this module; it stands for "no edge weights were supplied".
NONE_DUMY = "none_dummy"


######################

# FE Modifications

class BFSConv(pyg_nn.MessagePassing):
    """One relaxation step for hop distances.

    Every edge transports the distance of its ``x_j`` endpoint plus one; the
    incoming values are reduced with ``aggr`` and the outcome is capped by the
    distance the node already had.
    """

    def __init__(self, aggr = "min"):
        super(BFSConv, self).__init__(aggr=aggr)

    def message(self, x_j):
        # Travelling along one edge adds exactly one hop.
        return 1 + x_j

    def forward(self, distances, edge_index):
        """Return the element-wise minimum of the aggregated messages and ``distances``."""
        relaxed = self.propagate(edge_index, x=distances)
        return torch.minimum(relaxed, distances)

class BFS(torch.nn.Module):
    """Computes hop distances from a set of start nodes by repeated relaxation."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = BFSConv()

    def forward(self, data, starts, starting_nodes_per_graph=1):
        """Return a ``(data.num_nodes, starting_nodes_per_graph)`` distance tensor.

        Args:
            data: Object exposing ``edge_index`` and ``num_nodes``.
            starts: Index (tensor) of the rows that are initialised to 0; every
                column of those rows is set to 0.
            starting_nodes_per_graph: Number of columns of the result.

        Returns:
            Tensor on ``edge_index``'s device. Entries that were never lowered
            by ``self.conv`` keep the value ``inf``.
        """
        edge_index = data.edge_index
        inf = float('Inf')
        # Allocate on the target device right away instead of CPU + copy.
        distances = torch.full(
            (data.num_nodes, starting_nodes_per_graph),
            inf,
            device=edge_index.device,
        )
        distances[starts] = 0

        # Upper bound on the number of relaxations (same bound as before:
        # num_nodes + 2), needed because disconnected components keep `inf`
        # entries forever.
        rounds_left = data.num_nodes + 2
        while rounds_left > 0 and bool((distances == inf).any()):
            relaxed = self.conv(distances, edge_index)
            rounds_left -= 1
            # The relaxation is deterministic, so once a step changes nothing
            # no later step can either; stop instead of burning the remaining
            # budget on unreachable nodes.
            if torch.equal(relaxed, distances):
                break
            distances = relaxed
        return distances

######
    
class FloodModel(pyg_nn.MessagePassing):
    """Processor that sweeps node states away from start nodes and back again.

    Nodes are grouped by their hop distance (computed with ``BFS``) from the
    chosen start nodes. In every round the states are updated distance layer
    by distance layer moving away from the start ("flood": ``DownConv`` between
    consecutive layers, ``CrossConv`` inside a layer) and then back towards it
    ("echo": ``CrossConv`` inside a layer, ``UpConv`` between layers). Inputs
    pass through ``encoder`` first; the final states pass through ``decoder``.

    Args:
        in_channels: Width of the input node features.
        out_channels: Width of the decoded output. NOTE: this value is also
            used as the internal hidden width (see ``hidden_dim``).
        num_rounds: Number of flood/echo rounds; each round owns its own
            convolutions.
        num_k: Number of runs (one start-node column per run) whose final
            states are averaged.
        hidden_dim: Accepted for interface compatibility but currently ignored;
            the hidden width is taken from ``out_channels``.
        hidden_state_factor: Multiplier for the inner width of the MLPs.
        mlp_depth: Number of additional inner layers in every MLP.
        dropout: Dropout probability used inside the MLPs.
        normalization: Callable that builds a normalization layer from a width.
        activation: Activation module instance inserted into the MLPs.
        aggregation: Value forwarded to the convolutions as ``aggr``.
        conv: One of ``"gru"``, ``"gin"`` or ``"gin-mlp"``.
        prediction_mode: ``"node"`` or ``"graph"``; decides which rows are decoded.
        start_mode: ``"single"`` or ``"all"``.
        start_selection: ``"random"`` or ``"fixed"``; only consulted when
            ``start_mode == "single"``.
        pool_mode: Stored on the instance; not read by the code in this class.
        train_mode: Stored on the instance; not read by the code in this class.
        is_weighted: When true the edge MLP gets one extra input feature for
            the edge weight.

    Raises:
        ValueError: If ``conv`` is not one of the known names.
    """

    def __init__(self, 
                 in_channels, out_channels=1, 
                 num_rounds=1, num_k = 1,
                 hidden_dim=32, hidden_state_factor=2, mlp_depth = 1, dropout = 0, normalization = torch.nn.LayerNorm, activation = torch.nn.ReLU(),
                 aggregation = "add", conv = "gru", prediction_mode = "node", start_mode = "single", start_selection = "fixed", pool_mode = "sum", train_mode = "single", is_weighted = False):
        super().__init__()

        self.num_rounds = num_rounds
        self.num_k = num_k
        self.dropout = dropout

        self.input_dim = in_channels
        self.output_dim = out_channels
        # NOTE: the `hidden_dim` argument is intentionally left unused here to
        # keep existing behaviour; the hidden width follows `out_channels`.
        self.hidden_dim = out_channels

        # Edge MLP input: both endpoint states, plus one scalar weight if weighted.
        self.edge_dim = self.hidden_dim * 2 + 1 if is_weighted else self.hidden_dim * 2

        self.encoder = self.get_mlp(self.input_dim, hidden_state_factor * self.hidden_dim, self.hidden_dim, mlp_depth, normalization, activation, last_relu=True)
        self.decoder = self.get_mlp(self.hidden_dim, hidden_state_factor * self.hidden_dim, self.output_dim, mlp_depth, normalization, activation, last_relu=False)

        # Templates; every convolution receives its own deep copy.
        edge_mlp = self.get_mlp(self.edge_dim, hidden_state_factor * 2 * self.hidden_dim, self.hidden_dim, mlp_depth, normalization, activation, last_relu=True)
        state_mlp = self.get_mlp(self.hidden_dim, hidden_state_factor * self.hidden_dim, self.hidden_dim, mlp_depth, normalization, activation, last_relu=True)

        num_convs = num_rounds

        def conv_stack(make_conv):
            # One independent convolution per round.
            return nn.ModuleList([make_conv() for _ in range(num_convs)])

        if conv == "gru":
            def make_conv():
                return GRUMLPConv(in_channels=self.hidden_dim, out_channels=self.hidden_dim, mlp_edge=copy.deepcopy(edge_mlp), aggr=aggregation)
        elif conv == "gin":
            # NOTE(review): `GINConv` is not defined or imported in the visible
            # part of this file; confirm where it is supposed to come from.
            def make_conv():
                return GINConv(copy.deepcopy(state_mlp), aggr=aggregation)
        elif conv == "gin-mlp":
            # NOTE(review): `GINMLPConv` is not defined or imported in the
            # visible part of this file; confirm where it is supposed to come from.
            def make_conv():
                return GINMLPConv(mlp=copy.deepcopy(state_mlp), mlp_edge=copy.deepcopy(edge_mlp), aggr=aggregation)
        else:
            raise ValueError(f"Unknown convolution {conv}")

        # `forward` uses all three stacks for every convolution type, so all
        # three are always created (previously "gin-mlp" lacked CrossConv).
        self.DownConv = conv_stack(make_conv)
        self.UpConv = conv_stack(make_conv)
        self.CrossConv = conv_stack(make_conv)

        self.prediction_mode = prediction_mode
        self.start_mode = start_mode
        self.start_selection = start_selection
        self.pool_mode = pool_mode
        self.train_mode = train_mode

    def dist_mask(self, edge_index, dist_edge_index, a, b, edge_weight = NONE_DUMY):
        """Select the edges whose endpoint distances are exactly ``(a, b)``.

        Args:
            edge_index: ``(2, E)`` edge list.
            dist_edge_index: Distances looked up per edge endpoint, laid out
                like ``edge_index``.
            a: Required distance of the row-0 endpoint.
            b: Required distance of the row-1 endpoint.
            edge_weight: Optional per-edge weights, or ``NONE_DUMY``.

        Returns:
            ``(selected_edge_index, selected_edge_weight)``; the second entry
            is ``NONE_DUMY`` when no weights were given.
        """
        keep = (dist_edge_index[0] == a) & (dist_edge_index[1] == b)
        selected_edges = torch.masked_select(edge_index, keep).view((2, -1))
        selected_weight = NONE_DUMY
        # Identity test, consistent with the other sentinel checks in this
        # file; comparing a tensor to a string with `!=` is fragile.
        if edge_weight is not NONE_DUMY:
            selected_weight = torch.masked_select(edge_weight, keep)

        return selected_edges, selected_weight

    def masked_update(self, x, x_new, m):
        """Return ``x`` with the rows listed in ``m`` replaced by rows of ``x_new``."""
        # Column mask of 0/1 values so it broadcasts over the feature dimension.
        mask = torch.zeros(x.shape[0], 1, dtype=torch.float, device=x.device).index_fill(0, m, 1)
        return mask * x_new + (1 - mask) * x

    def get_mlp(self, input_dim, hidden_dim, output_dim, mlp_depth, normalization, activation, last_relu = True):
        """Build ``Linear -> norm -> act -> dropout`` blocks followed by a final ``Linear``.

        One block maps ``input_dim`` to ``hidden_dim`` and ``mlp_depth`` more
        blocks keep that width. When ``last_relu`` is true the final linear
        layer is followed by a normalization layer and ``activation``.
        """
        width = int(hidden_dim)
        modules = [torch.nn.Linear(input_dim, width), normalization(width), activation, torch.nn.Dropout(self.dropout)]
        for _ in range(int(mlp_depth)):
            modules += [torch.nn.Linear(width, width), normalization(width), activation, torch.nn.Dropout(self.dropout)]
        modules.append(torch.nn.Linear(width, output_dim))

        if last_relu:
            modules.append(normalization(output_dim))
            modules.append(activation)

        return torch.nn.Sequential(*modules)

    def step(self, edge_index, dist_e_index, a, b, conv, x, edge_weight=NONE_DUMY):
        """Apply ``conv`` on the ``(a, b)`` edges and update only the nodes they point to."""
        masked_edges, masked_edge_weight = self.dist_mask(edge_index, dist_e_index, a, b, edge_weight)
        receivers = masked_edges[1]
        # The current state is passed both as input features and as the third
        # positional argument (``last_hidden`` for GRUMLPConv).
        x_new = conv(x, masked_edges, x, masked_edge_weight)
        return self.masked_update(x, x_new, receivers)

    def _select_starts(self, data):
        """Return the start nodes as a tensor with one column per run."""
        if self.start_mode not in ("single", "all"):
            raise ValueError(f"Unknown start_mode {self.start_mode}")

        has_root = hasattr(data, "root") and data.root is not None

        if self.start_mode == "single" and self.start_selection == "random":
            # One uniformly drawn node per graph and run.
            flood_start = torch.stack([
                torch.randint(int(data.ptr[i]), int(data.ptr[i + 1]), (self.num_k,))
                for i in range(data.num_graphs)
            ])
        elif has_root and (self.start_mode == "all" or self.start_selection == "fixed"):
            # `root` is added to the per-graph offsets in `ptr`.
            flood_start = (data.root + data.ptr[:-1]).view((-1, 1))
        else:
            if self.start_mode == "all":
                logger.warning("WARNING: start_mode = all, but no starting nodes defined, defaulting to start at 0")
            # Default to the first node of every graph (no random draw needed).
            flood_start = data.ptr[:-1].reshape(-1, 1).repeat(1, self.num_k)

        # An explicit per-node start indicator `s` overrides everything above.
        s = getattr(data, "s", None)
        if s is not None and s.shape == torch.Size([data.num_nodes]):
            flood_start = torch.nonzero(s).view(-1, 1)

        return flood_start

    def forward(self, x, edge_index, edge_weight=NONE_DUMY, data=None):
        """Run the flood/echo rounds and decode the resulting node states.

        Args:
            x: Node features; cast to float before encoding.
            edge_index: ``(2, E)`` edge list.
            edge_weight: Optional per-edge weights, or ``NONE_DUMY``.
            data: Batch object; ``num_nodes``, ``ptr``, ``num_graphs``,
                ``edge_index`` and optionally ``root`` / ``s`` are read from it.
                Required despite the ``None`` default.

        Returns:
            Decoded states for all nodes, or only for the start nodes when
            ``prediction_mode == "graph"`` or when ``prediction_mode == "node"``
            with ``start_mode == "all"``.

        Raises:
            ValueError: If ``data`` is missing, ``start_mode`` is unknown, or
                graph prediction is requested with ``num_k != 1``.
        """
        if data is None:
            raise ValueError("FloodModel.forward requires the `data` argument")
        if self.prediction_mode == "graph" and self.num_k != 1:
            # should probably either pool all nodes, or pool all starting nodes
            raise ValueError("undefined behaviour for graph prediction with more than a single starting node")

        x = x.to(torch.float)

        precomp = BFS()
        x_agg = torch.zeros((data.num_nodes, self.hidden_dim, self.num_k), device=x.device)

        flood_start = self._select_starts(data)

        # Do the flooding phases for each starting node
        for k in range(self.num_k):
            # Precompute distances for the wave activation pattern.
            starts = flood_start[:, k]
            D = precomp(data, starts).view(-1)
            dist_e_index = D[edge_index]
            # Largest finite distance = number of layers to sweep.
            maxD = torch.max(D[D != float('Inf')]).long().item()

            # Encode the input (inside the loop so each run gets its own pass).
            x_k = self.encoder(x)

            for phase in range(self.num_rounds):
                down = self.DownConv[phase]
                cross = self.CrossConv[phase]
                up = self.UpConv[phase]

                # flood down
                for flood in range(maxD):
                    x_k = self.step(edge_index, dist_e_index, flood, flood + 1, down, x_k, edge_weight)
                    x_k = self.step(edge_index, dist_e_index, flood + 1, flood + 1, cross, x_k, edge_weight)

                # echo back
                for echo in range(maxD, 0, -1):
                    x_k = self.step(edge_index, dist_e_index, echo, echo, cross, x_k, edge_weight)
                    x_k = self.step(edge_index, dist_e_index, echo, echo - 1, up, x_k, edge_weight)

            x_agg[:, :, k] = x_k

        # Aggregate the results of the k runs to a single hidden representation.
        x_agg = torch.mean(x_agg, 2)

        # Either keep all nodes or only the start nodes of the last run.
        only_starts = self.prediction_mode == "graph" or (
            self.prediction_mode == "node" and self.start_mode == "all"
        )
        if only_starts:
            x_agg = x_agg[starts]

        # Decode the hidden representation to the output.
        return self.decoder(x_agg)

   

######################

class PGN(pyg_nn.MessagePassing):
    """Adapted from https://github.com/google-deepmind/clrs/blob/64e016998f14305f94cf3f6d19ac9d7edc39a185/clrs/_src/processors.py#L330

    Messages are built from linear projections of both edge endpoints (plus a
    projected edge weight when one is given) and passed through a small MLP.
    The aggregated messages and a linear skip connection of the node features
    are summed to form the output. Graph-level features are not supported.
    """
    def __init__(self, in_channels, out_channels, aggr, mid_act=None, activation=nn.ReLU()):
        super().__init__(aggr=aggr)
        logger.info(f"PGN: in_channels: {in_channels}, out_channels: {out_channels}")
        self.in_channels, self.out_channels = in_channels, out_channels
        # Messages use the same width as the output.
        self.mid_channels = out_channels
        self.mid_act, self.activation = mid_act, activation

        width = self.mid_channels

        # Per-endpoint projections feeding the message MLP.
        self.m_1 = nn.Linear(in_channels, width)  # source node
        self.m_2 = nn.Linear(in_channels, width)  # target node
        self.msg_mlp = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width)
        )

        # Lifts a scalar edge weight to message width.
        self.edge_weight_scaler = nn.Linear(1, width)

        # Output projections: skip connection and aggregated messages.
        self.o1 = nn.Linear(in_channels, out_channels)
        self.o2 = nn.Linear(width, out_channels)

    def forward(self, x, edge_index, edge_weight=NONE_DUMY):
        """Return ``activation(o1(x) + o2(aggregated messages))``; no activation if it is ``None``."""
        aggregated = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        combined = self.o1(x) + self.o2(aggregated)
        if self.activation is None:
            return combined
        return self.activation(combined)

    def message(self, x_j, x_i, edge_weight=NONE_DUMY):
        """Build one message per edge (j is source, i is target)."""
        pre = self.m_1(x_j) + self.m_2(x_i)
        if edge_weight is not NONE_DUMY:
            pre = pre + self.edge_weight_scaler(edge_weight.reshape(-1, 1))

        hidden = self.msg_mlp(pre)
        if self.mid_act is None:
            return hidden
        return self.mid_act(hidden)

######################
# Modules from https://github.com/floriangroetschla/Recurrent-GNNs-for-algorithm-learning/blob/main/model.py
# Adapted to work with edge weights

class GRUConv(pyg_nn.MessagePassing):
    """Message passing layer whose node update is a ``GRUCell``.

    Each message is ``relu(x_j + W * edge_weight)``; the aggregated messages
    are the GRU input and ``last_hidden`` is its hidden state.
    """

    def __init__(self, in_channels, out_channels, aggr):
        super().__init__(aggr=aggr)
        logger.info(f"GRUConv: in_channels: {in_channels}, out_channels: {out_channels}")
        self.rnn = torch.nn.GRUCell(in_channels, out_channels)
        # Lifts a scalar edge weight to the feature width.
        self.edge_weight_scaler = nn.Linear(1, in_channels)

    def message(self, x_j, edge_weight):
        weight_term = self.edge_weight_scaler(edge_weight.unsqueeze(-1))
        return F.relu(x_j + weight_term)

    def forward(self, x, edge_index, edge_weight, last_hidden):
        """Aggregate the edge messages and feed them, with ``last_hidden``, to the GRU cell."""
        aggregated = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        return self.rnn(aggregated, last_hidden)


class GRUMLPConvOG(pyg_nn.MessagePassing):
    """Variant of ``GRUMLPConv`` whose messages ignore edge weights.

    Messages are ``mlp_edge(concat(x_j, x_i))``; the aggregated messages are
    fed to a ``GRUCell`` together with ``last_hidden``.

    Args:
        in_channels: Input width of the GRU cell (width of the aggregated messages).
        out_channels: Hidden width of the GRU cell.
        mlp_edge: Module applied to the concatenated endpoint features.
        aggr: Aggregation passed to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, mlp_edge, aggr):
        # Zero-argument super(): the previous `super(GRUMLPConv, self)` named a
        # different class and raised TypeError on instantiation.
        super().__init__(aggr=aggr)
        self.rnn = torch.nn.GRUCell(in_channels, out_channels)
        self.mlp_edge = mlp_edge

    def forward(self, x, edge_index, last_hidden):
        """Aggregate the edge messages and update ``last_hidden`` with the GRU cell."""
        out = self.propagate(edge_index, x=x)
        out = self.rnn(out, last_hidden)
        return out

    def message(self, x_j, x_i):
        """Return ``mlp_edge`` applied to the concatenated endpoint features."""
        concatted = torch.cat((x_j, x_i), dim=-1)
        return self.mlp_edge(concatted)

class GRUMLPConv(pyg_nn.MessagePassing):
    """GRU-updated message passing with an MLP over concatenated endpoint features.

    The message for an edge is ``mlp_edge`` applied to ``[x_j, x_i]``, extended
    by the edge weight as one extra feature when weights are supplied.
    """

    def __init__(self, in_channels, out_channels, mlp_edge, aggr):
        super().__init__(aggr=aggr)
        self.rnn = torch.nn.GRUCell(in_channels, out_channels)
        self.mlp_edge = mlp_edge

    def message(self, x_j, x_i, edge_weight=NONE_DUMY):
        pieces = [x_j, x_i]
        if edge_weight is not NONE_DUMY:
            # The weight becomes one additional trailing feature per edge.
            pieces.append(edge_weight.unsqueeze(-1))
        return self.mlp_edge(torch.cat(pieces, dim=-1))

    def forward(self, x, edge_index, last_hidden, edge_weight=NONE_DUMY):
        """Aggregate the edge messages and update ``last_hidden`` with the GRU cell."""
        aggregated = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        return self.rnn(aggregated, last_hidden)

def _gruconv_module(in_channels, out_channels, aggr="add"):
    """Factory for ``GRUConv`` that supplies a default aggregation."""
    layer = GRUConv(in_channels=in_channels, out_channels=out_channels, aggr=aggr)
    return layer

def _grumlpconv_module(in_channels, out_channels, aggr="add", layers=2, dropout=0.0, use_bn=False, is_weighted=False):
    """Build a ``GRUMLPConv`` together with its edge MLP.

    The MLP input is two node feature vectors (plus one scalar when
    ``is_weighted``); all of its layers have output width ``in_channels``.
    """
    # Two endpoint feature vectors, optionally followed by the edge weight.
    message_width = 2 * in_channels
    if is_weighted:
        message_width += 1

    edge_mlp = nn.Sequential(nn.Linear(message_width, in_channels))
    if use_bn:
        logger.debug("Using batch norm in GIN module")
        edge_mlp.add_module("bn_input", nn.BatchNorm1d(in_channels))

    for idx in range(layers - 1):
        edge_mlp.add_module("relu_{}".format(idx), nn.ReLU())
        edge_mlp.add_module("linear_{}".format(idx), nn.Linear(in_channels, in_channels))
        if not use_bn:
            continue
        logger.debug("Using batch norm in GIN module")
        edge_mlp.add_module("bn_{}".format(idx), nn.BatchNorm1d(in_channels))

    if dropout > 0:
        edge_mlp.add_module("dropout", nn.Dropout(dropout))

    return GRUMLPConv(in_channels, out_channels, edge_mlp, aggr)

######################

def _gin_module(in_channels, out_channels, eps=0, train_eps=False, layers=2, dropout=0.0, use_bn=False, aggr="add"):
    """Build a ``pyg_nn.GINConv`` around an MLP with ``layers`` linear layers.

    Optional batch norm follows every linear layer; optional dropout is the
    last module of the MLP.
    """
    input_linear = nn.Linear(in_channels, out_channels)

    # Collect the named modules that follow the input layer, in execution order.
    named_layers = []
    if use_bn:
        logger.debug("Using batch norm in GIN module")
        named_layers.append(("bn_input", nn.BatchNorm1d(out_channels)))
    for depth in range(layers - 1):
        named_layers.append((f"relu_{depth}", nn.ReLU()))
        named_layers.append((f"linear_{depth}", nn.Linear(out_channels, out_channels)))
        if use_bn:
            logger.debug("Using batch norm in GIN module")
            named_layers.append((f"bn_{depth}", nn.BatchNorm1d(out_channels)))
    if dropout > 0:
        named_layers.append(("dropout", nn.Dropout(dropout)))

    mlp = nn.Sequential(input_linear)
    for label, layer in named_layers:
        mlp.add_module(label, layer)

    return pyg_nn.GINConv(mlp, eps, train_eps, aggr=aggr)

def _gine_module(in_channels, out_channels, eps=0, train_eps=False, layers=2, dropout=0.0, use_bn=False, edge_dim=1, aggr="add"):
    """Build a ``pyg_nn.GINEConv`` around an MLP with ``layers`` linear layers.

    Same MLP layout as the GIN factory; ``edge_dim`` is forwarded to
    ``GINEConv``.
    """
    mlp = nn.Sequential(nn.Linear(in_channels, out_channels))

    # Input stage normalisation.
    if use_bn:
        logger.debug("Using batch norm in GIN module")
        mlp.add_module("bn_input", nn.BatchNorm1d(out_channels))

    # Remaining (layers - 1) ReLU -> Linear [-> BatchNorm] stages.
    stage = 0
    while stage < layers - 1:
        mlp.add_module("relu_%d" % stage, nn.ReLU())
        mlp.add_module("linear_%d" % stage, nn.Linear(out_channels, out_channels))
        if use_bn:
            logger.debug("Using batch norm in GIN module")
            mlp.add_module("bn_%d" % stage, nn.BatchNorm1d(out_channels))
        stage += 1

    if dropout > 0:
        mlp.add_module("dropout", nn.Dropout(dropout))

    return pyg_nn.GINEConv(mlp, eps, train_eps, edge_dim=edge_dim, aggr=aggr)

def _get_processor(name):
    if name == "GCNConv":
        return pyg_nn.GCNConv
    elif name == "GINConv":
        return _gin_module    
    elif name == "GINEConv":
        return _gine_module
    elif name == "GRUConv":
        return _gruconv_module
    elif name == "RecGNNConv": # initially called GRUMLPConv
        return _grumlpconv_module
    elif name == "PGN":
        return PGN
    elif name == "FE":
        return FloodModel
    else:
        raise ValueError(f"Unknown processor {name}")
    
class Processor(nn.Module, ABC):
    """Wraps the configured processor core and an optional layer norm.

    The core is selected by ``cfg.MODEL.PROCESSOR.NAME`` and built with
    ``cfg.MODEL.PROCESSOR.KWARGS[0]``. Its input width is two or three times
    ``cfg.MODEL.HIDDEN_DIM`` (three when ``PROCESSOR_USE_LAST_HIDDEN`` is set),
    plus one when ``has_randomness`` is true.
    """

    def __init__(self, cfg, has_randomness=False):
        super().__init__()
        self.cfg = cfg
        model_cfg = self.cfg.MODEL
        proc_cfg = model_cfg.PROCESSOR

        multiplier = 3 if model_cfg.PROCESSOR_USE_LAST_HIDDEN else 2
        processor_input = model_cfg.HIDDEN_DIM * multiplier
        if has_randomness:
            # One extra input feature for the randomness column.
            processor_input += 1

        build_core = _get_processor(proc_cfg.NAME)
        self.core = build_core(
            in_channels=processor_input,
            out_channels=model_cfg.HIDDEN_DIM,
            **proc_cfg.KWARGS[0],
        )
        if proc_cfg.LAYERNORM.ENABLE:
            self.norm = pyg_nn.LayerNorm(model_cfg.HIDDEN_DIM, mode=proc_cfg.LAYERNORM.MODE)

        # Cached once: whether the core's forward takes `last_hidden`.
        self._core_requires_last_hidden = self._core_accepts("last_hidden")

    def _core_accepts(self, parameter_name):
        """Tell whether the core's ``forward`` declares ``parameter_name``."""
        return parameter_name in signature(self.core.forward).parameters

    def forward(self, input_hidden, hidden, last_hidden, batch_assignment, randomness=None, **kwargs):
        """Stack the hidden states, run the core on them and optionally normalise.

        Remaining keyword arguments are forwarded to the core unchanged;
        ``last_hidden`` is added to them when the core's ``forward`` accepts it.
        """
        use_last = self.cfg.MODEL.PROCESSOR_USE_LAST_HIDDEN
        core_input = stack_hidden(input_hidden, hidden, last_hidden, use_last)
        if randomness is not None:
            core_input = torch.cat((core_input, randomness.unsqueeze(1)), dim=-1)

        if self._core_requires_last_hidden:
            kwargs["last_hidden"] = last_hidden
        result = self.core(core_input, **kwargs)

        if not self.cfg.MODEL.PROCESSOR.LAYERNORM.ENABLE:
            return result
        return self.norm(result, batch=batch_assignment)

    def has_edge_weight(self):
        return self._core_accepts("edge_weight")

    def has_edge_attr(self):
        return self._core_accepts("edge_attr")



