from re import sub
from typing import Callable
import torch
import torch.nn.functional as F
import numpy as np
import dgl
from dgl.nn import GlobalAttentionPooling
from dgl.nn.pytorch import GraphConv
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from .ERUMLayers import RUMLayer, Consistency

# Identity mapping over the names 'sum', 'mul' and 'concat' (presumably the
# available attention-combination operators, judging by the variable name --
# confirm against callers; nothing in this module reads it).
att_op_dict = {op_name: op_name for op_name in ('sum', 'mul', 'concat')}

class ERUMModel(torch.nn.Module):
    """Stack of ``RUMLayer`` blocks between an input and an output projection.

    The forward pass projects the node features to ``hidden_features``, runs
    them through ``depth`` ``RUMLayer`` instances, projects to
    ``out_features`` and applies a softmax over the last axis. Every layer
    also returns an auxiliary loss which is accumulated (scaled by
    ``self_supervise_weight``); in training mode a ``Consistency`` term
    (scaled by ``consistency_weight``) is added on top.

    Args:
        in_features: Width of the raw node features.
        out_features: Width of the output projection.
        hidden_features: Width used inside the layer stack.
        depth: Number of ``RUMLayer`` blocks.
        activation: Stored as ``self.activation``; not applied in this
            class's own ``forward``.
        temperature: Forwarded to ``Consistency``.
        self_supervise_weight: Scale applied to each layer's auxiliary loss.
        consistency_weight: Default scale of the consistency term.
        **kwargs: Forwarded unchanged to every ``RUMLayer``.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            hidden_features: int,
            depth: int,
            activation: Callable = torch.nn.ELU(),
            temperature=0.1,
            self_supervise_weight=0.05,
            consistency_weight=0.01,
            **kwargs,
    ):
        super().__init__()
        # Plain (non-module) bookkeeping attributes.
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.out_dim = hidden_features
        self.depth = depth
        self.self_supervise_weight = self_supervise_weight
        self.consistency_weight = consistency_weight

        # Sub-modules are created in a fixed order (projections, layer stack,
        # activation, consistency) so parameter initialisation and
        # ``state_dict`` ordering stay stable.
        self.fc_in = torch.nn.Linear(in_features, hidden_features, bias=True)
        self.fc_out = torch.nn.Linear(hidden_features, out_features, bias=True)
        self.layers = torch.nn.ModuleList(
            [
                RUMLayer(hidden_features, hidden_features, in_features, **kwargs)
                for _ in range(depth)
            ]
        )
        self.activation = activation
        self.consistency = Consistency(temperature=temperature)

    def forward(self, g, h, e=None, consistency_weight=None, subsample=None):
        """Run the layer stack and return ``(probabilities, loss)``.

        Args:
            g: Graph object; only ``g.local_var()`` is called on it here, the
                result is handed to every layer.
            h: Raw node features; also passed to every layer as its third
                positional argument.
            e: Optional edge input forwarded to every layer.
            consistency_weight: Overrides ``self.consistency_weight`` when
                not ``None``.
            subsample: Forwarded to every layer.

        Returns:
            Tuple of the softmax-normalised output of ``fc_out`` and the
            accumulated loss (the float ``0.0`` when nothing was added).
        """
        g = g.local_var()
        weight = (
            self.consistency_weight
            if consistency_weight is None
            else consistency_weight
        )

        raw = h
        hidden = self.fc_in(raw)
        total = 0.0
        for position, block in enumerate(self.layers):
            if position:
                # Collapse dim 0 of the previous layer's output before it is
                # fed to the next layer (presumably the per-walk sample axis
                # -- confirm against RUMLayer).
                hidden = hidden.mean(0)
            hidden, aux = block(g, hidden, raw, e=e, subsample=subsample)
            total = total + self.self_supervise_weight * aux

        probs = self.fc_out(hidden).softmax(-1)
        if not self.training:
            return probs, total

        # The consistency term only contributes while training.
        return probs, total + self.consistency(probs) * weight

class ERUMGraphRegressionModel(ERUMModel):
    """Graph-level variant of ``ERUMModel`` with an MLP read-out head.

    Construction arguments are those of ``ERUMModel``. The keyword argument
    ``dropout`` is required: it is read from ``kwargs`` for the head (a
    ``KeyError`` is raised when it is missing) and, being part of
    ``kwargs``, it is also forwarded to the parent constructor.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Replace the parent's single linear output layer by
        # activation -> linear -> activation -> dropout -> linear.
        width = self.hidden_features
        head = [
            self.activation,
            torch.nn.Linear(width, width),
            self.activation,
            torch.nn.Dropout(kwargs["dropout"]),
            torch.nn.Linear(width, self.out_features),
        ]
        self.fc_out = torch.nn.Sequential(*head)

    def forward(self, g, h, e=None, subsample=None):
        """Return ``(graph_prediction, loss)`` for the graph(s) in ``g``.

        Args:
            g: DGL graph; a local copy is used so the caller's node data is
                not modified.
            h: Raw node features, also given to every layer as its third
                positional argument.
            e: Optional edge input forwarded to every layer.
            subsample: Forwarded to every layer.

        Returns:
            Tuple of the read-out head's output on the mean-pooled node
            representation and the accumulated, weighted layer losses. No
            consistency term is added in this class.
        """
        g = g.local_var()
        raw = h
        hidden = self.fc_in(raw)
        total = 0.0
        for position, block in enumerate(self.layers):
            if position > 0:
                # SiLU, then average over dim 0 of the previous layer's
                # output before it enters the next layer.
                hidden = torch.nn.functional.silu(hidden).mean(0)
            hidden, aux = block(g, hidden, raw, e=e, subsample=subsample)
            total = total + self.self_supervise_weight * aux

        # Average over dim 0 once more, store the result as node data and
        # mean-pool it per graph.
        g.ndata["h"] = hidden.mean(0)
        pooled = dgl.mean_nodes(g, "h")
        return self.fc_out(pooled), total


