import gym
from typing import List
from envs.tabular_env import TabularEnv
from envs.dis_tabular_env import DiscountedTabularEnv
import numpy as np
import random
import scipy.signal


def discount_cumsum(x, discount):
    """
    Compute the discounted suffix sums of a sequence (trick borrowed from rllab).

    For an input [x0, x1, x2] the result is

        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]

    Args:
        x: the sequence to accumulate; it is reversed with slicing and filtered along axis 0.
        discount: the discount factor applied per step.

    Returns:
        The array of discounted cumulative sums, in the same order as x.
    """
    backwards = x[::-1]
    # The IIR filter y[t] = x[t] + discount * y[t - 1], run over the reversed
    # input, accumulates every suffix of the original sequence in one pass.
    accumulated = scipy.signal.lfilter([1], [1, float(-discount)], backwards, axis=0)
    return accumulated[::-1]


def cumulative_discounted(data_list: List, discount_factor: float):
    """
    Return the discounted sum of a sequence: sum_i data_list[i] * discount_factor^i.

    Args:
        data_list: the values to accumulate, ordered by time step.
        discount_factor: the per-step discount.

    Returns:
        The discounted total; 0 when data_list is empty.
    """
    return sum(
        value * np.float_power(discount_factor, exponent)
        for exponent, value in enumerate(data_list)
    )


def sample_one_trajetory(env: gym.Env, policy: np.ndarray, env_type='Episodic', is_deterministic=False):
    """
    Roll out a single trajectory with the given policy until the env reports done.

    Args:
         env: the env to interact with; `env.step` must return (next_obs, reward, done, info).
         policy: the sample policy, for Episodic env, its shape is (ns, na, H), for Discounted env, its shape is (ns, na)
         env_type: 'Episodic' or 'Discounted'.
         is_deterministic: if True take the argmax action, otherwise sample from the action distribution.

    Returns:
        trajectory: the visited [(obs, action, step)] tuples.
        ret: the plain sum of rewards for 'Episodic', or the sum discounted by `env._gamma` for 'Discounted'.
    """
    visited = []
    rewards = []
    state = env.reset()
    t = 0
    assert env_type in ['Episodic', 'Discounted'], 'Invalid environment type.'
    episodic = env_type == 'Episodic'

    finished = False
    while not finished:
        # Episodic policies are time-dependent, discounted ones are stationary.
        action_probs = policy[state, :, t] if episodic else policy[state, :]
        if is_deterministic:
            action = np.argmax(action_probs)
        else:
            action = np.random.choice(a=env.action_space.n, p=action_probs)
        visited.append((state, action, t))
        next_state, reward, finished, _ = env.step(action)
        t += 1
        rewards.append(reward)
        state = next_state

    if episodic:
        return visited, sum(rewards)

    # The first entry of the discounted suffix sums is the return from step 0.
    discounted_returns = discount_cumsum(rewards, env._gamma)
    return visited, discounted_returns[0]


def sample_dataset_per_traj(env: gym.Env, policy: np.ndarray, num_traj: int, is_deterministic=False):
    """
    Collect a fixed number of episodic trajectories and flatten them into one dataset.

    Args:
        env: the environment.
        policy: the policy, numpy array with shape [n_state, n_action, H]
        num_traj: the number of trajectories to collect
        is_deterministic: take the deterministic action or not

    Returns:
        dataset: [(state, action, step)], the samples of all trajectories concatenated
            in collection order (the trajectories are not kept as separate lists).
    """
    samples = []
    returns = []
    for _ in range(num_traj):
        one_traj, one_return = sample_one_trajetory(env, policy, 'Episodic', is_deterministic)
        returns.append(one_return)
        samples.extend(one_traj)
    mean_return = sum(returns) / num_traj
    print('Collect {} trajectories and average return is {}'.format(num_traj, mean_return))
    return samples

def sample_dataset(env: gym.Env, policy: np.ndarray, num_data: int, is_deterministic: bool) -> List[tuple]:
    """
    Collect (state, action, step) samples by rolling out whole episodic trajectories.

    int(num_data / H) + 1 trajectories are sampled, where H is `env._max_episode_steps`,
    and the concatenated samples are truncated to the first `num_data` entries.

    Args:
        env: the environment.
        policy: the policy, numpy array with shape [n_state, n_action, H]
        num_data: the number of data to collect
        is_deterministic: take the deterministic action or not

    Returns:
        data_set: [(state, action, step)]; it holds fewer than `num_data` entries if the
            sampled trajectories together contain fewer samples than requested.
    """
    horizon = env._max_episode_steps
    num_trajectories = int(num_data / horizon) + 1

    collected = []
    returns = []
    for _ in range(num_trajectories):
        one_traj, one_return = sample_one_trajetory(env=env, policy=policy, is_deterministic=is_deterministic)
        returns.append(one_return)
        collected += one_traj

    mean_return = sum(returns) / num_trajectories
    print('Collect {} trajectories and average return is {}'.format(num_trajectories, mean_return))

    return collected[: num_data]


