import numpy as np
import torch

import lfrl.torch.pytorch_util as ptu


def _default_prediction(dynamics_model, state_actions, terminal_cutoff=0.5):
    transitions = dynamics_model.sample(state_actions)
    if (transitions != transitions).any():
        print('warning: nan transitions')
        transitions[transitions != transitions] = 0
    r = transitions[:,:1]
    d = (transitions[:,1:2] > terminal_cutoff).float()
    obs_delta = transitions[:,2:]
    return r, d, obs_delta


def model_policy_rollout_torch(
        dynamics_model,
        policy,
        start_states,
        max_path_length=1,
        terminal_cutoff=0.5,
        predict_transition=None,
):
    """Roll out ``policy`` inside the learned model and return path dicts.

    All rollouts advance in lock-step for ``max_path_length`` steps; the next
    observation is always the current observation plus the predicted delta.
    Afterwards each rollout is cut right after its first transition whose
    terminal flag is not below 0.5.

    Args:
        dynamics_model: provides ``obs_dim`` and ``action_dim`` and, when
            ``predict_transition`` is None, ``sample(state_actions)``.
        policy: ``policy.forward(obs)`` must return a sequence whose first
            element is the batch of actions.
        start_states: array with one start observation per row; converted
            with ``ptu.from_numpy``.
        max_path_length: number of model steps simulated per rollout.
        terminal_cutoff: forwarded to ``_default_prediction``; unused when
            ``predict_transition`` is supplied.
        predict_transition: optional callable mapping the concatenated
            state-action batch to ``(reward, done, obs_delta)``.

    Returns:
        List with one dict per rollout holding ``observations``, ``actions``,
        ``rewards``, ``next_observations`` and ``terminals`` (numpy arrays
        trimmed to the rollout length) plus ``agent_infos`` and ``env_infos``
        (lists of empty lists, one per step).
    """
    batch_size = start_states.shape[0]
    horizon = max_path_length
    state_dim = dynamics_model.obs_dim
    act_dim = dynamics_model.action_dim

    # Pick the transition function once instead of branching on every step.
    if predict_transition is None:
        def step_model(model_input):
            return _default_prediction(
                dynamics_model, model_input, terminal_cutoff=terminal_cutoff
            )
    else:
        step_model = predict_transition

    obs_buf = ptu.zeros((batch_size, horizon + 1, state_dim))
    act_buf = ptu.zeros((batch_size, horizon, act_dim))
    rew_buf = ptu.zeros((batch_size, horizon, 1))
    # Slot 0 is never written, so every rollout starts as non-terminal.
    term_buf = ptu.zeros((batch_size, horizon + 1, 1))

    obs_buf[:, 0] = ptu.from_numpy(start_states)
    for step in range(horizon):
        current_obs = obs_buf[:, step]
        act_buf[:, step], *_ = policy.forward(current_obs)
        model_input = torch.cat([current_obs, act_buf[:, step]], dim=1)

        reward, done, delta = step_model(model_input)

        rew_buf[:, step] = reward
        term_buf[:, step + 1] = done
        obs_buf[:, step + 1] = current_obs + delta

    obs_np, act_np, rew_np, term_np = (
        ptu.get_numpy(buf) for buf in (obs_buf, act_buf, rew_buf, term_buf)
    )

    paths = []
    for row in range(batch_size):
        # The path length is the index of the first flagged terminal slot,
        # or the full horizon when no slot is flagged.
        length = horizon
        for step in range(horizon):
            if not term_np[row, step, 0] < 0.5:
                length = step
                break

        paths.append({
            'observations': obs_np[row, :length],
            'actions': act_np[row, :length],
            'rewards': rew_np[row, :length],
            'next_observations': obs_np[row, 1:length + 1],
            'terminals': term_np[row, 1:length + 1],
            'agent_infos': [[] for _ in range(length)],
            'env_infos': [[] for _ in range(length)],
        })

    return paths


