import torch
import math

class RunningStats:
    """Accumulate the mean and sample standard deviation of a stream of values.

    Values arrive in batches (tensors, or anything ``torch.as_tensor`` accepts).
    Every element of a batch counts as one scalar observation, whatever the
    batch's shape. Batches are merged with Chan et al.'s parallel-variance
    update, so the running statistics equal those of the concatenated data
    without the data ever being stored.

    Attributes:
        n: Number of observations seen so far.
        mu: Running mean (``0`` before any data has been seen).
        sum_sq_dev: Running sum of squared deviations from the mean (often
            called M2); the sample variance is ``sum_sq_dev / (n - 1)``.
    """

    def __init__(self):
        self.reset()

    def update(self, R) -> None:
        """Fold a batch of observations into the running statistics.

        Args:
            R: Tensor (or array-like) of observations. All elements are used,
                regardless of shape. An empty batch is ignored. Integer/bool
                input is promoted to float64 before averaging.
        """
        # detach() so that inputs requiring grad do not build an autograd graph.
        values = torch.as_tensor(R).detach()
        count = values.numel()
        if count == 0:
            return
        if not (values.is_floating_point() or values.is_complex()):
            # torch.mean rejects integer/bool dtypes; promote so stats are defined.
            values = values.to(torch.float64)

        # Statistics of this batch alone, centred on its own mean for stability.
        batch_mean = values.mean().item()
        batch_sq_dev = ((values - batch_mean) ** 2).sum().item()

        # Chan et al.: merge (n, mu, M2) with (count, batch_mean, batch_M2).
        total = self.n + count
        delta = batch_mean - self.mu
        self.mu += delta * count / total
        self.sum_sq_dev += batch_sq_dev + delta * delta * self.n * count / total
        self.n = total

    def running_mean(self):
        """Return the mean of all observations so far (``0`` if none)."""
        return self.mu

    def running_std_dev(self):
        """Return the sample (n - 1) standard deviation, or ``0`` if n <= 1."""
        if self.n > 1:
            return math.sqrt(self.sum_sq_dev / (self.n - 1))
        else:
            return 0

    def reset(self):
        """Discard all accumulated statistics."""
        self.n = 0
        self.mu = 0
        self.sum_sq_dev = 0