import torch
import torch.nn as nn
from nns.fno.fno import FNO1d

class PredictionWrapper(object):
    """Carry a single loss value out of a model's ``forward`` call.

    ``FNO1dWrap.forward`` returns one of these instead of raw predictions, and
    ``fno_loss_fn`` reads the value back through the ``loss`` attribute.

    Args:
        loss: the value to store; kept as-is with no copying or conversion.
    """

    def __init__(self, loss):
        super().__init__()
        # Expose the value unchanged as a plain public attribute.
        self.loss = loss

def fno_loss_fn(y_pred_wrapper, label):
    """Return the loss already stored on ``y_pred_wrapper``.

    ``label`` is accepted but ignored. Presumably the (prediction, target)
    signature exists so this can be plugged into a generic training loop that
    expects a loss function — TODO confirm against the caller.

    Args:
        y_pred_wrapper: any object exposing a ``loss`` attribute.
        label: unused.

    Returns:
        ``y_pred_wrapper.loss``, unchanged.
    """
    del label  # intentionally unused
    stored_loss = y_pred_wrapper.loss
    return stored_loss

class FNO1dWrap(nn.Module):
    """Wrap ``FNO1d`` so that ``forward`` runs an autoregressive rollout and returns its loss.

    Args:
        num_channels: forwarded unchanged to ``FNO1d``.
        width: forwarded unchanged to ``FNO1d``.
        modes: forwarded unchanged to ``FNO1d``.
        initial_step: forwarded to ``FNO1d``; also the first time index that is
            predicted during the rollout.
        t_train: exclusive upper bound on the time index rolled out in ``forward``.
    """

    def __init__(self, num_channels, width, modes, initial_step, t_train):
        super().__init__()
        # NOTE: the attribute name ``nn`` is kept because it sets the
        # ``state_dict`` key prefix ("nn."); renaming it would break checkpoints.
        self.nn = FNO1d(num_channels, width, modes, initial_step)
        self.initial_step = initial_step
        self.t_train = t_train
        self.loss_fn = nn.MSELoss(reduction="mean")

    def forward(self, batched_data):
        """Roll the model from ``initial_step`` up to ``t_train`` and sum per-step MSE.

        Args:
            batched_data: a 3-tuple ``(data, grid, label)``. The last two axes of
                ``data`` are presumably (time, variable) — TODO confirm against
                the data loader. ``label`` is sliced along its second-to-last
                axis to pick the target for each step.

        Returns:
            PredictionWrapper whose ``loss`` is the sum, over the rolled-out
            steps, of the mean-squared error between the model output and the
            matching ``label`` slice. If ``t_train <= initial_step`` no step runs
            and ``loss`` is the integer ``0``.
        """
        data, grid, label = batched_data

        # Merge the last two axes of ``data`` into one so the model sees
        # [..., t_init * v]. Computed once: the window keeps its shape across
        # steps, assuming the model output has size 1 along the time axis.
        inp_shape = [*data.shape[:-2], -1]

        loss = 0
        for t in range(self.initial_step, self.t_train):
            inp = data.reshape(inp_shape)
            # squeeze(3) only removes axis 3 when its size is 1; otherwise no-op.
            if inp.dim() >= 4:
                inp = inp.squeeze(3)

            # Call the module itself (not ``.forward``) so registered hooks run.
            im = self.nn(inp, grid)

            # Target for this step: one time slice of the label.
            y = label[..., t:t + 1, :]
            batch_size = im.size(0)
            loss += self.loss_fn(im.reshape(batch_size, -1), y.reshape(batch_size, -1))

            # Slide the input window: drop the oldest slice, append the prediction.
            data = torch.cat((data[..., 1:, :], im), dim=-2)

        return PredictionWrapper(loss)


def id_loss_fn(data, label):
    """Identity "loss": hand back ``data`` untouched and ignore ``label``.

    Presumably used when the model output already is the loss value — TODO
    confirm against the training loop that calls it.

    Args:
        data: returned as-is (same object).
        label: unused.

    Returns:
        ``data`` itself.
    """
    _ = label  # unused; kept for the (prediction, target) call signature
    result = data
    return result
    