def model_policy_latent_rollout_torch(
        dynamics_model,
        policy,
        start_states,
        latents,
        max_path_length=1,
        terminal_cutoff=0.5,
        predict_transition=None,
):
    """Roll out a latent-conditioned policy inside the learned model.

    At every step the policy receives the current observations concatenated
    with ``latents`` (the same latent batch is reused for every step), while
    the transition predictor receives observations concatenated with the
    chosen actions. Each rollout is cut right after its first transition
    whose terminal flag is not below 0.5.

    Args:
        dynamics_model: provides ``obs_dim`` and ``action_dim`` and, when
            ``predict_transition`` is None, ``sample(state_actions)``.
        policy: ``policy.forward(x)`` must return a sequence whose first
            element is the batch of actions.
        start_states: array with one start observation per row.
        latents: array concatenated to the observations along dim 1;
            converted with ``ptu.from_numpy``.
        max_path_length: number of model steps simulated per rollout.
        terminal_cutoff: forwarded to ``_default_prediction``; ``None`` is
            replaced by ``1e6``. Unused when ``predict_transition`` is given.
        predict_transition: optional callable mapping the concatenated
            state-action batch to ``(reward, done, obs_delta)``.

    Returns:
        List with one dict per rollout holding ``observations``, ``actions``,
        ``rewards``, ``next_observations`` and ``terminals`` (numpy arrays
        trimmed to the rollout length) plus ``agent_infos`` and ``env_infos``
        (lists of empty lists, one per step).
    """
    cutoff = 1e6 if terminal_cutoff is None else terminal_cutoff

    n_paths = start_states.shape[0]
    n_steps = max_path_length
    state_dim = dynamics_model.obs_dim
    act_dim = dynamics_model.action_dim

    obs_seq = ptu.zeros((n_paths, n_steps + 1, state_dim))
    act_seq = ptu.zeros((n_paths, n_steps, act_dim))
    rew_seq = ptu.zeros((n_paths, n_steps, 1))
    done_seq = ptu.zeros((n_paths, n_steps + 1, 1))

    obs_seq[:, 0] = ptu.from_numpy(start_states)
    latent_batch = ptu.from_numpy(latents)
    for k in range(n_steps):
        obs_k = obs_seq[:, k]
        policy_input = torch.cat((obs_k, latent_batch), dim=1)
        act_seq[:, k], *_ = policy.forward(policy_input)
        model_input = torch.cat([obs_k, act_seq[:, k]], dim=1)

        if predict_transition is not None:
            reward, done, delta = predict_transition(model_input)
        else:
            reward, done, delta = _default_prediction(
                dynamics_model, model_input, terminal_cutoff=cutoff
            )

        rew_seq[:, k] = reward
        done_seq[:, k + 1] = done
        obs_seq[:, k + 1] = obs_k + delta

    obs_np = ptu.get_numpy(obs_seq)
    act_np = ptu.get_numpy(act_seq)
    rew_np = ptu.get_numpy(rew_seq)
    done_np = ptu.get_numpy(done_seq)

    def path_length(flags):
        # Index of the first slot whose flag is not below 0.5, capped at the
        # number of simulated steps.
        for k in range(n_steps):
            if not flags[k, 0] < 0.5:
                return k
        return n_steps

    paths = []
    for j in range(n_paths):
        length = path_length(done_np[j])
        paths.append({
            'observations': obs_np[j, :length],
            'actions': act_np[j, :length],
            'rewards': rew_np[j, :length],
            'next_observations': obs_np[j, 1:length + 1],
            'terminals': done_np[j, 1:length + 1],
            'agent_infos': [[] for _ in range(length)],
            'env_infos': [[] for _ in range(length)],
        })

    return paths


