import gym
import numpy as np
import torch


def collect_experience(member, env_wrapper, buffer=None):
    """Advance ``env_wrapper`` by one environment step using ``member``'s agent.

    If the wrapper is flagged as done, its environment is reset first and the
    wrapper's step counter is cleared. The action comes from
    ``member.agent.sample_action(state)``.

    Args:
        member: object exposing ``agent.sample_action(state)``.
        env_wrapper: object carrying ``env``, ``state``, ``done``,
            ``step_count`` and ``max_episode_steps``; it is updated in place.
        buffer: optional object with a ``push(state, action, reward,
            next_state, done)`` method. ``None`` means "do not store".

    Returns:
        ``None`` when a buffer was supplied (the transition is pushed to it),
        otherwise the tuple ``(state, action, reward, next_state, done)``.
        In both cases ``done`` is the value after the time-limit override
        described below.
    """
    if env_wrapper.done:
        env_wrapper.state = env_wrapper.env.reset()
        env_wrapper.step_count = 0
        env_wrapper.done = False

    state = env_wrapper.state
    action = member.agent.sample_action(state)
    next_state, reward, done, info = env_wrapper.env.step(action)

    # infinite bootstrapping: the step that reaches the time limit is never
    # recorded as terminal, whatever the environment reported.
    if env_wrapper.step_count + 1 == env_wrapper.max_episode_steps:
        done = False

    # Compare against None explicitly: a buffer that defines __len__ is falsy
    # while empty, and a truthiness test would silently skip the first push.
    has_buffer = buffer is not None
    if has_buffer:
        buffer.push(state, action, reward, next_state, done)

    # prep for next iteration
    env_wrapper.state = next_state
    env_wrapper.step_count += 1
    env_wrapper.done = done
    if env_wrapper.step_count >= env_wrapper.max_episode_steps:
        # Hitting the step limit still ends the episode for the wrapper, so
        # the next call starts with a reset.
        env_wrapper.done = True

    if has_buffer:
        return None
    return state, action, reward, next_state, done


def evaluate_agent(
    agent, env, eval_episodes, max_episode_steps, render=False, verbosity=0
):
    """Roll out ``agent`` in ``env`` in eval mode and return its mean return.

    The agent is switched with ``agent.eval()`` before the rollout and back
    with ``agent.train()`` afterwards, including when the rollout raises.

    Args:
        agent: object exposing ``eval()`` and ``train()``; passed to
            ``run_env`` to choose actions.
        env: environment handed to ``run_env``.
        eval_episodes: number of episodes to run.
        max_episode_steps: per-episode step cap handed to ``run_env``.
        render: forwarded to ``run_env``.
        verbosity: forwarded to ``run_env``.

    Returns:
        ``.mean()`` of the per-episode returns produced by ``run_env``.
    """
    agent.eval()
    try:
        returns = run_env(
            agent, env, eval_episodes, max_episode_steps, render, verbosity=verbosity
        )
    finally:
        # Never leave the agent stuck in eval mode if the rollout fails.
        agent.train()
    return returns.mean()


def warmup_buffer(buffer, env, warmup_steps, max_episode_steps):
    """Fill ``buffer`` with ``warmup_steps`` transitions from random actions.

    Actions are drawn from ``env.action_space.sample()``; a sample that is not
    already an ``np.ndarray`` is converted to a 0-d float array before being
    passed to ``env.step``. The environment is reset once up front and again
    whenever the previous step reported ``done`` or the episode reached
    ``max_episode_steps``.

    Args:
        buffer: object with ``push(state, action, reward, next_state, done)``.
        env: environment with ``reset()``, ``step(action)`` and
            ``action_space.sample()``.
        warmup_steps: number of transitions to collect.
        max_episode_steps: step cap after which an episode is restarted.
    """
    obs = env.reset()
    episode_over = False
    episode_len = 0
    for _ in range(warmup_steps):
        if episode_over:
            # start a fresh episode before taking the next random step
            obs = env.reset()
            episode_len = 0
            episode_over = False

        action = env.action_space.sample()
        if not isinstance(action, np.ndarray):
            action = np.array(float(action))

        next_obs, reward, episode_over, _info = env.step(action)
        # the stored flag is the one reported by the environment itself
        buffer.push(obs, action, reward, next_obs, episode_over)

        obs = next_obs
        episode_len += 1
        if episode_len >= max_episode_steps:
            # force a reset on the next iteration once the cap is reached
            episode_over = True


def run_env(agent, env, episodes, max_steps, render=False, verbosity=1, discount=1.0):
    """Run ``agent`` in ``env`` for several episodes and collect the returns.

    Each episode starts with ``env.reset()`` and ends when the environment
    reports ``done`` or after ``max_steps`` steps. Actions come from
    ``agent.forward(state)``.

    Args:
        agent: object exposing ``forward(state)``.
        env: environment whose ``step(action)`` returns
            ``(state, reward, done, info)``.
        episodes: number of episodes to run.
        max_steps: maximum number of steps per episode.
        render: when true, ``env.render("rgb_array")`` is called once before
            the first episode and after every step.
        verbosity: when truthy, each episode's return is printed.
        discount: per-step discount factor; step ``t`` of an episode is
            weighted by ``discount ** t``.

    Returns:
        ``torch.tensor`` built from the list of per-episode returns.
    """
    episode_return_history = []
    if render:
        env.render("rgb_array")
    for episode in range(episodes):
        episode_return = 0.0
        state = env.reset()
        for step_num in range(max_steps):
            action = agent.forward(state)
            state, reward_hist, done, _ = env.step(action)
            if render:
                env.render("rgb_array")
            # The reward may be an array-like that needs summing, or a plain
            # scalar, which has no ``.sum`` and is used as-is.
            if hasattr(reward_hist, "sum"):
                step_reward = reward_hist.sum()
            else:
                step_reward = reward_hist
            episode_return += step_reward * (discount ** step_num)
            if done:
                break
        if verbosity:
            print(f"Episode {episode}:: {episode_return}")
        episode_return_history.append(episode_return)
    return torch.tensor(episode_return_history)
