import numpy as np
import torch


def compute_advantage(gamma, lmbda, td_delta, dones):
    """Compute generalized advantage estimates by a backward pass over time.

    The recursion is ``A[t] = delta[t] + gamma * lmbda * A[t + 1]``. The running
    accumulator is zeroed at any step whose done flag equals 1.0, so advantages
    are not propagated back across that boundary.

    Args:
        gamma: discount factor.
        lmbda: GAE lambda.
        td_delta: tensor indexed as ``[t, 0]``. Time runs along dim 0 and only
            column 0 is read. It is detached before use, so no gradient flows.
        dones: indexed as ``[t, 0]`` alongside ``td_delta``.

    Returns:
        A tensor shaped like ``td_delta`` (via ``zeros_like``) whose column 0
        holds the advantages.
    """
    deltas = td_delta.detach()
    decay = gamma * lmbda
    out = torch.zeros_like(deltas)
    running = 0.0
    t = len(deltas) - 1
    while t >= 0:
        # An episode boundary at t: do not bootstrap from later steps.
        if dones[t, 0] == 1.0:
            running = 0.0
        running = deltas[t, 0] + decay * running
        out[t, 0] = running
        t -= 1
    return out


def evaluate(env, env_seed, agents, max_episode_steps=None, mean_return=None):
    """Run one evaluation episode and return its undiscounted total reward.

    Args:
        env: environment whose ``reset(seed=...)`` returns ``(obs, info)`` and
            whose ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        env_seed: seed passed to ``env.reset``.
        agents: object exposing ``take_action(state, eval=True)``.
        max_episode_steps: optional upper bound on the number of ``env.step``
            calls. ``None`` (default) runs until the environment reports
            termination or truncation.
        mean_return: unused. It is kept only for backward compatibility with
            existing call sites.

    Returns:
        The sum of rewards collected during the episode (``0`` if no step ran).
    """
    state_np, _ = env.reset(seed=env_seed)
    episode_return = 0
    done_np = False
    steps = 0
    while not done_np:
        # Honor the step cap so a never-ending environment cannot hang evaluation.
        if max_episode_steps is not None and steps >= max_episode_steps:
            break
        action_np = agents.take_action(state_np, eval=True)
        state_np, reward_np, terminated, truncated, _ = env.step(action_np)
        done_np = np.logical_or(terminated, truncated)
        episode_return += reward_np
        steps += 1
    return episode_return


def rollout(env, agents, rollout_step):
    """Collect ``rollout_step`` steps of transitions from a vectorized multi-agent env.

    Args:
        env: vectorized environment exposing ``num_envs``, ``num_agents``,
            ``device``, ``reset()`` (returns a per-agent list of state tensors)
            and ``step(actions)`` (returns ``(next_states, rewards, done, info)``
            with per-agent lists for states and rewards).
        agents: object exposing ``take_action(state_list)`` that returns a
            per-agent list of action tensors.
        rollout_step: number of ``env.step`` calls to perform.

    Returns:
        A tuple ``(transition_dict, mean_return, run_step)``:

        - ``transition_dict``: lists (one entry per step) of stacked states,
          actions, next states, agent-averaged rewards and done flags.
        - ``mean_return``: rewards summed over every step and env, divided by
          ``num_envs``.
        - ``run_step``: count of (env, step) pairs whose done flag was False.
    """
    # Reset the environment.
    state_list = env.reset()
    transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
    # Allocate directly on the target device instead of creating on CPU and copying.
    episode_return = torch.zeros(env.num_envs, device=env.device)
    run_step = 0

    for _ in range(rollout_step):
        action = agents.take_action(state_list)
        next_state_list, reward, done, _info = env.step(action)
        # Team reward: sum the per-agent rewards, flatten, and average over agents.
        reward = sum(reward).view(-1) / env.num_agents

        # The stacked shapes below assume each per-agent tensor is [num_envs, dim].
        transition_dict['states'].append(torch.stack(state_list, dim=1))  # [num_envs, num_agents, state_dim]
        transition_dict['actions'].append(torch.stack(action, dim=1))  # [num_envs, num_agents, action_dim]
        transition_dict['next_states'].append(torch.stack(next_state_list, dim=1))  # [num_envs, num_agents, state_dim]
        transition_dict['rewards'].append(reward)  # [num_envs] (already averaged over agents)
        transition_dict['dones'].append(done.unsqueeze(-1))  # [num_envs, 1]

        state_list = next_state_list
        episode_return += reward

        # `~done` requires a boolean tensor; this counts envs still running at this step.
        run_step += torch.sum(~done).item()

    return transition_dict, episode_return.sum() / env.num_envs, run_step