import torch.nn as nn
import torch
import copy
from utils import NodeType
from torch_scatter import scatter_add, scatter_mean
import torch.nn.functional as F


def build_mlp(in_size, hidden_size, out_size, lay_norm=True):
    """Build a three-layer ReLU MLP, optionally followed by a LayerNorm.

    Args:
        in_size: Input feature width.
        hidden_size: Width of both hidden layers.
        out_size: Output feature width.
        lay_norm: When true, return ``Sequential(mlp, LayerNorm(out_size))``;
            otherwise return the bare MLP.

    Returns:
        An ``nn.Sequential``. The nesting (MLP at index 0, LayerNorm at
        index 1) determines the ``state_dict`` keys, so keep it stable.
    """
    # Layers are created in the same order as always so that seeded random
    # initialisation is reproducible.
    layers = [nn.Linear(in_size, hidden_size), nn.ReLU()]
    layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
    layers.append(nn.Linear(hidden_size, out_size))
    mlp = nn.Sequential(*layers)
    if not lay_norm:
        return mlp
    return nn.Sequential(mlp, nn.LayerNorm(normalized_shape=out_size))
  
class Encoder(nn.Module):
    """Embed raw mesh features and derive the learned aggregation weights.

    ``forward`` takes the feature dictionary ``data_list``. It replaces the raw
    node / element / face / face-to-face / scripted features with
    ``hidden_size``-wide embeddings. It also adds the auxiliary entries that the
    processor blocks read: ``faces_agg_received_edges`` and the weight tensors
    ``e_n_fac_ori`` / ``e_f_fac_ori`` / ``f_n_fac_ori``, plus their
    L2-normalised counterparts ``e_n_fac`` / ``e_f_fac`` / ``f_n_fac``.

    Args:
        node_input_size: Last-dim width of ``data_list["node_features"]``.
        element_input_size: Last-dim width of ``data_list["element_features"]``.
        face_input_size: Last-dim width of ``data_list["face_features"]``.
        face_face_input_size: Last-dim width of
            ``data_list["face_to_face_features"]``.
        scripted_input_size: Last-dim width of ``data_list["scripted_features"]``.
        node_num: Unused; kept for interface compatibility.
        enode_num: Unused; kept for interface compatibility.
        knode_num: ``data_list["e_n_vec"]`` must have ``3 * knode_num`` columns
            (presumably nodes per element -- confirm with the caller).
        kface_num: ``data_list["e_f_vec"]`` must have ``3 * kface_num`` columns
            (presumably faces per element).
        kf_node: ``data_list["f_n_vec"]`` must have ``3 * kf_node`` columns
            (presumably nodes per face).
        device: Stored as ``self.device``; not used by this module itself.
        hidden_size: Width of every embedding.
    """

    def __init__(self,
                 node_input_size,
                 element_input_size,
                 face_input_size,
                 face_face_input_size,
                 scripted_input_size,
                 node_num,
                 enode_num,
                 knode_num,
                 kface_num,
                 kf_node,
                 device,
                 hidden_size = 128):
        super(Encoder, self).__init__()

        # The element / face encoders see their own features concatenated with
        # e_n_vec / f_n_vec. The widths of those tensors are fixed by
        # e_n_encoder / f_n_encoder below (3 * knode_num and 3 * kf_node).
        # Deriving the extra width from the same arguments keeps both consumers
        # consistent; the former hard-coded 8 * 3 and 4 * 3 only worked for
        # knode_num=8, kf_node=4. For that case the shapes are unchanged, so
        # existing checkpoints still load.
        self.nb_encoder = build_mlp(node_input_size, hidden_size, hidden_size)
        self.element_encoder = build_mlp(element_input_size + 3 * knode_num, hidden_size, hidden_size)
        self.face_encoder = build_mlp(face_input_size + 3 * kf_node, hidden_size, hidden_size)
        self.face_encoder2 = build_mlp(2 * hidden_size, hidden_size, hidden_size)
        self.ff_encoder = build_mlp(face_face_input_size + hidden_size, hidden_size, hidden_size)
        # (sic) attribute name kept as-is: renaming it would change state_dict keys.
        self.scripted_eoncoder = build_mlp(scripted_input_size, hidden_size, hidden_size)
        self.device = device

        # Each maps a 3*k-wide vector to k weights (k = knode_num / kface_num / kf_node).
        self.e_n_encoder = build_mlp(3 * knode_num, hidden_size, knode_num)
        self.e_f_encoder = build_mlp(3 * kface_num, hidden_size, kface_num)
        self.f_n_encoder = build_mlp(3 * kf_node, hidden_size, kf_node)

    def normalize_rows_l2(self, matrix):
        """Divide ``matrix`` by its L2 norm taken along dim 1.

        The norm is computed with ``keepdim=True``, so it broadcasts back over
        ``matrix``. ``1e-8`` is added to the norm to avoid division by zero.
        """
        norm = torch.norm(matrix, p=2, dim=1, keepdim=True)
        return matrix / (norm + 1e-8)

    def forward(self, data_list):
        """Encode ``data_list`` in place and return the same dictionary.

        Reads ``node_features``, ``element_features``, ``face_features``,
        ``face_to_face_features``, ``scripted_features``, ``faces_to_faces``,
        ``e_n_vec``, ``e_f_vec`` and ``f_n_vec``. Overwrites the five feature
        entries with their embeddings. Adds ``faces_agg_received_edges``,
        ``e_n_fac_ori``, ``e_f_fac_ori``, ``f_n_fac_ori``, ``e_n_fac``,
        ``e_f_fac`` and ``f_n_fac``.
        """
        e_n_vec = data_list["e_n_vec"]
        e_f_vec = data_list["e_f_vec"]
        f_n_vec = data_list["f_n_vec"]

        node_features_ = self.nb_encoder(data_list["node_features"])
        scripted_features_ = self.scripted_eoncoder(data_list["scripted_features"])

        element_features_ = self.element_encoder(
            torch.concat([data_list["element_features"], e_n_vec], dim=-1))

        face_self_features_ = self.face_encoder(
            torch.concat([data_list["face_features"], f_n_vec], dim=-1))
        face_features_ = self.face_encoder2(
            torch.concat([face_self_features_, scripted_features_], dim=-1))

        # Each row of faces_to_faces is an edge. The embedding of the face in
        # column 0 is fed into the edge MLP, and the resulting edge messages are
        # summed onto the face in column 1.
        faces_to_faces_edge = data_list["faces_to_faces"]
        face_to_face_features_ = self.ff_encoder(torch.concat(
            [data_list["face_to_face_features"], face_features_[faces_to_faces_edge[:, 0]]], dim=-1))
        receivers = faces_to_faces_edge[:, 1].to(torch.int64)
        data_list["faces_agg_received_edges"] = scatter_add(
            face_to_face_features_, receivers, dim=0, dim_size=face_features_.shape[0])

        data_list["node_features"] = node_features_
        data_list["element_features"] = element_features_
        data_list["face_features"] = face_features_
        data_list["face_to_face_features"] = face_to_face_features_
        data_list["scripted_features"] = scripted_features_

        # Learned weights with a trailing singleton dim. For 2-D *_vec inputs
        # they have shape (rows, k, 1), and the *_fac variants are
        # L2-normalised over the k slots (dim 1).
        e_n_fac_ori = self.e_n_encoder(e_n_vec).unsqueeze(-1)
        e_f_fac_ori = self.e_f_encoder(e_f_vec).unsqueeze(-1)
        f_n_fac_ori = self.f_n_encoder(f_n_vec).unsqueeze(-1)
        data_list["e_n_fac_ori"] = e_n_fac_ori
        data_list["e_f_fac_ori"] = e_f_fac_ori
        data_list["f_n_fac_ori"] = f_n_fac_ori
        data_list["e_n_fac"] = self.normalize_rows_l2(e_n_fac_ori)
        data_list["e_f_fac"] = self.normalize_rows_l2(e_f_fac_ori)
        data_list["f_n_fac"] = self.normalize_rows_l2(f_n_fac_ori)

        return data_list


