import torch
import torch.nn as nn
import torch.nn.functional as F

class LiquidNeuron(nn.Module):
    """Continuous-time recurrent cell integrated with fixed-step explicit Euler.

    The hidden state evolves as ``dh/dt = f([h, x])``, where ``f`` is a
    two-layer MLP (``self.dynamics``). Each forward call applies ``steps``
    Euler updates of size ``dt`` with ``x`` held fixed.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: size of the hidden state ``h``.
        output_dim: size of the returned output.
        dt: Euler step size.
        steps: number of Euler updates per forward call.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dt=0.1, steps=3):
        super(LiquidNeuron, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dt = dt
        self.steps = steps
        # Vector field f([h, x]) -> dh/dt; consumes the state concatenated with the input.
        self.dynamics = nn.Sequential(
            nn.Linear(hidden_dim + input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.output_mapping = nn.Linear(hidden_dim, output_dim)

        # Derives an initial hidden state from x when the caller supplies none.
        self.init_mapping = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, h=None):
        """Integrate the hidden state for ``steps`` Euler steps.

        Args:
            x: input tensor whose last dimension is ``input_dim``.
            h: optional hidden state whose last dimension is ``hidden_dim``;
                when ``None`` it is initialised as ``init_mapping(x)``.

        Returns:
            ``(out, h)`` where ``out = output_mapping(h)`` and ``h`` is the
            hidden state after the final Euler step.
        """
        if h is None:
            h = self.init_mapping(x)

        for _ in range(self.steps):
            h = h + self.dt * self.dynamics(torch.cat([h, x], dim=-1))

        return self.output_mapping(h), h

import torch
import torch.nn as nn

class LiquidTimeSeries(nn.Module):
    """Per-station regressor driven by an Euler-integrated hidden state.

    Every (batch, station) series is processed independently. For each hour,
    the encoded features act as a constant forcing term while the state is
    advanced ``steps`` explicit Euler steps:
    ``state += dt * (liquid_dynamics(state) + 0.5 * time_encoder(x_t))``.
    The final state is mapped to a single scalar per station.

    Args:
        n_features: size of the last dimension of the input.
        hidden_dim: size of the hidden state.
        dt: Euler step size.
        steps: Euler updates applied per input hour.
    """

    def __init__(self, n_features, hidden_dim, dt=0.1, steps=3):
        super().__init__()
        self.n_features = n_features
        self.hidden_dim = hidden_dim
        self.dt = dt
        self.steps = steps
        self.time_encoder = nn.Linear(n_features, hidden_dim)
        self.liquid_dynamics = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Predict one value per (batch, station).

        Args:
            x: 4-D tensor ``(batch, n_stations, n_hours, n_features)``.
                It may be non-contiguous (e.g. produced by ``permute``).

        Returns:
            Tensor of shape ``(batch, n_stations, 1)``.
        """
        batch_size, n_stations, n_hours, n_features = x.shape
        n_series = batch_size * n_stations

        # reshape rather than view: view() raises on non-contiguous inputs.
        x_flat = x.reshape(n_series, n_hours, n_features)

        # Encode all hours in one batched call; the 0.5 forcing scale is
        # constant across Euler steps, so apply it once here.
        drive = 0.5 * self.time_encoder(x_flat)

        # Match the parameter dtype so e.g. a .double() model does not mix dtypes.
        state = torch.zeros(
            n_series,
            self.hidden_dim,
            device=x.device,
            dtype=self.time_encoder.weight.dtype,
        )
        for t in range(n_hours):
            drive_t = drive[:, t]
            for _ in range(self.steps):
                state = state + self.dt * (self.liquid_dynamics(state) + drive_t)

        prediction = self.output_layer(state)
        return prediction.reshape(batch_size, n_stations, 1)

import torch
import torch.nn as nn
from Modules.Activations import Tanh
from Modules.GNN.liquidnet import LiquidNeuron