def sample_dataset_from_distribution(state_dist: np.ndarray, expert_policy: np.ndarray, num_samples: int):
    """
    Collect expert demonstrations directly from the discounted stationary state distribution.

    States are drawn i.i.d. from `state_dist`; for each drawn state an action is drawn from the
    matching row of `expert_policy` by inverse-CDF sampling (first index whose cumulative
    probability exceeds a uniform draw).

    Args:
        state_dist: the state distribution, a numpy array with shape (ns).
        expert_policy: the expert policy, a numpy array with shape (ns, na).
        num_samples: the size of dataset.
    Returns:
        dataset: a list of tuples, [(state, action)].
        unique_states: the unique states in dataset.
    """
    num_state = state_dist.shape[0]
    all_sampled_states = np.random.choice(a=num_state, size=num_samples, p=state_dist)

    cum_all_action_dists = expert_policy[all_sampled_states, :].cumsum(axis=1)
    # Floating point cumulative sums can end slightly below 1. A uniform draw above
    # that total would make the comparison False everywhere and argmax would then
    # silently return action 0, even when it has zero probability. Rescaling every
    # CDF by its own total makes the last reachable entry exactly 1, which is always
    # greater than a draw from [0, 1). All-zero rows are left untouched.
    row_totals = cum_all_action_dists[:, -1:]
    cum_all_action_dists = cum_all_action_dists / np.where(row_totals > 0, row_totals, 1.0)

    u = np.random.rand(num_samples, 1)
    all_sampled_actions = (u < cum_all_action_dists).argmax(axis=1)
    unique_states = np.unique(all_sampled_states)
    dataset = list(zip(all_sampled_states, all_sampled_actions))

    return dataset, unique_states

def get_next_state(transition_prob:np.ndarray,dataset:List[tuple]):
    """Sample from the next state distribution given (s,a)

    For every (s, a) pair the next state is drawn from transition_prob[s, a, :] by
    inverse-CDF sampling.

    Args:
        transition_prob (np.ndarray): transition probability with shape [dim_state,dim_action,dim_state], where the last dim stands for s_next
        dataset (List[tuple]): list of (s,a)

    Returns:
        dataset (np.ndarray): integer array with shape [len(dataset), 2] whose rows are (s,s_next);
            an empty [0, 2] array when the input dataset is empty.
    """
    if len(dataset) == 0:
        # Nothing to sample; indexing with an empty tuple would select the whole table.
        return np.empty((0, 2), dtype=int)

    pairs = np.array(dataset)
    states, actions = pairs[:, 0], pairs[:, 1]

    next_s_cum = transition_prob[states, actions].cumsum(axis=1)
    # Rescale each CDF by its own total so that rounding in the cumulative sum can
    # never leave a uniform draw above every entry (which would make argmax fall
    # back to state 0). All-zero rows are left untouched.
    row_totals = next_s_cum[:, -1:]
    next_s_cum = next_s_cum / np.where(row_totals > 0, row_totals, 1.0)

    u = np.random.rand(len(dataset), 1)
    all_sampled_next_s = (u < next_s_cum).argmax(axis=1)

    return np.stack([states, all_sampled_next_s], axis=1)


def evaluate(env: gym.Env, policy: np.ndarray, num_trajectories: int, env_type='Episodic', is_deterministic=False):
    """ Evaluate the value of the learned policy by Monte Carlo rollouts.
    Args:
        env: the env to interact
        policy: numpy array with shape (n_state, n_action, H)
        num_trajectories: the number of trajectories to sample
        env_type: the type of environment, forwarded to sample_one_trajetory.
        is_deterministic: take the deterministic action or not
    Returns:
        mean_ret: the average return over the sampled trajectories.
        mean_length: the average number of steps per sampled trajectory.
    """
    total_return = 0.0
    total_length = 0.0
    for _ in range(num_trajectories):
        one_traj, one_return = sample_one_trajetory(env, policy, env_type, is_deterministic)
        total_length += len(one_traj)
        total_return += one_return

    return total_return / num_trajectories, total_length / num_trajectories


