import torch
from einops import rearrange

from models.multi_input_ffno import MultiInputFFNO

def rollout_step(initial_step, final_step, model, yy, target, grid, loss_fn, is_training, step_cfg):
    """Autoregressively roll ``model`` forward one time step at a time.

    Starting from frame ``initial_step - 1`` of ``yy``, the model predicts
    frames ``initial_step .. final_step - 1`` one step per call. After each
    step, the input window is advanced by dropping its oldest frame. It then
    appends either the model's own prediction or, when training with
    ``step_cfg["teacher_forcing"]`` set, the ground-truth frame from ``yy``.

    Args:
        initial_step: index of the first predicted frame; must be >= 1.
        final_step: exclusive end index of the predicted frames; clamped to
            the number of frames available along axis -2 of ``yy``.
        model: called as ``model(inp, grid)`` with ``inp`` shaped
            ``[b, x1, ..., xd, t*v]``; the output gets a time axis inserted
            at -2 and is compared against ``target[..., t:t+1, :]``.
        yy: input trajectory, time on axis -2 (may be a noisy copy of target).
        target: ground-truth trajectory, time on axis -2.
        grid: passed through to the model unchanged.
        loss_fn: called as ``loss_fn(prediction, target)`` once per step.
        is_training: teacher forcing is only considered when True.
        step_cfg: mapping read via ``.get("teacher_forcing", False)``.

    Returns:
        ``(loss, target, pred)``. ``loss`` is the mean of the per-step losses.
        ``target`` is returned unsliced. ``pred`` stacks the predictions
        along axis -2.

    Raises:
        ValueError: if there is no time step to predict after clamping.
    """
    assert initial_step >= 1, "Initial step should be at least 1"
    # Never roll past the last available frame.
    final_step = min(final_step, yy.shape[-2])
    n_steps = final_step - initial_step
    if n_steps <= 0:
        raise ValueError(
            f"final_step ({final_step}) must be greater than initial_step ({initial_step})"
        )

    # Window of previous frames fed to the model; starts as the single frame
    # at initial_step - 1 with an explicit time axis at -2.
    xx = yy[..., initial_step - 1, :].unsqueeze(-2)
    # Reshape target for the input: [b, x1, ..., xd, t_init*v].
    inp_shape = [*xx.shape[:-2], -1]
    # Short-circuit keeps step_cfg untouched outside training, as before.
    use_teacher_forcing = is_training and step_cfg.get("teacher_forcing", False)

    loss = 0.0
    preds = []
    for t in range(initial_step, final_step):
        inp = xx.reshape(inp_shape)

        # Target at the current time step.
        _target = target[..., t:t + 1, :]
        # y is the same as the target but can be noisy.
        y = yy[..., t:t + 1, :]

        im = model(inp, grid).unsqueeze(-2)

        assert im.shape[:-1] == _target.shape[:-1], f"Prediction shape {im.shape} should be equal to target shape {_target.shape}"
        # Averaging over the rollout length keeps the loss scale independent
        # of how many steps are predicted.
        loss += loss_fn(im, _target) / n_steps
        preds.append(im)

        # Slide the window: drop the oldest frame and append either the
        # ground truth (teacher forcing) or the model's own prediction.
        xx = torch.cat((xx[..., 1:, :], y if use_teacher_forcing else im), dim=-2)

    pred = torch.cat(preds, dim=-2)
    # The full, unsliced target is returned, as in the other step helpers.
    return loss, target, pred


def single_step(model, yy, target, grid, loss_fn):
    """Score a single-step prediction from frame 0 to frame 1.

    The model sees frame 0 of ``yy`` (time on axis -2, which is indexed away)
    and its output is compared with frame 1 of ``target``.

    Args:
        model: called as ``model(frame, grid)``.
        yy: input trajectory, time on axis -2.
        target: ground-truth trajectory, time on axis -2.
        grid: passed through to the model unchanged.
        loss_fn: called as ``loss_fn(prediction, next_frame)``.

    Returns:
        ``(loss, target, pred)`` with ``target`` returned unsliced.
    """
    first_frame = yy[..., 0, :]
    next_frame = target[..., 1, :]
    pred = model(first_frame, grid)
    assert pred.shape[:-1] == next_frame.shape[:-1], f"Prediction shape {pred.shape} should be equal to target shape {next_frame.shape}"
    return loss_fn(pred, next_frame), target, pred