def model_action_rollout_torch(
        dynamics_model,
        actions,
        start_states,
        max_path_length=1,
        terminal_cutoff=0.5,
):
    """Replay fixed action sequences through the learned model.

    Unlike the policy rollouts, the actions are supplied by the caller and
    indexed as ``actions[:, t]`` at step ``t``. The next observation is the
    current observation plus the predicted delta, and each rollout is cut
    right after its first transition whose terminal flag is not below 0.5.

    Args:
        dynamics_model: provides ``obs_dim`` and ``sample(state_actions)``.
        actions: array of per-rollout action sequences; converted with
            ``ptu.from_numpy`` and read as ``actions[:, t]``.
        start_states: array with one start observation per row.
        max_path_length: number of model steps simulated per rollout.
        terminal_cutoff: threshold applied to the model's terminal output.

    Returns:
        List with one dict per rollout holding ``observations``, ``actions``,
        ``rewards``, ``next_observations`` and ``terminals`` (numpy arrays
        trimmed to the rollout length) plus ``agent_infos`` and ``env_infos``
        (lists of empty lists, one per step).
    """
    num_rollouts = start_states.shape[0]
    obs_dim = dynamics_model.obs_dim

    observations = ptu.zeros((num_rollouts, max_path_length + 1, obs_dim))
    actions = ptu.from_numpy(actions)
    rewards = ptu.zeros((num_rollouts, max_path_length, 1))
    # Slot 0 is never written, so every rollout starts as non-terminal.
    terminals = ptu.zeros((num_rollouts, max_path_length + 1, 1))

    observations[:, 0] = ptu.from_numpy(start_states)
    for t in range(max_path_length):
        state_actions = torch.cat(
            [observations[:, t], actions[:, t]], dim=1)
        # Same sample / NaN-scrub / split logic as the policy rollouts.
        r, d, obs_delta = _default_prediction(
            dynamics_model, state_actions, terminal_cutoff=terminal_cutoff
        )
        rewards[:, t] = r
        terminals[:, t + 1] = d
        observations[:, t + 1] = observations[:, t] + obs_delta

    observations = ptu.get_numpy(observations)
    actions = ptu.get_numpy(actions)
    rewards = ptu.get_numpy(rewards)
    terminals = ptu.get_numpy(terminals)

    paths = []
    for i in range(num_rollouts):
        # Length is the index of the first flagged terminal slot, or the
        # full horizon when no slot is flagged.
        rollout_len = 0
        while rollout_len < max_path_length and \
              terminals[i, rollout_len, 0] < 0.5:
            rollout_len += 1

        paths.append(dict(
            observations=observations[i, :rollout_len],
            actions=actions[i, :rollout_len],
            rewards=rewards[i, :rollout_len],
            next_observations=observations[i, 1:rollout_len + 1],
            terminals=terminals[i, 1:rollout_len + 1],
            agent_infos=[[] for _ in range(rollout_len)],
            env_infos=[[] for _ in range(rollout_len)],
        ))

    return paths


def model_action_rollout_torch_online(
        dynamics_model,
        actions,
        start_states,
        max_path_length=1,
        gamma=.99,
        terminal_cutoff=0.5,
):
    """Replay fixed action sequences and accumulate discounted returns.

    Only running totals are kept (no per-step buffers). A rollout stops
    collecting reward once a transition has been flagged terminal, but its
    observation keeps being advanced by the model until the loop ends.

    Args:
        dynamics_model: provides ``sample(state_actions)``; column 0 of the
            sample is read as reward, column 1 as terminal score and the
            remaining columns as observation delta.
        actions: array of per-rollout action sequences; converted with
            ``ptu.from_numpy`` and read as ``actions[:, t]``.
        start_states: array with one start observation per row.
        max_path_length: number of model steps simulated.
        gamma: per-step discount factor.
        terminal_cutoff: a transition is terminal when its terminal score is
            strictly greater than this value.

    Returns:
        List with one dict per rollout with keys ``returns`` (discounted
        return), ``path_len`` and ``next_observations`` (the observation
        after the last simulated step).
    """
    num_rollouts = start_states.shape[0]

    actions = ptu.from_numpy(actions)
    returns = ptu.zeros(num_rollouts)
    terminals = ptu.zeros((num_rollouts, 1))
    path_lens = ptu.zeros(num_rollouts)

    observations = ptu.from_numpy(start_states)
    discount = 1
    for t in range(max_path_length):
        state_actions = torch.cat(
            [observations, actions[:, t]], dim=1)
        transitions = dynamics_model.sample(state_actions)
        nan_mask = torch.isnan(transitions)
        if nan_mask.any():
            print('warning: nan transitions')
            transitions[nan_mask] = 0
        # Reward is masked with the flags from BEFORE this step, so the
        # terminating transition itself still contributes its reward.
        returns += discount * (1 - terminals[:, 0]) * transitions[:, 0]
        discount *= gamma
        # Terminal flags are sticky: once set they stay set.
        terminals = torch.max(
            (transitions[:, 1:2] > terminal_cutoff).float(),
            terminals
        )
        observations = observations + transitions[:, 2:]
        # NOTE(review): counted with the flags from AFTER this step, so the
        # terminating transition adds to `returns` but not to `path_lens`.
        path_lens += (1 - terminals[:, 0])

    observations = ptu.get_numpy(observations)
    returns = ptu.get_numpy(returns)
    path_lens = ptu.get_numpy(path_lens)

    return [
        dict(
            returns=returns[i],
            path_len=path_lens[i],
            next_observations=observations[i],
        )
        for i in range(num_rollouts)
    ]