class GnBlock(nn.Module):
    """One node -> element/face -> node message-passing step.

    ``forward`` reads the dictionary written by ``Encoder`` (or a previous
    ``GnBlock``) and runs four steps in order:

    0. Pools LayerNorm-ed node features into elements and faces, using the
       normalised weights ``e_n_fac`` / ``f_n_fac``.
    1. Updates faces residually from weighted element messages
       (``e_f_fac_ori``), ``faces_agg_received_edges`` and the scripted
       features.
    2. Updates elements residually from their (updated) faces (``e_f_fac``).
    3. Scatters weighted element features back to nodes (``e_n_fac_ori``) and
       adds them residually, then adds ``FFN(Norm2(.))`` as a second residual.

    Args:
        hidden_size: Width of all latent features.
        node_num: Unused; kept for interface compatibility.
        face_num: Number of faces gathered per element (the middle dim used
            when reshaping element<->face tensors).
        f_node: Unused; kept for interface compatibility.
        device: Stored as ``self.device``; not used by this module itself.
    """

    def __init__(self, hidden_size, node_num, face_num, f_node, device):

        super(GnBlock, self).__init__()

        self.hidden_size = hidden_size
        self.face_num = face_num

        # Construction order is kept unchanged so seeded initialisation is reproducible.
        self._combined_net = build_mlp(4 * hidden_size, hidden_size, hidden_size)
        self.facrb_encoder = build_mlp(2 * hidden_size, hidden_size, hidden_size)
        self._face_elem_net = build_mlp(2 * hidden_size, hidden_size, hidden_size)
        self.elerb_encoder = build_mlp(2 * hidden_size, hidden_size, hidden_size)
        self.FFN = build_mlp(hidden_size, hidden_size, hidden_size)
        self.Norm1 = nn.LayerNorm(normalized_shape=hidden_size)
        self.Norm2 = nn.LayerNorm(normalized_shape=hidden_size)

        self.device = device

    def normalize_rows_l2(self, matrix):
        """Divide ``matrix`` by its L2 norm taken along dim 1.

        The norm is computed with ``keepdim=True``, so it broadcasts back over
        ``matrix``. ``1e-8`` is added to the norm to avoid division by zero.
        """
        norm = torch.norm(matrix, p=2, dim=1, keepdim=True)
        return matrix / (norm + 1e-8)

    def _scatter_weighted(self, weights, features, num_slots, index, dim_size):
        """Scatter ``weights * features`` onto ``index`` and divide by each target's root-sum-square weight.

        ``features`` (one row per source) is repeated over ``num_slots`` slots
        and scaled by ``weights`` (source, slot, 1). The squared weights are
        carried along as an extra column, so a single ``scatter_add`` yields
        both the weighted sum and sum(w**2) for every target row. Targets that
        receive nothing end up as zeros (0 / 1e-8).
        """
        hidden = self.hidden_size
        messages = (weights * features.reshape(-1, 1, hidden).expand(-1, num_slots, -1)).reshape(-1, hidden)
        messages = torch.concat([messages, (weights ** 2).reshape(-1, 1)], dim=1)
        summed = scatter_add(messages, index.reshape(-1).to(torch.int64), dim=0, dim_size=dim_size)
        return summed[:, :-1] / (torch.sqrt(summed[:, -1].reshape(-1, 1)) + 1e-8)

    def forward(self, data_list):
        """Run one message-passing step and return the updated ``data_list``.

        Reads ``node_features``, ``face_features``, ``element_features``,
        ``scripted_features``, the connectivity tensors ``faces``, ``elements``
        and ``cells_faces``, the encoder weights ``e_n_fac_ori``, ``e_n_fac``,
        ``e_f_fac_ori``, ``e_f_fac``, ``f_n_fac``, and
        ``faces_agg_received_edges``. Writes back ``node_features``,
        ``face_features`` and ``element_features``.
        """
        hidden = self.hidden_size
        ori_node_features = data_list["node_features"]
        node_features_normed = self.Norm1(ori_node_features)

        face_features = data_list["face_features"]
        element_features = data_list["element_features"]
        scripted_features = data_list["scripted_features"]
        faces = data_list["faces"]
        elements = data_list["elements"]
        cells_faces = data_list["cells_faces"]

        e_n_fac_ori = data_list["e_n_fac_ori"]
        e_n_fac = data_list["e_n_fac"]
        e_f_fac_ori = data_list["e_f_fac_ori"]
        e_f_fac = data_list["e_f_fac"]
        f_n_fac = data_list["f_n_fac"]

        # Step 0: weighted sum of each element's / face's node features,
        # fused with the current element / face features.
        elem_nodes = node_features_normed[elements.reshape(-1)].reshape(elements.shape[0], -1, hidden)
        elem_from_nodes = torch.sum(e_n_fac * elem_nodes, dim=1)
        element_features_upd = self.elerb_encoder(torch.concat([element_features, elem_from_nodes], dim=1))

        face_nodes = node_features_normed[faces.reshape(-1)].reshape(faces.shape[0], -1, hidden)
        face_from_nodes = torch.sum(f_n_fac * face_nodes, dim=1)
        face_features_upd = self.facrb_encoder(torch.concat([face_features, face_from_nodes], dim=1))

        # Step 1: faces receive weighted element messages, plus the
        # face-to-face aggregate and scripted features (residual update).
        avg_face_features = self._scatter_weighted(
            e_f_fac_ori, element_features_upd, self.face_num, cells_faces, faces.shape[0])
        face_features = face_features + self._combined_net(torch.concat(
            [face_features_upd, avg_face_features, data_list["faces_agg_received_edges"], scripted_features],
            dim=1))

        # Step 2: elements gather their updated faces (residual update).
        elem_faces = face_features[cells_faces.reshape(-1)].reshape(-1, self.face_num, hidden)
        elem_from_faces = torch.sum(e_f_fac * elem_faces, dim=1)
        element_features = element_features + self._face_elem_net(
            torch.concat([element_features_upd, elem_from_faces], dim=1))

        # Step 3: scatter weighted element features back to their nodes, then
        # apply a second residual through FFN(Norm2(.)).
        sum_node_features = self._scatter_weighted(
            e_n_fac_ori, element_features, elements.shape[1], elements, node_features_normed.shape[0])
        node_features = ori_node_features + sum_node_features
        node_features = node_features + self.FFN(self.Norm2(node_features))

        data_list["node_features"] = node_features
        data_list["face_features"], data_list["element_features"] = face_features, element_features

        return data_list


