import ray
import numpy as np
from typing import Dict, Optional, Set


@ray.remote
class Channel:
    """Ray actor that buffers embeddings and gradients keyed by batch id.

    Two independent dictionaries hold, for each ``batch_id``, the most
    recently sent embedding and the most recently sent gradient. Reads are
    non-blocking: asking for a batch that has not been sent yet returns
    ``None`` instead of waiting. Buffered entries remain until
    ``clear_batch`` or ``clear_all`` removes them.
    """

    def __init__(self):
        # batch_id -> latest embedding sent for that batch.
        self.embedding_buffer: Dict[str, np.ndarray] = {}
        # batch_id -> latest gradient sent for that batch.
        self.gradient_buffer: Dict[str, np.ndarray] = {}

    def send_embedding(self, batch_id: str, embedding: np.ndarray) -> None:
        """Buffer ``embedding`` for ``batch_id``, overwriting any earlier one."""
        self.embedding_buffer.update({batch_id: embedding})

    def receive_embedding(self, batch_id: str) -> Optional[np.ndarray]:
        """Return the buffered embedding for ``batch_id``, or ``None`` if absent.

        The entry is left in the buffer; reading does not consume it.
        """
        return self.embedding_buffer.get(batch_id, None)

    def send_gradient(self, batch_id: str, gradient: np.ndarray) -> None:
        """Buffer ``gradient`` for ``batch_id``, overwriting any earlier one."""
        self.gradient_buffer.update({batch_id: gradient})

    def receive_gradient(self, batch_id: str) -> Optional[np.ndarray]:
        """Return the buffered gradient for ``batch_id``, or ``None`` if absent.

        The entry is left in the buffer; reading does not consume it.
        """
        return self.gradient_buffer.get(batch_id, None)

    def clear_batch(self, batch_id: str) -> None:
        """Drop both the embedding and the gradient stored for ``batch_id``.

        Missing entries are ignored, so calling this for an unknown or
        already-cleared batch is harmless.
        """
        for buffer in (self.embedding_buffer, self.gradient_buffer):
            buffer.pop(batch_id, None)

    def clear_all(self) -> None:
        """Empty both buffers, discarding every stored embedding and gradient."""
        for buffer in (self.embedding_buffer, self.gradient_buffer):
            buffer.clear()

    def get_pending_batches(self) -> Set[str]:
        """Return every batch id that still has an embedding or gradient buffered."""
        return set(self.embedding_buffer).union(self.gradient_buffer)