from functools import partial
from typing import List, Dict
from typing import Optional
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
from scipy.integrate import solve_ivp
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from src.utils.random import temp_seed


class ODEDataset(Dataset):
    """Base dataset of trajectories obtained by numerically integrating an ODE.

    Each environment ``i`` is described by ``self.params[i]``. Its trajectory is
    integrated with ``scipy.integrate.solve_ivp`` on first access and memoized
    in ``self.cache`` (keyed by environment index).

    Child classes must implement ``_get_params``, ``_get_init_cond`` and ``_f``.

    Indexing modes:

    * ``rollout=True``: one item per environment,
      ``(prev, cur, target, t, env_idx, param)``. Here ``prev`` holds the first
      ``k`` steps, ``cur`` the next ``k``, and ``target`` everything after.
    * ``rollout=False``: one item per (environment, start step) pair,
      ``(prev, cur, target, pf_target, t, env_idx, param)``. Each of these is a
      ``k``-step window; ``pf_target`` starts ``(n + 1) * k`` steps after the
      window start.
    """

    def __init__(self,
                 n_envs: int,
                 bundling_k: int,
                 T: float,
                 dt: float,
                 rollout: bool,
                 method: str = "RK45",
                 seed: int = 1234,
                 pushforward_n: int = 2,
                 param_range: Optional[Union[List[float], Dict]] = None,
                 params: Optional[np.ndarray] = None):
        """
        Args:
            n_envs: number of environments to generate from ``param_range``.
                When ``param_range`` is None it is replaced by
                ``params.shape[0]``.
            bundling_k: number of time steps per window (``k``).
            T: integration horizon; samples are taken at
                ``np.arange(0, T, dt)``.
            dt: sampling interval of the returned trajectories.
            rollout: if True, one item per environment with the full
                trajectory; otherwise fixed-size windows (see class docstring).
            method: integrator name forwarded to ``solve_ivp``.
            seed: seed used through ``temp_seed`` when parameters are generated
                from ``param_range``.
            pushforward_n: pushforward depth ``n`` used to place ``pf_target``.
            param_range: specification forwarded to ``_get_params``.
            params: explicit per-environment parameters, indexed by env index.

        Raises:
            ValueError: if neither ``param_range`` nor ``params`` is given, or
                if (in window mode) the trajectory is too short to hold a
                single window.
        """
        super().__init__()

        # Explicit raise instead of assert: asserts are stripped under -O.
        if param_range is None and params is None:
            raise ValueError("Either 'param_range' or 'params' must be given.")
        if param_range is None:
            n_envs = params.shape[0]
        self.n_envs = n_envs
        self.param_range = param_range
        self.T = T
        self.dt = dt
        self.method = method
        self.random_seed = seed
        self.pushforward_n = pushforward_n
        self.bundling_k = bundling_k

        if params is None:
            # set random seed locally. Assure that dataset always generate the
            # same envs given the random seed
            with temp_seed(self.random_seed):
                self.params = self._get_params(param_range)
        else:
            self.params = params

        self.cache = {}  # the key will be the environment index

        self.rollout = rollout
        if self.rollout:
            self.len = n_envs
        else:
            # The furthest slice read by _get_item_pf is pf_target, which ends
            # (n + 2) * k steps after the window start.
            self.samples_per_traj = int(T / dt) - (pushforward_n + 2) * bundling_k + 1
            if self.samples_per_traj < 1:
                raise ValueError(
                    f"Trajectory of {int(T / dt)} steps is too short for "
                    f"bundling_k={bundling_k} and pushforward_n={pushforward_n}.")
            self.len = n_envs * self.samples_per_traj

    def _get_params(self, param_range):
        raise NotImplementedError("must be implemented in the child class")

    def _get_init_cond(self, env_index):
        raise NotImplementedError("must be implemented in the child class")

    def _f(self, t, x, env_index):
        # Called by _solve_ivp as ``self._f(t, x, env_index=...)``; child
        # classes must accept the ``env_index`` keyword.
        raise NotImplementedError("must be implemented in the child class")

    def __getitem__(self, index):
        """
        index: (int) the key for accessing data.
        """
        if self.rollout:
            return self._get_item_rollout(index)
        return self._get_item_pf(index)

    def _get_solution(self, env_idx):
        """Return the memoized integration result for environment ``env_idx``."""
        result = self.cache.get(env_idx)
        if result is None:
            result = self.cache[env_idx] = self._solve_ivp(env_idx)
        return result

    def _get_item_pf(self, index):
        """Return ``(prev, cur, target, pf_target, t, env_idx, param)`` windows."""
        # Plain int arithmetic so env_idx is a Python int, like in rollout mode.
        env_idx, time_stamp = divmod(int(index), self.samples_per_traj)
        n, k = self.pushforward_n, self.bundling_k

        result = self._get_solution(env_idx)
        state = result['state'].T  # [T, state]
        prev = state[time_stamp: time_stamp + k, ...]  # ctx inference input
        cur = state[time_stamp + k: time_stamp + 2 * k, ...]  # current obs
        target = state[time_stamp + 2 * k: time_stamp + 3 * k, ...]
        pf_target = state[time_stamp + (n + 1) * k: time_stamp + (n + 2) * k, ...]

        return (prev, cur, target, pf_target,
                result['t'],
                result['env_idx'],
                result['param'])

    def _get_item_rollout(self, index):
        """Return ``(prev, cur, target, t, env_idx, param)`` for a whole trajectory."""
        k = self.bundling_k
        env_idx = int(index)

        result = self._get_solution(env_idx)
        state = result['state'].T  # [T, state]
        prev = state[:k, ...]  # ctx inference input
        cur = state[k: 2 * k, ...]  # current obs
        target = state[2 * k:, ...]
        return prev, cur, target, result['t'], result['env_idx'], result['param']

    def _solve_ivp(self, env_index):
        """Integrate environment ``env_index`` on the grid ``np.arange(0, T, dt)``.

        Returns:
            dict with ``'state'`` (float tensor [state, T]), ``'t'`` (float
            tensor of the sample times), ``'env_idx'`` and ``'param'``
            (float tensor of ``self.params[env_index]``).

        Raises:
            RuntimeError: if ``solve_ivp`` reports failure, since a truncated
                trajectory would silently produce misaligned windows.
        """
        t_eval = np.arange(0.0, self.T, self.dt)
        y0 = self._get_init_cond(env_index)
        sol = solve_ivp(partial(self._f, env_index=env_index), (0.0, self.T),
                        y0=y0, method=self.method, t_eval=t_eval)
        if not sol.success:
            raise RuntimeError(
                f"solve_ivp failed for environment {env_index}: {sol.message}")
        return {
            'state': torch.from_numpy(sol.y).float(),  # [state, T]
            # Reuse t_eval so 't' always matches the sampled state length.
            't': torch.from_numpy(t_eval).float(),
            'env_idx': env_index,
            # np.asarray also accepts scalar or list-valued per-env params.
            'param': torch.as_tensor(np.asarray(self.params[env_index]),
                                     dtype=torch.float32),
        }

    def __len__(self):
        return self.len


