import torch
import torch.nn as nn

class StationAwareLinear(nn.Module):
    """Linear layer with an optional learned per-station gain and offset.

    A shared ``nn.Linear`` produces the base output. When station indices
    are supplied, rows are looked up in two learnable tables of shape
    ``(n_stations, out_features)``, multiplied by ``residual_scale``, and
    applied to the base output as::

        base * (1 + scaled_weight_row) + scaled_bias_row

    Both tables are initialised to zero, so a freshly constructed layer
    returns the plain linear output whether or not indices are given.

    Args:
        in_features: Input size handed to the wrapped ``nn.Linear``.
        out_features: Output size of the wrapped ``nn.Linear`` and the row
            width of both station tables.
        n_stations: Number of rows in each station table.
        residual_scale: Constant factor applied to every looked-up row
            before it is used.
    """

    def __init__(self, in_features, out_features, n_stations=2288, residual_scale=0.1):
        super().__init__()
        self.standard_linear = nn.Linear(in_features, out_features)

        # Zero tables make the station correction an identity at start-up.
        table_shape = (n_stations, out_features)
        self.station_weight_delta = nn.Parameter(torch.zeros(table_shape))
        self.station_bias_delta = nn.Parameter(torch.zeros(table_shape))

        self.residual_scale = residual_scale

    def forward(self, x, station_indices=None):
        """Apply the shared linear map, then the station correction if any.

        Args:
            x: Input passed unchanged to the wrapped ``nn.Linear``.
            station_indices: Optional index used to select rows from both
                station tables. The selected rows must broadcast against
                the linear output; presumably this is one integer index
                per sample -- TODO confirm against callers. ``None`` skips
                the correction entirely.

        Returns:
            The linear output, modulated per station when indices are given.
        """
        base = self.standard_linear(x)
        if station_indices is None:
            return base

        # Despite its name, the "weight" table is a multiplicative gain on
        # the output features, not a change to the linear weight matrix.
        gain = 1.0 + self.station_weight_delta[station_indices] * self.residual_scale
        shift = self.station_bias_delta[station_indices] * self.residual_scale
        return base * gain + shift