def model_policy_rollout_torch_online(
        dynamics_model,
        policy,
        start_states,
        max_path_length=1,
        gamma=.99,
        terminal_cutoff=0.5,
):
    """Roll out ``policy`` in the model and return discounted returns only.

    A rollout stops collecting reward once a transition has been flagged
    terminal; the terminating transition itself still contributes its reward
    because the mask uses the flags from before the step.

    Args:
        dynamics_model: provides ``sample(state_actions)``; column 0 of the
            sample is read as reward, column 1 as terminal score and the
            remaining columns as observation delta.
        policy: ``policy.forward(obs)`` must return a sequence whose first
            element is the batch of actions.
        start_states: array with one start observation per row.
        max_path_length: number of model steps simulated.
        gamma: per-step discount factor.
        terminal_cutoff: a transition is terminal when its terminal score is
            strictly greater than this value.

    Returns:
        Numpy array with one discounted return per rollout.
    """
    num_rollouts = start_states.shape[0]

    returns = ptu.zeros(num_rollouts)
    terminals = ptu.zeros((num_rollouts, 1))

    observations = ptu.from_numpy(start_states)
    discount = 1
    for _ in range(max_path_length):
        actions, *_ = policy.forward(observations)
        state_actions = torch.cat((observations, actions), dim=1)
        transitions = dynamics_model.sample(state_actions)
        nan_mask = torch.isnan(transitions)
        if nan_mask.any():
            print('warning: nan transitions')
            transitions[nan_mask] = 0
        returns += discount * (1 - terminals[:, 0]) * transitions[:, 0]
        discount *= gamma
        # Terminal flags are sticky: once set they stay set.
        terminals = torch.max(
            (transitions[:, 1:2] > terminal_cutoff).float(),
            terminals
        )
        observations = observations + transitions[:, 2:]

    return ptu.get_numpy(returns)


def model_policy_latent_rollout_torch_online(
        dynamics_model,
        policy,
        start_states,
        latents,
        max_path_length=1,
        gamma=.99,
        terminal_cutoff=0.5,
):
    """Roll out a latent-conditioned policy and accumulate running totals.

    The policy sees the current observations concatenated with ``latents``
    (the same latent batch at every step). Results are returned as torch
    tensors; nothing is converted back to numpy.

    Args:
        dynamics_model: provides ``sample(state_actions)``; column 0 of the
            sample is read as reward, column 1 as terminal score and the
            remaining columns as observation delta.
        policy: ``policy.forward(x)`` must return a sequence whose first
            element is the batch of actions.
        start_states: array with one start observation per row.
        latents: array concatenated to the observations along dim 1.
        max_path_length: number of model steps simulated.
        gamma: per-step discount factor.
        terminal_cutoff: a transition is terminal when its terminal score is
            strictly greater than this value; ``None`` is replaced by ``1e6``.

    Returns:
        Dict of tensors: ``observations`` (observation after the last
        simulated step), ``returns`` (discounted return), ``terminals``
        (sticky 0/1 flags) and ``path_lens``.
    """
    if terminal_cutoff is None:
        terminal_cutoff = 1e6

    num_rollouts = start_states.shape[0]

    returns = ptu.zeros(num_rollouts)
    terminals = ptu.zeros((num_rollouts, 1))
    path_lens = ptu.zeros(num_rollouts)

    observations = ptu.from_numpy(start_states)
    latents = ptu.from_numpy(latents)
    discount = 1
    for _ in range(max_path_length):
        state_latents = torch.cat((observations, latents), dim=1)
        actions, *_ = policy.forward(state_latents)
        state_actions = torch.cat((observations, actions), dim=1)
        transitions = dynamics_model.sample(state_actions)
        nan_mask = torch.isnan(transitions)
        if nan_mask.any():
            print('warning: nan transitions')
            transitions[nan_mask] = 0
        # Reward is masked with the flags from BEFORE this step, so the
        # terminating transition itself still contributes its reward.
        returns += discount * (1 - terminals[:, 0]) * transitions[:, 0]
        discount *= gamma
        # Terminal flags are sticky: once set they stay set.
        terminals = torch.max(
            (transitions[:, 1:2] > terminal_cutoff).float(),
            terminals
        )
        observations = observations + transitions[:, 2:]
        # NOTE(review): counted with the flags from AFTER this step, so the
        # terminating transition adds to `returns` but not to `path_lens`.
        path_lens += (1 - terminals[:, 0])

    return dict(
        observations=observations,
        returns=returns,
        terminals=terminals,
        path_lens=path_lens,
    )


