from typing import Dict, Optional

import numpy as np

from src.datamodules.components.ode_datamodule import ODEDataset, ODEDataModule
from src.utils.random import temp_seed


class GlycolyticOscillatorDataset(ODEDataset):
    """Multi-environment dataset for the 7-species glycolytic oscillator.

    Each environment gets its own 14-entry parameter vector (see
    ``_get_params``). Only ``k1`` and ``K1`` vary between environments; all
    other entries are fixed constants. Parameter values and initial-condition
    ranges follow CoDA's setting.
    """

    # (low, high) sampling bounds for each of the 7 state variables,
    # following CoDA's setting. Kept at class level so the table is built once
    # instead of on every ``_get_init_cond`` call.
    _IC_RANGE = np.array([
        (0.15, 1.60),
        (0.19, 2.16),
        (0.04, 0.20),
        (0.10, 0.35),
        (0.08, 0.30),
        (0.14, 2.67),
        (0.05, 0.10),
    ])

    def _get_params(self, param_range: Dict) -> np.ndarray:
        """Sample one parameter vector per environment.

        Column order: J0, k1, k2, k3, k4, k5, k6, K1, q, N, A, kappa, psi, k
        (following CoDA's parameter setting).

        Args:
            param_range: Mapping with at least the keys ``'k1'`` and ``'K1'``,
                each a ``(low, high)`` pair.

        Returns:
            Float array of shape ``(self.n_envs, 14)``. ``k1`` is drawn as an
            integer uniformly from ``[low, high]`` (inclusive upper bound);
            ``K1`` is drawn uniformly from ``[low, high)``. Both draws use
            NumPy's global RNG, ``k1`` first and then ``K1``.
        """
        params = np.zeros(shape=(self.n_envs, 14))
        params[:, 0] = 2.5  # J0
        # `+ 1` makes the upper bound inclusive for randint.
        params[:, 1] = np.random.randint(param_range['k1'][0], param_range['k1'][1] + 1, self.n_envs)  # k1
        params[:, 2] = 6  # k2
        params[:, 3] = 16  # k3
        params[:, 4] = 100  # k4
        params[:, 5] = 1.28  # k5
        params[:, 6] = 12  # k6
        params[:, 7] = np.random.uniform(param_range['K1'][0], param_range['K1'][1], self.n_envs)  # K1
        params[:, 8] = 4  # q
        params[:, 9] = 1  # N
        params[:, 10] = 4  # A
        params[:, 11] = 13  # kappa
        params[:, 12] = 0.1  # psi
        params[:, 13] = 1.8  # k
        return params

    def _get_init_cond(self, env_index):
        """Return a random initial state (7 values) for ``env_index``.

        Each component is drawn uniformly from its ``_IC_RANGE`` interval
        ``[low, high)``. The draw happens inside ``temp_seed(env_index)``,
        presumably so a given environment index always yields the same state
        (assumes ``temp_seed`` seeds NumPy's global RNG; confirm).
        """
        low = self._IC_RANGE[:, 0]
        width = self._IC_RANGE[:, 1] - low
        with temp_seed(env_index):
            return np.random.random(7) * width + low

    def _f(self, t, x, env_index):
        """Return dx/dt for state ``x`` under environment ``env_index``.

        ``t`` is unused because the system is autonomous. ``x`` is indexed as a
        7-element state vector ``x[0] .. x[6]``. The result is a new float
        array of length 7.
        """
        J0, k1, k2, k3, k4, k5, k6, K1, q, N, A, kappa, psi, k = self.params[env_index]

        # Reaction fluxes. Each appears in several equations, so compute it
        # once. The evaluation order matches the expanded form exactly.
        v1 = k1 * x[0] * x[5] / (1 + (x[5] / K1) ** q)  # inhibited by x[5] (Hill exponent q)
        v2 = k2 * x[1] * (N - x[4])
        v3 = k3 * x[2] * (A - x[5])
        v4 = k4 * x[3] * x[4]
        v6 = k6 * x[1] * x[4]
        v_ex = kappa * (x[3] - x[6])  # exchange term between x[3] and x[6]

        d = np.zeros(7)
        d[0] = J0 - v1
        d[1] = 2 * v1 - v2 - v6
        d[2] = v2 - v3
        d[3] = v3 - v4 - v_ex
        d[4] = v2 - v4 - v6
        d[5] = -2 * v1 + 2 * v3 - k5 * x[5]
        # Kept as (psi * kappa) * (...) to reproduce the original rounding.
        d[6] = psi * kappa * (x[3] - x[6]) - k * x[6]
        return d


class GlycolyticOscillatorDataModule(ODEDataModule):
    """Data module exposing train/val/test glycolytic-oscillator datasets."""

    def setup(self, stage: Optional[str] = None):
        """Build one ``GlycolyticOscillatorDataset`` per split.

        For each of ``train``, ``val`` and ``test`` (in that order), the
        dataset is constructed from the keyword arguments stored in
        ``self.<split>_dataset_params`` and stored as ``self.<split>_dataset``.
        Those params attributes are presumably populated by ``ODEDataModule``
        (confirm). ``stage`` is accepted for interface compatibility and is
        not used.
        """
        for split in ("train", "val", "test"):
            kwargs = getattr(self, f"{split}_dataset_params")
            setattr(self, f"{split}_dataset", GlycolyticOscillatorDataset(**kwargs))


def get_coda_params(train: bool):
    """Return CoDA's fixed parameter grid for the glycolytic oscillator.

    Rows enumerate every (k1, K1) combination, with k1 as the outer loop and
    K1 as the inner loop. The training grid is 3 x 3 = 9 rows; the test grid
    is 2 x 2 = 4 rows. Columns follow the order J0, k1, k2, k3, k4, k5, k6, K1,
    q, N, A, kappa, psi, k. All columns except k1 and K1 hold the same
    constants as the dataset's ``_get_params``.

    Args:
        train: Truthy selects the training grid; falsy selects the test grid.

    Returns:
        float64 array of shape ``(len(k1) * len(K1), 14)``.
    """
    k1_values, K1_values = (
        ([100, 90, 80], [1, 0.75, 0.5]) if train else ([85, 95], [0.625, 0.875])
    )

    # Same ordering as itertools.product(k1_values, K1_values).
    grid = [(a, b) for a in k1_values for b in K1_values]

    base_row = np.array(
        [
            2.5,   # J0
            0.0,   # k1 (filled per row below)
            6,     # k2
            16,    # k3
            100,   # k4
            1.28,  # k5
            12,    # k6
            0.0,   # K1 (filled per row below)
            4,     # q
            1,     # N
            4,     # A
            13,    # kappa
            0.1,   # psi
            1.8,   # k
        ],
        dtype=float,
    )
    table = np.tile(base_row, (len(grid), 1))
    table[:, 1] = [a for a, _ in grid]
    table[:, 7] = [b for _, b in grid]
    return table
