import numpy as np
import haiku as hk
import copy
from tqdm import tqdm

from utils import normalise_data_mean_stddev


def evaluate_episode_rtg(
    seed,
    agent_state,
    env,
    act_dim,
    agent,
    alpha=0.,
    max_ep_len=1000,
    state_mean=0.0,
    state_stddev=1.0,
    concat_obs=False,
    robomimic=False,
):
    """Roll out one evaluation episode and report return, length and mean RTG.

    The agent is queried once per environment step with a normalised
    observation; the episode ends when the environment reports ``done``,
    when ``max_ep_len`` steps have been taken, or (with ``robomimic=True``)
    as soon as the environment reports task success.

    Args:
        seed: Seed for the ``hk.PRNGSequence`` that supplies one fresh RNG
            key per call to ``agent.get_action``.
        agent_state: Opaque agent state forwarded to ``agent.get_action``.
        env: Environment exposing ``reset()`` and ``step(action)``, where
            ``step`` returns a 4-tuple ``(obs, reward, done, info)``. When
            ``robomimic`` is true it must also expose ``is_success()``
            returning a mapping with a ``"task"`` entry.
        act_dim: Unused; kept so existing callers keep working.
        agent: Object exposing ``get_action(rng, agent_state, obs, alpha)``
            that returns ``(action, rtg_pred)``.
        alpha: Forwarded unchanged to ``agent.get_action``.
        max_ep_len: Maximum number of environment steps to take.
        state_mean: Mean passed to ``normalise_data_mean_stddev``.
        state_stddev: Standard deviation passed to
            ``normalise_data_mean_stddev``.
        concat_obs: If true, each observation is treated as a mapping and
            its values are concatenated along the last axis (in the
            mapping's iteration order) before normalisation.
        robomimic: If true, ``env.is_success()["task"]`` is checked after
            every step; on success the episode return is set to 1 and the
            rollout stops.

    Returns:
        A tuple ``(episode_return, episode_length, mean_rtg_pred)`` where
        ``mean_rtg_pred`` is the per-step average of the ``rtg_pred`` values
        returned by the agent, or ``0.0`` if no step was taken.
    """
    rng_seq = hk.PRNGSequence(seed)
    obs = env.reset()

    episode_return, episode_length = 0, 0
    rtg_pred_sum = 0
    done = False

    # Strict `<` so that at most `max_ep_len` steps are executed.
    while not done and episode_length < max_ep_len:
        if concat_obs:
            obs = np.concatenate([obs[key] for key in obs], -1)
        obs = normalise_data_mean_stddev(obs, state_mean, state_stddev)

        action, rtg_pred = agent.get_action(
            rng=next(rng_seq),
            agent_state=agent_state,
            obs=obs,
            alpha=alpha,
        )
        rtg_pred_sum += rtg_pred

        next_obs, reward, done, _ = env.step(action)
        # Copy so later in-place changes inside the env cannot alias `obs`.
        obs = copy.deepcopy(next_obs)
        episode_return += reward
        episode_length += 1

        if robomimic and env.is_success()["task"]:
            # Success overrides the accumulated reward with a flat 1.
            episode_return = 1
            break

    # Guard against a zero-step rollout (e.g. max_ep_len <= 0).
    mean_rtg_pred = rtg_pred_sum / episode_length if episode_length else 0.0
    return episode_return, episode_length, mean_rtg_pred
