"""Auto-regressive models to model time-series data. A model is a
`torch.nn.Module`.

Data is assumed to be of the form (state, action, next state), where each
item is a `torch.Tensor`.

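Behind the scenes, the model is a residual model: the network output is
scaled by a time step `dt` and added to the current state to obtain the
predicted next state.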
"""
from swmpo.transition import Transition
import random
import torch


def get_relu_mlp(
        input_size: int,
        hidden_sizes: list[int],
        output_size: int,
        seed: str,
        ) -> torch.nn.Module:
    """Build a freshly initialized multi-layer perceptron with ReLU units.

    Each hidden layer is the sequence `Linear -> BatchNorm1d -> ReLU`; the
    output layer is a bare `Linear`. With an empty `hidden_sizes` the result
    is a single `Linear` from `input_size` to `output_size`.

    `seed` initializes a private `random.Random`; its first three random
    bytes form the integer handed to `torch.manual_seed`. Note that this
    reseeds torch's global random number generator as a side effect.
    """
    # Derive the torch seed from `seed` and reseed torch before any layer
    # is instantiated, so parameter initialization is reproducible.
    seed_source = random.Random(seed)
    torch_seed = int.from_bytes(seed_source.randbytes(3), 'big', signed=False)
    torch.manual_seed(seed=torch_seed)

    # Widths at every layer boundary: input, hidden..., output.
    widths = [input_size] + hidden_sizes + [output_size]
    final_index = len(widths) - 2

    modules = []
    for index, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
        modules.append(
            torch.nn.Linear(in_features=fan_in, out_features=fan_out)
        )
        if index == final_index:
            # The output layer gets no normalization/activation so that the
            # entire state space stays within the range of the network.
            continue
        modules.append(torch.nn.BatchNorm1d(num_features=fan_out))
        modules.append(torch.nn.ReLU())

    return torch.nn.Sequential(*modules)


def get_prediction(
        source_state: torch.Tensor,
        action: torch.Tensor,
        model: torch.nn.Module,
        dt: float,
        ) -> torch.Tensor:
    """Return the model's prediction for one (state, action) pair.

    The state and action are concatenated, given a leading dimension of
    size 1, and passed through `model`. The model output is treated as a
    residual: it is multiplied by `dt` and added to `source_state`.
    """
    # Single-sample batch: joined (state, action) with a leading axis.
    model_input = torch.cat((source_state, action), dim=0).unsqueeze(dim=0)
    # The network output is a rate of change; scale it by the time step.
    scaled_residual = model(model_input) * dt
    return source_state + scaled_residual


def get_raw_error(
        transition: Transition,
        model: torch.nn.Module,
        dt: float,
        ) -> float:
    """Return how far the model's prediction is from the recorded next
    state of `transition`.

    The prediction is obtained with `get_prediction` from the transition's
    source state and action. The error is the norm (`Tensor.norm()` with
    its default arguments) of the difference between that prediction and
    `transition.next_state`, converted to a Python float.
    """
    state, action = transition.source_state, transition.action
    estimate = get_prediction(
        source_state=state,
        action=action,
        model=model,
        dt=dt,
    )
    difference = estimate - transition.next_state
    return difference.norm().item()


def get_input_output_size(
        transition: Transition,
        ) -> tuple[int, int]:
    """Return the `(input_size, output_size)` pair for a model of
    transitions shaped like the given one.

    The input size is the length of the source state plus the length of
    the action; the output size is the length of the source state.
    """
    n_state = len(transition.source_state)
    n_action = len(transition.action)
    return n_state + n_action, n_state
