import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from src.gift.models.attention import AttentionBlock
# Module-wide default device: the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from src.gift.models.attention import AttentionBlock
import numpy as np
from torch.nn.utils import spectral_norm
def relaxed_spectral_norm(module, lip_constant=5.0, n_power_iterations=1):
    """Spectral normalization with an adjustable Lipschitz constant.

    The largest singular value of ``module.weight`` is estimated with power
    iteration on every forward call. The weight is rescaled in place only when
    that estimate exceeds ``lip_constant``, so the estimated spectral norm is
    capped at ``lip_constant`` rather than forced to 1. The rescaling is
    permanent: it is written into ``module.weight`` itself, not applied as a
    reparametrization.

    Args:
        module: a module with a ``weight`` attribute; the weight is viewed as a
            matrix of shape ``(weight.size(0), -1)`` (e.g. ``nn.Linear``).
        lip_constant: upper bound for the estimated spectral norm; must be > 0.
        n_power_iterations: power-iteration steps per forward call (>= 1).
            The default of 1 matches the previous behaviour.

    Returns:
        The same ``module``, with a frozen power-iteration vector attached as
        ``module._u`` and a forward pre-hook registered.

    Raises:
        ValueError: if ``lip_constant`` <= 0 or ``n_power_iterations`` < 1.
    """
    if lip_constant <= 0:
        raise ValueError(f"lip_constant must be positive, got {lip_constant}")
    if n_power_iterations < 1:
        raise ValueError(f"n_power_iterations must be >= 1, got {n_power_iterations}")

    weight = module.weight
    # Persistent power-iteration vector (one entry per output row). It is built
    # with the weight's dtype/device so the matmuls below never mix devices.
    # It stays a frozen nn.Parameter so state_dict keys and parameters() are
    # unchanged.
    module._u = torch.nn.Parameter(
        weight.new_empty(weight.size(0)).normal_(0, 1), requires_grad=False
    )

    def spectral_norm_pre_hook(mod, inputs):
        # This runs BEFORE forward, so the clamped weight is the one used to
        # compute the output and saved for backward.
        # A post-forward hook would leave the current output unclamped. It
        # would also overwrite (via `.data`, which autograd does not track) a
        # weight already saved for backward, silently corrupting input grads.
        # NOTE: if the module is called more than once before a single
        # backward, later calls may still rescale a weight saved by earlier ones.
        with torch.no_grad():
            w = mod.weight
            w_mat = w.reshape(w.size(0), -1)
            for _ in range(n_power_iterations):
                v = F.normalize(torch.matmul(w_mat.t(), mod._u), dim=0, eps=1e-12)
                mod._u.data = F.normalize(torch.matmul(w_mat, v), dim=0, eps=1e-12)

            # Rayleigh-quotient style estimate of the top singular value.
            sigma = torch.matmul(mod._u, torch.matmul(w_mat, v))
            if sigma > lip_constant:
                w.data.copy_(w.data / sigma * lip_constant)

    module.register_forward_pre_hook(spectral_norm_pre_hook)
    return module