class ERUMJointModel(torch.nn.Module):
    """RUM layer stack with a node-level head and a graph-level head.

    The shared trunk is an input projection followed by ``depth``
    ``RUMLayer`` blocks. ``fc_out_node`` maps the trunk output to
    ``node_out_features``; the graph head pools node representations with
    ``GlobalAttentionPooling`` and feeds them through a small MLP.

    Args:
        in_features: Width of the raw node features.
        node_out_features: Width of the node-level output.
        graph_out_features: Width of the graph-level output.
        hidden_features: Width used inside the trunk.
        depth: Number of ``RUMLayer`` blocks.
        activation: Activation module used (twice) in the graph head.
        temperature: Forwarded to ``Consistency``.
        self_supervise_weight: Scale applied to each layer's auxiliary loss.
        consistency_weight: Default scale of the consistency term.
        dropout: Dropout probability of the graph head. It is a named
            parameter, so it is NOT part of ``kwargs`` passed to ``RUMLayer``.
        **kwargs: Forwarded unchanged to every ``RUMLayer``.
    """

    def __init__(
            self,
            in_features: int,
            node_out_features: int,
            graph_out_features: int,
            hidden_features: int,
            depth: int,
            activation: Callable = torch.nn.ELU(),
            temperature=0.1,
            self_supervise_weight=0.05,
            consistency_weight=0.01,
            dropout=0.1,
            **kwargs,
    ):
        super().__init__()
        self.fc_in = torch.nn.Linear(in_features, hidden_features, bias=True)
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.depth = depth
        self.layers = torch.nn.ModuleList()
        for _ in range(depth):
            self.layers.append(RUMLayer(hidden_features, hidden_features, in_features, **kwargs))

        # Node-level head: a single linear projection.
        self.fc_out_node = torch.nn.Linear(hidden_features, node_out_features, bias=True)

        # Graph-level head applied to the pooled node representation.
        self.fc_out_graph = torch.nn.Sequential(
            activation,
            torch.nn.Linear(hidden_features, hidden_features),
            activation,
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_features, graph_out_features),
        )

        self.activation = activation
        self.consistency = Consistency(temperature=temperature)
        self.self_supervise_weight = self_supervise_weight
        self.consistency_weight = consistency_weight
        # Attention pooling whose gate is a linear map to one score per node.
        self.pooling_layer = GlobalAttentionPooling(torch.nn.Linear(hidden_features, 1))

    def forward(self, g, h, e=None, consistency_weight=None, subsample=None):
        """Return ``(h_node, h_graph, loss)``.

        Args:
            g: DGL graph; a local copy is used so the caller's node data is
                not modified.
            h: Raw node features, also given to every layer as its third
                positional argument.
            e: Optional edge input forwarded to every layer.
            consistency_weight: Overrides ``self.consistency_weight`` when
                not ``None``.
            subsample: Forwarded to every layer.

        Returns:
            ``h_node``: ``fc_out_node`` applied to the trunk output (any
            leading axis produced by the last layer is kept).
            ``h_graph``: graph head applied to the attention-pooled node
            representation.
            ``loss``: weighted layer losses, plus the weighted consistency
            term in training mode.
        """
        g = g.local_var()
        if consistency_weight is None:
            consistency_weight = self.consistency_weight

        h0 = h
        h = self.fc_in(h)
        loss = 0.0

        for idx, layer in enumerate(self.layers):
            if idx > 0:
                # SiLU, then average over dim 0 of the previous layer's
                # output before it enters the next layer.
                h = torch.nn.functional.silu(h)
                h = h.mean(0)
            h, _loss = layer(g, h, h0, e=e, subsample=subsample)
            loss = loss + self.self_supervise_weight * _loss

        h_node = self.fc_out_node(h)

        # A layer's output is always averaged over dim 0 before it is reused
        # (see the loop above), i.e. it carries an extra leading axis compared
        # with the per-node input. Node data and pooling need exactly one row
        # per node, so that axis is collapsed here; a plain 2-D tensor (e.g.
        # with no layers) is passed through untouched.
        h_pool = h.mean(0) if h.dim() > 2 else h
        g.ndata['h'] = h_pool
        h_graph = self.pooling_layer(g, g.ndata['h'])
        h_graph = self.fc_out_graph(h_graph)

        if self.training:
            # NOTE(review): the raw output of fc_out_node is passed to the
            # consistency term here without a softmax -- confirm that this is
            # what Consistency expects.
            _loss = self.consistency(h_node) * consistency_weight
            loss = loss + _loss

        return h_node, h_graph, loss

