import torch
import numpy as np
from tqdm import tqdm
from torch.distributions import Categorical
from algos.misodice_discrete import DiscreteActor


class RolloutWorkerDiscrete:
    """Run greedy episodes with a discrete-action actor and collect returns/wins.

    The worker queries ``model`` for action logits, masks out unavailable
    actions using the environment's availability array, and picks one action
    per logit row.
    """

    def __init__(self, model: "DiscreteActor", n_agents, device="cuda"):
        """Store the actor, the agent count, and the device for input tensors.

        Args:
            model: Actor whose ``forward`` maps an observation tensor to action
                logits (presumably one row per agent -- confirm against
                ``DiscreteActor``).
            n_agents: Number of agents. Stored; not read by this class.
            device: Torch device on which observation/availability tensors are
                created.
        """
        self.model = model
        self.n_agents = n_agents
        self.device = device

    def sample(self, obs, avails, deterministic=False):
        """Choose one action per logit row, restricted to available actions.

        Args:
            obs: Array-like observations. Converted with ``np.array`` and a
                leading size-1 dimension is squeezed away (presumably shaped
                ``(1, n_agents, obs_dim)`` -- TODO confirm with the env).
            avails: Array-like availability mask with the same leading layout
                as ``obs``. Its log is added to the logits, so 0 entries get a
                logit of -inf (assumed to be a 0/1 mask -- confirm).
            deterministic: If True take the arg-max action; otherwise sample
                from the masked categorical distribution.

        Returns:
            numpy array of action indices, one per logit row.
        """
        with torch.no_grad():
            obs = torch.tensor(np.array(obs), dtype=torch.float32, device=self.device).squeeze(0)
            avails = torch.tensor(np.array(avails), dtype=torch.float32, device=self.device).squeeze(0)
            # log(0) = -inf, so unavailable actions receive zero probability.
            # NOTE(review): if every action in a row is unavailable, the row is
            # all -inf and the normalisation below produces NaN.
            logits = self.model.forward(obs) + avails.log()
            # Normalise to log-probabilities; this changes neither the argmax
            # nor the categorical distribution being sampled.
            logits = logits - logits.logsumexp(-1, True)
            actions = logits.argmax(-1) if deterministic else Categorical(logits=logits).sample()
            actions = actions.cpu().numpy()
        return actions

    def rollout(self, env, num_episodes=32, verbose=False):
        """Play ``num_episodes`` greedy episodes and report returns and wins.

        The model is switched to eval mode for the rollout. Its previous
        train/eval mode is restored afterwards, even if the environment raises.

        Args:
            env: Environment whose ``reset()`` returns ``(obs, _, avails)`` and
                whose ``step(actions)`` returns
                ``(obs, _, rewards, dones, infos, avails)``.
            num_episodes: Number of episodes to run.
            verbose: Show a tqdm progress bar over episodes.

        Returns:
            dict with ``"returns"`` (per episode: the sum over steps of
            ``np.mean(rewards)``, rounded to 3 decimals) and ``"winrates"``
            (per episode: 1 if ``infos[0]["won"]`` at the final step is truthy,
            else 0).
        """
        # Objects without a ``training`` flag are treated as being in train
        # mode, matching the previous unconditional ``train()`` call.
        was_training = getattr(self.model, "training", True)
        self.model.eval()
        T_rewards, T_wins = [], []
        try:
            episodes = range(num_episodes)
            if verbose:
                episodes = tqdm(episodes, desc="Rollout", leave=False, ncols=80)
            for _ in episodes:
                reward_sum = 0
                obs, _, avails = env.reset()
                while True:
                    actions = self.sample(obs, avails, deterministic=True)
                    obs, _, rewards, dones, infos, avails = env.step(actions)
                    # Per-step reward is the mean of ``rewards`` (presumably
                    # one entry per agent).
                    reward_sum += np.mean(rewards)
                    if np.all(dones):
                        break
                T_rewards.append(round(reward_sum, 3))
                # Outcome is taken from the first ``infos`` entry at the
                # terminal step (presumably agent/env 0 -- confirm).
                T_wins.append(1 if infos[0]["won"] else 0)
        finally:
            # Only return to train mode if that is where the model started.
            if was_training:
                self.model.train()
        return {
            "returns": T_rewards,
            "winrates": T_wins,
        }