def estimate_occupancy_measure(env: TabularEnv, policy: np.ndarray, num_trajectories: int, is_deterministic=False):
    """ Estimate the occupancy measure of the policy via MC.

    Visit counts of (state, action, step) are accumulated over sampled episodic
    trajectories and normalised separately for every step.

    Args:
        env: the env to interact
        policy: numpy array with shape (n_state, n_action, H); episodic rollouts index it
            as policy[state, :, step]
        num_trajectories: the number of trajectories to sample
        is_deterministic: take the deterministic action or not
    Returns:
        rho: the estimated occupancy measure, numpy array with shape (n_state, n_action, H).
    """
    n_state, n_action, H = env.observation_space.n, env.action_space.n, env._max_episode_steps
    rho = np.zeros(shape=[n_state, n_action, H], dtype=np.float32)
    for _ in range(num_trajectories):
        # sample_one_trajetory only accepts 'Episodic' or 'Discounted'; any other
        # tag (such as the former 'Episode') fails its validation.
        traj, _ = sample_one_trajetory(env, policy, 'Episodic', is_deterministic)
        for state, action, step in traj:
            rho[state, action, step] += 1.0
    normalizer = np.sum(rho, axis=(0, 1), keepdims=True)

    for h in range(H):
        # if step h was never visited, estimate rho_h(s, a) as the uniform distribution.
        if normalizer[0, 0, h] == 0:
            rho[:, :, h] = 1.0 / (n_state * n_action)
        else:
            rho[:, :, h] = rho[:, :, h] / normalizer[0, 0, h]

    return rho


def estimate_discounted_occupancy_measure(env: DiscountedTabularEnv, policy: np.ndarray, num_trajectories: int,
                                          is_deterministic=False):
    """
    Estimate the discounted occupancy measure via Monte Carlo rollouts.

    Every visited (state, action) pair contributes gamma^step, with gamma read from
    `env._gamma`; the accumulated weights are normalised to sum to one.

    Args:
        env: the environment.
        policy: the stationary policy, a numpy array with shape [ns, na].
        num_trajectories: the number of trajectories.
        is_deterministic: take the deterministic action or not
    Returns:
        rho: the estimated discounted occupancy measure, numpy array with shape [ns, na].
    """
    num_state = env.observation_space.n
    num_action = env.action_space.n
    gamma = env._gamma
    weights = np.zeros(shape=[num_state, num_action], dtype=np.float32)

    for _ in range(num_trajectories):
        one_traj, _ = sample_one_trajetory(env, policy, 'Discounted', is_deterministic)
        for state, action, step in one_traj:
            weights[state, action] += np.float_power(gamma, step)

    total_weight = float(np.sum(weights))
    return weights / total_weight


def estimate_dis_occupancy_measure_from_data(num_state: int, num_action: int, dataset: List) -> np.ndarray:
    """
    Estimate the discounted occupancy measure as the empirical (state, action) frequency of a dataset.

    The dataset is expected to be drawn from the stationary discounted state-action distribution,
    so the samples are counted as they are and not re-weighted with gamma^t.

    Args:
         num_state: # S
         num_action: # A
         dataset: [(s, a)]
    Return:
        dis_occ: the discounted occupancy measure, numpy array with shape (num_state, num_action).
    """
    counts = np.zeros(shape=(num_state, num_action), dtype=np.float32)
    for sample in dataset:
        counts[sample[0], sample[1]] += 1.0

    return counts / len(dataset)


def estimate_occupancy_measure_from_data(num_state: int, num_action: int, max_episode_steps: int, dataset: List[tuple]):
    """
    Estimate the per-step occupancy measure from expert demonstrations.

    The (state, action) visit counts are normalised independently for every step.

    Args:
        num_state: # S
        num_action: # A
        max_episode_steps: H
        dataset: [(state, action, step)]
    Returns:
        occupancy_measure: numpy array with shape [dim_state, dim_action, max_episode_steps]
    """
    counts = np.zeros(shape=(num_state, num_action, max_episode_steps), dtype=np.float32)
    for sample in dataset:
        counts[sample[0], sample[1], sample[2]] += 1.0

    # Total number of samples observed at each step.
    per_step_totals = np.sum(counts, axis=(0, 1))

    for h, step_total in enumerate(per_step_totals):
        if step_total == 0:
            # No data at step h: fall back to the uniform distribution over (s, a).
            counts[:, :, h] = 1.0 / (num_state * num_action)
        else:
            counts[:, :, h] = counts[:, :, h] / step_total

    return counts


def get_optimal_policy(num_state: int, num_action: int, optimal_action: int, env_id: str):
    """
    Build the deterministic policy that picks `optimal_action` in every state.

    Args:
        num_state: # S
        num_action: # A
        optimal_action: the index of the action that receives probability one.
        env_id: 'CliffWalking' or 'Bandit'.
    Returns:
        optimal_policy: numpy array with shape (num_state, num_action); float64 for
            'CliffWalking' and float32 for 'Bandit'.
    Raises:
        ValueError: if env_id is neither 'CliffWalking' nor 'Bandit'.
    """
    if env_id == 'CliffWalking':
        one_hot = np.zeros(shape=[num_action], dtype=np.float64)
        one_hot[optimal_action] = 1.0
        # Repeat the same one-hot action distribution for every state.
        return np.tile(one_hot, reps=(num_state, 1))

    if env_id == 'Bandit':
        bandit_policy = np.zeros(shape=(num_state, num_action), dtype=np.float32)
        bandit_policy[:, optimal_action] = 1.0
        return bandit_policy

    raise ValueError('%s is not supported' % env_id)