def model_policy_latent_lstm_rollout_torch_online(
        dynamics_model,
        policy,
        start_states,
        latents,
        max_path_length=1,
        gamma=.99,
        terminal_cutoff=0.5,
):
    """Latent-conditioned rollout for a dynamics model with carried state.

    Same bookkeeping as the non-recurrent latent rollout, except that
    ``dynamics_model.sample`` also takes and returns ``hidden`` and
    ``sample_idx``, which are threaded from one step to the next (both start
    as ``None``). Results are returned as torch tensors.

    Args:
        dynamics_model: provides
            ``sample(state_actions, hidden=..., sample_idx=...)`` returning
            ``(transitions, hidden, sample_idx)``; column 0 of
            ``transitions`` is read as reward, column 1 as terminal score
            and the remaining columns as observation delta.
        policy: ``policy.forward(x)`` must return a sequence whose first
            element is the batch of actions.
        start_states: array with one start observation per row.
        latents: array concatenated to the observations along dim 1.
        max_path_length: number of model steps simulated.
        gamma: per-step discount factor.
        terminal_cutoff: a transition is terminal when its terminal score is
            strictly greater than this value; ``None`` is replaced by ``1e6``.

    Returns:
        Dict of tensors: ``observations`` (observation after the last
        simulated step), ``returns`` (discounted return), ``terminals``
        (sticky 0/1 flags) and ``path_lens``.
    """
    if terminal_cutoff is None:
        terminal_cutoff = 1e6

    num_rollouts = start_states.shape[0]

    returns = ptu.zeros(num_rollouts)
    terminals = ptu.zeros((num_rollouts, 1))
    path_lens = ptu.zeros(num_rollouts)

    observations = ptu.from_numpy(start_states)
    latents = ptu.from_numpy(latents)
    discount = 1
    hidden, sample_idx = None, None
    for _ in range(max_path_length):
        state_latents = torch.cat((observations, latents), dim=1)
        actions, *_ = policy.forward(state_latents)
        state_actions = torch.cat((observations, actions), dim=1)
        # Whatever the model hands back as hidden/sample_idx is fed into the
        # next call unchanged.
        transitions, hidden, sample_idx = dynamics_model.sample(
            state_actions, hidden=hidden, sample_idx=sample_idx)
        nan_mask = torch.isnan(transitions)
        if nan_mask.any():
            print('warning: nan transitions')
            transitions[nan_mask] = 0
        # Reward is masked with the flags from BEFORE this step, so the
        # terminating transition itself still contributes its reward.
        returns += discount * (1 - terminals[:, 0]) * transitions[:, 0]
        discount *= gamma
        # Terminal flags are sticky: once set they stay set.
        terminals = torch.max(
            (transitions[:, 1:2] > terminal_cutoff).float(),
            terminals
        )
        observations = observations + transitions[:, 2:]
        # NOTE(review): counted with the flags from AFTER this step, so the
        # terminating transition adds to `returns` but not to `path_lens`.
        path_lens += (1 - terminals[:, 0])

    return dict(
        observations=observations,
        returns=returns,
        terminals=terminals,
        path_lens=path_lens,
    )


