import os
from omegaconf.listconfig import ListConfig

import numpy as np
import torch
from torch.utils.data import Dataset

from src.utils import pylogger


# Module-level logger obtained from the project's pylogger helper, keyed by this module's name.
log = pylogger.get_pylogger(__name__)


class PDEDataset(Dataset):
    """
    Base PDE dataset class.

    Subclasses must implement ``_get_element(index)``, returning a dict with the keys
    ``'state'``, ``'t'``, ``'env_idx'``, ``'param'`` and ``'ood'``. ``'state'`` is sliced
    along its first axis (presumably time — confirm in the child class). Subclasses should
    also define ``default_params``, which is used when ``params`` is None.

    Args:
        size (int): size of the 2D grid
        t_horizon (float): total time
        n_data_per_env (int): number of data points per environment
        n_envs (int): number of environments
        dt_eval (float): time step for evaluation
        dt_int (float): time step for integration
        n_states (int): number of states
        params (dict | list | str): parameters for the PDE. A list (or ListConfig) is used as
            one parameter set per environment; a single nested list ``[[...]]`` is unwrapped;
            any other non-string value is wrapped into a one-element list. A string is passed
            to ``eval`` and each resulting value becomes a ``'visc'`` entry.
        cache_file (str): path to the cache file (used verbatim, no suffix is appended)
        pushforward_n (int): number of pushforward steps
        rollout (bool): whether ``__getitem__`` returns rollout samples or pushforward samples
        bundling_k (int): bundling factor (multiplied by ``n_states`` internally)
        method (str): integration method
        group (str): train/test
        create_dataset (bool): whether to create the dataset when the cache file is missing
        cpus_usage (float): fraction of cpus to use when creating the dataset
        use_ray (bool): use ray instead of multiprocessing when creating the dataset
        **unused_kwargs: ignored; logged as a warning
    """
    def __init__(self,
                 size=64,
                 t_horizon=1.0,
                 n_data_per_env=10,
                 n_envs=1,
                 dt_eval=0.1,
                 dt_int=1e-3,
                 n_states=1,
                 params=None,
                 cache_file=None,
                 pushforward_n=2,
                 rollout=True,
                 bundling_k=1,
                 method='RK45',
                 group='train',
                 create_dataset=False,
                 cpus_usage=0.5,
                 use_ray=False,
                 **unused_kwargs):

        super().__init__()
        if unused_kwargs:
            log.warning(f'Unused kwargs: {unused_kwargs}')
        self.size = int(size)  # size of the 2D grid
        if isinstance(params, str):
            # SECURITY NOTE: `eval` executes arbitrary code. Only pass trusted configuration here.
            # NOTE(review): the keys 'f'/'visc'/'ood' and the OOD thresholds are hard-coded and
            # presumably specific to one child PDE — confirm before reusing elsewhere.
            params = eval(params)
            params = [{'f': 0.1, 'visc': v, 'ood': v<8e-4 or v>1.2e-3} for v in params]
        elif params is None:
            params = self.default_params

        if isinstance(params, list) or isinstance(params, ListConfig):
            # A single nested sequence, e.g. [[p1, p2]], is unwrapped to [p1, p2]. A nested
            # ListConfig counts too, because omegaconf configs produce ListConfig rather than list.
            nested = len(params) == 1 and (
                isinstance(params[0], list) or isinstance(params[0], ListConfig))
            self.params_eq = params[0] if nested else params
        else:
            # A single parameter set (e.g. one dict) becomes a one-element list.
            self.params_eq = [params]
        self.n_data_per_env = n_data_per_env
        self.n_envs = n_envs
        self.t_horizon = float(t_horizon)  # total time
        self.n = self._n_steps(self.t_horizon, dt_eval)  # number of iterations
        self.dt_eval = dt_eval
        self.dt_int = dt_int
        self.len = self.n_envs * self.n_data_per_env
        self.test = (group == 'test')
        self.n_states = n_states
        self.max = np.iinfo(np.int32).max
        self.method = method
        self.rollout = rollout
        self.pushforward_n = pushforward_n
        self.bundling_k = bundling_k * n_states # need to be multiplied by states; these become channels
        self.cache = {}
        self.cache_file = cache_file
        self.create_dataset_flag = create_dataset
        self.cpus_usage = cpus_usage
        self.use_ray = use_ray

    @staticmethod
    def _n_steps(t_horizon, dt):
        """Return floor(t_horizon / dt), robust to floating-point round-off.

        A plain ``int(t_horizon / dt)`` gives 2 for ``0.3 / 0.1`` (== 2.9999999999999996).
        Ratios within a tiny relative tolerance of an integer snap to that integer.
        """
        ratio = t_horizon / dt
        nearest = round(ratio)
        if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
            return int(nearest)
        return int(ratio)

    @staticmethod
    def _n_workers(fraction, n_cpus):
        """Return ``int(fraction * n_cpus)`` clamped to at least 1.

        An unknown CPU count (None) is treated as 1. The clamp matters because
        ``multiprocessing.Pool(0)`` raises and ray cannot schedule tasks with zero CPUs.
        """
        return max(1, int(fraction * (n_cpus or 1)))

    def __getitem__(self, index):
        """Return a rollout sample if ``self.rollout`` is set, otherwise a pushforward sample."""
        if self.rollout:
            return self._get_item_rollout(index)
        return self._get_item_pf(index)

    def __len__(self):
        """Return ``n_envs * n_data_per_env``."""
        return self.len

    def load_dataset(self, *unused_args, **unused_kwargs):
        """Load ``self.cache`` from ``self.cache_file`` if it exists.

        If the file is missing and ``create_dataset`` was requested, the dataset is generated
        (and saved) instead. Does nothing when ``cache_file`` is None.
        """
        if self.cache_file is not None:
            if os.path.exists(self.cache_file):
                log.info(f'Data found. Delete {self.cache_file} and re-run with create_dataset=True to regenerate data.')
                # The cache is a pickled dict saved by create_dataset(); load only trusted files.
                self.cache = np.load(self.cache_file, allow_pickle=True).item()
            else:
                log.warning(f'File {self.cache_file} does not exist. Create dataset first!')
                if self.create_dataset_flag:
                    self.create_dataset(self.cache_file)

    def _get_item_pf(self, index):
        """Return a pushforward training sample for environment ``index``.

        Consecutive windows of ``k = self.bundling_k`` rows of ``state`` are taken from the
        start of the trajectory (time stamp fixed at 0).

        Returns:
            tuple: (prev, cur, target, pf_target, t, env_idx, param, ood), where ``prev``,
            ``cur`` and ``target`` are windows 0, 1 and 2, and ``pf_target`` is window
            ``pushforward_n + 1``.
        """
        # Each item maps to one environment; sampling within a trajectory is not implemented.
        env_idx = index
        time_stamp = 0
        n, k = self.pushforward_n, self.bundling_k
        result = self._get_element(env_idx)
        state = result['state']  # presumably [T, ...] — sliced along the first axis
        prev = state[time_stamp : time_stamp + k, ...]  # ctx inference input
        cur = state[time_stamp + k : time_stamp + 2*k, ...]  # current obs
        target = state[time_stamp + 2*k : time_stamp + 3*k, ...]
        pf_target = state[time_stamp + (n + 1) * k: time_stamp + (n + 2) * k, ...]
        return (prev, cur, target, pf_target,
                result['t'],
                result['env_idx'],
                result['param'],
                result['ood'])

    def _get_item_rollout(self, index):
        """Return a rollout sample for environment ``index``.

        Returns:
            tuple: (prev, cur, target, t, env_idx, param, ood), where ``prev`` and ``cur`` are
            the first two windows of ``k = self.bundling_k`` rows of ``state`` and ``target``
            is everything after them.
        """
        k = self.bundling_k
        result = self._get_element(index)
        state = result['state']  # presumably [T, ...] — sliced along the first axis
        prev = state[0 : k, ...]  # ctx inference input
        cur = state[k : 2*k, ...]  # current obs
        target = state[2*k :, ...]
        return prev, cur, target, result['t'], result['env_idx'], result['param'], result['ood']

    def _get_element(self, index):
        """Return the sample dict for ``index``; must be implemented by subclasses."""
        raise NotImplementedError("Implement this in the child class!")

    def create_dataset(self, cache_file):
        """
        Generate every element with ``_get_element`` in parallel and cache the results.

        Uses ray if ``self.use_ray`` is set, otherwise a multiprocessing pool. Either way it
        uses ``cpus_usage`` of the available CPUs (at least one worker). The results are stored
        in ``self.cache`` as ``{index: element}`` and, if ``cache_file`` is not None, saved
        to exactly that path.
        """
        log.info("Creating dataset") # use multiprocessing

        if self.use_ray:
            import psutil
            import ray
            # cpu_count(logical=False) may return None; fall back to the logical count.
            n_cpus = psutil.cpu_count(logical=False) or psutil.cpu_count()
            ray.init(num_cpus=self._n_workers(self.cpus_usage, n_cpus), ignore_reinit_error=True)
            try:
                @ray.remote
                def _func(index):
                    return self._get_element(index)

                results = ray.get([_func.remote(i) for i in range(self.len)])
            finally:
                ray.shutdown()
        else:
            import multiprocessing as mp
            # The context manager releases the workers even if a task raises.
            with mp.Pool(self._n_workers(self.cpus_usage, mp.cpu_count())) as pool:
                results = pool.map(self._get_element, range(self.len))

        results = dict(enumerate(results))
        if cache_file is not None:
            cache_dir = os.path.dirname(cache_file)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            # Write through a file handle: np.save(<str path>) appends '.npy' to paths lacking
            # that suffix, so load_dataset() (which checks the exact path) would never find it.
            with open(cache_file, 'wb') as f:
                np.save(f, results)
            log.info('Dataset created!')
        self.cache = results