class ODEDataModule(pl.LightningDataModule):
    """Base LightningDataModule that serves train/val/test DataLoaders.

    Subclasses must implement ``setup`` to populate ``train_dataset``,
    ``val_dataset`` and ``test_dataset``. Presumably they build these from the
    stored ``*_dataset_params`` dicts; confirm this in each subclass.
    """

    def __init__(self,
                 num_train_envs: int,
                 num_test_envs: int,
                 train_dataset_params: Optional[dict] = None,
                 val_dataset_params: Optional[dict] = None,
                 test_dataset_params: Optional[dict] = None,
                 batch_size: int = 32,
                 num_workers: Optional[int] = None,
                 pin_memory: bool = True):
        """
        Args:
            num_train_envs: value stored as ``'n_envs'`` in the train params.
            num_test_envs: value stored as ``'n_envs'`` in the val/test params.
            train_dataset_params: extra dataset kwargs. Defaults to
                ``{'seed': 1234}``.
            val_dataset_params: extra dataset kwargs. Defaults to a copy of
                the test params.
            test_dataset_params: extra dataset kwargs. Defaults to
                ``{'seed': 4321, 'rollout': True}``.
            batch_size: DataLoader batch size.
            num_workers: DataLoader workers. Defaults to 20% of the CPU count.
            pin_memory: forwarded to DataLoader.

        Caller-supplied dicts are copied, never modified in place.
        """
        super().__init__()

        self.batch_size = batch_size

        if num_workers is None:
            import os
            # os.cpu_count() returns None when the count is undeterminable.
            num_workers = int(0.2 * (os.cpu_count() or 1))
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        if train_dataset_params is None:
            train_dataset_params = {'seed': 1234}
        else:
            train_dataset_params = dict(train_dataset_params)
        train_dataset_params['n_envs'] = num_train_envs
        self.train_dataset_params = train_dataset_params

        if test_dataset_params is None:
            test_dataset_params = {'seed': 4321, 'rollout': True}
        else:
            test_dataset_params = dict(test_dataset_params)
        test_dataset_params['n_envs'] = num_test_envs
        self.test_dataset_params = test_dataset_params

        # Copy even in the default case so val/test params never alias.
        if val_dataset_params is None:
            val_dataset_params = dict(test_dataset_params)
        else:
            val_dataset_params = dict(val_dataset_params)
            val_dataset_params['n_envs'] = num_test_envs
        self.val_dataset_params = val_dataset_params

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: Optional[str] = None):
        raise NotImplementedError

    def _get_loader(self, dataset, is_train: bool):
        """Build a DataLoader for ``dataset``, shuffling only when training."""
        return DataLoader(dataset=dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory,
                          shuffle=is_train)

    def train_dataloader(self):
        return self._get_loader(self.train_dataset, is_train=True)

    def val_dataloader(self):
        return self._get_loader(self.val_dataset, is_train=False)

    def test_dataloader(self):
        return self._get_loader(self.test_dataset, is_train=False)