def model_policy_rollout_torch_rnd(
        dynamics_model,
        policy,
        rnd,
        start_states,
        rnd_threshold,
        max_path_length=1,
        terminal_cutoff=0.5,
):
    """Roll out ``policy`` in the model, halting rollouts flagged by ``rnd``.

    On each step only the rollouts that are neither terminal nor halted are
    pushed through the policy, ``rnd`` and the dynamics model. A rollout is
    halted when its ``rnd`` prediction exceeds ``rnd_threshold``; the loop
    exits early once every rollout is terminal or halted. Terminal flags are
    carried forward, so they stay set once raised.

    Args:
        dynamics_model: provides ``obs_dim``, ``action_dim`` and
            ``sample(state_actions)``.
        policy: ``policy.forward(obs)`` must return a sequence whose first
            element is the batch of actions.
        rnd: provides ``get_prediction(state_actions)``; its output is
            squeezed on the last dim and compared with ``rnd_threshold``.
        start_states: array with one start observation per row.
        rnd_threshold: rollouts whose prediction is strictly greater are
            halted.
        max_path_length: maximum number of model steps per rollout.
        terminal_cutoff: threshold applied to the model's terminal output.

    Returns:
        List with one dict per rollout holding ``observations``, ``actions``,
        ``rewards``, ``next_observations`` and ``terminals`` (numpy arrays
        trimmed to the rollout length) plus ``agent_infos`` and ``env_infos``.
        The rollout length is the smaller of the first-terminal index and
        the number of steps the rollout survived un-halted, but at least 1.
    """
    batch_size = start_states.shape[0]
    horizon = max_path_length
    state_dim = dynamics_model.obs_dim
    act_dim = dynamics_model.action_dim

    obs_buf = ptu.zeros((batch_size, horizon + 1, state_dim))
    act_buf = ptu.zeros((batch_size, horizon, act_dim))
    rew_buf = ptu.zeros((batch_size, horizon, 1))
    term_buf = ptu.zeros((batch_size, horizon + 1, 1))

    # 0/1 per rollout: 1 once the rnd prediction has exceeded the threshold.
    halted = np.zeros(batch_size)
    # Number of steps each rollout has completed without being halted.
    unhalted_steps = np.zeros(batch_size)

    obs_buf[:, 0] = ptu.from_numpy(start_states)
    for step in range(horizon):
        halted_t = ptu.from_numpy(halted)
        inactive = torch.max(halted_t, term_buf[:, step].squeeze(dim=-1))
        if torch.sum(inactive) >= batch_size:
            break

        # Boolean numpy mask selecting rollouts that still need simulating.
        active = ptu.get_numpy(inactive) < 0.5

        act_buf[active, step], *_ = policy.forward(obs_buf[active, step])
        model_input = torch.cat(
            [obs_buf[active, step], act_buf[active, step]], dim=1)

        rnd_scores = ptu.get_numpy(
            rnd.get_prediction(model_input).squeeze(dim=-1))
        halted[active] = rnd_scores > rnd_threshold
        unhalted_steps[halted < 0.5] += 1

        sampled = dynamics_model.sample(model_input)
        nan_mask = torch.isnan(sampled)
        if nan_mask.any():
            print('warning: nan transitions')
            sampled[nan_mask] = 0

        rew_buf[active, step] = sampled[:, :1]
        term_buf[active, step + 1] = (sampled[:, 1:2] > terminal_cutoff).float()
        # Carry earlier terminal flags forward so inactive rows stay flagged.
        term_buf[:, step + 1] = torch.max(
            term_buf[:, step], term_buf[:, step + 1])
        obs_buf[active, step + 1] = obs_buf[active, step] + sampled[:, 2:]

    obs_np = ptu.get_numpy(obs_buf)
    act_np = ptu.get_numpy(act_buf)
    rew_np = ptu.get_numpy(rew_buf)
    term_np = ptu.get_numpy(term_buf)

    paths = []
    for row in range(batch_size):
        # Index of the first flagged terminal slot, or the full horizon.
        length = horizon
        for step in range(horizon):
            if not term_np[row, step, 0] < 0.5:
                length = step
                break
        # Also cap by the un-halted step count, but keep at least one step.
        length = max(min(length, int(unhalted_steps[row])), 1)

        paths.append({
            'observations': obs_np[row, :length],
            'actions': act_np[row, :length],
            'rewards': rew_np[row, :length],
            'next_observations': obs_np[row, 1:length + 1],
            'terminals': term_np[row, 1:length + 1],
            'agent_infos': [[] for _ in range(length)],
            'env_infos': [[] for _ in range(length)],
        })

    return paths


