from typing import Optional, Tuple

import ray
import torch
from torchtyping import TensorType

from hive.agents.agent import Agent
from hive.envs.base import BaseEnv

class RolloutWorker:
    """Collect batches of trajectories by running an agent in an environment.

    Each call to :meth:`sample` resets ``env``, repeatedly asks ``agent`` for
    actions and steps the env until every trajectory in the batch reports
    done, and returns the padded per-step tensors.
    """

    def __init__(self, env: BaseEnv, agent: Agent, batch_size: int):
        """
        Args:
            env: Environment to roll out in. This class reads ``horizon`` and
                ``obs_dim`` from it and calls ``reset()``, ``step(actions)``
                and ``get_backward_actions(states, actions)``.
            agent: Policy used to pick actions. This class calls its
                ``get_agent_state``, ``act`` and ``load_state_dict`` methods.
            batch_size: Default number of trajectories per :meth:`sample`.
        """
        self.env, self.agent, self.batch_size = env, agent, batch_size
        self.horizon = self.env.horizon
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def sample(self, batch_size: Optional[int] = None) -> Tuple[
        TensorType['batch_size', 'horizon', 'ndim_times_side_len', float], # states
        TensorType['batch_size', 'horizon', int], # actions
        TensorType['batch_size', 'horizon', int], # backwards actions
        TensorType['batch_size', 'horizon', int], # dones
        TensorType['batch_size', float], # rewards
    ]:
        """Roll out one batch of trajectories until all of them are done.

        Args:
            batch_size: Number of trajectories. Falls back to
                ``self.batch_size`` when ``None`` (or any other falsy value).
                NOTE(review): ``env.reset()`` takes no batch size, so this
                presumably has to equal the env's own batch size; confirm.

        Returns:
            ``(states, actions, backward_actions, dones, rewards)``.
            Actions of trajectories that were already done before a step are
            set to -1. Time steps after the whole batch is done keep the
            buffer fill values: -1 for states and actions, 1 for dones.
            ``rewards`` is whatever the last ``env.step`` returned, or zeros
            if no step was taken.

        Raises:
            RuntimeError: if some trajectory is still not done after
                ``self.horizon`` steps.
        """
        batch_size = batch_size or self.batch_size
        all_states, all_dones, all_rewards, all_actions = \
            self._get_init_sample_tensors(batch_size)

        dones = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        states = self.env.reset()

        i = 0
        all_states[:, i] = states
        all_dones[:, i] = dones
        while not dones.all():
            if i >= self.horizon:
                # all_actions has exactly `horizon` columns; without this check
                # the write below fails with an opaque IndexError.
                raise RuntimeError(
                    f"Rollout did not terminate within horizon={self.horizon} "
                    "steps: env.step never reported every trajectory as done."
                )

            with torch.no_grad():
                agent_states = self.agent.get_agent_state(
                    all_states,
                    all_actions,
                    all_dones,
                    all_rewards,
                    i
                )

                actions = self.agent.act(agent_states)

            # `dones` still holds the flags from *before* this step, so this
            # pads the actions of trajectories that had already finished.
            actions[dones] = -1
            all_actions[:, i] = actions

            i += 1
            # Rewards are replaced (not accumulated) on every step, so the
            # returned rewards are the ones from the final env.step call.
            states, all_rewards, dones = self.env.step(actions)

            # A step taken at index `horizon` has no slot of its own; its
            # states/dones overwrite the last slot instead.
            idx_to_insert = i if i != self.horizon else self.horizon - 1
            all_states[:, idx_to_insert] = states
            all_dones[:, idx_to_insert] = dones

        return (
            all_states,
            all_actions,
            self.env.get_backward_actions(all_states, all_actions),
            all_dones,
            all_rewards,
        )

    def _get_init_sample_tensors(self, batch_size: Optional[int] = None) -> Tuple[
        TensorType['batch_size', 'horizon', 'ndim_times_side_len', float], # states
        TensorType['batch_size', 'horizon', int], # dones
        TensorType['batch_size', float], # rewards
        TensorType['batch_size', 'horizon', int], # actions
    ]:
        """Allocate the padded buffers that :meth:`sample` fills in.

        Args:
            batch_size: Number of trajectories; ``None`` means
                ``self.batch_size``.

        Returns:
            ``(all_states, all_dones, all_rewards, all_actions)`` on
            ``self.device``. States and actions are filled with -1, dones
            with 1, and rewards with 0.
        """
        if batch_size is None:
            batch_size = self.batch_size

        all_states = torch.full(
            (batch_size, self.horizon, self.env.obs_dim),
            fill_value=-1,
            dtype=torch.float,
            device=self.device
        )

        # NOTE(review): stored as torch.int even though the per-step dones are
        # bool; kept as-is because callers may do integer arithmetic on it.
        all_dones = torch.ones(
            (batch_size, self.horizon),
            dtype=torch.int,
            device=self.device
        )

        all_rewards = torch.zeros(
            (batch_size,),
            dtype=torch.float,
            device=self.device
        )

        all_actions = torch.full(
            (batch_size, self.horizon),
            fill_value=-1,
            dtype=torch.int,
            device=self.device
        )

        return all_states, all_dones, all_rewards, all_actions


    def set_agent_weights(self, weights):
        """Load ``weights`` into the agent via its ``load_state_dict``."""
        self.agent.load_state_dict(weights)

@ray.remote
class RolloutWorkerRemote(RolloutWorker):
    """Ray actor variant of ``RolloutWorker`` with identical methods.

    Per Ray's actor API, instances are created with
    ``RolloutWorkerRemote.remote(...)`` and methods are invoked as
    ``worker.sample.remote()``, which returns an object ref.
    """
    pass
