import heapq

import torch


class StreamingMedian:
    """
    Track the median of a stream of numbers using two heaps.

    ``lower_half`` is a max-heap (stored as negated values, because ``heapq``
    only provides a min-heap) holding the smaller half of the stream, and
    ``upper_half`` is a min-heap holding the larger half. After every update,
    all elements of ``lower_half`` are <= all elements of ``upper_half``, and
    ``lower_half`` holds either as many elements as ``upper_half`` or exactly
    one more. Each update is O(log n).
    """

    def __init__(self):
        """
        Initialize the StreamingMedian calculator.
        """
        self.lower_half = []  # Max-heap (negated values to simulate a max-heap with heapq)
        self.upper_half = []  # Min-heap

    def update(self, value):
        """
        Update the median with a new value from the stream.

        Args:
        - value (float or int): The new value to add to the stream.

        Returns:
        - float: The current median after adding the new value.
        """
        # Route the value to the half it belongs in.
        if not self.lower_half or value <= -self.lower_half[0]:
            heapq.heappush(self.lower_half, -value)  # Negate to simulate a max-heap
        else:
            heapq.heappush(self.upper_half, value)

        # Rebalance so lower_half has the same size as upper_half or one more.
        if len(self.lower_half) > len(self.upper_half) + 1:
            heapq.heappush(self.upper_half, -heapq.heappop(self.lower_half))
        elif len(self.upper_half) > len(self.lower_half):
            heapq.heappush(self.lower_half, -heapq.heappop(self.upper_half))

        # lower_half is non-empty here, so this never returns None.
        return self.get_median()

    def update_batch(self, values):
        """
        Update the median with a batch of values.

        Args:
        - values (torch.Tensor or iterable of numbers): The new values to add
          to the stream. A tensor of any shape is flattened.

        Returns:
        - float: The current median after adding the new values, or None if
          no values have been added at all.
        """
        if isinstance(values, torch.Tensor):
            # Convert once to Python scalars. Iterating a tensor yields 0-d
            # tensors, which makes heap comparisons slow and tensor-typed
            # medians. A multi-dimensional tensor would yield whole rows,
            # which cannot be ordered.
            values = values.reshape(-1).tolist()
        for value in values:
            self.update(value)
        return self.get_median()

    def get_median(self):
        """
        Get the current median.

        Returns:
        - float: The current median, or None if no values have been added.
        """
        if not self.lower_half:
            return None
        if len(self.lower_half) == len(self.upper_half):
            # Even count: average the two middle elements.
            return (-self.lower_half[0] + self.upper_half[0]) / 2.0
        # Odd count: lower_half holds the extra element, so its max is the median.
        return -self.lower_half[0]


class RunningAverage:
    def __init__(self):
        """
        Initialize the RunningAverage calculator.
        """
        self.total = torch.tensor(0.0)  # Total sum of values
        self.count = torch.tensor(0)  # Count of values added

    def update(self, value):
        """
        Update the running average with a new value from the stream.

        Args:
        - value (torch.Tensor): The new scalar value (torch tensor) to update the running average.

        Returns:
        - torch.Tensor: The updated running average.
        """
        self.total += value
        self.count += 1
        return self.total / self.count

    def update_batch(self, values):
        """
        Update the running average with a batch of values.

        Args:
        - values (torch.Tensor): The new scalar values (torch tensor) to update the running average.

        Returns:
        - torch.Tensor: The updated running average.
        """
        self.total += values.sum()
        self.count += values.numel()
        return self.total / self.count

    def get_average(self):
        """
        Get the current running average.

        Returns:
        - torch.Tensor: The current running average, or None if no values have been added.
        """
        if self.count == 0:
            return None
        return self.total / self.count


class EMA:
    """
    Exponential moving average (EMA) of a stream of tensors.

    The first value initializes the state as-is. Each later value ``x``
    updates it as ``ema = alpha * x + (1 - alpha) * ema``, so a larger
    ``alpha`` weights recent values more heavily. Inputs are detached, so the
    stored state never holds a reference to the caller's autograd graph.
    """

    def __init__(self, alpha=0.1):
        """
        Initialize the EMA calculator.

        Args:
        - alpha (float): Smoothing factor, typically between 0 and 1.
        """
        self.alpha = alpha
        self.ema_value = None  # Stays None until the first value is added.

    def update(self, value):
        """
        Update the EMA with a new value from the stream.

        Args:
        - value (torch.Tensor or number): The new value to fold into the EMA.

        Returns:
        - torch.Tensor: The updated EMA value.
        """
        # Normalize Python numbers to tensors, and detach so the persistent
        # state doesn't chain the caller's autograd graph across updates.
        value = torch.as_tensor(value).detach()
        if self.ema_value is None:
            # Clone because as_tensor/detach share storage with the input;
            # later in-place edits by the caller must not alter the state.
            self.ema_value = value.clone()
        else:
            self.ema_value = self.alpha * value + (1 - self.alpha) * self.ema_value
        return self.ema_value

    def get_ema(self):
        """
        Get the current EMA value.

        Returns:
        - torch.Tensor: The current EMA value, or None if no values have been added.
        """
        return self.ema_value
