import numpy as np
import torch
from collections import deque


def evaluate(env, agents, max_episode_steps=None):
    """Run one evaluation episode and return the accumulated reward.

    Actions are chosen with ``agents.take_action(state, eval=True)``. The
    episode ends when ``np.logical_or(terminated, truncated)`` is truthy or,
    if ``max_episode_steps`` is given, once that many steps have been taken.

    Args:
        env: Environment with ``reset()`` and a ``step(action)`` that returns
            the 5-tuple ``(next_state, reward, terminated, truncated, info)``.
        agents: Object exposing ``take_action(state, eval=...)``.
        max_episode_steps: Optional cap on the number of ``env.step`` calls.
            ``None`` (default) runs until the environment reports done.

    Returns:
        The sum of the rewards returned by ``env.step`` (``0`` when no step
        is taken).
    """
    state_np = env.reset()  # presumably [agent_num, obs_dim]
    episode_return = 0
    done_np = False
    steps = 0
    while not done_np:
        # Previously this parameter was accepted but silently ignored.
        if max_episode_steps is not None and steps >= max_episode_steps:
            break
        action_np = agents.take_action(state_np, eval=True)
        next_state_np, reward_np, terminated, truncated, _ = env.step(action_np)
        done_np = np.logical_or(terminated, truncated)

        episode_return += reward_np
        state_np = next_state_np
        steps += 1

    return episode_return

def evaluate_dg(env, agents, max_episode_steps=None):
    """Run one evaluation episode; return the reward sum and the last action.

    Same as ``evaluate`` except that the environment is reset with
    ``env.reset(eval=True)`` and the final action taken is returned too.

    Args:
        env: Environment with ``reset(eval=...)`` and a ``step(action)`` that
            returns ``(next_state, reward, terminated, truncated, info)``.
        agents: Object exposing ``take_action(state, eval=...)``.
        max_episode_steps: Optional cap on the number of ``env.step`` calls.
            ``None`` (default) runs until the environment reports done.

    Returns:
        ``(episode_return, action_np)``. ``action_np`` is the last action
        passed to ``env.step``, or ``None`` if no step was taken.
    """
    state_np = env.reset(eval=True)  # presumably [agent_num, obs_dim]
    episode_return = 0
    done_np = False
    action_np = None
    steps = 0
    while not done_np:
        # Previously this parameter was accepted but silently ignored.
        if max_episode_steps is not None and steps >= max_episode_steps:
            break
        action_np = agents.take_action(state_np, eval=True)
        next_state_np, reward_np, terminated, truncated, _ = env.step(action_np)
        done_np = np.logical_or(terminated, truncated)

        episode_return += reward_np
        state_np = next_state_np
        steps += 1

    return episode_return, action_np

def rollout(env, agents, min_steps=4000):
    """Collect complete episodes until at least ``min_steps`` transitions exist.

    Episodes are never cut short, so the returned count can exceed
    ``min_steps`` by up to one episode's length.

    Args:
        env: Environment with ``reset()`` and a ``step(action)`` that returns
            ``(next_state, reward, terminated, truncated, info)``.
        agents: Object exposing ``take_action(state)``; called without
            ``eval`` here.
        min_steps: Minimum number of transitions to gather. Defaults to
            ``4000``, the value previously hard-coded.

    Returns:
        ``(transition_dict, env_step)``. ``transition_dict`` maps
        ``'states'``, ``'actions'``, ``'next_states'``, ``'rewards'`` and
        ``'dones'`` to per-step lists. ``env_step`` is the number of
        transitions collected.
    """
    transition_dict = {
        'states': [],
        'actions': [],
        'next_states': [],
        'rewards': [],
        'dones': [],
    }
    env_step = 0
    while env_step < min_steps:
        state_np = env.reset()  # presumably [agent_num, obs_dim]
        done_np = False
        while not done_np:
            action_np = agents.take_action(state_np)  # presumably [agent_num, action_dim]
            next_state_np, reward_np, terminated_np, truncated_np, _ = env.step(action_np)
            done_np = np.logical_or(terminated_np, truncated_np)

            transition_dict['states'].append(state_np)
            transition_dict['actions'].append(action_np)
            transition_dict['rewards'].append(reward_np)
            transition_dict['next_states'].append(next_state_np)
            transition_dict['dones'].append(done_np)

            state_np = next_state_np
            env_step += 1

    return transition_dict, env_step


def rollout_async(envs, agents, max_ep_len):
    """Step a vectorized environment and keep only finished episodes.

    ``envs.step`` is called exactly ``max_ep_len`` times in total (this is
    not a per-episode limit). Each sub-environment's transitions are buffered
    separately. They are moved into the output lists only once that
    sub-environment reports done. Any episode still in progress when the
    loop ends is dropped.

    The function never resets individual sub-environments. It presumably
    relies on ``envs`` auto-resetting finished ones; confirm this against the
    vector-env implementation.

    Args:
        envs: Vectorized environment exposing ``env_num``, ``reset()`` and a
            ``step(actions)`` that returns
            ``(next_states, rewards, terminated, truncated, info)``, each
            indexable by sub-environment along axis 0.
        agents: Object exposing ``take_action_async(states)``.
        max_ep_len: Total number of vectorized steps to execute.

    Returns:
        ``(transition_dict, env_step)``. ``transition_dict`` maps
        ``'states'``, ``'actions'``, ``'next_states'``, ``'rewards'`` and
        ``'dones'`` to per-transition lists from completed episodes.
        ``env_step`` is the number of such transitions.
    """
    n_envs = envs.env_num
    buffers = {
        'states': [],
        'actions': [],
        'next_states': [],
        'rewards': [],
        'dones': [],
    }
    pending = [deque() for _ in range(n_envs)]
    obs = envs.reset()  # presumably [env_num, agent_num, obs_dim]

    for _tick in range(max_ep_len):
        acts = agents.take_action_async(obs)
        next_obs, rewards, terminated, truncated, _info = envs.step(acts)
        dones = np.logical_or(terminated, truncated)

        for env_idx in range(n_envs):
            pending[env_idx].append(
                (obs[env_idx], acts[env_idx], rewards[env_idx],
                 next_obs[env_idx], dones[env_idx])
            )
        obs = next_obs

        if not dones.any():
            continue
        # Flush each finished sub-environment's episode, in ascending index order.
        for env_idx in np.nonzero(dones)[0]:
            finished = pending[env_idx]
            for s, a, r, s_next, d in list(finished):
                buffers['states'].append(s)
                buffers['actions'].append(a)
                buffers['rewards'].append(r)
                buffers['next_states'].append(s_next)
                buffers['dones'].append(d)
            finished.clear()

    return buffers, len(buffers['dones'])