import abc
import copy
import math
import re
import warnings
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, Literal, Final, Callable, OrderedDict

import torch
import torch.nn as nn
import yaml
from safetensors.torch import load_file
from torch import Tensor
from transformers import PretrainedConfig, AutoConfig

# --- Base Blocks ---

class Block(nn.Module):
    """Base ``nn.Module`` that carries a weight-initialization scheme.

    Args:
        resume: Stored as ``self.resume``; not otherwise used by this class.
        **kwargs: Only ``initialization_scheme`` is read (default ``"kaiming_normal"``);
            any other key is ignored.
    """

    def __init__(self, resume: bool = False, **kwargs):
        super().__init__()
        self.initialization_scheme = kwargs.get("initialization_scheme", "kaiming_normal")
        self.resume = resume

    def param_init(self):
        """Re-initialize every ``nn.Linear`` found in ``self.modules()``.

        ``"kaiming_normal"`` draws fan-in Kaiming-normal weights with leaky-ReLU gain;
        ``"lecun_normal"`` draws fan-in Kaiming-normal weights with linear gain.
        Any other scheme leaves weights untouched. Biases are always zeroed.
        """
        nonlinearity_by_scheme = {"kaiming_normal": "leaky_relu", "lecun_normal": "linear"}
        nonlinearity = nonlinearity_by_scheme.get(self.initialization_scheme)
        linears = (m for m in self.modules() if isinstance(m, nn.Linear))
        for linear in linears:
            if nonlinearity is not None:
                nn.init.kaiming_normal_(linear.weight, mode="fan_in", nonlinearity=nonlinearity)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

