import torch
from torch import nn
from einops import rearrange,repeat


class GCN(nn.Module):
    """Multi-hop graph convolution followed by a linear projection and dropout.

    For every support matrix ``a`` the input is propagated ``order`` times
    (``a·x``, ``a·(a·x)``, ...). The untouched input and every propagated
    tensor are concatenated on the last axis and projected to ``c_out``.

    Args:
        c_in: Size of the input's last axis.
        c_out: Size of the output's last axis.
        dropout: Dropout probability applied after the projection.
        support_len: Number of support matrices ``forward`` will receive. It
            sizes the projection layer, so it must equal the actual count.
        order: Number of propagation hops per support matrix (>= 1).
    """

    def __init__(self, c_in, c_out, dropout=0.1, support_len=1, order=2):
        super(GCN, self).__init__()
        # Concatenated width: the raw input plus `order` hops for each support.
        c_in = (order * support_len + 1) * c_in
        self.mlp = nn.Linear(c_in, c_out)
        self.order = order
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, support):
        """Propagate ``x`` over each support matrix and project the result.

        Args:
            x: 4-D tensor (indexed ``btkd`` in the einsum). Its third axis is
                the one mixed by the supports. Presumably (batch, time, nodes,
                features) — confirm with callers.
            support: A single 2-D tensor, or a list/tuple of them. Each has shape
                ``(x.shape[2], x.shape[2])``, and there are ``support_len`` of them.

        Returns:
            Tensor with the same first three axes as ``x`` and last axis ``c_out``.
        """
        # Accept one matrix or any list/tuple of matrices.
        if not isinstance(support, (list, tuple)):
            support = [support]
        out = [x]
        for a in support:
            # out[..., n, :] = sum_k a[n, k] * x[..., k, :]
            x1 = torch.einsum('btkd,nk->btnd', x, a)
            out.append(x1)
            for _ in range(2, self.order + 1):
                x1 = torch.einsum('btkd,nk->btnd', x1, a)
                out.append(x1)
        h = torch.cat(out, dim=-1)
        h = self.mlp(h)
        h = self.dropout(h)
        return h


class featMLP(nn.Module):
    """Two-layer MLP over the last axis: Linear -> ReLU -> Dropout -> Linear.

    Args:
        num_feat_dim: Size of the input's last axis.
        hid_dim: Hidden and output size.
        dropout_rate: Dropout probability between the two linear layers.
    """

    def __init__(self, num_feat_dim, hid_dim, dropout_rate=0.1):
        super().__init__()
        # Attribute names matter: other code reads `meta_projection.weight`.
        self.meta_projection = nn.Linear(num_feat_dim, hid_dim)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(hid_dim, hid_dim)

    def forward(self, feat):
        """Return the ``hid_dim``-sized embedding of ``feat``."""
        hidden = self.meta_projection(feat)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)
        return self.fc(hidden)




class getSTembedding_adp(nn.Module):
    """Embed static features with the embedder selected by ``static_func_type``.

    Only ``'featMLP'`` is implemented. Other values are still accepted at
    construction, so a model that never feeds static features can be built.
    Calling ``forward`` in that case raises ValueError.

    Args:
        num_feat_dim: Size of the static features' last axis.
        hid_dim: Embedding size.
        static_func_type: Embedder type; ``'featMLP'`` is the only one implemented.
        dropout_rate: Dropout probability passed to the embedder.
    """

    def __init__(self, num_feat_dim, hid_dim, static_func_type, dropout_rate=0.1):
        super().__init__()
        self.static_func_type = static_func_type
        if static_func_type == 'featMLP':
            self.embedder = featMLP(num_feat_dim, hid_dim, dropout_rate)
        else:
            # No embedder for this type; forward reports it explicitly instead
            # of failing with an AttributeError on a missing attribute.
            self.embedder = None

    def forward(self, aux_data):
        """Return the embedding of ``aux_data``.

        Raises:
            ValueError: If ``static_func_type`` has no implemented embedder.
        """
        if self.embedder is None:
            raise ValueError(
                f"Unsupported static_func_type {self.static_func_type!r}; "
                "only 'featMLP' is implemented."
            )
        return self.embedder(aux_data)

