import torch
from neuralop.models import FNO as neuralop_FNO

from src.utils.database import standardize


class FNO(neuralop_FNO):
    """neuralop ``FNO`` with input standardization and autoregressive rollout.

    Construction is inherited unchanged from ``neuralop.models.FNO``.
    """

    @staticmethod
    def _spatial_dims(x: torch.Tensor) -> tuple[int, ...]:
        """Return the axes of ``x`` to standardize over.

        Every axis from 2 onward is included, minus at most two trailing
        size-1 axes. Axes 0 and 1 (presumably batch and channel) are never
        dropped, so a size-1 channel axis is not mistaken for padding and a
        size-1 axis in the middle does not shift the range.
        """
        ndim = x.ndim
        for _ in range(2):
            if ndim <= 2 or x.shape[ndim - 1] != 1:
                break
            ndim -= 1
        return tuple(range(2, ndim))

    def forward(
        self,
        x: torch.Tensor,
        predict_normed=False,
        n_future_steps=1,
        state_labels=None,
        dset_name: str | None = None,
    ):
        """Roll the FNO forward autoregressively from the last input step.

        Args:
            x: Input tensor indexed as ``x[:, t, ...]``; only the last entry
                along axis 1 is used. Presumably laid out as
                (batch, time, channel, *spatial) -- TODO confirm against the
                data pipeline.
            predict_normed: If True, predictions are mapped back to the input
                scale via ``pred * std + mean``; if False they are returned
                in the standardized space.
                NOTE(review): the flag name reads as the opposite of what it
                does; confirm against callers before relying on it.
            n_future_steps: Number of autoregressive steps. Each step feeds
                the previous (standardized) prediction back into the network,
                which presumably requires matching in/out channel counts.
                Must be >= 1.
            state_labels: Unused by this model.
            dset_name: Unused by this model.

        Returns:
            ``(predictions, metadata)``. ``predictions`` stacks the per-step
            network outputs along a new axis 1. ``metadata`` holds the
            standardization statistics under ``'mean'`` and ``'std'``, each
            with an extra axis inserted at position 1 so they line up with
            the stacked step axis of ``predictions``.

        Raises:
            ValueError: If ``n_future_steps`` is less than 1 (previously this
                surfaced as an opaque ``torch.stack`` error on an empty list).
        """
        if n_future_steps < 1:
            raise ValueError(
                f"n_future_steps must be >= 1, got {n_future_steps}"
            )

        # Only the most recent time step conditions the rollout.
        current = x[:, -1, ...]

        # Standardize over the spatial axes; keep the stats so callers (or the
        # predict_normed branch below) can undo it.
        current, mean, std = standardize(
            current, dims=self._spatial_dims(current), return_stats=True
        )
        metadata = {'mean': mean.unsqueeze(1), 'std': std.unsqueeze(1)}

        outputs = []
        for _ in range(n_future_steps):
            current = super().forward(current)
            outputs.append(current)

        preds = torch.stack(outputs, dim=1)

        if predict_normed:
            # Undo the standardization; assumes `standardize` computes
            # (x - mean) / std -- TODO confirm against src.utils.database.
            preds = preds * metadata['std'] + metadata['mean']

        return preds, metadata