class LiquidEmbedding(nn.Module):
    """Embed a flattened multi-day feature history plus a 2-value position.

    ``forward`` slices each input row as ``n_days * feature_dim`` history
    values followed by 2 position values. Each day is projected and fed
    through a :class:`LiquidNeuron`, whose hidden state is carried from day
    to day. The last day's output is fused with a projection of the position.

    Args:
        feature_dim: number of history features per day.
        hidden_dim: embedding size (also the LiquidNeuron hidden size).
        n_days: number of days in the history; must be >= 1 at forward time.
        dt: Euler step size passed to the LiquidNeuron.
        steps: Euler steps per day passed to the LiquidNeuron.
    """

    def __init__(self, feature_dim, hidden_dim, n_days=6, dt=0.1, steps=3):
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.n_days = n_days

        self.feature_proj = nn.Linear(feature_dim, hidden_dim // 2)

        self.liquid_neuron = LiquidNeuron(
            input_dim=hidden_dim // 2,
            hidden_dim=hidden_dim,
            output_dim=hidden_dim,
            dt=dt,
            steps=steps
        )

        self.pos_encoder = nn.Linear(2, hidden_dim // 2)

        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim),
            Tanh()
        )

    def forward(self, x):
        """Return a ``(batch, hidden_dim)`` embedding for 2-D input ``x``.

        Args:
            x: tensor of shape ``(batch, n_days * feature_dim + 2)``.

        Raises:
            ValueError: if ``n_days < 1`` or the history width does not equal
                ``n_days * feature_dim``.
        """
        if self.n_days < 1:
            raise ValueError(f"n_days must be >= 1, got {self.n_days}")

        batch_size = x.size(0)

        history_features = x[:, :-2]
        position = x[:, -2:]

        expected = self.n_days * self.feature_dim
        if history_features.shape[1] != expected:
            raise ValueError(
                f"expected {expected} history features "
                f"(n_days={self.n_days} * feature_dim={self.feature_dim}), "
                f"got {history_features.shape[1]}"
            )
        history_reshaped = history_features.reshape(
            batch_size, self.n_days, self.feature_dim
        )

        liquid_state = None
        for day in range(self.n_days):
            day_proj = self.feature_proj(history_reshaped[:, day])
            # LiquidNeuron returns (output, hidden_state): carry the hidden
            # state into the next day, keep the output for fusion.
            liquid_out, liquid_state = self.liquid_neuron(day_proj, liquid_state)

        pos_embedding = self.pos_encoder(position)

        return self.fusion(torch.cat([liquid_out, pos_embedding], dim=1))

class GeoLiquidNeuron(nn.Module):
    """Euler-integrated recurrent cell whose step size depends on position.

    A small MLP maps the 2-value ``pos`` to ``tau`` in (0.5, 1.5)
    (sigmoid + 0.5). The per-sample step is ``dt = dt_base * tau``, which is
    broadcast over the hidden dimension. The state then follows
    ``h += dt * f([h, x])`` for ``steps`` iterations.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: size of the hidden state ``h``.
        output_dim: size of the returned output.
        dt_base: base Euler step size, scaled by ``tau``.
        steps: number of Euler updates per forward call.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dt_base=0.1, steps=3):
        super(GeoLiquidNeuron, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dt_base = dt_base
        self.steps = steps

        # pos (last dim 2) -> scalar in (0, 1); shifted to (0.5, 1.5) in forward.
        self.tau_generator = nn.Sequential(
            nn.Linear(2, 16),
            nn.Tanh(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

        self.dynamics = nn.Sequential(
            nn.Linear(hidden_dim + input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.output_mapping = nn.Linear(hidden_dim, output_dim)

        self.init_mapping = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, pos, h=None):
        """Integrate the state with a position-dependent Euler step.

        Args:
            x: input tensor whose last dimension is ``input_dim``.
            pos: tensor whose last dimension is 2.
            h: optional hidden state; defaults to ``init_mapping(x)``.

        Returns:
            ``(out, h)`` with ``out = output_mapping(h)`` after the last step.
        """
        dt = self.dt_base * (self.tau_generator(pos) + 0.5)

        if h is None:
            h = self.init_mapping(x)

        for _ in range(self.steps):
            h = h + dt * self.dynamics(torch.cat([h, x], dim=-1))

        return self.output_mapping(h), h

import torch
from torch import nn

class GeoLiquidOU(nn.Module):
    """Ornstein-Uhlenbeck-style recurrent cell with a position-dependent step.

    The state follows ``dh = kappa * (theta(x) - h) dt + sigma dW``, discretised
    with Euler-Maruyama:
    ``h += kappa * (theta - h) * dt + sigma * sqrt(dt) * eps``, with
    ``eps ~ N(0, 1)``. The step ``dt = dt_base * tau`` uses ``tau`` in
    (0.5, 1.5), produced from the 2-value ``pos`` by a small MLP.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: size of the hidden state ``h``.
        output_dim: size of the returned output.
        dt_base: base step size, scaled by ``tau``.
        steps: number of Euler-Maruyama updates per forward call.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dt_base=0.1, steps=3):
        super(GeoLiquidOU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dt_base = dt_base
        self.steps = steps
        self.tau_generator = nn.Sequential(
            nn.Linear(2, 16),
            nn.Tanh(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )
        # Input-dependent mean-reversion target theta(x).
        self.theta_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh()
        )
        # Per-unit mean-reversion rate; unconstrained, so a negative value
        # would make the drift repulsive rather than mean-reverting.
        self.kappa = nn.Parameter(torch.ones(1, hidden_dim))
        # Per-unit diffusion scale.
        self.sigma = nn.Parameter(torch.ones(1, hidden_dim) * 0.1)
        self.init_mapping = nn.Linear(input_dim, hidden_dim)
        self.output_mapping = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, pos, h=None, stochastic=None):
        """Advance the OU state ``steps`` times.

        Args:
            x: input tensor whose last dimension is ``input_dim``.
            pos: tensor whose last dimension is 2.
            h: optional hidden state; defaults to ``init_mapping(x)``.
            stochastic: whether to add the diffusion noise. ``None`` (default)
                follows ``self.training``, so ``eval()`` gives deterministic
                outputs. Pass ``True`` to sample noise at inference.

        Returns:
            ``(out, h)`` with ``out = output_mapping(h)`` after the last step.
        """
        dt = self.dt_base * (self.tau_generator(pos) + 0.5)
        if h is None:
            h = self.init_mapping(x)
        if stochastic is None:
            stochastic = self.training

        # x is fixed over the integration, so the target is computed once.
        theta = self.theta_net(x)
        # Euler-Maruyama: drift scales with dt, Brownian increment with sqrt(dt).
        sqrt_dt = torch.sqrt(dt)
        for _ in range(self.steps):
            h = h + self.kappa * (theta - h) * dt
            if stochastic:
                h = h + self.sigma * sqrt_dt * torch.randn_like(h)

        return self.output_mapping(h), h


