from typing import Optional

import numpy as np

from src.datamodules.components.ode_datamodule import ODEDataset, ODEDataModule
from src.utils.random import temp_seed


class MassSpringDataset(ODEDataset):
    """
    1D Mass Spring System with DF=2.
    -k1-m1-k2-m2-k3

    Two masses (m1, m2) on a line. Spring k1 ties m1 to a fixed point, k3 ties
    m2 to a fixed point, and k2 couples the two masses. The state vector is
    ``[pos1, pos2, vel1, vel2]``.
    """

    def _get_params(self, param_range=None):
        """Sample the physical parameters for every environment.

        Args:
            param_range: ``(low, high)`` bounds of the uniform draw. When
                ``None``, ``self.param_range`` is used instead.

        Returns:
            Array of shape ``(self.n_envs, 5)``. The columns are, in order,
            ``m1, m2, k1, k2, k3``.
        """
        # Previously the argument was ignored in favour of self.param_range;
        # honour it when given, otherwise keep the old behaviour.
        if param_range is None:
            param_range = self.param_range
        low, high = param_range[0], param_range[1]
        # NOTE: uses the global NumPy RNG without a local seed, so the result
        # depends on whatever seeding the caller has done.
        return np.random.uniform(low=low, high=high, size=(self.n_envs, 5))

    def _get_init_cond(self, env_index):
        """Draw the initial state for environment ``env_index``.

        Returns:
            Length-4 array ``[pos1, pos2, vel1, vel2]`` sampled uniformly from
            ``[-2.5, 2.5)``. The draw runs under ``temp_seed(env_index)``, so it
            is presumably reproducible per environment index (confirm against
            ``src.utils.random.temp_seed``).
        """
        with temp_seed(env_index):
            return np.random.uniform(low=-2.5, high=2.5, size=4)

    def _f(self, t, x, env_index):
        """Right-hand side ``dx/dt`` of the mass-spring ODE.

        Args:
            t: Time. Unused, because the system has no explicit time dependence.
            x: State ``[pos1, pos2, vel1, vel2]``.
            env_index: Row of ``self.params`` holding ``m1, m2, k1, k2, k3``.

        Returns:
            Length-4 array ``[vel1, vel2, acc1, acc2]``.
        """
        m1, m2, k1, k2, k3 = self.params[env_index]
        pos, vel = x[:2], x[2:]
        # Stiffness matrix with each row already divided by its mass, so that
        # the accelerations are simply K @ pos.
        K = np.array([[-(k1 + k2) / m1, k2 / m1],
                      [k2 / m2, -(k2 + k3) / m2]])
        return np.concatenate([vel, K @ pos])


class MassSpringDataModule(ODEDataModule):
    """Data module that builds a ``MassSpringDataset`` for each split."""

    def setup(self, stage: Optional[str] = None):
        """Create the train, val and test datasets.

        For each split, the dataset is built from the matching
        ``<split>_dataset_params`` mapping (passed as keyword arguments) and
        stored on ``<split>_dataset``.

        Args:
            stage: Accepted for interface compatibility. It is ignored, and all
                three splits are always built.
        """
        for split in ("train", "val", "test"):
            dataset_kwargs = getattr(self, split + "_dataset_params")
            setattr(self, split + "_dataset", MassSpringDataset(**dataset_kwargs))
