import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class TimeSeriesEncoder(nn.Module):
    """Encode each timestep of a batch of multivariate time series into ``out_dim`` features.

    Pipeline (see ``forward``):
        NaN/Inf -> 0, clip to [-10, 10], reshape to
        ``[batch, n_med, seq_len, feature_dim]``, standardise every feature with
        statistics pooled over batch/series/time, clip to [-3, 3], then per
        timestep: LayerNorm -> Linear -> (optional residual) -> LayerNorm -> GELU.
    The projection has no recurrence, so timesteps are encoded independently
    apart from the shared standardisation statistics.

    When ``input_dim`` is None the layers are built lazily on the first
    ``forward`` call and rebuilt whenever the incoming feature size changes.
    Lazily created parameters do not exist at construction time, so an
    optimizer created before the first forward call will not update them.

    Args:
        out_dim: Size of the per-timestep output embedding.
        input_dim: Feature size of the (flattened) input, or None to build the
            layers lazily.
        use_residual: Add a scaled skip connection from the normalised input
            before the output LayerNorm (disabled by default).
        residual_scale: Weight applied to the skip connection.
    """

    # Class-level defaults; also used if an instance lacks these attributes
    # (e.g. it was unpickled without running ``__init__``).
    use_residual = False
    residual_scale = 0.1

    def __init__(self, out_dim=64, input_dim=None, use_residual=False, residual_scale=0.1):
        super().__init__()
        self.out_dim = out_dim
        self.input_dim = input_dim
        self.use_residual = use_residual
        self.residual_scale = residual_scale

        if input_dim is not None:
            self._build_layers(input_dim)
        else:
            # Built lazily by _get_or_create_projection on the first forward call.
            self.input_norm = None
            self.feature_projection = None
            self.output_norm = None
            self.residual_proj = None

    @staticmethod
    def _init_small_linear(linear):
        """Kaiming-normal init scaled by 1e-3 with zero bias, so the layer starts near zero."""
        nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
        with torch.no_grad():
            linear.weight.mul_(0.001)
        nn.init.constant_(linear.bias, 0.0)

    def _build_layers(self, feature_dim, device=None):
        """(Re)create the norm/projection layers for ``feature_dim``-sized inputs.

        Layers are created in a fixed order (input norm, projection, output
        norm, residual projection) before re-initialisation, so the RNG stream
        consumed is the same on every build path.
        """
        def place(module):
            return module if device is None else module.to(device)

        self.input_norm = place(nn.LayerNorm(feature_dim))
        self.feature_projection = place(nn.Linear(feature_dim, self.out_dim))
        self.output_norm = place(nn.LayerNorm(self.out_dim))
        if feature_dim != self.out_dim:
            self.residual_proj = place(nn.Linear(feature_dim, self.out_dim))
        else:
            # Shapes already match: the skip connection needs no projection.
            self.residual_proj = nn.Identity()

        self._init_small_linear(self.feature_projection)
        if feature_dim != self.out_dim:
            self._init_small_linear(self.residual_proj)

    def _get_or_create_projection(self, feature_dim, device):
        """Return ``(input_norm, projection, output_norm, residual_proj)`` sized for ``feature_dim``.

        Rebuilds (and re-initialises) all layers when none exist yet or when
        the projection was built for a different feature size.
        """
        if self.feature_projection is None or self.feature_projection.in_features != feature_dim:
            self._build_layers(feature_dim, device)
        return self.input_norm, self.feature_projection, self.output_norm, self.residual_proj

    @staticmethod
    def _to_4d(x):
        """Reshape ``x`` to ``[batch, n_med, seq_len, feature_dim]``.

        3D input gains a singleton ``n_med`` axis; 5D input has its last two
        axes flattened into the feature axis.

        Raises:
            ValueError: if ``x`` is not 3D, 4D or 5D.
        """
        if x.dim() == 3:
            return x.unsqueeze(1)
        if x.dim() == 4:
            return x
        print(f"[WARNING] TimeSeriesEncoder received {len(x.shape)}D input: {x.shape}")
        print(f"Expected 4D: [batch, n_med, seq_len, feature_dim]")
        if x.dim() == 5:
            # [batch, n_med, seq_len, n_sub_features, val_miss(=2)]
            b, n, t, f1, f2 = x.shape
            x = x.reshape(b, n, t, f1 * f2)
            print(f"Auto-fixed 5D input to 4D by flattening last two dims -> {x.shape}")
            return x
        raise ValueError(f"TimeSeriesEncoder expects 4 dims [batch, n_med, seq_len, feature_dim], got {x.shape}")

    @staticmethod
    def _standardize(x):
        """Standardise each feature of a 4D tensor over (batch, n_med, seq_len), then clip to [-3, 3]."""
        if x.numel() == 0:
            return x
        dims = (0, 1, 2)
        mean = x.mean(dim=dims, keepdim=True)
        # Unbiased std is NaN when only one sample is pooled; treat that as zero spread.
        std = torch.nan_to_num(x.std(dim=dims, keepdim=True), nan=0.0) + 1e-8
        return torch.clamp((x - mean) / std, min=-3.0, max=3.0)

    @staticmethod
    def _valid_step_mask(lengths, batch_size, n_med, seq_len, device):
        """Return a bool mask ``[batch, n_med, seq_len]`` (True = step < length), or None.

        ``lengths`` must hold ``batch_size * n_med`` entries (e.g. shape
        ``[batch, n_med]``, or ``[batch]`` when ``n_med == 1``). On a size
        mismatch a warning is printed and None is returned (no masking).
        """
        lengths = torch.as_tensor(lengths, device=device)
        if lengths.numel() != batch_size * n_med:
            print(f"[WARNING] TimeSeriesEncoder lengths shape {tuple(lengths.shape)} does not match "
                  f"[batch={batch_size}, n_med={n_med}]; padded steps are not masked")
            return None
        steps = torch.arange(seq_len, device=device)
        return steps < lengths.reshape(batch_size, n_med, 1)

    def _encode(self, x):
        """Apply norm -> projection -> (residual) -> norm -> GELU to a standardised 4D tensor."""
        batch_size, n_med, seq_len, feature_dim = x.shape
        input_norm, projection, output_norm, residual_proj = self._get_or_create_projection(feature_dim, x.device)

        flat = x.reshape(batch_size * n_med, seq_len, feature_dim)
        x_norm = torch.clamp(input_norm(flat), min=-5.0, max=5.0)
        projected = torch.clamp(projection(x_norm), min=-10.0, max=10.0)

        if self.use_residual:
            if feature_dim == self.out_dim:
                residual = x_norm
            else:
                residual = torch.clamp(residual_proj(x_norm), min=-10.0, max=10.0)
            projected = projected + self.residual_scale * residual

        projected = torch.clamp(output_norm(projected), min=-5.0, max=5.0)

        if not torch.isfinite(projected).all():
            print(f"[ERROR] TimeSeriesEncoder projection output contains NaN or inf!")
            print(f"Projection input stats: min={flat.min():.6f}, max={flat.max():.6f}, mean={flat.mean():.6f}")
            print(f"Projection output stats: min={projected.min():.6f}, max={projected.max():.6f}, mean={projected.mean():.6f}")
            projected = torch.nan_to_num(projected, nan=0.0, posinf=0.0, neginf=0.0)

        projected = nn.functional.gelu(projected)
        return projected.reshape(batch_size, n_med, seq_len, self.out_dim)

    def forward(self, x, lengths=None):
        """Encode ``x`` into per-timestep embeddings.

        Args:
            x: Float tensor shaped ``[batch, n_med, seq_len, feature_dim]``,
                ``[batch, seq_len, feature_dim]`` or
                ``[batch, n_med, seq_len, f1, f2]``.
            lengths: Optional number of valid timesteps per series
                (``batch * n_med`` entries). Output positions at or beyond a
                series' length are set to zero.

        Returns:
            Tensor ``[batch, n_med, seq_len, out_dim]``. All zeros if the
            standardised input is all zeros or if encoding raises a
            ``RuntimeError``.

        Raises:
            ValueError: if ``x`` is not 3D, 4D or 5D.
        """
        if torch.isnan(x).any() or torch.isinf(x).any():
            print(f"[WARNING] TimeSeriesEncoder input contains NaN/Inf, replacing with zeros")
            x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        x = torch.clamp(x, min=-10.0, max=10.0)

        # Canonicalise the layout BEFORE computing statistics so that dims
        # (0, 1, 2) always mean (batch, n_med, seq_len), whatever the input rank.
        x = self._to_4d(x)
        x = self._standardize(x)
        batch_size, n_med, seq_len, _ = x.shape

        if torch.allclose(x, torch.zeros_like(x)):
            print(f"[WARNING] TimeSeriesEncoder input is all zeros (padding), returning zeros")
            return torch.zeros(batch_size, n_med, seq_len, self.out_dim, device=x.device, dtype=x.dtype)

        try:
            x_output = self._encode(x)
        except RuntimeError as e:
            # Best-effort: tensor-level failures yield zeros instead of aborting the batch.
            print(f"[ERROR] TimeSeriesEncoder processing failed: {e}")
            x_output = torch.zeros(batch_size, n_med, seq_len, self.out_dim, device=x.device, dtype=x.dtype)

        if lengths is not None:
            valid = self._valid_step_mask(lengths, batch_size, n_med, seq_len, x.device)
            if valid is not None:
                x_output = x_output.masked_fill(~valid.unsqueeze(-1), 0.0)

        return x_output