class Decoder(nn.Module):
    """Map latent node features to per-node outputs with an MLP (no LayerNorm).

    Args:
        hidden_size: Width of ``data_list["node_features"]``.
        output_size: Width of the returned tensor.
        node_num: Unused; kept for interface compatibility.
    """

    def __init__(self, hidden_size, output_size, node_num):
        super().__init__()
        self.hidden_size = hidden_size
        # Attribute name is part of the state_dict keys; keep it stable.
        self._decode_node_net = build_mlp(hidden_size, hidden_size, output_size, lay_norm=False)

    def forward(self, data_list):
        """Return the decoded ``data_list["node_features"]``."""
        return self._decode_node_net(data_list["node_features"])


class EncoderProcesserDecoder(nn.Module):
    """Full model: ``Encoder`` -> ``message_passing_num`` x ``GnBlock`` -> ``Decoder``.

    Args:
        message_passing_num: Number of stacked ``GnBlock`` processors.
        node_input_size, element_input_size, face_input_size,
        face_face_input_size, scripted_input_size: Raw feature widths passed
            to ``Encoder``.
        node_output_size: Width of the decoded per-node output.
        face_num: Passed to ``Encoder`` as ``kface_num`` and to every
            ``GnBlock`` as ``face_num``.
        node_num: Passed to ``Encoder`` as ``knode_num`` / ``enode_num`` and to
            every ``GnBlock`` as ``node_num``.
        device: Forwarded to the encoder and processors.
        hidden_size: Latent width shared by all submodules.
    """

    def __init__(self, message_passing_num, node_input_size, node_output_size, element_input_size,
                 face_input_size, face_face_input_size, scripted_input_size, face_num, node_num, device, hidden_size=96):

        super().__init__()

        # Integer-division estimate, presumably nodes per face
        # (e.g. 3 * 8 // 6 == 4 for node_num=8, face_num=6) -- confirm for other cell types.
        nodes_per_face = 3 * node_num // face_num

        # Submodules are created in the same order as before (encoder,
        # processors, decoder) so seeded initialisation is reproducible.
        self.encoder = Encoder(
            node_input_size=node_input_size,
            element_input_size=element_input_size,
            face_input_size=face_input_size,
            face_face_input_size=face_face_input_size,
            scripted_input_size=scripted_input_size,
            hidden_size=hidden_size,
            node_num=nodes_per_face,
            enode_num=node_num,
            knode_num=node_num,
            kface_num=face_num,
            kf_node=nodes_per_face,
            device=device,
        )

        self.processer_list = nn.ModuleList([
            GnBlock(hidden_size=hidden_size, node_num=node_num, face_num=face_num,
                    f_node=nodes_per_face, device=device)
            for _ in range(message_passing_num)
        ])

        self.decoder = Decoder(hidden_size=hidden_size, output_size=node_output_size, node_num=node_num)

    def forward(self, data_list):
        """Encode ``data_list``, run every processor in sequence, and return the decoder output."""
        state = self.encoder(data_list)
        for block in self.processer_list:
            state = block(state)
        return self.decoder(state)







