from typing import Optional

import numpy as np

from src.datamodules.components.ode_datamodule import ODEDataset, ODEDataModule
from src.utils.random import temp_seed


class LotkaVolterraDataset(ODEDataset):
    """Lotka-Volterra (predator-prey) ODE dataset.

    Each environment integrates

        dx/dt = alpha * x - beta * x * y
        dy/dt = delta * x * y - gamma * y

    with per-environment parameters ``(alpha, beta, gamma, delta)`` stored in
    ``self.params``. Following CoDA's dataset, ``alpha`` and ``gamma`` are
    fixed to 0.5, while ``beta`` and ``delta`` are drawn uniformly from the
    configured parameter range.
    """

    def _get_params(self, param_range=None):
        """Sample the (alpha, beta, gamma, delta) parameters for every environment.

        Args:
            param_range: ``(low, high)`` bounds of the uniform draw. When
                ``None``, falls back to ``self.param_range``. Previously the
                argument was ignored and ``self.param_range`` was always used.

        Returns:
            ``np.ndarray`` of shape ``(self.n_envs, 4)`` whose columns are
            alpha, beta, gamma and delta of the LV equation.
        """
        if param_range is None:
            param_range = self.param_range
        low, high = param_range[0], param_range[1]
        # Draw all four columns so the global RNG is consumed exactly as
        # before. Then overwrite alpha and gamma with constants, following
        # CoDA's dataset implementation.
        params = np.random.uniform(low=low, high=high, size=(self.n_envs, 4))
        params[:, 0] = 0.5
        params[:, 2] = 0.5
        return params

    def _get_init_cond(self, env_index):
        """Return the initial state ``[x0, y0]`` for environment ``env_index``.

        Both components are drawn uniformly from [1, 2), following the CoDA
        dataset generation routine. The draw happens inside
        ``temp_seed(env_index)``, which presumably seeds NumPy's global RNG
        with the environment index and restores it afterwards. TODO: confirm
        this in ``src.utils.random``.
        """
        with temp_seed(env_index):
            return np.random.uniform(low=1, high=2, size=2)

    def _f(self, t, x, env_index):
        """Right-hand side of the LV system for environment ``env_index``.

        Args:
            t: Current time. Unused, because the system is autonomous.
            x: Current state; ``x[0]`` is the prey and ``x[1]`` the predator.
            env_index: Row of ``self.params`` to use.

        Returns:
            ``np.ndarray`` of length 2 (float64) holding the time derivatives.
        """
        alpha, beta, gamma, delta = self.params[env_index]
        prey, predator = x[0], x[1]
        # Keep the original left-to-right evaluation order so results are
        # bit-identical to the previous implementation.
        return np.array(
            [
                alpha * prey - beta * prey * predator,
                delta * prey * predator - gamma * predator,
            ],
            dtype=float,
        )


class LotkaVolterraDataModule(ODEDataModule):
    """ODE data module whose train/val/test splits are Lotka-Volterra datasets."""

    def setup(self, stage: Optional[str] = None):
        """Instantiate the train, validation and test datasets.

        Each split is built from its own keyword-argument dict
        (``train_dataset_params``, ``val_dataset_params``,
        ``test_dataset_params``). ``stage`` is accepted for interface
        compatibility but is not inspected: all three splits are always built.
        The order train -> val -> test is kept fixed because dataset
        construction presumably draws from NumPy's global RNG.
        """
        dataset_cls = LotkaVolterraDataset
        self.train_dataset = dataset_cls(**self.train_dataset_params)
        self.val_dataset = dataset_cls(**self.val_dataset_params)
        self.test_dataset = dataset_cls(**self.test_dataset_params)


if __name__ == '__main__':
    # Smoke test: build the data module and print the shapes of one training
    # batch.
    # NOTE(review): val_dataset_params is not passed here, yet setup() reads
    # self.val_dataset_params. Confirm that ODEDataModule supplies a usable
    # default.
    ds_params = {
        'bundling_k': 20,
        'T': 50.0,
        'dt': 0.5,
        'rollout': True,
        'param_range': [0.5, 1.0],
    }
    dm = LotkaVolterraDataModule(
        num_train_envs=10,
        num_test_envs=10,
        train_dataset_params=ds_params,
        test_dataset_params=ds_params,
    )
    dm.setup()
    train_loader = dm.train_dataloader()
    cur, prev, target, t, env_idx, param = next(iter(train_loader))
    print(cur.shape, prev.shape, target.shape)
