import numpy as np
import torch
from minigrid.minigrid_env import MiniGridEnv
from tensordict import TensorDict
from tqdm import trange


def generate_on_policy_data(
    policy, envs: MiniGridEnv, num_trajectories: int, max_steps: int, seed: int, device: torch.device
) -> TensorDict:
    """Generate on-policy observation trajectories for a given policy and environment.

    Every trajectory starts from ``envs.reset(seed=seed)``, so all trajectories begin in
    the same gridworld. Each trajectory is rolled out with ``policy.get_eval_action``
    (under ``torch.no_grad``) until the environment reports termination or truncation,
    or until ``max_steps`` environment steps have been taken.

    Args:
        policy: Policy exposing ``get_eval_action(obs)`` that returns a tuple whose first
            element is an action tensor (moved to CPU/numpy before stepping the env).
        envs: Environment to generate data in. Observations are concatenated along
            dim 0, so this presumably is a vectorized env with a single sub-environment
            (leading batch dim of 1) — TODO confirm.
        num_trajectories: Number of trajectories to generate.
        max_steps: Maximum number of environment steps per trajectory.
        seed: Seed passed to ``envs.reset`` at the start of every trajectory.
        device: Device to store data on.

    Returns:
        TensorDict with an empty batch size mapping ``"trajectory_{i}"`` to the
        observations of trajectory ``i`` (the reset observation followed by one
        observation per step taken), concatenated along dim 0.
    """
    data = {}
    for trajectory in trange(num_trajectories):
        # NOTE: Always reset to the same seed to generate the same gridworld
        next_obs, _ = envs.reset(seed=seed)
        next_obs = torch.Tensor(next_obs).to(device)
        observations = [next_obs]

        # range(max_steps) already bounds the rollout length; we only need to
        # break early when the episode ends.
        for _ in range(max_steps):
            # Step the environment
            with torch.no_grad():
                action, _ = policy.get_eval_action(next_obs)
            next_obs, _, terminations, truncations, _ = envs.step(action.cpu().numpy())

            next_obs = torch.Tensor(next_obs).to(device)
            # NOTE(review): if `envs` auto-resets on done, this observation may be the
            # first observation of the next episode rather than the final one — confirm
            # against the env/wrapper in use.
            observations.append(next_obs)

            # Stop as soon as the episode terminates or is truncated
            next_done = np.logical_or(terminations, truncations)
            if next_done:
                break

        data[f"trajectory_{trajectory}"] = torch.cat(observations)

    # Trajectories have different lengths, so they cannot share a batch dimension;
    # pass the dict as `source` with an explicit empty batch size so construction
    # works across tensordict versions.
    return TensorDict(data, batch_size=[], device=device)
