from typing import Tuple

import numpy as np

from icpe.feat import Feature
from icpe.MRP.cartpole import CartPole
from icpe.MRP.mrp import MRP
from joblib import Parallel, delayed


def simulate(mrp: MRP, num: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    '''
    param mrp: process to roll out; must provide reset() and step(state)
    param num: number of transitions to sample
    return: (states, next_states, rewards)

    Roll out one trajectory: start from mrp.reset() and call mrp.step num
    times, feeding each returned state back into the next call.
    The visited states are stored as int32 and the rewards as float32;
    next_states is the same trajectory shifted forward by one step.
    If num <= 0 no step is taken and all three arrays are empty.
    '''
    current = mrp.reset()
    visited = [current]
    collected = []
    for _ in range(num):
        current, reward = mrp.step(current)
        visited.append(current)
        collected.append(reward)

    trajectory = np.array(visited, dtype=np.int32)
    reward_arr = np.array(collected, dtype=np.float32)

    # The trajectory holds num + 1 states: dropping the last one gives the
    # source state of every transition, dropping the first gives its target.
    current_states = trajectory[:-1]
    next_states = trajectory[1:]
    return current_states, next_states, reward_arr


class ContextGenerator:
    '''
    Builds batches of contexts from an MRP, a discount factor and a
    feature map. Each context stacks current features, discounted next
    features and rewards column by column.
    '''

    def __init__(self, mrp: MRP, gamma: float, feat: Feature):
        '''
        param mrp: process the transitions are drawn from
        param gamma: factor multiplied into the next-state features
        param feat: feature map; called on feature indices and read via .phi
        '''
        self.feat = feat
        self.gamma = gamma
        self.mrp = mrp

    def _get_expected(self, states: np.ndarray) -> np.ndarray:
        '''
        param states: state indices used to select rows of mrp.P
        return: the selected rows of mrp.P multiplied by feat.phi
        '''
        transition_rows = self.mrp.P[states]
        return transition_rows @ self.feat.phi

    def generate_context(self, n: int, b: int, enumerated: bool, expected: bool) -> np.ndarray:
        '''
        param n: context length
        param b: batch size
        param enumerated: if true, column i of the context is state i
                          (requires n == mrp.n_states); otherwise the
                          states come from a simulated trajectory
        param expected: if true, use mrp.P and mrp.r for the next features
                        and rewards instead of sampling them with mrp.step
        return: numpy array of shape (b, 2*d+1, n)

        Generate a batch of training data. Neither enumerated nor expected
        may be set when the MRP is a CartPole; violations raise
        AssertionError.
        NOTE(review): contexts are built inside joblib workers. If the MRP
        carries its own random state, confirm that the copies handed to the
        workers do not all produce the same samples.
        '''
        is_cartpole = isinstance(self.mrp, CartPole)
        assert not enumerated or not is_cartpole
        assert not expected or not is_cartpole
        if enumerated:
            assert n == self.mrp.n_states

        def build_single() -> np.ndarray:
            '''
            return: one context of shape (2*d+1, n)
            '''
            if enumerated and expected:
                # every state once, next features and rewards from P and r
                cur_feats = self.feat.phi  # (n, d)
                every_state = np.arange(n)
                nxt_feats = self.gamma * self._get_expected(every_state)
                reward_row = self.mrp.r.reshape((1, n))
            elif enumerated:
                # every state once, one sampled transition per state
                cur_feats = self.feat.phi  # (n, d)
                sampled_feats = []
                sampled_rewards = []
                for state in range(n):
                    successor, reward = self.mrp.step(state)
                    successor_idx = self.mrp.get_feature_index(successor)
                    sampled_feats.append(self.feat(successor_idx))
                    sampled_rewards.append(reward)
                nxt_feats = self.gamma * np.array(sampled_feats)
                reward_row = np.array(sampled_rewards).reshape((1, n))
            elif expected:
                # states from a trajectory, next features and rewards from
                # P and r; the sampled successors and rewards are discarded
                visited, _, _ = simulate(self.mrp, n)
                visited_idx = self.mrp.get_feature_index(visited)
                cur_feats = self.feat(visited_idx)
                nxt_feats = self.gamma * self._get_expected(visited_idx)
                reward_row = self.mrp.r[visited_idx].reshape((1, n))
            else:
                # everything taken from one simulated trajectory
                visited, successors, sampled = simulate(self.mrp, n)
                visited_idx = self.mrp.get_feature_index(visited)
                successor_idx = self.mrp.get_feature_index(successors)
                cur_feats = self.feat(visited_idx)
                nxt_feats = self.gamma * self.feat(successor_idx)
                reward_row = sampled.reshape((1, n))

            # rows: current features, discounted next features, rewards
            blocks = (cur_feats.T, nxt_feats.T, reward_row)
            return np.concatenate(blocks, axis=0)

        pool = Parallel(n_jobs=-1)
        jobs = (delayed(build_single)() for _ in range(b))
        return np.stack(pool(jobs), axis=0)


def pro(ctxt: np.ndarray, x_q: np.ndarray) -> np.ndarray:
    '''
    param ctxt: context of shape (2*d+1, n)
    param x_q: query of shape (d, )
    return: the prompt of shape (2*d+1, n+1)

    Append one extra column to the context: its first d entries hold the
    query and the remaining d+1 entries are zero. ctxt itself is left
    untouched. Raises AssertionError if the row count of ctxt does not
    equal 2*d+1.
    '''
    n_rows = ctxt.shape[0]
    d = x_q.shape[0]
    assert n_rows == 2 * d + 1

    # Start from an all-zero column so that only the query slots need filling.
    blank_column = np.zeros((n_rows, 1))
    prompt = np.concatenate((ctxt, blank_column), axis=1)
    prompt[:d, -1] = x_q
    return prompt
