from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset

from src.datamodules.components.ode_datamodule import ODEDataModule
from src.utils import pylogger
from src.utils.random import temp_seed

log = pylogger.get_pylogger(__name__)


class PolynomialDataset(Dataset):
    """Synthetic dataset in which every item is one polynomial "environment".

    Environment ``i`` is described by a row of ``order + 1`` coefficients
    ``params[i]``, stored from the highest power down to the constant term,
    i.e. ``y = params[i, 0] * x**order + ... + params[i, -1] * x**0``.
    Indexing the dataset returns ``n_samples_per_env`` points ``(x, y)`` of
    that polynomial together with its coefficients and the index.

    Args:
        n_envs: Number of environments to generate. Ignored when ``params``
            is given (the number of rows of ``params`` is used instead).
        n_samples_per_env: Number of ``x`` points drawn per environment.
        order: Polynomial order. Ignored when ``params`` is given (derived
            from the number of columns of ``params``).
        param_range: ``[low, high]`` bounds for uniformly sampled
            coefficients; only used when ``params`` is ``None``.
        params: Optional explicit coefficients of shape
            ``[n_envs, order + 1]``. The data is copied, so the caller's
            array is never modified.
        x_range: ``[low, high]`` bounds for uniformly sampled ``x`` values.
        bias: If ``False``, the constant term (last column) is zeroed.
        seed: Seed used when sampling the coefficients.
        dominant_only: If ``True``, every coefficient except the leading one
            (first column) is zeroed.
    """

    # NOTE: the list defaults below are shared between calls but are only
    # read (unpacked / stored), never mutated in this class.
    def __init__(self,
                 n_envs: int,
                 n_samples_per_env: int = 20,
                 order: int = 4,
                 param_range: Optional[List[float]] = [0.1, 1.5],
                 params: Optional[np.ndarray] = None,
                 x_range: List[float] = [-1.0, 1.0],
                 bias: bool = True,
                 seed: int = 1234,
                 dominant_only: bool = False):
        assert param_range is not None or params is not None, "Either 'param_range' or 'params' must be given."

        super().__init__()
        self.n_samples_per_env = n_samples_per_env
        self.param_range = param_range

        if params is None:
            # Sample the coefficients under a local seed so that the same
            # seed always yields the same environments.
            # (temp_seed is a project helper; presumably it seeds the RNG
            # only for the duration of the `with` block.)
            with temp_seed(seed):
                coeffs = torch.FloatTensor(n_envs, order + 1).uniform_(*param_range)
        else:
            # Copy explicitly: converting a float32 array can alias the
            # caller's memory, and the coefficients may be edited below.
            coeffs = torch.as_tensor(params, dtype=torch.float32).clone()
            # Explicit coefficients define both the number of environments
            # and the order (a row holds order + 1 coefficients).
            n_envs = coeffs.shape[0]
            order = coeffs.shape[1] - 1

        self.n_envs = n_envs
        self.order = order

        # ablation study purpose
        if not bias:
            log.info("Disable the bias terms of the polynomial coefficients.")
            coeffs[:, -1] = 0.0

        if dominant_only:
            log.info("Disable the non-dominant coefficients of the polynomial coefficients.")
            coeffs[:, 1:] = 0.0
        self.params = coeffs

        # index -> {'x', 'y', 'param', 'index'}; filled lazily by __getitem__
        self.cache = {}

        self.x_range = x_range
        self.seed = seed
        # Keep the length in sync with the actual number of coefficient rows.
        self.len = self.n_envs

    def __getitem__(self, index):
        """Return ``(x, y, param, index)`` for environment ``index``.

        ``x`` and ``y`` have shape ``[n_samples_per_env, 1]`` and ``param``
        has shape ``[1, order + 1]``. Results are computed once and cached.
        """
        entry = self.cache.get(index)
        if entry is None:
            # The x samples are drawn under a seed equal to the index (not
            # self.seed), so a given index is sampled the same way each time.
            with temp_seed(index):
                x = torch.FloatTensor(self.n_samples_per_env, 1).uniform_(*self.x_range)  # [n samples, 1]

            # Columns are x**order, ..., x**1, x**0 to match the coefficient layout.
            powers = torch.cat([x ** i for i in reversed(range(self.order + 1))], dim=-1)  # [n samples, order + 1]
            param = self.params[index].view(1, -1)  # [1, order + 1]
            y = (param * powers).sum(-1).view(-1, 1)  # [n samples, 1]
            entry = {'x': x, 'y': y, 'param': param, 'index': index}
            self.cache[index] = entry

        return entry['x'], entry['y'], entry['param'], index

    def __len__(self):
        """Return the number of environments."""
        return self.len


class PolynomialDataModule(ODEDataModule):
    """Data module whose train/val/test splits are ``PolynomialDataset`` objects."""

    def setup(self, stage: Optional[str] = None):
        """Instantiate the three split datasets from their parameter dicts.

        For each split the keyword arguments are read from
        ``self.<split>_dataset_params`` and the resulting dataset is stored
        as ``self.<split>_dataset``. ``stage`` is accepted but not used.
        """
        for split in ("train", "val", "test"):
            dataset_kwargs = getattr(self, f"{split}_dataset_params")
            setattr(self, f"{split}_dataset", PolynomialDataset(**dataset_kwargs))
