import torch
import numpy as np
from tqdm import tqdm
from algos.misodice_continuous import TanhActor



class RolloutWorkerContinuous:
    """Run evaluation episodes with a continuous-action actor model.

    Observations are normalized as ``(obs + obs_shift) * obs_scale`` before
    being passed to the model. The defaults (shift 0.0, scale 1.0) leave them
    unchanged; callers may overwrite ``obs_shift`` / ``obs_scale`` after
    construction.
    """

    def __init__(self, model: "TanhActor", n_agents, device="cuda"):
        self.model = model
        # Stored for callers; not read by any method of this class.
        self.n_agents = n_agents
        self.device = device
        # When True, `sample` returns the model's deterministic action output
        # instead of the stochastic one.
        self.deterministic = False

        self.obs_scale = 1.0
        self.obs_shift = 0.0

    def sample(self, obs):
        """Return actions for ``obs`` as a NumPy array (no gradient tracking).

        Args:
            obs: array-like observation(s), converted to a float32 tensor on
                ``self.device`` and normalized with ``obs_shift``/``obs_scale``.

        Returns:
            The selected action tensor from the model, moved to CPU as a
            NumPy array.
        """
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
            obs_t = (obs_t + self.obs_shift) * self.obs_scale

            all_actions, _ = self.model.forward(obs_t)
            # NOTE(review): assumes all_actions[0] is the deterministic action
            # and all_actions[1] the sampled one -- confirm against
            # TanhActor.forward.
            actions = all_actions[0] if self.deterministic else all_actions[1]
            return actions.cpu().numpy()

    def _run_episode(self, env):
        """Play one episode until every entry of ``dones`` is true; return its return."""
        episode_return = 0
        # NOTE(review): env.reset() is unpacked as a 3-tuple and env.step() as
        # a 6-tuple (obs, _, rewards, dones, _, _) -- presumably a multi-agent
        # wrapper API; confirm against the env implementation.
        obs, _, _ = env.reset()
        while True:
            actions = self.sample(obs)
            obs, _, rewards, dones, _, _ = env.step(actions)
            # Average the (presumably per-agent) rewards of this step.
            episode_return += np.mean(rewards)
            if np.all(dones):
                break
        return round(episode_return, 3)

    def rollout(self, env, num_episodes=32, verbose=False):
        """Evaluate the model for ``num_episodes`` episodes in ``env``.

        The model is switched to eval mode for the rollout. Its previous
        train/eval mode is restored afterwards, even if an exception is raised.

        Args:
            env: environment exposing ``reset()`` and ``step(actions)``.
            num_episodes: number of episodes to run.
            verbose: show a tqdm progress bar when True.

        Returns:
            ``{"returns": [...]}``: for each episode, the sum over steps of the
            mean step reward, rounded to 3 decimals.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            episode_iter = tqdm(
                range(num_episodes),
                desc="Rollout",
                leave=False,
                disable=not verbose,
                ncols=80,
            )
            returns = [self._run_episode(env) for _ in episode_iter]
        finally:
            # Restore the caller's mode instead of unconditionally training.
            self.model.train(was_training)
        return {"returns": returns}