class getSTembedding_dynamic(nn.Module):
    """Graph-convolutional embedding of dynamic inputs.

    ``'gcn'`` uses a one-hop GCN and ``'gcn2'`` a two-hop GCN, both mapping
    ``hid_dim`` -> ``hid_dim``. ``'sgcn'`` and ``'sgcn2'`` are recognised names,
    but no layers exist for them, so they are rejected at construction.

    Args:
        hid_dim: Input and output size of the last axis.
        feat_type: One of ``'gcn'``, ``'gcn2'`` (``'sgcn'``/``'sgcn2'`` raise).
        support_len: Number of support matrices passed to ``forward``.

    Raises:
        ValueError: For an unknown ``feat_type``.
        NotImplementedError: For ``'sgcn'`` / ``'sgcn2'``.
    """

    def __init__(self, hid_dim, feat_type, support_len):
        super().__init__()
        if feat_type not in ('gcn', 'gcn2', 'sgcn', 'sgcn2'):
            raise ValueError(
                f"feat_type must be one of 'gcn', 'gcn2', 'sgcn', 'sgcn2'; got {feat_type!r}"
            )
        if feat_type in ('sgcn', 'sgcn2'):
            # forward would need layers (act / gcn_layer2) that are never built;
            # fail here rather than with an AttributeError on the first call.
            raise NotImplementedError(f"feat_type {feat_type!r} is not implemented")
        self.feat_type = feat_type
        self.in_dim = hid_dim
        order = 1 if feat_type == 'gcn' else 2
        self.gcn_layer = GCN(hid_dim, hid_dim, support_len=support_len, order=order)

    def forward(self, x, adj_mx):
        """Apply the GCN layer to ``x`` with support matrix/matrices ``adj_mx``."""
        return self.gcn_layer(x, adj_mx)

class SNIP(nn.Module):
    """Combine a static-feature embedding with a graph-propagated dynamic one.

    The static part embeds the last-axis concatenation of all provided static
    feature tensors. Unless ``dynamic_func_type == 'none'``, the dynamic part
    runs ``getSTembedding_dynamic`` on ``x``. When both are present they are
    summed. Otherwise the one that exists is returned; a static-only result is
    broadcast over the first two axes of ``x``.

    Args:
        static_feats_dim_list: Last-axis sizes of the static feature tensors;
            their sum is the static embedder's input size.
        emb_dim: Embedding size.
        support_len: Number of support matrices passed to ``forward``.
        static_func_type: Static embedder type (see ``getSTembedding_adp``).
        dynamic_func_type: Dynamic embedder type, or ``'none'`` to disable it.
        dropout_rate: Dropout probability for the static embedder.
        dynamic_input_dim: Last-axis size of ``x`` when it differs from
            ``emb_dim``. A bias-free linear layer then projects ``x`` to ``emb_dim``.
    """

    def __init__(self, static_feats_dim_list, emb_dim, support_len,
                 static_func_type='featMLP', dynamic_func_type='gcn2',
                 dropout_rate=0.1, dynamic_input_dim=None):
        super(SNIP, self).__init__()
        self.static_func_type = static_func_type
        self.emb_dim = emb_dim
        self.dynamic_func_type = dynamic_func_type

        num_all_prior_feats_dim = sum(static_feats_dim_list)
        self.static_feat_embedder = getSTembedding_adp(
            num_all_prior_feats_dim, emb_dim, static_func_type, dropout_rate)

        if dynamic_func_type != 'none':
            # Registered before the dynamic embedder so seeded parameter
            # initialisation order stays stable.
            if dynamic_input_dim is not None and dynamic_input_dim != emb_dim:
                self.input_emb_linear = nn.Linear(dynamic_input_dim, emb_dim, bias=False)
            self.dynamic_feat_embedder = getSTembedding_dynamic(
                emb_dim, dynamic_func_type, support_len)

    def forward(self, x, static_feats_list, adj_list):
        """Return the SNIP embedding for ``x``.

        Args:
            x: 4-D tensor, unpacked as ``(B, T, N, d)``.
            static_feats_list: Static feature tensors; ``None`` entries are
                skipped and the rest are concatenated on the last axis. Each is
                2-D (shared across the batch) or 3-D (leading batch axis).
            adj_list: Support matrix or list of them for the dynamic embedder.

        Returns:
            Presumably a ``(B, T, N, emb_dim)`` tensor when the static features
            cover the same N nodes as ``x`` — confirm with callers.

        Raises:
            ValueError: If the dynamic part is disabled and no static features
                are given, or if ``x`` needs projecting but ``dynamic_input_dim``
                was not configured.
        """
        B, T, _, _ = x.shape

        present = [feat for feat in static_feats_list if feat is not None]
        static_emb = None
        if present:
            static_emb = self.static_feat_embedder(torch.cat(present, dim=-1))
            # Insert singleton axes so the embedding broadcasts against x's layout.
            if static_emb.dim() == 2:
                static_emb = rearrange(static_emb, 'n d -> 1 1 n d')
            elif static_emb.dim() == 3:
                static_emb = rearrange(static_emb, 'b n d -> b 1 n d')

        if self.dynamic_func_type != 'none':
            if x.shape[-1] != self.emb_dim:
                if not hasattr(self, 'input_emb_linear'):
                    raise ValueError(
                        f"x has last dimension {x.shape[-1]} but emb_dim is {self.emb_dim}; "
                        "pass dynamic_input_dim to enable the input projection."
                    )
                x = self.input_emb_linear(x)
            dynamic_emb = self.dynamic_feat_embedder(x, adj_list)
            return dynamic_emb if static_emb is None else static_emb + dynamic_emb

        if static_emb is None:
            raise ValueError(
                "dynamic_func_type is 'none' and no static features were provided.")
        if static_emb.shape[0] == 1:
            return repeat(static_emb, '1 1 n d -> b t n d', b=B, t=T)
        # Batched static features: (B, 1, N, d) -> (B, T, N, d).
        return repeat(static_emb, 'b 1 n d -> b t n d', b=B, t=T)

    def get_meta_emb(self):
        """Return the transposed input-projection weight of the featMLP embedder.

        The shape is ``(sum(static_feats_dim_list), emb_dim)``. Returns ``None``
        when ``static_func_type`` is not ``'featMLP'``.
        """
        if self.static_func_type == 'featMLP':
            return self.static_feat_embedder.embedder.meta_projection.weight.t()
        return None


