import multiprocessing
import random
from functools import partial
from typing import Optional, Union

import numpy as np
import torch

from core.abstract_mdp import AbstractMDP
from core.data import TransitionData, AbstractStateDict
from core.utils import show, run_parallel
from envs import ConditionalActionEnv


def collect_data(env: ConditionalActionEnv,
                 max_timestep=1_000_000,
                 max_episode=1_000_000,
                 max_timestep_per_ep=1_000_000,
                 verbose=False,
                 seed=None,
                 n_jobs=1,
                 mdp=None,
                 state_buffer=None,
                 **kwargs) -> Union[TransitionData, AbstractStateDict]:
    """
    Collect data from the environment by running exploration workers in parallel.

    The total timestep and episode budgets are split into whole-number shares, one per worker, so that the
    shares sum exactly to the requested totals and every worker labels its episodes with a disjoint range of
    episode numbers.

    :param env: the environment
    :param max_timestep: the maximum number of timesteps in total (not to be confused with maximum time steps per
    episode). Default is 1,000,000; pass np.inf for no limit
    :param max_episode: the maximum number of episodes in total. Default is 1,000,000; pass np.inf for no limit
    :param max_timestep_per_ep: the maximum number of timesteps per episode. Default is 1,000,000
    :param verbose: whether to print additional information
    :param seed: the random seed. Use for reproducibility
    :param n_jobs: the number of processes to spawn to collect data in parallel. If -1, use all CPUs
    :param mdp: optional abstract MDP. When given, the workers ask it for actions instead of using the
    exploration policy
    :param state_buffer: optional buffer that receives the transitions instead of a TransitionData object.
    Only supported with a single worker (n_jobs == 1)
    :param kwargs: forwarded to the workers. 'exploration_policy' may hold a zero-argument callable returning an
    action; it defaults to env.sample_action
    :return: the concatenated transition data of all workers, or, when a state buffer is used, the buffer
    returned by the worker after flush_buffer() has been called on it
    :raises ValueError: if both max_timestep and max_episode are infinite, if n_jobs is neither -1 nor a
    positive number, or if state_buffer is combined with more than one worker
    """

    if max_timestep == np.inf and max_episode == np.inf:
        raise ValueError('Must specify at least a maximum timestep or episode limit')

    if n_jobs == -1:
        n_jobs = multiprocessing.cpu_count()
    if n_jobs < 1:
        raise ValueError('n_jobs must be -1 or a positive integer, got {}'.format(n_jobs))
    if state_buffer is not None and n_jobs != 1:
        # Only a single worker result can be returned as the buffer, so reject this before doing any work.
        raise ValueError('Collecting into a state_buffer is only supported with n_jobs=1')

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    timestep_shares = _split_budget(max_timestep, n_jobs)
    episode_shares = _split_budget(max_episode, n_jobs)

    # Each worker numbers its episodes starting at its own offset. With a finite episode budget the offsets are
    # the running sum of the shares. With an unlimited episode budget the timestep share is used instead; this
    # bounds the episode count as long as every episode consumes at least one timestep.
    id_strides = timestep_shares if max_episode == np.inf else episode_shares

    functions = []
    episode_offset = 0
    for job_timesteps, job_episodes, stride in zip(timestep_shares, episode_shares, id_strides):
        # Workers with nothing to do are dropped, but one job is always kept so a result exists.
        if functions and (job_timesteps <= 0 or job_episodes <= 0):
            continue
        functions.append(
            partial(_collect_data, env, np.random.randint(0, 1_000_000),
                    job_timesteps, job_episodes, max_timestep_per_ep,
                    verbose, episode_offset, mdp, state_buffer, **kwargs))
        episode_offset += stride

    # run collection in parallel
    results = run_parallel(functions)
    if isinstance(results[0], TransitionData):
        return TransitionData.concat(results)

    # Buffer mode: guaranteed to be a single worker by the validation above.
    assert len(results) == 1
    state_buffer = results[0]
    state_buffer.flush_buffer()
    return state_buffer


def _split_budget(total, n_jobs):
    """
    Split a budget into n_jobs whole-number shares that sum to the (truncated) total.

    The first ``total % n_jobs`` shares are one larger than the rest. An infinite budget is handed to every
    job unchanged.

    :param total: the overall budget (number of timesteps or episodes), possibly np.inf
    :param n_jobs: the number of shares to produce; must be positive
    :return: a list with one share per job
    """
    if total == np.inf:
        return [total] * n_jobs
    base, extra = divmod(int(total), n_jobs)
    return [base + 1 if i < extra else base for i in range(n_jobs)]