class HistoryEncoder(nn.Module):
    """Encode a padded history sequence plus a goal vector into one state vector.

    The per-step history features are concatenated along dim 2 and run through
    an LSTM. The top layer's final hidden state is concatenated with an MLP
    encoding of the goal. The result goes through a fusion MLP whose layer
    widths are ``hiddens_enc``.

    Args:
        input_dim: width of ``history_batch['vitals']``; only used when ``input_x``.
        output_dim: width of ``history_batch['outputs']`` and of the goal vector.
        treatment_dim: width of ``history_batch['current_treatments']``.
        static_dim: width of ``history_batch['static_features']``.
        hiddens_enc: fusion-MLP widths. ``hiddens_enc[0]`` is also the LSTM
            hidden size and the goal-encoder width; ``hiddens_enc[-1]`` is the
            output width. Must be non-empty.
        num_layers: number of stacked LSTM layers.
        dropout: dropout used between LSTM layers (only when ``num_layers > 1``)
            and between fusion layers; also passed to the attention block.
        use_attention: build an ``AttentionBlock`` as ``self.attention``.
        num_heads: head count for the attention block.
        use_spectral_norm: wrap each fusion ``nn.Linear`` with
            ``relaxed_spectral_norm``.
        lip_constant: bound passed to ``relaxed_spectral_norm``.
        input_x: also feed ``history_batch['vitals']`` into the LSTM.

    Raises:
        ValueError: if ``hiddens_enc`` is empty.
    """

    def __init__(self, input_dim=1, output_dim=1, treatment_dim=2, static_dim=1, hiddens_enc=(128, 128),
                 num_layers=2, dropout=0.1, use_attention=False, num_heads=4, use_spectral_norm=False,
                 lip_constant=5.0, input_x=False):
        super().__init__()
        if not hiddens_enc:
            raise ValueError("hiddens_enc 列表不能为空")
        lstm_hidden_dim = hiddens_enc[0]
        self.output_hidden_dim = hiddens_enc[-1]

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.treatment_dim = treatment_dim
        self.static_dim = static_dim
        self.use_attention = use_attention
        self.use_spectral_norm = use_spectral_norm
        self.lip_constant = lip_constant
        self.input_x = input_x

        # Per-step LSTM input width = outputs + treatments + static (+ vitals).
        self.feature_dim = output_dim + treatment_dim + static_dim
        if input_x:
            self.feature_dim += self.input_dim
        print(f"feature_dim:{self.feature_dim}")

        self.lstm = nn.LSTM(
            input_size=self.feature_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        if use_attention:
            self.attention = AttentionBlock(lstm_hidden_dim, num_heads, dropout)

        self.goal_encoder = nn.Sequential(
            nn.Linear(output_dim, lstm_hidden_dim),
            nn.ReLU(),
            nn.Linear(lstm_hidden_dim, lstm_hidden_dim),
        )

        # Fusion MLP over [history_encoding, goal_encoding]. ReLU + Dropout go
        # between layers only, so the last layer's output is left linear.
        fusion_layers = []
        in_dim = lstm_hidden_dim * 2
        for i, h_dim in enumerate(hiddens_enc):
            linear_layer = nn.Linear(in_dim, h_dim)
            if use_spectral_norm:
                linear_layer = relaxed_spectral_norm(linear_layer, lip_constant=lip_constant)
            fusion_layers.append(linear_layer)
            if i < len(hiddens_enc) - 1:
                fusion_layers.append(nn.ReLU())
                fusion_layers.append(nn.Dropout(dropout))
            in_dim = h_dim

        self.fusion_layer = nn.Sequential(*fusion_layers)

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise non-LSTM weight matrices and zero every bias.

        Rules:
        - LSTM weights keep PyTorch's default init; LSTM biases are zeroed.
        - With spectral norm enabled, fusion-layer weights keep their default
          init as well.
        """
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'lstm' in name:
                    continue
                # xavier_normal_ raises for tensors with fewer than 2 dims (e.g.
                # a normalisation scale vector, should AttentionBlock contain
                # one — not visible here), so such weights keep their default.
                if param.dim() < 2:
                    continue
                if not self.use_spectral_norm or 'fusion_layer' not in name:
                    nn.init.xavier_normal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, history_batch, goal_batch):
        """Encode the history and the goal into a single state tensor.

        Args:
            history_batch: mapping with tensors ``'outputs'``,
                ``'static_features'`` and ``'current_treatments'`` (plus
                ``'vitals'`` when ``input_x``), concatenated along dim 2 —
                presumably ``(batch, seq, features)`` given ``batch_first=True``.
                It may also contain ``'seq_lengths'``: a tensor or a sequence of
                ints giving each sequence's true length.
            goal_batch: goal tensor or ``np.ndarray`` whose last dim is
                ``output_dim``. A 1-D goal is treated as a batch of one, and a
                batch of one is broadcast over the history batch.

        Returns:
            Tensor of shape ``(batch, hiddens_enc[-1])``.
        """
        features = [
            history_batch['outputs'],
            history_batch['static_features'],
            history_batch['current_treatments'],
        ]
        if self.input_x:
            features.append(history_batch['vitals'])
        combined_features = torch.cat(features, dim=2)

        if 'seq_lengths' in history_batch:
            seq_lengths = history_batch['seq_lengths']
            if torch.is_tensor(seq_lengths):
                # pack_padded_sequence requires CPU lengths; lists are accepted as-is.
                seq_lengths = seq_lengths.cpu()
            packed_features = nn.utils.rnn.pack_padded_sequence(
                combined_features, seq_lengths, batch_first=True, enforce_sorted=False
            )
            _, (h_n, _) = self.lstm(packed_features)
        else:
            _, (h_n, _) = self.lstm(combined_features)

        # Final hidden state of the top LSTM layer.
        history_encoding = h_n[-1]
        # NOTE(review): self.attention (built when use_attention=True) is never
        # used in this forward pass — confirm whether that is intended.

        if isinstance(goal_batch, np.ndarray):
            # Place the goal where the model actually lives rather than on a
            # global default device, which can differ (CPU model on a CUDA host,
            # non-default GPU), and match the model's dtype.
            goal_batch = torch.tensor(
                goal_batch, dtype=history_encoding.dtype, device=history_encoding.device
            )

        if goal_batch.dim() == 1:
            goal_batch = goal_batch.unsqueeze(0)

        if goal_batch.size(0) == 1 and history_encoding.size(0) > 1:
            goal_batch = goal_batch.expand(history_encoding.size(0), -1)

        goal_encoding = F.relu(self.goal_encoder(goal_batch))

        combined_encoding = torch.cat([history_encoding, goal_encoding], dim=1)
        return self.fusion_layer(combined_encoding)