class MLP(Block):
    """Feed-forward stack ``[Linear -> activation -> Dropout] * len(hidden_layers) -> Linear``.

    Submodule names (``linear_{i}``, ``activation_{i}``, ``dropout_{i}``, ``output``,
    ``output_activation``) determine the state-dict keys, so they must stay stable.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        hidden_layers: Widths of the hidden layers (may be empty).
        hidden_act: Activation inserted after every hidden ``Linear``. Defaults to a
            fresh ``nn.ReLU()`` per MLP; a passed-in instance is reused by all hidden layers.
        output_act: Optional activation appended after the output layer.
        dropout: Dropout probability applied after each hidden activation.
        **kwargs: Forwarded to ``Block`` (e.g. ``initialization_scheme``).
    """

    def __init__(self, in_features: int, out_features: int, hidden_layers: List[int],
                 hidden_act: Optional[nn.Module] = None, output_act: Optional[nn.Module] = None,
                 dropout: float = 0.0, **kwargs):
        super(MLP, self).__init__(**kwargs)
        # A module-instance default (``= nn.ReLU()``) is created once at definition time
        # and shared by every MLP; build a fresh one per instance instead.
        if hidden_act is None:
            hidden_act = nn.ReLU()
        self.layers = nn.Sequential()
        in_size = in_features
        for i, h_size in enumerate(hidden_layers):
            self.layers.add_module(f"linear_{i}", nn.Linear(in_size, h_size))
            self.layers.add_module(f"activation_{i}", hidden_act)
            self.layers.add_module(f"dropout_{i}", nn.Dropout(dropout))
            in_size = h_size
        self.layers.add_module("output", nn.Linear(in_size, out_features))
        if output_act is not None:
            self.layers.add_module("output_activation", output_act)
        self.param_init()

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension ``in_features``)."""
        return self.layers(x)

# --- Normalization ---

class InstanceNormalization(abc.ABC):
    """Interface for invertible per-instance normalizations.

    The static helpers flatten every dimension between the first (batch) and the
    last (feature) one, so implementations can reduce over a single middle axis.
    """

    @abc.abstractmethod
    def get_norm_stats(self, values: Tensor, mask: Optional[Tensor] = None, **kwargs) -> Any:
        """Return the statistics consumed by the (inverse) normalization maps."""

    @abc.abstractmethod
    def normalization_map(self, values: Tensor, norm_stats: Any, derivative_num: Optional[int] = 0) -> Tensor:
        """Apply the normalization, or its ``derivative_num``-th derivative, to ``values``."""

    @abc.abstractmethod
    def inverse_normalization_map(self, values: Tensor, norm_stats: Any, derivative_num: Optional[int] = 0) -> Tensor:
        """Apply the inverse normalization, or its ``derivative_num``-th derivative."""

    @staticmethod
    def squash_intermediate_dims(values: Tensor) -> tuple[Tensor, tuple]:
        """Reshape ``values`` to ``[B, -1, D]``; return it together with the original shape."""
        shape = values.shape
        flat = values.reshape(shape[0], -1, shape[-1])
        return flat, shape

    @staticmethod
    def expand_norm_stats(shape: tuple, norm_stats: tuple[Tensor]) -> tuple[Tensor]:
        """Insert an axis at -2 in each statistic and expand it to ``shape``."""
        return tuple(stat.unsqueeze(-2).expand(shape) for stat in norm_stats)

class Standardization(InstanceNormalization):
    """Per-instance z-scoring ``(x - mean) / std``.

    Statistics are taken per batch element and per feature, pooled over all
    intermediate dimensions of a ``[B, ..., D]`` input (``mean``/``std`` are ``[B, D]``).
    """

    def get_norm_stats(self, values: Tensor, mask: Optional[Tensor] = None) -> tuple[Tensor]:
        """Compute ``(mean, std)`` of ``values`` over its intermediate dimensions.

        Args:
            values: Tensor of shape ``[B, ..., D]``.
            mask: Optional tensor whose truthy entries mark observed values. After
                flattening to ``[B, -1, D']`` it must broadcast to the flattened ``values``.

        Returns:
            ``(mean, std)``, each ``[B, D]``. ``std`` is the population (biased) standard
            deviation, clipped below at ``1e-6``.
        """
        values_flat, _ = self.squash_intermediate_dims(values)
        if mask is None:
            mean = values_flat.mean(dim=-2)
            var = ((values_flat - mean.unsqueeze(-2)) ** 2).mean(dim=-2)
        else:
            mask_flat, _ = self.squash_intermediate_dims(mask.bool())
            mask_flat = torch.broadcast_to(mask_flat, values_flat.shape)
            # NaN-out unobserved entries so nanmean averages observed entries only.
            # NOTE: a (batch, feature) slot with no observed entry yields NaN statistics.
            mean = torch.nanmean(torch.where(mask_flat, values_flat, torch.nan), dim=-2)
            se = (values_flat - mean.unsqueeze(-2)) ** 2
            var = torch.nanmean(torch.where(mask_flat, se, torch.nan), dim=-2)
        # Population variance in both branches keeps ``mask=None`` identical to an
        # all-true mask (Bessel-corrected ``torch.std`` would disagree with the masked path).
        std = torch.clip(torch.sqrt(var), min=1e-6)
        return mean, std

    def _broadcast_stats(self, values: Tensor, norm_stats: tuple[Tensor]):
        """Flatten ``values`` and broadcast ``(mean, std)`` over its middle axis."""
        mean, std = norm_stats
        values_flat, original_shape = self.squash_intermediate_dims(values)
        mean_exp = mean.unsqueeze(-2).expand_as(values_flat)
        std_exp = std.unsqueeze(-2).expand_as(values_flat)
        return values_flat, original_shape, mean_exp, std_exp

    def normalization_map(self, values: Tensor, norm_stats: tuple[Tensor], derivative_num: int = 0) -> Tensor:
        """Return ``(x - mean) / std`` (0th), ``1 / std`` (1st) or zeros (higher derivatives)."""
        values_flat, original_shape, mean_exp, std_exp = self._broadcast_stats(values, norm_stats)
        if derivative_num == 0:
            out = (values_flat - mean_exp) / std_exp
        elif derivative_num == 1:
            out = 1.0 / std_exp
        else:
            out = torch.zeros_like(values_flat)
        return out.reshape(original_shape)

    def inverse_normalization_map(self, values: Tensor, norm_stats: tuple[Tensor], derivative_num: int = 0) -> Tensor:
        """Return ``x * std + mean`` (0th), ``std`` (1st) or zeros (higher derivatives)."""
        values_flat, original_shape, mean_exp, std_exp = self._broadcast_stats(values, norm_stats)
        if derivative_num == 0:
            out = values_flat * std_exp + mean_exp
        elif derivative_num == 1:
            out = std_exp
        else:
            out = torch.zeros_like(values_flat)
        return out.reshape(original_shape)

class DeltaLogCentering(InstanceNormalization):
    """Rescales times so consecutive increments along dim 2 have geometric mean ``target_value``.

    ``get_norm_stats`` returns ``(log_mean,)``, the mean log-increment per batch element
    and feature. The forward map multiplies by ``target_value / exp(log_mean)``.
    """

    def __init__(self, target_value: float = 0.01, **kwargs):
        self.target_value = target_value

    def get_norm_stats(self, values: Tensor, mask: Optional[Tensor] = None) -> tuple[Tensor]:
        """Return ``(log_mean,)`` of the increments ``values[:, :, 1:] - values[:, :, :-1]``.

        With a mask, only increments whose later endpoint (``mask[:, :, 1:]``) is truthy
        are averaged. Increments are clipped at ``1e-12`` before the log.
        """
        increments = values[:, :, 1:, :] - values[:, :, :-1, :]
        increments_flat, _ = self.squash_intermediate_dims(increments)
        log_increments = torch.log(torch.clip(increments_flat, min=1e-12))
        if mask is not None:
            valid, _ = self.squash_intermediate_dims(mask[:, :, 1:, :].bool())
            valid = torch.broadcast_to(valid, increments_flat.shape)
            log_mean = torch.nanmean(torch.where(valid, log_increments, torch.nan), dim=-2)
        else:
            log_mean = torch.mean(log_increments, dim=-2)
        return (log_mean,)

    @staticmethod
    def _expand_scale(scale: Tensor, like: Tensor) -> Tensor:
        """Broadcast a ``[B, D]`` scale over the intermediate dimensions of ``like``."""
        flat_shape = like.reshape(like.shape[0], -1, like.shape[-1]).shape
        return scale.unsqueeze(-2).expand(flat_shape).reshape(like.shape)

    def normalization_map(self, values: Tensor, norm_stats: tuple[Tensor], derivative_num: int = 0) -> Tensor:
        """Return ``values * s`` (0th), ``s`` (1st) or zeros, with ``s = target / exp(log_mean)``."""
        (log_mean,) = norm_stats
        if derivative_num not in (0, 1):
            return torch.zeros_like(values)
        scale = self._expand_scale(self.target_value / torch.exp(log_mean), values)
        return values * scale if derivative_num == 0 else scale

    def inverse_normalization_map(self, values: Tensor, norm_stats: tuple[Tensor], derivative_num: int = 0) -> Tensor:
        """Return ``values * s`` (0th), ``s`` (1st) or zeros, with ``s = exp(log_mean) / target``."""
        (log_mean,) = norm_stats
        if derivative_num not in (0, 1):
            return torch.zeros_like(values)
        scale = self._expand_scale(torch.exp(log_mean) / self.target_value, values)
        return values * scale if derivative_num == 0 else scale

# --- Attention Layers ---

class LinearAttention(nn.Module):
    """Multi-head kernelized (linear) attention.

    Queries and keys go through a positive feature map φ, and the output is
    ``φ(Q) (φ(K)^T V)``, optionally divided by ``φ(Q) · Σ_k φ(K)``.
    The ``dropout`` argument is accepted for signature parity with
    ``nn.MultiheadAttention`` but is not applied anywhere in this module.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 feature_map: Literal["elu", "softmax"] = "elu", normalize: bool = True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Created (and registered) in the order Q, K, V, out.
        self.linear_Q, self.linear_K, self.linear_V, self.out_proj = (
            nn.Linear(embed_dim, embed_dim, bias=bias) for _ in range(4)
        )
        self.feature_map_type = feature_map
        self.normalize = normalize

    def feature_map(self, x: Tensor) -> Tensor:
        """Return ``elu(x) + 1`` for ``"elu"``; any other setting uses softmax over the last dim."""
        if self.feature_map_type != "elu":
            return x.softmax(dim=-1)
        return torch.nn.functional.elu(x) + 1

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None):
        """Attend from ``query`` ``[B, Tq, E]`` to ``key``/``value`` ``[B, Tk, E]``.

        ``key_padding_mask`` (``[B, Tk]``, truthy = ignore) zeroes the mapped keys and values
        at padded positions. Returns ``[B, Tq, E]``.
        """
        batch = key.shape[0]
        q_len = query.shape[1]
        k_len = key.shape[1]

        def split_heads(x: Tensor, length: int) -> Tensor:
            # (B, T, E) -> (B, heads, T, head_dim)
            return x.reshape(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

        q = self.feature_map(split_heads(self.linear_Q(query), q_len))
        k = self.feature_map(split_heads(self.linear_K(key), k_len))
        v = split_heads(self.linear_V(value), k_len)

        if key_padding_mask is not None:
            keep = (~key_padding_mask.bool()).float()[:, None, :, None]  # (B, 1, Tk, 1)
            k = k * keep
            v = v * keep

        if not self.normalize:
            norm_coeff = 1.0
        else:
            k_total = k.sum(dim=-2, keepdim=True).expand(-1, -1, q_len, -1)
            norm_coeff = (q * k_total).sum(dim=-1, keepdim=True)

        kv = k.transpose(-2, -1) @ v  # (B, heads, head_dim, head_dim)
        out = (1 / (norm_coeff + 1e-8)) * q @ kv  # (B, heads, Tq, head_dim)
        out = out.transpose(1, 2).reshape(batch, q_len, self.embed_dim)
        return self.out_proj(out)

class ResidualAttentionLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: str = "torch.nn.ReLU", bias: bool = True, batch_first: bool = True,
                 query_residual: bool = True, attn_method: str = "nn_multihead", **kwargs):
        super().__init__()
        self.batch_first = batch_first
        self.attn_method = attn_method
        if attn_method == "nn_multihead":
            self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=bias, batch_first=batch_first)
        else:
            self.attn = LinearAttention(d_model, nhead, dropout=dropout, bias=bias, 
                                        feature_map=kwargs.get("lin_feature_map", "elu"))
            
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias)
        self.norm1 = nn.LayerNorm(d_model, bias=bias)
        self.norm2 = nn.LayerNorm(d_model, bias=bias)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU() # Simplified from create_class_instance
        self.query_residual = query_residual

    def forward(self, queries: Tensor, keys: Tensor, values: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        if self.attn_method == "nn_multihead":
            attn_out, _ = self.attn(queries, keys, values, key_padding_mask=key_padding_mask)
        else:
            attn_out = self.attn(queries, keys, values, key_padding_mask=key_padding_mask)
        
        attn_out = self.dropout1(attn_out)
        
        if self.query_residual:
            x = self.norm1(queries + attn_out)
        else:
            x = self.norm1(attn_out)
        
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(x))))
        ff_out = self.dropout2(ff_out)
        
        return self.norm2(x + ff_out)