def _collect_data(env: ConditionalActionEnv,
                  seed: Optional[int] = None,
                  max_timestep: int = 1_000_000,
                  max_episode: int = 1_000_000,
                  max_timestep_per_ep: int = 1_000_000,
                  verbose: bool = False,
                  episode_offset: int = 0,
                  mdp: Optional[AbstractMDP] = None,
                  state_buffer: Optional[AbstractStateDict] = None,
                  **kwargs) -> Union[TransitionData, AbstractStateDict]:
    """
    Run episodes in the environment and record every transition.

    Episodes are run until either the episode budget or the timestep budget is exhausted. Actions come from
    mdp.sample_action when an mdp is given, otherwise from the exploration policy.

    :param env: the environment; reset() must return (state, info) and step() must return
    (next_state, reward, terminated, truncated, info)
    :param seed: if not None, seeds both the random module and numpy's global generator
    :param max_timestep: the maximum number of timesteps this worker may take in total
    :param max_episode: the maximum number of episodes this worker may run
    :param max_timestep_per_ep: the maximum number of timesteps per episode
    :param verbose: whether to print progress information
    :param episode_offset: added to the local episode counter when labelling recorded episodes
    :param mdp: optional abstract MDP used to choose actions; required when state_buffer is given
    :param state_buffer: optional buffer; when given, transitions are added to it instead of a TransitionData
    :param kwargs: 'exploration_policy' may hold a zero-argument callable returning an action; it defaults to
    env.sample_action and is only used when no mdp is given
    :return: state_buffer if one was given, otherwise the collected TransitionData
    :raises ValueError: if state_buffer is given without an mdp
    """
    if state_buffer is not None and mdp is None:
        # Raised explicitly and before stepping the environment: an assert would be stripped under
        # ``python -O`` and would only trigger after the first transition had been taken.
        raise ValueError('An mdp must be provided when collecting into a state_buffer')

    transition_data = TransitionData()

    # set the seed
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # The exploration policy does not change between steps, so look it up once.
    exploration_policy = None
    if mdp is None:
        exploration_policy = kwargs.get('exploration_policy', env.sample_action)

    n_episode = 0
    n_timesteps = 0

    while n_episode < max_episode and n_timesteps < max_timestep:
        show('Running episode {}'.format(n_episode + episode_offset), verbose)
        state, reset_info = env.reset()
        prev_info = {"t": 0, "position": reset_info.get('position', [0, 0])}
        done = False
        ep_timestep = 0
        while not done and n_timesteps < max_timestep and ep_timestep < max_timestep_per_ep:
            current_options = env.available_mask
            # -1 is the placeholder stored when no abstract state was obtained for this step
            sbar = -1
            sbar_next = -1

            if mdp is not None:
                # the second argument tells the mdp that this is the first step of the episode
                action, sbar = mdp.sample_action(env, ep_timestep == 0)
            else:
                action = exploration_policy()

            next_state, reward, terminated, truncated, step_info = env.step(action)
            steps = step_info.get('steps', 1)
            info = {"position": step_info.get('position', [0, 0]), "t": ep_timestep}
            n_timesteps += 1
            next_options = env.available_mask
            done = terminated or truncated

            if state_buffer is None:
                # calculate the state vars that change
                mask = np.where(np.array(state) != np.array(next_state))[0]
                transition_data.add(n_episode + episode_offset, state, action,
                                    reward, next_state, done,
                                    mask, steps, current_options, next_options,
                                    prev_info, info)
            else:
                next_probs = mdp.get_grounding_prob(torch.tensor(next_state, dtype=torch.float), next_options)
                if next_probs is not None:
                    # sample the next abstract state according to the returned probabilities
                    sbar_next = mdp.states[np.random.choice(len(next_probs), p=next_probs.numpy())]
                state_buffer.add_to_buffer(state, action, reward, next_state, done,
                                           steps, current_options, next_options, prev_info,
                                           info, sbar, sbar_next)

            ep_timestep += 1
            show('\tStep: {}'.format(ep_timestep), verbose and ep_timestep % 50 == 0)
            state = next_state
            prev_info = info
        n_episode += 1

    if state_buffer is not None:
        return state_buffer

    return transition_data
