# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from einops import *

class Value_Encoder(nn.Module):
    """Embed every scalar observation independently into ``output_dim`` features.

    An input of shape (..., K) is mapped to (..., K, output_dim) by one shared
    ``Linear(1, output_dim)`` followed by ReLU.

    Args:
        output_dim: width of each per-scalar embedding.
    """

    def __init__(self, output_dim):
        # nn.Module.__init__ must run before attributes are assigned: assigning a
        # submodule or Parameter earlier raises, so keep the conventional order.
        super(Value_Encoder, self).__init__()
        self.output_dim = output_dim
        self.encoder = nn.Sequential(
            nn.Linear(1, output_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``encoder`` applied to ``x`` with a trailing singleton axis.

        Accepts any leading shape; the previous [B, L, K] input keeps producing
        [B, L, K, output_dim] exactly as before.
        """
        # Append a size-1 feature axis so each scalar is encoded on its own.
        return self.encoder(x.unsqueeze(-1))
    

class Time_Encoder(nn.Module):
    """Learnable time embedding: one linear channel plus ``embed_time - 1`` sinusoids.

    ``forward`` accepts timestamps shaped [B, L, K] or [B, L] and returns
    [B, L, K, embed_time] or [B, L, 1, embed_time] respectively.

    Args:
        embed_time: total width of the time embedding.
        var_num: stored as ``self.var_num``; not used by this class.
    """

    def __init__(self, embed_time, var_num):
        super(Time_Encoder, self).__init__()
        self.periodic = nn.Linear(1, embed_time - 1)
        self.var_num = var_num
        self.linear = nn.Linear(1, 1)

    def forward(self, tt):
        # Rank-3 input keeps its per-variable axis; rank-2 input gets a
        # singleton variable axis so the output is always 4-D.
        pattern = 'b l k -> b l k 1' if tt.dim() == 3 else 'b l -> b l 1 1'
        tt = rearrange(tt, pattern)

        seasonal = torch.sin(self.periodic(tt))
        trend = self.linear(tt)
        # Linear channel first, then the sinusoidal channels.
        return torch.cat((trend, seasonal), dim=-1)

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> GELU -> LayerNorm -> Linear.

    Args:
        input_size: width of the last input dimension.
        hidden_size: hidden width (LayerNorm is applied at this width).
        output_size: width of the last output dimension.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, output_size),
        ]
        # One Sequential named ``layers`` keeps state_dict keys at
        # ``layers.0.*``, ``layers.2.*`` and ``layers.3.*``.
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        out = self.layers(x)
        return out

class EdgeSAGEConv(nn.Module):
    """GraphSAGE-style convolution whose messages also carry edge attributes.

    For every edge ``e``, the features of node ``edge_index[1, e]`` are
    concatenated with ``edge_attr[e]``, passed through Linear + ReLU, and the
    resulting messages are **summed** into node ``edge_index[0, e]``. Each node
    is then updated from ``[own features, aggregated messages]`` with
    Linear + ReLU and, optionally, L2-normalised.

    Args:
        in_channels: node input width.
        out_channels: node output width (also the message width).
        edge_channels: edge attribute width.
        normalize_emb: if True, L2-normalise each output node embedding.
    """

    def __init__(self, in_channels, out_channels, edge_channels, normalize_emb=True):
        super(EdgeSAGEConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_channels = edge_channels
        self.normalize_emb = normalize_emb
        self.message_lin = nn.Linear(in_channels + edge_channels, out_channels)
        self.agg_lin = nn.Linear(in_channels + out_channels, out_channels)
        self.message_activation = nn.ReLU()
        self.update_activation = nn.ReLU()

    def forward(self, x, edge_attr, edge_index):
        # First row receives the aggregate; second row supplies the neighbour features.
        agg_idx, nbr_idx = edge_index

        messages = self.message_lin(torch.cat([x[nbr_idx], edge_attr], dim=-1))
        messages = self.message_activation(messages)

        # Sum-aggregate the per-edge messages onto their receiving nodes.
        aggregated = x.new_zeros((x.size(0), self.out_channels))
        aggregated.index_add_(0, agg_idx, messages)

        updated = self.agg_lin(torch.cat([x, aggregated], dim=-1))
        updated = self.update_activation(updated)
        if not self.normalize_emb:
            return updated
        return F.normalize(updated, p=2, dim=-1)

    def __repr__(self):
        return (f"{type(self).__name__}({self.in_channels}, {self.out_channels}, "
                f"edge_channels={self.edge_channels})")


class GNNStack(torch.nn.Module):
    """Stack of EdgeSAGEConv layers, each followed by a residual edge update.

    Every layer updates node embeddings, applies dropout to them, and then
    adds ``mlp([x[edge_index[0]], x[edge_index[1]], edge_attr])`` to the edge
    attributes.

    Args:
        node_channels: node embedding width (constant across layers).
        edge_channels: edge attribute width (constant across layers).
        normalize_embs: indexable per-layer flags forwarded to EdgeSAGEConv;
            needs at least ``num_layers`` entries.
        num_layers: number of message-passing layers.
        dropout: dropout probability applied to node embeddings after every layer.
    """

    def __init__(self, node_channels, edge_channels, normalize_embs, num_layers, dropout):
        super(GNNStack, self).__init__()
        self.node_channels = node_channels
        self.edge_channels = edge_channels
        self.normalize_embs = normalize_embs
        self.num_layers = num_layers
        self.dropout = dropout

        # Convolutions are built before the edge MLPs (fixes parameter creation order).
        self.convs = self.build_convs(node_channels, edge_channels, normalize_embs, num_layers)
        self.edge_update_mlps = self.build_edge_update_mlps(node_channels, edge_channels, num_layers)

    def build_convs(self, node_channels, edge_channels, normalize_embs, num_layers):
        """Return one width-preserving EdgeSAGEConv per layer."""
        return nn.ModuleList(
            EdgeSAGEConv(node_channels, node_channels, edge_channels, normalize_embs[layer])
            for layer in range(num_layers)
        )

    def build_edge_update_mlps(self, node_channels, edge_channels, num_layers):
        """Return one ``Linear(2 * node + edge -> edge) + ReLU`` per layer."""
        return nn.ModuleList(
            nn.Sequential(
                nn.Linear(2 * node_channels + edge_channels, edge_channels),
                nn.ReLU(),
            )
            for _ in range(num_layers)
        )

    def update_edge_attr(self, x, edge_attr, edge_index, mlp):
        """Return ``mlp`` applied to [first-endpoint, second-endpoint, edge] features."""
        endpoints_a = x[edge_index[0]]
        endpoints_b = x[edge_index[1]]
        return mlp(torch.cat((endpoints_a, endpoints_b, edge_attr), dim=-1))

    def forward(self, x, edge_attr, edge_index):
        """Run all layers; return the final ``(node_embeddings, edge_attr)`` pair."""
        for conv, edge_mlp in zip(self.convs, self.edge_update_mlps):
            x = F.dropout(conv(x, edge_attr, edge_index), p=self.dropout, training=self.training)
            # Residual update of the edge attributes from the new node embeddings.
            edge_attr = edge_attr + self.update_edge_attr(x, edge_attr, edge_index, edge_mlp)
        return x, edge_attr



class GTDE(nn.Module):
    """Graph-based time-decay encoder over a batch-wide patient/variable graph.

    At every time step one graph is assembled whose first ``batch`` nodes are
    per-patient nodes and whose remaining ``num_of_nodes`` nodes are learnable
    variable ("modality") nodes shared by the whole batch. Each observed
    (patient, variable) pair becomes an undirected edge whose attribute is the
    observation embedding. After message passing, the per-(patient, variable)
    hidden state ``h`` is exponentially decayed by the current interval and
    gated toward the updated edge embedding.

    Args:
        d_in: stored as ``self.d_in``; not otherwise used in this class.
        d_model: width of node, edge and hidden-state embeddings.
        num_of_nodes: number of variable nodes.
    """

    def __init__(self, d_in, d_model, num_of_nodes):
        super(GTDE, self).__init__()
        self.d_in = d_in
        self.d_model = d_model
        self.num_of_nodes = num_of_nodes
        # Two message-passing layers, both L2-normalising node outputs, no dropout.
        self.gnn = GNNStack(d_model, d_model, [True, True], 2, 0.)
        self.modality_nodes = nn.Parameter(torch.zeros(num_of_nodes, d_model))
        nn.init.xavier_uniform_(self.modality_nodes)
        # nn.MultiheadAttention requires d_model to be divisible by num_heads (4).
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=4, batch_first=True)
        self.r_gate = nn.Linear(d_model + d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)

        # Maps an edge embedding to a per-feature decay rate (made >= 0 in forward).
        self.decay_param = MLP(d_model, d_model * 2, d_model)
        self.codebook = nn.Embedding(4096, d_model)
        self.codebook.weight.data.uniform_(-1 / 4096, 1 / 4096)

    def init_hidden_states(self, x):
        """Return zeros of shape (x.shape[0], x.shape[2], d_model) on ``x.device``."""
        return torch.zeros(size=(x.shape[0], x.shape[2], self.d_model)).to(x.device)

    @staticmethod
    def _code_similarity(nodes, codebook):
        """Return the (num_nodes, num_codes) cosine similarity of nodes vs. codes.

        Each node row and each code row is L2-normalised along the feature
        dimension before the product, so every entry lies in [-1, 1].
        (Normalising ``codebook.T`` along its last dim instead normalises each
        feature column across all codes, which is not a cosine similarity.)
        """
        return F.normalize(nodes, dim=-1) @ F.normalize(codebook, dim=-1).T

    @staticmethod
    def _quantized_residual(g_nodes, codebook):
        """Add a softly-quantised codebook reconstruction to every graph node."""
        # Softmax already yields rows summing to 1; no extra renormalisation needed.
        weights = GTDE._code_similarity(g_nodes, codebook).softmax(dim=-1)
        g_nodes_quant = weights @ codebook                        # (num_nodes, d_model)
        # NOTE(review): this adds q * |q| / |g| (not a magnitude-matched q * |g| / |q|);
        # confirm that is the intended "adaptive residual scaling".
        scale = g_nodes_quant.norm(dim=-1, keepdim=True) / (g_nodes.norm(dim=-1, keepdim=True) + 1e-6)
        return g_nodes + g_nodes_quant * scale

    def forward(self, obs_emb, observed_mask, lengths, avg_interval):
        """Encode a batch of irregular multivariate series.

        Args:
            obs_emb: (batch, steps, nodes, features) observation embeddings;
                ``features`` must equal ``d_model`` because the initial patient
                nodes are concatenated with the (num_of_nodes, d_model) table.
            observed_mask: (batch, steps, nodes); non-zero where observed.
            lengths: accepted for interface compatibility; unused.
            avg_interval: indexed as ``avg_interval[:, step]`` and used as a
                (batch, nodes, 1) tensor after ``unsqueeze(-1)`` — presumably
                (batch, steps, nodes); TODO confirm against the caller.

        Returns:
            (batch, (2 + nodes) * d_model) tensor: the final patient node, its
            nearest codebook vector, and the flattened hidden states.
        """
        batch, steps, _, features = obs_emb.size()
        device = obs_emb.device

        # Per-(patient, variable) hidden state, (batch, nodes, d_model).
        h = self.init_hidden_states(obs_emb)

        g_patient_nodes = torch.ones(batch, features, device=device)
        g_nodes = torch.cat([g_patient_nodes, self.modality_nodes], dim=0)

        # NOTE(review): detached, so the codebook receives no gradient through this
        # module (same as the former ``.data`` access); confirm a frozen codebook
        # is intended.
        codebook = self.codebook.weight.detach()
        for step in range(steps):
            if step > 0:
                g_nodes = self._quantized_residual(g_nodes, codebook)
                # Each patient node attends over its own row of hidden states.
                query = g_nodes[:batch].unsqueeze(1)                      # (batch, 1, d_model)
                g_patient_nodes = self.proj(self.attn(query, h, h)[0].squeeze(1))
            else:
                g_patient_nodes = g_nodes[:batch]

            cur_interval = avg_interval[:, step].unsqueeze(-1)            # (batch, nodes, 1)
            cur_interval_weight = torch.sigmoid(-cur_interval)
            # Shift variable nodes by their batch-mean interval-weighted copy.
            # NOTE(review): this rescaling happens every step, but ``self.gnn``
            # (whose layers normalise node outputs) only runs on steps with at
            # least one observation; verify long empty stretches stay bounded.
            variable_nodes = g_nodes[batch:]
            variable_nodes = variable_nodes + (cur_interval_weight * variable_nodes.unsqueeze(0)).mean(dim=0)
            g_nodes = torch.cat([g_patient_nodes, variable_nodes], dim=0)

            cur_obs = obs_emb[:, step]
            cur_mask = observed_mask[:, step]
            nodes_need_update = torch.where(cur_mask)                     # (patient_idx, variable_idx)
            if nodes_need_update[0].shape[0] == 0:
                continue

            # Edges patient -> variable node (variable ids offset past the patient
            # nodes), then the reversed copies so the graph is undirected.
            g_edge_index = torch.stack([nodes_need_update[0], nodes_need_update[1] + batch])
            g_edge_index = torch.cat([g_edge_index, g_edge_index.flip([0])], dim=1)
            num_edges = nodes_need_update[0].shape[0]
            g_edge_attr = cur_obs[nodes_need_update].repeat(2, 1)

            g_nodes, edge_attr = self.gnn(g_nodes, g_edge_attr, g_edge_index)
            edge_attr = edge_attr[:num_edges]                             # forward-direction copies

            decay_rate = F.softplus(self.decay_param(edge_attr))          # (E, d_model), >= 0
            delta_t = cur_interval[nodes_need_update]                     # (E, 1)
            gamma = torch.exp(-decay_rate * delta_t)                      # in (0, 1] for delta_t >= 0

            h_decay = h[nodes_need_update] * gamma
            r = torch.sigmoid(self.r_gate(torch.cat([edge_attr, h_decay], dim=-1)))
            # Out-of-place write: ``h`` was already consumed by ``self.attn`` in the
            # autograd graph, so mutating it in place is a version-counter hazard.
            h = h.index_put(nodes_need_update, (1 - r) * h_decay + r * edge_attr)

        # Re-weight variables by how often they were observed (softmax over variables).
        h = h + h * F.softmax(observed_mask.sum(dim=1), dim=-1).unsqueeze(-1)
        h = h.reshape(batch, -1)
        g_patient_nodes = g_nodes[:batch]

        code_idx = torch.argmax(self._code_similarity(g_patient_nodes, codebook), dim=-1)  # (batch,)
        code_top = codebook[code_idx]

        return torch.cat([g_patient_nodes, code_top, h], dim=-1)

class DBGL(nn.Module):
    """Classifier over irregular multivariate series built around ``GTDE``.

    Each observed value is embedded as value + absolute-time + variable-type
    embeddings, encoded by ``GTDE``, optionally concatenated with an embedding
    of static features, and mapped to ``n_class`` un-normalised scores.

    Args:
        DEVICE: device the whole module is moved to.
        hidden_dim: embedding width used throughout.
        num_of_variables: number of variables V.
        num_of_timestamps: stored only.
        d_static: width of static features; 0 disables the static branch.
        n_class: number of output classes.
    """

    def __init__(self, DEVICE, hidden_dim, num_of_variables, num_of_timestamps, d_static,
                 n_class):

        super(DBGL, self).__init__()
        self.num_of_variables = num_of_variables
        self.num_of_timestamps = num_of_timestamps
        self.hidden_dim = hidden_dim
        # NOTE(review): not used in forward; kept for state_dict compatibility.
        self.adj = nn.Parameter(torch.ones(size=[num_of_variables, num_of_variables]))
        self.value_enc = Value_Encoder(output_dim=hidden_dim)
        self.abs_time_enc = Time_Encoder(embed_time=hidden_dim, var_num=num_of_variables)
        self.type_emb = nn.Embedding(num_of_variables, hidden_dim)
        self.GTDE = GTDE(d_in=self.hidden_dim, d_model=self.hidden_dim,
                         num_of_nodes=num_of_variables)
        self.d_static = d_static
        # GTDE emits (2 + V) * hidden_dim features; the static embedding adds one more block.
        if d_static != 0:
            self.emb = nn.Linear(d_static, hidden_dim)
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim * (3 + num_of_variables), 256),
                nn.ReLU(),
                nn.Linear(256, n_class))
        else:
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim * (2 + num_of_variables), 256),
                nn.ReLU(),
                nn.Linear(256, n_class))

        self.DEVICE = DEVICE
        self.to(DEVICE)

    def forward(self, P, P_static, P_avg_interval, P_length, P_time):
        """Return class scores for a batch.

        Args:
            P: (b, t, 2 * V) tensor; the first half of the last axis holds values,
                the second half the observation mask.
            P_static: static features; used only when ``d_static != 0``.
            P_avg_interval: forwarded to ``GTDE`` as ``avg_interval``.
            P_length: forwarded to ``GTDE`` as ``lengths``.
            P_time: timestamps forwarded to ``Time_Encoder``.

        Raises:
            ValueError: if the model was built with ``d_static != 0`` but
                ``P_static`` is None.
        """
        _, _, num_cols = P.shape
        v = num_cols // 2
        observed_data = P[:, :, :v]
        observed_mask = P[:, :, v:]
        mask = observed_mask.unsqueeze(-1)                    # (b, t, v, 1)

        value_emb = self.value_enc(observed_data) * mask
        abs_time_emb = self.abs_time_enc(P_time) * mask
        # The (v, d) type table broadcasts over batch and time.
        structure_input_encoding = (value_emb + abs_time_emb + self.type_emb.weight) * mask

        last_hidden_state = self.GTDE(structure_input_encoding, observed_mask, P_length, P_avg_interval)

        # Branch on how the model was built, not on what the caller passed:
        # the static layers only exist when d_static != 0.
        if self.d_static == 0:
            return self.classifier(last_hidden_state)
        if P_static is None:
            raise ValueError(
                f"DBGL was built with d_static={self.d_static} but P_static is None")
        static_emb = self.emb(P_static)
        return self.classifier(torch.cat([last_hidden_state, static_emb], dim=-1))