def _make_random_batch(batch_size, num_series, max_len, n_sub_features, indicator_dim, min_length):
    """Return a random 5D batch ``[batch, series, time, n_sub_features, indicator_dim]`` and lengths.

    Channel 1 of the last axis is overwritten with a binary indicator
    (about 20% ones). Lengths are drawn uniformly from [min_length, max_len].
    """
    data = torch.randn(batch_size, num_series, max_len, n_sub_features, indicator_dim)
    data[..., 1] = (torch.rand(batch_size, num_series, max_len, n_sub_features) > 0.8).float()
    lengths = torch.randint(min_length, max_len + 1, (batch_size, num_series))
    return data, lengths


def _zero_after_lengths(data, lengths):
    """Zero every timestep at or beyond each series' length (axis 2 is time)."""
    steps = torch.arange(data.shape[2])
    valid = steps < lengths.unsqueeze(-1)  # [batch, series, time]
    return data.masked_fill(~valid[..., None, None], 0.0)


def _demo():
    """Run the encoder on three padded batches and sanity-check the outputs."""
    batch_size = 4
    num_time_series = 3
    max_sequence_length = 10
    n_sub_features = 5
    value_missing_indicator_dim = 2
    out_dim = 32

    # 5D inputs are flattened to n_sub_features * value_missing_indicator_dim features.
    encoder = TimeSeriesEncoder(out_dim=out_dim, input_dim=n_sub_features * value_missing_indicator_dim)
    print(f"Encoder Architecture:\n{encoder}\n")
    expected_shape = (batch_size, num_time_series, max_sequence_length, out_dim)

    # Case 1: every series has at least one real timestep.
    data_1, lengths_1 = _make_random_batch(batch_size, num_time_series, max_sequence_length,
                                           n_sub_features, value_missing_indicator_dim, min_length=1)
    data_1 = _zero_after_lengths(data_1, lengths_1)

    print(f"Input data shape: {data_1.shape}")
    print(f"Input lengths shape: {lengths_1.shape}")
    print(f"Example lengths (batch 0): {lengths_1[0]}\n")

    output_1 = encoder(data_1, lengths_1)
    print(f"Output encoding shape: {output_1.shape}\n")
    assert output_1.shape == expected_shape

    # Case 2: some series are entirely padding.
    data_2, lengths_2 = _make_random_batch(batch_size, num_time_series, max_sequence_length,
                                           n_sub_features, value_missing_indicator_dim, min_length=0)
    lengths_2[0, 0] = 0
    lengths_2[1, 2] = 0
    data_2 = _zero_after_lengths(data_2, lengths_2)

    print(f"Input data shape: {data_2.shape}")
    print(f"Input lengths shape: {lengths_2.shape}")
    print(f"Example lengths (batch 0): {lengths_2[0]}\n")

    output_2 = encoder(data_2, lengths_2)
    print(f"Output encoding shape: {output_2.shape}\n")
    print(f"Output for length 0 sequence (batch 0, ts 0):\n{output_2[0, 0]}\n")
    print(f"Output for length 0 sequence (batch 1, ts 2):\n{output_2[1, 2]}\n")
    assert output_2.shape == expected_shape
    assert torch.isfinite(output_2).all()

    # Case 3: the whole batch is padding.
    data_3 = torch.zeros(batch_size, num_time_series, max_sequence_length,
                         n_sub_features, value_missing_indicator_dim)
    lengths_3 = torch.zeros(batch_size, num_time_series, dtype=torch.long)

    print(f"Input data shape: {data_3.shape}")
    print(f"Input lengths shape: {lengths_3.shape}")

    output_3 = encoder(data_3, lengths_3)
    print(f"Output encoding shape: {output_3.shape}\n")
    assert output_3.shape == expected_shape
    assert torch.allclose(output_3, torch.zeros_like(output_3), atol=1e-6)
    print("All-padding batch produced all-zero outputs as expected.")


if __name__ == "__main__":
    _demo()