class ResidualEncoderLayer(nn.Module):
    """Self-attention layer with the call signature ``nn.TransformerEncoder`` uses for its layers.

    Only ``src`` and ``src_key_padding_mask`` are used; ``src_mask`` and ``is_causal``
    are accepted for compatibility and ignored.
    """

    def __init__(self, d_model, **kwargs):
        super().__init__()
        # The attribute name ``self_attn`` fixes the state-dict key prefix.
        self.self_attn = ResidualAttentionLayer(d_model=d_model, **kwargs)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, is_causal: bool = False):
        """Run self-attention over ``src`` (queries = keys = values)."""
        return self.self_attn(src, src, src, key_padding_mask=src_key_padding_mask)

class PathsSummaryBlockAttention(nn.Module):
    """Two-stage attention: per path over observations, then per location over paths.

    Stage 1 (``omega_1``): for each path, the location encodings attend over that path's
    observation encodings. Stage 2 (``omega_2``): for each location, one query (the
    location encoding, or a ones vector) attends over the per-path summaries.
    Neither stage adds a query residual.
    """

    def __init__(self, embed_dim, **kwargs):
        super().__init__()
        self.locations_as_final_query = kwargs.get("locations_as_final_query", True)
        # Force ``query_residual=False`` by overriding rather than passing it alongside
        # ``**kwargs``; a caller-supplied ``query_residual`` would otherwise raise TypeError.
        layer_kwargs = {**kwargs, "query_residual": False}
        self.omega_1 = ResidualAttentionLayer(d_model=embed_dim, **layer_kwargs)
        self.omega_2 = ResidualAttentionLayer(d_model=embed_dim, **layer_kwargs)

    def forward(self, locations_encoding: Tensor, observations_encoding: Tensor,
                observations_padding_mask: Optional[Tensor] = None, paths_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Summarize observations for every location.

        Args:
            locations_encoding: ``[B, G, H]``.
            observations_encoding: ``[B, P, T, H]``.
            observations_padding_mask: Optional, reshapeable to ``[B * P, T]``; passed as
                the key padding mask of stage 1.
            paths_padding_mask: Optional ``[B, P]`` or ``[B, P, 1]``; passed (per location)
                as the key padding mask of stage 2.

        Returns:
            ``[B, G, H]``.
        """
        B, G, H = locations_encoding.shape
        B, P, T, _ = observations_encoding.shape

        # Stage 1: every path sees all G locations as queries.
        loc_enc_expand = locations_encoding.unsqueeze(1).expand(-1, P, -1, -1).reshape(B * P, G, H)
        obs_enc_flat = observations_encoding.reshape(B * P, T, H)
        obs_mask_flat = None if observations_padding_mask is None else observations_padding_mask.reshape(B * P, T)
        loc_dep_path_enc = self.omega_1(loc_enc_expand, obs_enc_flat, obs_enc_flat, obs_mask_flat)
        # [B*P, G, H] -> [B*G, P, H]: regroup so each location sees its P path summaries.
        loc_dep_path_enc = loc_dep_path_enc.reshape(B, P, G, H).permute(0, 2, 1, 3).reshape(B * G, P, H)

        # Stage 2: one query per location.
        if self.locations_as_final_query:
            query = locations_encoding.reshape(B * G, 1, H)
        else:
            query = torch.ones((B * G, 1, H), device=locations_encoding.device, dtype=locations_encoding.dtype)

        paths_mask_expand = None
        if paths_padding_mask is not None:
            paths_mask_expand = paths_padding_mask.reshape(B, 1, P).expand(-1, G, -1).reshape(B * G, P)

        paths_dep_loc_enc = self.omega_2(query, loc_dep_path_enc, loc_dep_path_enc, paths_mask_expand)
        return paths_dep_loc_enc.reshape(B, G, H)

class AttentionOperator(nn.Module):
    """Decode location encodings by attending to observation encodings, then project with an MLP.

    Args:
        embed_dim: Embedding size ``H`` of the location and observation encodings.
        out_features: Size of the projected output.
        attention: Keyword arguments for the attention layers.
        projection: Reads ``hidden_layers`` (default ``[embed_dim, embed_dim]``) and
            ``dropout`` (default ``0.0``) for the output ``MLP``.
        paths_block_attention: If true, use ``PathsSummaryBlockAttention``. Otherwise, use a
            stack of ``num_res_layers`` ``ResidualAttentionLayer``s over the observations of
            all paths jointly.
        num_res_layers: Depth of that stack (unused when ``paths_block_attention`` is true).
    """

    def __init__(self, embed_dim, out_features, attention: Optional[dict] = None, projection: Optional[dict] = None,
                 paths_block_attention: bool = True, num_res_layers: int = 1):
        super().__init__()
        # ``None`` defaults avoid one shared ``{}`` object across every call.
        attention = {} if attention is None else attention
        projection = {} if projection is None else projection

        self.paths_block_attention = paths_block_attention
        if self.paths_block_attention:
            self.paths_summary_attention = PathsSummaryBlockAttention(embed_dim, **attention)
        else:
            self.res_layers = nn.ModuleList(
                [ResidualAttentionLayer(d_model=embed_dim, **attention) for _ in range(num_res_layers)]
            )

        # MLP keeps the checkpoint naming structure (projection.layers...).
        self.projection = MLP(in_features=embed_dim, out_features=out_features,
                              hidden_layers=projection.get("hidden_layers", [embed_dim, embed_dim]),
                              dropout=projection.get("dropout", 0.0))

    def forward(self, locations_encoding: Tensor, observations_encoding: Tensor,
                observations_padding_mask: Optional[Tensor] = None, paths_padding_mask: Optional[Tensor] = None):
        """Return ``[B, G, out_features]`` for ``locations_encoding`` ``[B, G, H]``.

        ``observations_encoding`` is ``[B, P, T, H]``. ``observations_padding_mask`` must be
        reshapeable to ``[B, P * T]`` and is used as a key padding mask.
        ``paths_padding_mask`` is only used by the paths-block variant.
        """
        if self.paths_block_attention:
            x = self.paths_summary_attention(locations_encoding, observations_encoding,
                                             observations_padding_mask, paths_padding_mask)
        else:
            B, P, T, H = observations_encoding.shape
            # ``reshape`` (not ``view``) also accepts non-contiguous inputs.
            obs_enc = observations_encoding.reshape(B, P * T, H)
            obs_mask = observations_padding_mask.reshape(B, P * T) if observations_padding_mask is not None else None
            x = locations_encoding
            for layer in self.res_layers:
                x = layer(x, obs_enc, obs_enc, obs_mask)
        return self.projection(x)

# --- Core Model Classes ---

class SDEConcepts:
    """Drift/diffusion values at ``locations``, convertible between original and normalized coordinates.

    ``normalize`` maps everything into the normalized state/time coordinates and
    ``renormalize`` maps it back. The ``normalized`` flag turns a repeated call in
    the same direction into a no-op.
    """

    def __init__(self, locations: Tensor, drift: Tensor, diffusion: Tensor, log_var_drift: Optional[Tensor] = None, log_var_diffusion: Optional[Tensor] = None, normalized: bool = False):
        self.locations = locations
        self.drift = drift
        self.diffusion = diffusion
        self.normalized = normalized
        self.log_var_drift = log_var_drift
        self.log_var_diffusion = log_var_diffusion

    def _states_transformation(self, states_norm, states_norm_stats, normalize: bool):
        """Scale by the first derivative of the state map (forward map if ``normalize``, else inverse).

        Drift and diffusion scale linearly; log-variances shift by ``2 * log(grad)``.
        """
        state_map = states_norm.normalization_map if normalize else states_norm.inverse_normalization_map
        grad = state_map(self.locations, states_norm_stats, derivative_num=1)
        self.drift = self.drift * grad
        self.diffusion = self.diffusion * grad
        if self.log_var_drift is not None:
            self.log_var_drift = self.log_var_drift + 2 * torch.log(grad)
        if self.log_var_diffusion is not None:
            self.log_var_diffusion = self.log_var_diffusion + 2 * torch.log(grad)

    def _times_transformation(self, times_norm, times_norm_stats, normalize: bool):
        """Scale for a change of the time variable.

        By the chain rule dx/dt_new = dx/dt_old * dt_old/dt_new. So the derivative of the
        map *opposite* to the requested direction is used, evaluated on a dummy
        ``[..., 1]`` time tensor shaped like the locations. Drift scales by it and
        diffusion by its square root. The drift log-variance shifts by ``2 * log`` of it and
        the diffusion log-variance by ``log`` of it.
        """
        dummy_times = torch.zeros_like(self.locations[..., 0:1])
        time_map = times_norm.inverse_normalization_map if normalize else times_norm.normalization_map
        inv_grad = time_map(dummy_times, times_norm_stats, derivative_num=1)
        self.drift = self.drift * inv_grad
        self.diffusion = self.diffusion * torch.sqrt(inv_grad)
        if self.log_var_drift is not None:
            self.log_var_drift = self.log_var_drift + 2 * torch.log(inv_grad)
        if self.log_var_diffusion is not None:
            self.log_var_diffusion = self.log_var_diffusion + torch.log(inv_grad)

    def _locations_transformation(self, states_norm, states_norm_stats, normalize: bool):
        """Map ``locations`` with the state map (forward if ``normalize``, else inverse)."""
        state_map = states_norm.normalization_map if normalize else states_norm.inverse_normalization_map
        self.locations = state_map(self.locations, states_norm_stats)

    def _transform_all(self, states_norm, states_norm_stats, times_norm, times_norm_stats, to_normalized: bool):
        """Apply the state, location and time transformations in that order, then set the flag."""
        self._states_transformation(states_norm, states_norm_stats, normalize=to_normalized)
        self._locations_transformation(states_norm, states_norm_stats, normalize=to_normalized)
        self._times_transformation(times_norm, times_norm_stats, normalize=to_normalized)
        self.normalized = to_normalized

    def normalize(self, states_norm, states_norm_stats, times_norm, times_norm_stats):
        """Move to normalized coordinates; no-op if already normalized."""
        if self.normalized:
            return
        self._transform_all(states_norm, states_norm_stats, times_norm, times_norm_stats, True)

    def renormalize(self, states_norm, states_norm_stats, times_norm, times_norm_stats):
        """Move back to original coordinates; no-op if not normalized."""
        if not self.normalized:
            return
        self._transform_all(states_norm, states_norm_stats, times_norm, times_norm_stats, False)

class OdeConcept:
    """Drift-only wrapper around ``SDEConcepts`` (diffusion fixed to zeros) bundling its norms and stats.

    ``drift``, ``locations`` and ``normalized`` mirror the wrapped ``self.base`` after
    construction and after every (re)normalization.
    """

    def __init__(self, locations, drift, normalized, states_norm, states_norm_stats, times_norm, times_norm_stats):
        self.states_norm = states_norm
        self.states_norm_stats = states_norm_stats
        self.times_norm = times_norm
        self.times_norm_stats = times_norm_stats
        self.base = SDEConcepts(locations, drift, torch.zeros_like(drift), None, None, normalized)
        self._sync()

    def normalize(self):
        """Normalize the wrapped concept with the stored norms/stats and refresh the mirrors."""
        self.base.normalize(self.states_norm, self.states_norm_stats, self.times_norm, self.times_norm_stats)
        self._sync()

    def renormalize(self):
        """Map the wrapped concept back to original coordinates and refresh the mirrors."""
        self.base.renormalize(self.states_norm, self.states_norm_stats, self.times_norm, self.times_norm_stats)
        self._sync()

    def _sync(self):
        """Copy ``drift``, ``locations`` and ``normalized`` from ``self.base``."""
        base = self.base
        self.drift = base.drift
        self.locations = base.locations
        self.normalized = base.normalized

    @classmethod
    def builder(cls):
        """Return a fresh fluent ``OdeConceptBuilder``."""
        return cls.OdeConceptBuilder()

    class OdeConceptBuilder:
        """Fluent builder: each setter stores its argument and returns the builder."""

        def __init__(self):
            self._locations = None
            self._drift = None
            self._normalized = None
            self._states_norm = None
            self._states_norm_stats = None
            self._times_norm = None
            self._times_norm_stats = None

        def _set(self, field, value):
            # Store under the private ``_<field>`` attribute and allow chaining.
            setattr(self, f"_{field}", value)
            return self

        def locations(self, v):
            return self._set("locations", v)

        def drift(self, v):
            return self._set("drift", v)

        def normalized(self, v):
            return self._set("normalized", v)

        def states_norm(self, v):
            return self._set("states_norm", v)

        def times_norm(self, v):
            return self._set("times_norm", v)

        def states_norm_stats(self, v):
            return self._set("states_norm_stats", v)

        def times_norm_stats(self, v):
            return self._set("times_norm_stats", v)

        def build(self):
            """Construct the ``OdeConcept`` from the stored values."""
            return OdeConcept(
                self._locations,
                self._drift,
                self._normalized,
                self._states_norm,
                self._states_norm_stats,
                self._times_norm,
                self._times_norm_stats,
            )

@dataclass
class TrajectoryFeatures:
    """Per-step inputs to ``TrajectoryEncoder`` (as built by ``FimOdeon.prepare_input``)."""

    x: Tensor  # state at the start of each step
    delta_x: Tensor  # state increment over each step
    delta_x_squared: Tensor  # element-wise square of delta_x
    delta_t: Tensor  # time increment over each step
    feature_mask: Tensor  # truthy where the step is observed

class TrajectoryEncoder(nn.Module):
    """Embed per-step trajectory features and contextualize them with a Transformer encoder.

    ``delta_t``, ``x``, ``delta_x`` and ``delta_x_squared`` are each projected to
    ``dim_embed // 4`` channels and concatenated in that order into one ``dim_embed`` token.
    The tokens of all steps of all paths are then encoded jointly.

    Raises:
        ValueError: If ``config.dim_embed`` is not divisible by 4.
    """

    def __init__(self, config):
        super().__init__()
        if config.dim_embed % 4 != 0:
            # Otherwise the concatenated token is narrower than dim_embed and the encoder
            # fails later with an opaque shape error.
            raise ValueError(f"dim_embed must be divisible by 4, got {config.dim_embed}")
        self.config = config
        dim_out = config.dim_embed // 4
        use_bias = config.use_bias_for_projection
        self.x_proj = nn.Linear(config.dim_max_trajectory, dim_out, bias=use_bias)
        self.delta_x_proj = nn.Linear(config.dim_max_trajectory, dim_out, bias=use_bias)
        self.delta_x_squared_proj = nn.Linear(config.dim_max_trajectory, dim_out, bias=use_bias)
        self.delta_t_proj = nn.Linear(1, dim_out, bias=use_bias)

        # Custom ResidualEncoderLayer with config-based attention.
        layer = ResidualEncoderLayer(d_model=config.dim_embed, **config.get_attention_layer_config())
        # NOTE(review): nn.TransformerEncoder presumably warns about nested tensors because the
        # layer is not an nn.TransformerEncoderLayer; that warning is expected and silenced.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor is True.*")
            self.context_encoder = nn.TransformerEncoder(layer, num_layers=config.num_context_encoder_layers)

    def forward(self, features: TrajectoryFeatures):
        """Return encodings shaped ``[b, t, n, dim_embed]`` for 4-D ``[b, t, n, ·]`` features.

        Steps whose ``feature_mask`` is falsy are passed to the encoder as padding.
        """
        x = self.x_proj(features.x)
        dx = self.delta_x_proj(features.delta_x)
        dx2 = self.delta_x_squared_proj(features.delta_x_squared)
        dt = self.delta_t_proj(features.delta_t)
        # Concatenation order must match the trained weights.
        feat = torch.cat([dt, x, dx, dx2], dim=-1)
        b, t, n, d = feat.shape
        feat = feat.reshape(b, t * n, d)
        mask = (~features.feature_mask.bool()).reshape(b, t * n)
        D = self.context_encoder(feat, src_key_padding_mask=mask)
        return D.reshape(b, t, n, d)

class FimOdeon(nn.Module):
    """Drift estimator: encodes observed trajectories and decodes the drift at query locations.

    States are standardized per instance (``Standardization``) and times rescaled
    (``DeltaLogCentering``). ``forward`` returns an ``OdeConcept`` whose drift and locations
    are in normalized coordinates (``normalized=True``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.spatial_norm = Standardization()
        self.temporal_norm = DeltaLogCentering()
        self.trajectory_encoder = TrajectoryEncoder(config)
        self.location_proj = nn.Sequential(
            nn.Linear(config.dim_max_trajectory, config.dim_embed), nn.ReLU(),
            nn.Linear(config.dim_embed, config.dim_embed)
        )
        self.functional_decoder = AttentionOperator(
            embed_dim=config.dim_embed, out_features=config.dim_max_trajectory,
            attention=config.get_attention_layer_config(), projection=config.get_projection_config(),
            paths_block_attention=False, num_res_layers=config.num_res_layers_functional_decoder
        )

    def pad_if_necessary(self, values: torch.Tensor) -> torch.Tensor:
        """Zero-pad the last dimension of ``values`` up to ``config.dim_max_trajectory``.

        Inputs that are already that wide (or wider) are returned unchanged.
        """
        if values.shape[-1] < self.config.dim_max_trajectory:
            padding = torch.zeros(values.shape[:-1] + (self.config.dim_max_trajectory - values.shape[-1],),
                                  device=values.device, dtype=values.dtype)
            values = torch.concat([values, padding], dim=-1)
        return values

    @staticmethod
    def _backward_fill(values: Tensor, mask: Tensor) -> Tensor:
        """Replace unobserved entries along dim -2 with the next observed entry.

        Unobserved entries after the last observed one take the value of the final entry.
        ``mask`` may be bool or numeric and must broadcast to ``values``.
        """
        observed = torch.broadcast_to(mask.bool(), values.shape)
        # Forward-fill the time-reversed sequence, then reverse back.
        rev_values = torch.flip(values, dims=(-2,))
        rev_observed = torch.flip(observed, dims=(-2,))
        filled = rev_values.clone()
        for i in range(1, values.shape[-2]):
            filled[..., i, :] = torch.where(rev_observed[..., i, :], rev_values[..., i, :], filled[..., i - 1, :])
        return torch.flip(filled, dims=(-2,))

    def prepare_input(self, trajectories, times, mask):
        """Backward-fill, normalize and difference the observations along dim 2.

        Returns:
            ``(features, builder)``: ``TrajectoryFeatures`` for consecutive steps, and an
            ``OdeConcept`` builder holding the state/time norms and their statistics.
        """
        traj_filled = self._backward_fill(trajectories, mask)
        times_filled = self._backward_fill(times, mask)

        concept = OdeConcept.builder()
        s_stats = self.spatial_norm.get_norm_stats(traj_filled, mask)
        concept.states_norm(self.spatial_norm).states_norm_stats(s_stats)
        traj_norm = self.spatial_norm.normalization_map(traj_filled, s_stats)

        t_stats = self.temporal_norm.get_norm_stats(times_filled, mask)
        concept.times_norm(self.temporal_norm).times_norm_stats(t_stats)
        times_norm = self.temporal_norm.normalization_map(times_filled, t_stats)

        # Step features: state at step start, increment, squared increment, time increment.
        X = traj_norm[:, :, :-1, :]
        dX = traj_norm[:, :, 1:, :] - X
        dX2 = dX ** 2
        dT = times_norm[:, :, 1:, :] - times_norm[:, :, :-1, :]
        f_mask = mask[:, :, :-1, :]

        return TrajectoryFeatures(x=X, delta_x=dX, delta_x_squared=dX2, delta_t=dT, feature_mask=f_mask), concept

    def forward(self, trajectories, times, locations, mask):
        """Predict the drift at ``locations`` from the observed trajectories.

        Args:
            trajectories: 4-D tensor ``[B, ·, S, D_traj]``; dim 2 (S) is the step axis that
                ``prepare_input`` fills and differences.
            times: Same leading dims as ``trajectories`` with a trailing dim of 1.
            locations: ``[B, L, D_traj]``.
            mask: Truthy where an observation is present, same leading dims as ``times``.

        Returns:
            An ``OdeConcept`` holding normalized locations and drift (last dim
            ``dim_max_trajectory``).
        """
        trajectories = self.pad_if_necessary(trajectories)
        locations = self.pad_if_necessary(locations)

        features, concept = self.prepare_input(trajectories, times, mask)

        # Locations are normalized with the trajectories' statistics.
        locations_norm = self.spatial_norm.normalization_map(locations, concept._states_norm_stats)
        concept.locations(locations_norm)

        D = self.trajectory_encoder(features)
        loc_enc = self.location_proj(locations_norm)

        padding_mask = (~features.feature_mask.bool()).reshape(D.shape[0], -1)
        pred = self.functional_decoder(loc_enc, D, padding_mask)
        concept.drift(pred).normalized(True)
        return concept.build()

class UncertaintyEstimator(nn.Module):
    """Predict one scalar per location from location and trajectory encodings.

    Wraps an ``AttentionOperator`` with ``out_features=1``, configured by
    ``config.get_u_model_config()``.
    """

    def __init__(self, config):
        super().__init__()
        cfg = config.get_u_model_config()
        self.functional_encoder = AttentionOperator(embed_dim=config.dim_embed, out_features=1, **cfg)

    def forward(self, loc_enc, D, feature_mask):
        """Return the operator output with its trailing size-1 dim squeezed.

        ``feature_mask`` is truthy for observed steps. Its negation is used as the
        key padding mask.
        """
        # Cast first: ``~`` is a logical NOT only for bool tensors (it errors on floats and
        # is a bitwise NOT on ints, which would mark every step as padding).
        padding_mask = (~feature_mask.bool()).reshape(D.shape[0], -1)
        u = self.functional_encoder(loc_enc, D, padding_mask)
        return u.squeeze(-1)

# --- Wrapper & Factory ---

@dataclass(kw_only=True)
class FimOdeonConfiguration(PretrainedConfig):
    """Hyper-parameters of ``FimOdeon`` and its uncertainty model.

    NOTE: as a keyword-only dataclass, its generated ``__init__`` does not call
    ``PretrainedConfig.__init__``; only the fields below are set on instances.
    """

    dim_max_trajectory: int  # state width that inputs are zero-padded to
    use_bias_for_projection: bool  # bias in the trajectory feature projections
    dim_embed: int  # embedding width
    num_context_encoder_layers: int  # depth of the trajectory encoder
    attention_method: str  # "nn_multihead" or a linear-attention variant
    attention_map: Optional[str]  # feature map for linear attention
    use_bias_in_attention: bool
    use_query_residual_in_attention: bool
    num_heads: int
    dim_feedforward: int
    dropout: float
    num_res_layers_functional_decoder: int
    num_res_layer_u_model: int
    dim_hidden_u_model: int
    dim_ffn_u_model: int  # not read by the methods below
    model_type: str = "FimOdeon"

    def get_attention_layer_config(self):
        """Keyword arguments for the attention layers."""
        return dict(
            nhead=self.num_heads,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            bias=self.use_bias_in_attention,
            query_residual=self.use_query_residual_in_attention,
            attn_method=self.attention_method,
            lin_feature_map=self.attention_map,
        )

    def get_projection_config(self):
        """Projection MLP settings for the drift decoder (two hidden layers of ``dim_embed``)."""
        return dict(hidden_layers=(self.dim_embed,) * 2, dropout=self.dropout)

    def get_u_model_config(self):
        """Keyword arguments for the uncertainty model's ``AttentionOperator``."""
        u_hidden = (self.dim_hidden_u_model,) * 2
        return dict(
            num_res_layers=self.num_res_layer_u_model,
            attention=self.get_attention_layer_config(),
            projection=dict(hidden_layers=u_hidden, dropout=self.dropout),
            paths_block_attention=False,
        )

class TrainingWrapper(nn.Module):
    """Container for ``FimOdeon`` (``self.model``) and ``UncertaintyEstimator`` (``self.u_model``).

    Args:
        config: Training-parameter mapping. The model config is taken from
            ``config["model"]["model_config"]`` if ``"model"`` is present, else from
            ``config["model_config"]``, else ``config`` itself is the model config.
    """

    def __init__(self, config: dict):
        super().__init__()
        if "model" in config:
            model_cfg = config["model"]["model_config"]
        elif "model_config" in config:
            model_cfg = config["model_config"]
        else:
            model_cfg = config

        self.model_config = FimOdeonConfiguration(**model_cfg)
        self.model = FimOdeon(self.model_config)
        self.u_model = UncertaintyEstimator(self.model_config)

    def load_state_dict(self, state_dict, strict=True, *args, **kwargs):
        """Delegate to ``nn.Module.load_state_dict``, forwarding any extra arguments (e.g. ``assign``)."""
        return super().load_state_dict(state_dict, strict, *args, **kwargs)

# --- Evaluator Class ---

class OdeonEval:
    """Load a trained FimOdeon checkpoint and run drift predictions.

    Args:
        path_to_checkpoints_dir: Directory holding the weights, either directly or in a
            ``best-model`` subdirectory. ``train_parameters.yaml`` is looked up in this
            directory, then its parent, then its grandparent.

    Raises:
        FileNotFoundError: If the config file or the weights cannot be found.
    """

    def __init__(self, path_to_checkpoints_dir: Path):
        config_dict = self._load_config(self._find_config(path_to_checkpoints_dir))
        self.wrapper = TrainingWrapper(config_dict)

        weights = self._load_weights(self._find_weights(path_to_checkpoints_dir))
        new_weights = OrderedDict()
        for key, value in weights.items():
            new_weights[self._remap_key(key)] = value
        self.wrapper.load_state_dict(new_weights, strict=True)

        self.model = self.wrapper.model
        self.model.eval()
        self.u_model = self.wrapper.u_model
        self.u_model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.wrapper.to(self.device)

    @staticmethod
    def _find_config(checkpoints_dir: Path) -> Path:
        """Return ``train_parameters.yaml`` from the directory, its parent or its grandparent."""
        for directory in (checkpoints_dir, checkpoints_dir.parent, checkpoints_dir.parent.parent):
            candidate = directory / "train_parameters.yaml"
            if candidate.exists():
                return candidate
        raise FileNotFoundError(f"Could not find train_parameters.yaml in {checkpoints_dir} or its parents.")

    @staticmethod
    def _load_config(config_path: Path):
        """Parse the YAML with a SafeLoader extended to build ``!!python/tuple`` as tuples."""
        class SafeLoaderWithTuple(yaml.SafeLoader):
            pass

        SafeLoaderWithTuple.add_constructor(
            'tag:yaml.org,2002:python/tuple',
            lambda loader, node: tuple(loader.construct_sequence(node))
        )
        with open(config_path, "r", encoding="utf-8") as f:
            return yaml.load(f, Loader=SafeLoaderWithTuple)

    @staticmethod
    def _find_weights(checkpoints_dir: Path) -> Path:
        """Return the first existing weights file (safetensors preferred, ``best-model`` first)."""
        candidates = (
            checkpoints_dir / "best-model" / "model.safetensors",
            checkpoints_dir / "model.safetensors",
            checkpoints_dir / "best-model" / "model-checkpoint.pth",
            checkpoints_dir / "model-checkpoint.pth",
        )
        for candidate in candidates:
            if candidate.exists():
                return candidate
        raise FileNotFoundError(f"Could not find model weights in {checkpoints_dir} or subdirectories.")

    @staticmethod
    def _load_weights(weights_path: Path):
        """Load a state dict from a ``.safetensors`` file or a torch pickle."""
        if weights_path.suffix == ".safetensors":
            return load_file(weights_path)
        # SECURITY: ``weights_only=False`` unpickles arbitrary objects and can execute code;
        # only load ``.pth`` checkpoints from trusted sources.
        return torch.load(weights_path, map_location='cpu', weights_only=False)

    @staticmethod
    def _remap_key(key: str) -> str:
        """Translate legacy checkpoint keys to the current module layout."""
        nk = key.replace("context_encoder.layers", "trajectory_encoder.context_encoder.layers")
        if nk.startswith("model.functional_encoder"):
            nk = nk.replace("model.functional_encoder", "model.functional_decoder")
        # Order matters only for readability here; none of these patterns overlap.
        for name in ("delta_t_proj", "delta_x_proj", "x_proj", "delta_x_squared_proj"):
            nk = nk.replace(f"model.{name}", f"model.trajectory_encoder.{name}")
        # Keys already under ``trajectory_encoder.context_encoder`` got prefixed twice above.
        return nk.replace("model.trajectory_encoder.trajectory_encoder", "model.trajectory_encoder")

    @torch.no_grad()
    def predict(self, traj, times, locations, mask=None):
        """Predict the drift at ``locations`` in the original (un-normalized) coordinates.

        ``mask`` defaults to all-observed. Inputs are moved to ``self.device``. The model
        zero-pads states to ``dim_max_trajectory`` and the returned drift keeps that width.
        """
        if mask is None:
            mask = torch.ones_like(traj[..., :1], dtype=torch.bool)
        traj, times, locations, mask = [x.to(self.device) for x in [traj, times, locations, mask]]
        concept = self.model(traj, times, locations, mask)
        concept.renormalize()
        return concept.drift