class ERUMGraphClassificationModel(torch.nn.Module):
    """Graph-level model over node features from a frozen RoBERTa encoder.

    Node features are built by ``get_nodes_feature`` from the graph's
    ``node_cids`` / ``node_types`` node data and then passed to an
    ``ERUMGraphRegressionModel``.

    Args:
        in_features: Input width of the inner model. ``get_nodes_feature``
            returns one node-type column concatenated with the embedding, so
            this should equal ``1 + embedding width`` (not validated here).
        out_features: Output width of the inner model.
        hidden_features: Hidden width of the inner model.
        depth: Number of layers of the inner model.
        num_samples: Forwarded to the inner model as ``num_samples``.
        walk_length: Forwarded to the inner model as ``length``.
        activation: Activation module forwarded to the inner model.
        temperature: Forwarded to the inner model.
        self_supervise_weight: Forwarded to the inner model.
        consistency_weight: Forwarded to the inner model.
        dropout: Forwarded to the inner model.
        block_size: Row count produced by ``preprocess_features``.
        model_name_or_path: Name or path given to ``from_pretrained`` for
            both the encoder config and the encoder weights.
        encoder_type: ``'by_code_line'`` (masked mean of the encoder's last
            hidden state) or ``'by_token'`` (sum of input word embeddings).
        **kwargs: Forwarded to the inner model.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            hidden_features: int,
            depth: int,
            num_samples: int,
            walk_length: int,
            activation: Callable = torch.nn.ELU(),
            temperature=0.1,
            self_supervise_weight=0.05,
            consistency_weight=0.01,
            dropout=0.1,
            block_size=400,
            model_name_or_path="microsoft/codebert-base",
            encoder_type="by_code_line",
            **kwargs,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.depth = depth
        self.num_samples = num_samples
        self.walk_length = walk_length
        self.dropout = dropout
        self.self_supervise_weight = self_supervise_weight
        self.consistency_weight = consistency_weight
        self.temperature = temperature
        self.activation = activation
        self.model_name_or_path = model_name_or_path

        self.encoder_config = RobertaConfig.from_pretrained(model_name_or_path)
        self.encoder = RobertaModel.from_pretrained(model_name_or_path)
        # Optional override of the embedding table used by the 'by_token'
        # mode; when left as None the encoder's own word embeddings are used.
        self.word_embeddings = None

        # Freeze the encoder weights.
        # NOTE(review): this does not pin the encoder to eval mode; it still
        # follows this module's train()/eval() state -- confirm that is
        # intended.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder_type = encoder_type
        self.block_size = block_size

        self.model = ERUMGraphRegressionModel(
            in_features=self.in_features,
            out_features=self.out_features,
            hidden_features=self.hidden_features,
            depth=self.depth,
            num_samples=self.num_samples,
            length=self.walk_length,
            dropout=self.dropout,
            self_supervise_weight=self.self_supervise_weight,
            consistency_weight=self.consistency_weight,
            temperature=self.temperature,
            activation=self.activation,
            **kwargs
        )

    def preprocess_features(self, features):
        """Zero-pad or truncate a 2-D feature matrix to ``block_size`` rows.

        Args:
            features: 2-D array-like, one row per item.

        Returns:
            ``numpy.ndarray`` with exactly ``self.block_size`` rows. Inputs
            with more rows are cut to the first ``block_size`` rows (a
            negative pad width would otherwise make ``np.pad`` raise).
        """
        features = np.asarray(features)[: self.block_size]
        pad = self.block_size - len(features)
        return np.pad(features, ((0, pad), (0, 0)), mode='constant')

    def get_encoder_vec(self, source_ids, mask):
        """Encode token ids and return ``(masked_mean, last_hidden_state)``.

        Args:
            source_ids: Token ids passed to the encoder.
            mask: Attention mask passed to the encoder; also used as the
                weights of the mean over dim 1.

        Returns:
            Tuple of the mask-weighted mean of the last hidden state over
            dim 1 and the last hidden state itself. Computed without
            gradient tracking.
        """
        with torch.no_grad():
            outputs = self.encoder(source_ids, attention_mask=mask)
            last_hidden = outputs.last_hidden_state

            # Zero out masked positions, then divide by the number of kept
            # positions; the clamp avoids dividing by zero for empty rows.
            mask_expanded = mask.unsqueeze(-1).to(last_hidden.dtype)
            sum_embeddings = torch.sum(last_hidden * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask.sum(dim=1, keepdim=True), min=1e-9)
            return sum_embeddings / sum_mask, last_hidden

    def get_nodes_feature(self, node_cids, node_types):
        """Build node features as ``[node_type | embedding]``.

        Args:
            node_cids: Integer tensor of token ids per node (assumed to be
                ``(num_nodes, tokens_per_node)`` -- TODO confirm).
            node_types: Per-node type values; converted to float and used as
                the first feature column.

        Returns:
            Tensor with the node-type column concatenated in front of the
            per-node embedding along dim 1.

        Raises:
            ValueError: If ``self.encoder_type`` is neither
                ``'by_code_line'`` nor ``'by_token'`` (instead of silently
                returning ``None``).
        """
        if self.encoder_type not in ('by_code_line', 'by_token'):
            raise ValueError(
                f"Unknown encoder_type {self.encoder_type!r}; "
                "expected 'by_code_line' or 'by_token'"
            )

        device = node_cids[0].device

        if self.encoder_type == 'by_code_line':
            code_ids = node_cids
            # Everything that is not the pad token takes part in the mean.
            mask = code_ids.ne(self.encoder_config.pad_token_id)
            with torch.no_grad(), torch.cuda.amp.autocast():
                code_embs, _ = self.get_encoder_vec(code_ids, mask)
        else:
            # Resolve the table locally. Assigning the encoder's embedding
            # module to ``self.word_embeddings`` would register it a second
            # time and add a duplicate 'word_embeddings.weight' entry to
            # state_dict() after the first call.
            word_embeddings = self.word_embeddings
            if word_embeddings is None:
                word_embeddings = self.encoder.embeddings.word_embeddings

            all_cids = node_cids
            vocab_size = word_embeddings.num_embeddings

            # Ids outside [0, vocab_size) are looked up as id 0 and then
            # zeroed out, so they contribute nothing to the sum.
            valid_mask = (all_cids >= 0) & (all_cids < vocab_size)
            safe_cids = torch.where(valid_mask, all_cids, torch.zeros_like(all_cids))

            with torch.no_grad():
                embeds = word_embeddings.weight[safe_cids]
                embeds = embeds * valid_mask.unsqueeze(-1).to(embeds.dtype)
                code_embs = embeds.sum(dim=1)

        node_types_tensor = node_types.float().unsqueeze(1).to(device)
        return torch.cat([node_types_tensor, code_embs], dim=1)

    def forward(self, g, e=None, subsample=None):
        """Return ``(logits, loss)`` of the inner model for graph ``g``.

        Node features are derived from ``g.ndata['node_cids']`` and
        ``g.ndata['node_types']``.
        """
        node_features = self.get_nodes_feature(g.ndata['node_cids'], g.ndata['node_types'])
        logits, loss = self.model(g, node_features, e, subsample)
        return logits, loss
    


class GCN(torch.nn.Module):
    """Two ``GraphConv`` layers with a ReLU in between.

    Args:
        in_feats: Input feature width of the first convolution.
        h_feats: Output width of the first / input width of the second.
        num_classes: Output width of the second convolution.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return the second convolution's output for graph ``g``."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)