def one_to_seq_step(initial_step, final_step, model, yy, target, grid, loss_fn, is_training, step_cfg, batch_dt):
    """Predict a sequence of time steps from a single input frame.

    The input is frame ``yy[..., initial_step - 1, :]``. The targets are
    frames ``target[..., initial_step:final_step, :]``.

    During training the model is called repeatedly. Each call produces up to
    ``n_timesteps`` frames. The next call is seeded with frame ``t_f - 1``:
    the ground truth under ``step_cfg.teacher_forcing``, otherwise the
    model's own last predicted frame. At evaluation time ``model.predict``
    performs the whole rollout.

    Args:
        initial_step: index of the first target frame; must be >= 1 because
            the input frame is ``initial_step - 1``.
        final_step: exclusive end index of the target frames.
        model: called as ``model(x, grid, n_timesteps=..., batch_dt=...)``
            while training and ``model.predict(...)`` at evaluation.
        yy: input trajectory, layout (B, Sx, [Sy], [Sz], T, V).
        target: ground-truth trajectory with the same time axis (-2) as ``yy``.
        grid: passed through to the model unchanged.
        loss_fn: called on 2-D ``(B * t, -1)`` flattened tensors.
        is_training: selects the chunked training rollout vs. ``model.predict``.
        step_cfg: read via attributes ``train_timesteps``, ``t_train`` (only
            when ``train_timesteps <= 0``), ``teacher_forcing`` (training)
            and ``version`` (evaluation).
        batch_dt: passed through to the model unchanged.

    Returns:
        ``(loss, target, pred)``. ``target`` is returned unsliced. ``pred``
        covers time steps ``initial_step .. final_step - 1``.

    Raises:
        NotImplementedError: if the chunk length ``n_timesteps`` exceeds
            ``final_step - initial_step``.
    """
    # Same guard as rollout_step: initial_step == 0 would silently read the
    # last frame (yy[..., -1, :]) as the input.
    assert initial_step >= 1, "Initial step should be at least 1"
    # Layout: yy (B, Sx, [Sy], [Sz], T, V), grid (B, Sx, [Sy], [Sz], 1)
    inp = yy[..., initial_step - 1, :]
    # The next time steps are the target.
    _target = target[..., initial_step:final_step, :]
    total_steps = final_step - initial_step

    if step_cfg.train_timesteps > 0:
        n_timesteps = step_cfg.train_timesteps
    else:
        # Non-positive train_timesteps means: use all train timesteps.
        n_timesteps = step_cfg.t_train - initial_step
    if n_timesteps > total_steps:
        raise NotImplementedError(
            f"n_timesteps {n_timesteps} should not exceed the total number of timesteps {total_steps}"
        )

    if is_training:
        # Predict n_timesteps frames per model call.
        ys = []
        x_ = inp
        for t_i in range(initial_step, final_step, n_timesteps):
            t_f = min(final_step, t_i + n_timesteps)
            y_ = model(x_, grid, n_timesteps=t_f - t_i, batch_dt=batch_dt)
            # Seed the next chunk with frame t_f - 1.
            x_ = yy[..., t_f - 1, :] if step_cfg.teacher_forcing else y_[..., -1, :]
            ys.append(y_)
        # Concatenate the chunks along the time axis.
        pred = torch.cat(ys, dim=-2)
    else:
        pred = model.predict(inp, grid, total_timesteps=total_steps, train_timesteps=n_timesteps,
                             batch_dt=batch_dt, version=step_cfg.version)

    assert pred.shape[:-1] == _target.shape[:-1], f"Prediction shape {pred.shape} should be equal to target shape {_target.shape}"
    # Fold time into the batch axis. Each (sample, time step) pair becomes one
    # row of the 2-D tensors passed to loss_fn.
    batched_pred = rearrange(pred, 'b ... t v -> (b t) ... v')
    batched_target = rearrange(_target, 'b ... t v -> (b t) ... v')
    n_rows = batched_pred.size(0)
    loss = loss_fn(batched_pred.reshape(n_rows, -1), batched_target.reshape(n_rows, -1)) / total_steps
    # evaluator handles the reshaping
    return loss, target, pred

def sequential_step(model, yy, target, grid, loss_fn, is_training, step_cfg, batch_dt):
    """Sequence-to-sequence step: frames ``0 .. T-1`` predict frames ``1 .. T``.

    Training: ground-truth input frames are fed to the model in chunks of
    ``step_cfg.train_timesteps`` along the time axis (-2). The outputs are
    concatenated.

    Evaluation: a ``MultiInputFFNO`` is seeded with its first ``history_len``
    frames through ``predict_with_history``. Any other model is seeded with
    frame 0 through ``predict``.

    Args:
        model: called as ``model(x, grid, batch_dt)`` while training.
        yy: input trajectory, layout (B, Sx, [Sy], [Sz], T, V).
        target: ground-truth trajectory with the same time axis as ``yy``.
        grid: passed through to the model unchanged.
        loss_fn: called on 2-D ``(B * T, -1)`` flattened tensors.
        is_training: selects chunked training vs. the model's predict method.
        step_cfg: read via attributes ``train_timesteps`` and, at evaluation,
            ``reset_memory``, ``LG_length`` and ``discard_state``.
        batch_dt: passed through to the model unchanged.

    Returns:
        ``(loss, target, pred)`` where ``target`` is the shifted
        ``target[..., 1:, :]``, unlike the other step helpers.
    """
    # Input frame k is paired with target frame k + 1.
    inp = yy[..., :-1, :]
    target = target[..., 1:, :]
    train_timesteps = step_cfg.train_timesteps
    T = target.shape[-2]

    if is_training:
        chunks = [
            model(inp[..., start:start + train_timesteps, :], grid, batch_dt)
            for start in range(0, T, train_timesteps)
        ]
        pred = torch.cat(chunks, dim=-2)
    else:
        if isinstance(model, MultiInputFFNO):
            seed = inp[..., :model.history_len, :]
            rollout = model.predict_with_history
        else:
            seed = inp[..., 0, :]
            rollout = model.predict
        pred = rollout(seed, grid, n_timesteps=T, train_timesteps=train_timesteps,
                       reset_memory=step_cfg.reset_memory, LG_length=step_cfg.LG_length,
                       batch_dt=batch_dt, discard_state=step_cfg.discard_state)

    assert pred.shape[:-1] == target.shape[:-1], f"Prediction shape {pred.shape} should be equal to target shape {target.shape}"
    # Fold time into the batch axis. Each (sample, time step) pair becomes one
    # row of the 2-D tensors passed to loss_fn.
    flat_pred = rearrange(pred, 'b ... t v -> (b t) ... v')
    flat_target = rearrange(target, 'b ... t v -> (b t) ... v')
    rows = flat_pred.size(0)
    loss = loss_fn(flat_pred.reshape(rows, -1), flat_target.reshape(rows, -1)) / T
    # evaluator handles the reshaping
    return loss, target, pred