class AttentionPrompt(nn.Module):
    """Compute a gated, weighted sum of learnable prompt vectors per input vector.

    Each input vector (last axis) is scored against every prompt with a sigmoid
    of their dot product. Scores not strictly above ``prompt_threshold`` are
    zeroed, and the output is the score-weighted sum of the prompt vectors.

    Args:
        num_prompt: Number of learnable prompt vectors.
        embed_dim: Size of each prompt vector.
        prompt_threshold: Gate; scores ``<=`` this value contribute nothing.
        dynamic_input_dim: Last-axis size of the input when it differs from
            ``embed_dim``. A bias-free linear layer then projects inputs to
            ``embed_dim``.
    """

    def __init__(self, num_prompt, embed_dim, prompt_threshold=0.2, dynamic_input_dim=None) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.dynamic_input_dim = dynamic_input_dim
        if self.dynamic_input_dim is not None and dynamic_input_dim != embed_dim:
            self.input_emb_linear = nn.Linear(dynamic_input_dim, embed_dim, bias=False)
        self.prompt = nn.Embedding(num_prompt, embed_dim)
        nn.init.kaiming_uniform_(self.prompt.weight, nonlinearity='leaky_relu', mode='fan_in', a=0.01)
        self.prompt_threshold = prompt_threshold

    def forward(self, patches):
        """Return the prompt embedding for ``patches``.

        :param patches: [B, P, N, D]. Only the last axis is contracted, so
            any leading shape works.
        :return: prompt_emb with the same leading shape as ``patches`` and last
            axis ``embed_dim``. The prompts are NOT added to ``patches``.
        """
        if self.dynamic_input_dim is not None and self.dynamic_input_dim != self.embed_dim:
            patches = self.input_emb_linear(patches)
        prompts = self.prompt.weight  # (num_prompt, embed_dim)
        score_map = torch.sigmoid(torch.matmul(patches, prompts.t()))  # [..., num_prompt]
        # Multiplicative mask (not torch.where) so NaN scores propagate as before.
        score_map = score_map * (score_map > self.prompt_threshold).float()
        # sum_p score[..., p] * prompts[p] as a matmul, which avoids building a
        # [..., num_prompt, embed_dim] intermediate. `.float()` above may have
        # promoted the scores, so match the prompts' dtype for matmul.
        return torch.matmul(score_map, prompts.to(score_map.dtype))
    