class GeoLiquidEmbedding(nn.Module):
    """Embed a flattened multi-day history plus a 2-value position via GeoLiquidOU.

    ``forward`` slices each input row as ``n_days * feature_dim`` history
    values followed by 2 position values. Each day is projected and fed
    through :class:`GeoLiquidOU` (conditioned on the position), carrying the
    hidden state day to day. The final output is fused with a position
    projection, and a 0.2-weighted MLP of the last day's raw features plus
    the position is added on top.

    Args:
        feature_dim: number of history features per day.
        hidden_dim: embedding size.
        n_days: number of days in the history; must be >= 1 at forward time.
        dt: base step size passed to GeoLiquidOU.
        steps: integration steps per day passed to GeoLiquidOU.
    """

    def __init__(self, feature_dim, hidden_dim, n_days=6, dt=0.1, steps=3):
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.n_days = n_days

        self.feature_proj = nn.Linear(feature_dim, hidden_dim)
        # Shortcut branch on [last-day raw features, position].
        self.e_mlp = nn.Sequential(
            nn.Linear(feature_dim + 2, self.hidden_dim),
            Tanh(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            Tanh())
        self.pos_proj = nn.Linear(2, hidden_dim // 2)

        self.geo_liquid = GeoLiquidOU(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            output_dim=hidden_dim,
            dt_base=dt,
            steps=steps
        )

        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim // 2, hidden_dim),
            nn.Tanh()
        )

    def forward(self, x):
        """Return a ``(batch, hidden_dim)`` embedding for 2-D input ``x``.

        Args:
            x: tensor of shape ``(batch, n_days * feature_dim + 2)``.

        Raises:
            ValueError: if ``n_days < 1`` or the history width does not equal
                ``n_days * feature_dim``.
        """
        # Previously n_days == 0 crashed with ZeroDivisionError (and would
        # then hit an unbound liquid_out); fail with a clear message instead.
        if self.n_days < 1:
            raise ValueError(f"n_days must be >= 1, got {self.n_days}")

        batch_size = x.size(0)
        history_features = x[:, :-2]
        position = x[:, -2:]

        expected = self.n_days * self.feature_dim
        if history_features.shape[1] != expected:
            raise ValueError(
                f"expected {expected} history features "
                f"(n_days={self.n_days} * feature_dim={self.feature_dim}), "
                f"got {history_features.shape[1]}"
            )
        history_reshaped = history_features.reshape(
            batch_size, self.n_days, self.feature_dim
        )

        liquid_state = None
        for day in range(self.n_days):
            day_proj = self.feature_proj(history_reshaped[:, day])
            liquid_out, liquid_state = self.geo_liquid(day_proj, position, liquid_state)

        pos_embedding = self.pos_proj(position)
        embedding = self.fusion(torch.cat([liquid_out, pos_embedding], dim=1))

        # n_days >= 1 is guaranteed above, so the shortcut always applies.
        last_day_proj = self.e_mlp(torch.cat((history_reshaped[:, -1], position), dim=1))
        return embedding + 0.2 * last_day_proj