def model_policy_rollout_with_disagreement_torch(
        dynamics_model,
        policy,
        start_states,
        disagreement_threshold,
        max_path_length=1,
        terminal_cutoff=0.5,
):
    """Roll out ``policy`` in the model, stopping on high model disagreement.

    All rollouts are stepped together and per-step tensors are collected in
    lists. A rollout is marked stopped once its disagreement value exceeds
    ``disagreement_threshold``; terminal flags are cumulative per rollout.
    The loop exits early once every rollout is terminal or stopped.

    Args:
        dynamics_model: provides ``sample_with_disagreement(state_actions)``
            returning ``(transitions, disagreement)``; column 0 of
            ``transitions`` is read as reward, column 1 as terminal score
            and the remaining columns as observation delta. ``action_dim``
            is only read when no step could be simulated.
        policy: ``policy.forward(obs)`` must return a sequence whose first
            element is the batch of actions.
        start_states: array with one start observation per row.
        disagreement_threshold: rollouts whose disagreement is strictly
            greater are stopped.
        max_path_length: maximum number of model steps per rollout.
        terminal_cutoff: a transition is terminal when its terminal score is
            strictly greater than this value.

    Returns:
        List with one dict per rollout holding ``observations``, ``actions``,
        ``rewards``, ``next_observations`` and ``terminals`` (numpy arrays
        trimmed to the rollout length) plus ``agent_infos`` and ``env_infos``.
        The rollout length is the smaller of the first-terminal index and
        the number of steps the rollout survived un-stopped, but at least 1
        whenever at least one step was simulated.
    """
    num_rollouts = start_states.shape[0]

    # Per-step lists; observations/terminals hold one more entry than
    # actions/rewards because they include the initial state.
    observations = [ptu.from_numpy(start_states)]
    actions = []
    rewards = []
    terminals = [ptu.zeros(num_rollouts)]

    stopped = np.zeros(num_rollouts)
    rollout_lens = np.zeros(num_rollouts)

    for t in range(max_path_length):
        finished = torch.max(ptu.from_numpy(stopped), terminals[t])
        if torch.sum(finished) >= num_rollouts:
            break

        acts, *_ = policy.forward(observations[t])
        actions.append(acts)
        state_actions = torch.cat([observations[t], acts], dim=1)

        transitions, disagreement = \
            dynamics_model.sample_with_disagreement(state_actions)
        nan_mask = torch.isnan(transitions)
        if nan_mask.any():
            print('warning: nan transitions')
            transitions[nan_mask] = 0

        rewards.append(transitions[:, :1])
        # Keep the flags 1-D so the running maximum is taken per rollout.
        # Combining an (N, 1) column with the (N,) history would broadcast
        # to (N, N) and mix flags across rollouts.
        new_terminals = (transitions[:, 1] > terminal_cutoff).float()
        terminals.append(torch.max(new_terminals, terminals[-1]))
        observations.append(observations[-1] + transitions[:, 2:])

        # assumes `disagreement` holds one value per rollout, shape (N,)
        # TODO confirm against the dynamics model
        disagreement = ptu.get_numpy(disagreement)
        stopped = np.maximum(stopped, disagreement > disagreement_threshold)
        rollout_lens[stopped < 0.5] += 1

    # Number of steps actually simulated (smaller than max_path_length when
    # the loop exited early).
    max_length = len(actions)

    def to_batch_major(tensors):
        # Stack a per-step list of (N, ...) tensors into an (N, steps, ...)
        # numpy array in a single pass.
        return np.stack([ptu.get_numpy(x) for x in tensors], axis=1)

    observations = to_batch_major(observations)
    terminals = np.expand_dims(to_batch_major(terminals), axis=-1)
    if max_length > 0:
        actions = to_batch_major(actions)
        rewards = to_batch_major(rewards)
    else:
        # Nothing was simulated (zero horizon or empty batch): return empty
        # step axes instead of failing on the empty lists.
        actions = ptu.get_numpy(
            ptu.zeros((num_rollouts, 0, dynamics_model.action_dim)))
        rewards = ptu.get_numpy(ptu.zeros((num_rollouts, 0, 1)))

    paths = []
    for i in range(num_rollouts):
        rollout_len = 0
        while rollout_len < max_length and \
              terminals[i, rollout_len, 0] < 0.5:
            rollout_len += 1
        rollout_len = min(rollout_len, int(rollout_lens[i]))
        # Keep at least one transition per path, but never more than were
        # simulated.
        rollout_len = min(max(rollout_len, 1), max_length)

        paths.append(dict(
            observations=observations[i, :rollout_len],
            actions=actions[i, :rollout_len],
            rewards=rewards[i, :rollout_len],
            next_observations=observations[i, 1:rollout_len + 1],
            terminals=terminals[i, 1:rollout_len + 1],
            agent_infos=[[] for _ in range(rollout_len)],
            env_infos=[[] for _ in range(rollout_len)],
        ))

    return paths
