import torch
import torch.nn as nn
import torch.nn.functional as F
from advection_diffusion import compute_advection, compute_diffusion

class LocationEmbedding(nn.Module):
    """Learnable lookup table mapping integer location ids to ``d_model`` vectors.

    A singleton axis is inserted at dim 2 of the looked-up vectors. For a 2-D
    index tensor of shape (a, b), the result has shape (a, b, 1, d_model).
    """

    def __init__(self, n_locations, d_model):
        super().__init__()
        table = nn.Embedding(n_locations, d_model)
        # Xavier-uniform initialization instead of nn.Embedding's default.
        nn.init.xavier_uniform_(table.weight)
        self.embedding = table

    def forward(self, location_idx):
        """Look up ``location_idx`` and insert a singleton axis at dim 2."""
        vectors = self.embedding(location_idx)
        return torch.unsqueeze(vectors, 2)

class LearnablePositionalEncoding(nn.Module):
    """Additive, fully learnable encoding indexed by position along dim 0.

    The table has shape (max_len, 1, 1, d_model), so it broadcasts over dims
    1 and 2 of a 4-D input.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        # randn is immediately overwritten by Xavier init; kept so the RNG
        # stream (and thus every later parameter init) is unchanged.
        self.pe = nn.Parameter(torch.randn(max_len, 1, 1, d_model))
        nn.init.xavier_uniform_(self.pe)

    def forward(self, x):
        """Add the first ``x.shape[0]`` rows of the encoding table to ``x``."""
        offsets = self.pe[: x.shape[0]]
        return x + offsets

class SpatialAttention(nn.Module):
    """Multi-head self-attention across the location axis.

    Every (time step, batch sample) pair attends only over its own locations.
    Information therefore never leaks between time steps or between samples
    of a batch.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=0.1)

    def forward(self, x):
        """Apply spatial self-attention.

        Args:
            x: Tensor of shape (seq_len, batch_size, n_locations, d_model).

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len, batch_size, n_locations, d_model = x.shape
        # nn.MultiheadAttention (batch_first=False) expects (L, N, E). Here the
        # sequence axis L is the locations, and the batch axis N enumerates every
        # (time, sample) pair. Flattening all tokens into L with N=1 (the old
        # behavior) made every token attend to every other, across the whole batch.
        tokens = x.reshape(seq_len * batch_size, n_locations, d_model).transpose(0, 1)
        attn_output, _ = self.attention(tokens, tokens, tokens)
        return attn_output.transpose(0, 1).reshape(seq_len, batch_size, n_locations, d_model)

class TemporalAttention(nn.Module):
    """Multi-head self-attention along the time axis.

    Every (batch sample, location) stream attends only over its own time
    steps. Different samples and different locations never interact.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=0.1)

    def forward(self, x):
        """Apply temporal self-attention.

        Args:
            x: Tensor of shape (seq_len, batch_size, n_locations, d_model).

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len, batch_size, n_locations, d_model = x.shape
        # (L, N, E) layout for batch_first=False: L = time, N = every
        # (sample, location) pair. Flattening all tokens into L with N=1 (the
        # old behavior) mixed samples and locations together.
        tokens = x.reshape(seq_len, batch_size * n_locations, d_model)
        attn_output, _ = self.attention(tokens, tokens, tokens)
        return attn_output.reshape(seq_len, batch_size, n_locations, d_model)

class SpatioTemporalTransformerLayer(nn.Module):
    """Post-norm transformer layer for (seq_len, batch, n_locations, d_model) tensors.

    Sub-layers, applied in order:

    1. (optional, ``use_attention=True``) residual spatial attention followed
       by residual temporal attention;
    2. self-attention along the time axis, one stream per (sample, location)
       pair, with residual + ``norm1``;
    3. (optional, ``encoder_output`` given) cross-attention with residual +
       ``norm2``;
    4. position-wise feed-forward (``fc`` + GELU) with residual + ``norm3``;
    5. (optional, ``temp_k`` given) advection / diffusion correction scaled by
       learnable gates.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=0.1)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=0.1)
        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(0.1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Gates for the physics terms. They are zero-initialized, so the
        # advection/diffusion corrections start switched off and are learned.
        self.advection_weight = nn.Parameter(torch.zeros(1))
        self.diffusion_weight = nn.Parameter(torch.zeros(1))
        self.advection_norm = nn.LayerNorm(d_model)
        self.diffusion_norm = nn.LayerNorm(d_model)

        self.spatial_attention = SpatialAttention(d_model, n_heads)
        self.temporal_attention = TemporalAttention(d_model, n_heads)

        self.activation = nn.GELU()

    def forward(self, x, encoder_output, temp_k, velocity_field, use_attention=False):
        """Run the layer.

        Args:
            x: Tensor of shape (seq_len, batch_size, n_locations, d_model).
            encoder_output: Memory for cross-attention, or None to skip it.
                Its trailing dims must flatten to (batch_size * n_locations,
                d_model). Its leading (time) length may differ from ``seq_len``.
            temp_k: Passed to ``compute_diffusion``. None skips the
                advection/diffusion correction entirely.
            velocity_field: Passed to ``compute_advection``. It is only used
                when ``temp_k`` is not None.
            use_attention: If True, first apply the residual spatial and
                temporal attention sub-layers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if use_attention:
            x = self.spatial_attention(x) + x
            x = self.temporal_attention(x) + x

        seq_len, batch_size, n_locations, d_model = x.shape
        n_streams = batch_size * n_locations

        # Self-attention along time; each (sample, location) pair is one batch
        # entry. reshape (not view) tolerates non-contiguous input.
        x_flat = x.reshape(seq_len, n_streams, d_model)
        attn_output, _ = self.self_attn(x_flat, x_flat, x_flat)
        attn_output = self.dropout(attn_output).reshape(x.shape)
        x = self.norm1(x + attn_output)

        if encoder_output is not None:
            # The query must be the post-self-attention state, not the layer
            # input (the old code reused the stale pre-norm1 tensor).
            query = x.reshape(seq_len, n_streams, d_model)
            # -1 lets the memory length differ from the query length.
            memory = encoder_output.reshape(-1, n_streams, d_model)
            cross_attn_output, _ = self.cross_attn(query, memory, memory)
            cross_attn_output = self.dropout(cross_attn_output).reshape(x.shape)
            x = self.norm2(x + cross_attn_output)

        # Feed-forward sub-layer with a residual connection, matching the
        # residual + norm pattern used by the attention sub-layers above.
        ff_output = self.dropout(self.activation(self.fc(x)))
        x = self.norm3(x + ff_output)

        if temp_k is not None:
            # compute_advection / compute_diffusion receive a batch-first
            # (batch, seq, n_loc, d) view. They are assumed to return the same
            # layout, which is permuted back here — TODO confirm against
            # advection_diffusion.
            batch_x = x.permute(1, 0, 2, 3)
            adv = compute_advection(batch_x, velocity_field).permute(1, 0, 2, 3)
            dif = compute_diffusion(batch_x, temp_k).permute(1, 0, 2, 3)
            x = (
                x
                + self.advection_weight * self.advection_norm(adv)
                + self.diffusion_weight * self.diffusion_norm(dif)
            )

        return x

class SpatioTemporalEncoder(nn.Module):
    """Stack of ``SpatioTemporalTransformerLayer`` modules without cross-attention."""

    def __init__(self, d_model, n_heads, num_layers):
        super().__init__()
        stack = [SpatioTemporalTransformerLayer(d_model, n_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x, temp_k, velocity_field):
        """Run ``x`` through every layer in order and return the final hidden state.

        Only the first layer runs the extra spatial/temporal attention branch.
        The other layers still own those sub-modules, but they receive no
        gradient.
        """
        for depth, layer in enumerate(self.layers):
            x = layer(x, None, temp_k, velocity_field, use_attention=depth == 0)
        return x

class SpatioTemporalDecoder(nn.Module):
    """Pooling head that maps encoder features to a ``horizon``-length output.

    The encoder output is averaged over dims 0 and 2, layer-normalized, and
    passed through a two-layer GELU MLP. The input is assumed to be
    (seq_len, batch, n_locations, d_model) as produced by the encoder, so the
    result is (batch, horizon).
    """

    def __init__(self, d_model, horizon=24):
        super().__init__()
        self.horizon = horizon
        self.norm = nn.LayerNorm(d_model)

        hidden = nn.Linear(d_model, d_model)
        head = nn.Linear(d_model, horizon)
        # An output clamp, nn.Hardtanh(-2, 2), was tried here and is disabled;
        # a looser limit makes learning easier.
        self.output_projection = nn.Sequential(hidden, nn.GELU(), head)

    def forward(self, encoder_output, temp_k, velocity_field):
        """Pool, normalize and project ``encoder_output``.

        ``temp_k`` and ``velocity_field`` are accepted for interface symmetry
        with the encoder but are not used.
        """
        pooled = encoder_output.mean(dim=[0, 2])
        normalized = self.norm(pooled)
        return self.output_projection(normalized)

class SpatioTemporalTransformer(nn.Module):
    """End-to-end model: input projection + location embedding, encoder, pooled decoder.

    ``num_decoder_layers`` is accepted but not used by this implementation.
    """

    def __init__(self, input_dim, d_model, n_heads,
                 num_encoder_layers, num_decoder_layers,
                 n_locations, sequence_length, horizon=24):
        super().__init__()
        # Construction order is kept stable so seeded initialization is reproducible.
        self.input_projection = nn.Linear(input_dim, d_model)
        self.location_embedding = LocationEmbedding(n_locations, d_model)
        # Built (and present in the state_dict) but currently not applied in forward().
        self.positional_encoding = LearnablePositionalEncoding(sequence_length, d_model)
        self.encoder = SpatioTemporalEncoder(d_model, n_heads, num_encoder_layers)
        self.decoder = SpatioTemporalDecoder(d_model, horizon)

    def forward(self, x, location_idx, temp_k, velocity_field):
        """Encode ``x`` and return the decoder's forecast.

        NOTE(review): ``x`` looks like (seq_len, batch, input_dim), so that
        ``unsqueeze(2)`` adds a singleton location axis. ``location_idx`` is cut
        to its first ``seq_len`` entries along dim 0 — confirm against callers.
        """
        hidden = self.input_projection(x)
        hidden = hidden.unsqueeze(2)
        n_steps = hidden.size(0)
        hidden = hidden + self.location_embedding(location_idx[:n_steps])
        # Positional encoding is disabled; re-enable with
        # hidden = self.positional_encoding(hidden)
        encoded = self.encoder(hidden, temp_k, velocity_field)
        return self.decoder(encoded, temp_k, velocity_field)
