import numpy as np
import torch

# Public names exported by ``from <module> import *``.
__all__ = ["compute_advantage", "evaluate", "rollout"]


def compute_advantage(gamma, lmbda, td_delta, dones):
    """Compute Generalized Advantage Estimates by a backward recursion.

    Walks the time axis from the last step to the first, accumulating
    ``A_t = delta_t + gamma * lmbda * A_{t+1}``. The carried term is dropped
    at every step whose ``dones`` entry equals 1.0, so advantages never
    leak across an episode boundary.

    Args:
        gamma: Discount factor.
        lmbda: GAE smoothing parameter (lambda).
        td_delta: Tensor of TD residuals indexed as ``[t, 0]``. Only column 0
            is processed; any other columns are left at zero in the result.
        dones: Indexable as ``[t, 0]``. A value equal to 1.0 (or ``True``)
            marks a step at which an episode ended.

    Returns:
        A tensor shaped like ``td_delta`` (detached from the autograd graph)
        holding the advantages in column 0.
    """
    deltas = td_delta.detach()
    result = torch.zeros_like(deltas)
    decay = gamma * lmbda
    running = 0.0
    for idx in range(len(deltas) - 1, -1, -1):
        # At an episode end the future advantage must not be bootstrapped.
        carry = 0.0 if dones[idx, 0] == 1.0 else running
        running = deltas[idx, 0] + decay * carry
        result[idx, 0] = running
    return result


def evaluate(env, env_seed, agents, max_episode_steps=None, mean_return=None):
    """Play one evaluation episode and return its undiscounted return.

    Args:
        env: Environment following the API used below: ``reset(seed=...)``
            returns ``(obs, info)`` and ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        env_seed: Seed forwarded to ``env.reset``.
        agents: Object exposing ``take_action(state, eval=True)``.
        max_episode_steps: Optional upper bound on the number of
            ``env.step`` calls. ``None`` (default) runs until the env
            reports termination or truncation.
        mean_return: Unused. Kept only so existing call sites that pass it
            keep working; its value is ignored.

    Returns:
        The sum of rewards collected during the episode.
    """
    state_np, _ = env.reset(seed=env_seed)
    episode_return = 0
    steps_taken = 0
    done_np = False
    while not done_np:
        # Honour the optional cap so an env that never signals termination
        # or truncation cannot hang evaluation.
        if max_episode_steps is not None and steps_taken >= max_episode_steps:
            break
        action_np = agents.take_action(state_np, eval=True)
        next_state_np, reward_np, terminated, truncated, _ = env.step(action_np)
        done_np = np.logical_or(terminated, truncated)
        state_np = next_state_np
        episode_return += reward_np
        steps_taken += 1

    return episode_return


def rollout(env, agents, rollout_step):
    """Collect a fixed number of synchronous steps from a vectorized env.

    The env is reset once, then stepped ``rollout_step`` times. Each per-step
    transition is appended to a dict of lists.

    Args:
        env: Vectorized multi-agent env exposing ``reset()``, ``step(actions)``
            (returning a 4-tuple ``(obs, rewards, done, info)``),
            ``num_envs`` and ``device``. ``obs`` and ``rewards`` are lists
            with one tensor per agent; presumably each has a leading
            ``num_envs`` dimension (inferred from the ``stack(..., dim=1)``
            calls) -- confirm against the env.
        agents: Object exposing ``take_action(obs_list)``, which returns a list
            of per-agent action tensors.
        rollout_step: Number of ``env.step`` calls to perform.

    Returns:
        A 3-tuple ``(transition_dict, mean_return, run_step)``:
        ``transition_dict`` maps ``'states'``, ``'actions'``, ``'next_states'``,
        ``'rewards'`` and ``'dones'`` to lists with one entry per step;
        ``mean_return`` is the summed team reward averaged over envs (a 0-d
        tensor); ``run_step`` is the number of (env, step) pairs whose
        ``done`` flag was False.
    """
    observations = env.reset()
    buffer = {key: [] for key in ('states', 'actions', 'next_states', 'rewards', 'dones')}
    returns = torch.zeros(env.num_envs).to(env.device)
    active_steps = 0

    for _ in range(rollout_step):
        acts = agents.take_action(observations)
        next_observations, per_agent_rewards, done, _info = env.step(acts)
        # Team reward: per-agent rewards summed, flattened to [num_envs].
        team_reward = sum(per_agent_rewards).view(-1)

        # Original author's note: for each env, the interaction is over once
        # its first done=True is seen. Agent lists are stacked on dim=1 ->
        # [num_envs, num_agents, ...].
        buffer['states'].append(torch.stack(observations, dim=1))
        buffer['actions'].append(torch.stack(acts, dim=1))
        buffer['next_states'].append(torch.stack(next_observations, dim=1))
        buffer['rewards'].append(team_reward)  # [num_envs] (already summed over agents)
        buffer['dones'].append(done.unsqueeze(-1))  # [num_envs, 1]

        observations = next_observations
        # NOTE(review): rewards from envs that already reported done are still
        # accumulated here; confirm the env returns zero reward after done.
        returns += team_reward

        # Count envs still running at this step (done == False).
        active_steps += (~done).sum().item()

    return buffer, returns.sum() / env.num_envs, active_steps