import random, os, time

import wandb

import torch
import multiprocessing as mp
from multiprocessing import shared_memory

import diversipy
import numpy as np

from collections import deque

from collections import deque

# Five color names. Not referenced anywhere in the visible part of this file;
# presumably consumed by plotting code elsewhere — confirm before removing.
colors = ['red', 'blue', 'green', 'orange', 'purple']

class ParallelReservoirSamplingMemory:
    """Fixed-capacity image memory filled by reservoir sampling (Algorithm R).

    Images are stored as ``uint8`` in a ``multiprocessing.shared_memory`` block
    so that producer processes started by :meth:`start_producers` read the
    pixels the parent writes in :meth:`add` (``n_items`` is an ``mp.Value`` for
    the same reason). Labels, ages and eviction statistics are plain NumPy
    arrays / lists updated only in the calling process.

    Args:
        img_size: Height and width of the stored square RGB images.
        capacity: Maximum number of stored images.
        batch_size: Images per batch returned by :meth:`fetch_data`.
        transform: Callable applied to each stored image; must return a tensor
            or a list/tuple of tensors (one per view).
        num_producers: Number of background producer processes.
        queue_max_size: Maximum number of prefetched batches in the queue.
    """

    def __init__(self, img_size, capacity, batch_size, transform,
                 num_producers, queue_max_size):

        self.capacity = capacity
        shape = (self.capacity, img_size, img_size, 3)

        # uint8 storage: one byte per element.
        self.shm = shared_memory.SharedMemory(
            create=True, size=self.capacity * img_size * img_size * 3)

        # NumPy array backed by the shared block, filled with ones in place
        # (no second full-size temporary array is allocated).
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        self.mem_labels = -1 * np.ones(self.capacity)  # -1 marks a never-filled slot

        self.n_items = mp.Value('i', 0)  # total items offered to add(); shared with producers
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()  # guards reads/writes of self.memory across processes
        self.producers = []
        self.producers_exist = False

        self.transform = transform
        self.mem_ages = np.zeros(self.capacity)  # add() calls each slot has survived
        self.eviction_ages = []  # age of each slot at the moment it was overwritten
        self.step = 0  # number of add() calls

    def log_class_balance(self, N=20):
        """Log label balance and eviction-age statistics to Weights & Biases.

        Logs the KL divergence between the label histogram of the filled slots
        and a uniform distribution over ``N`` classes, plus the mean, min and
        max age recorded for evicted items.

        Args:
            N: Number of classes; stored labels must lie in ``[0, N)``.
        """
        labels = self.mem_labels[:int(self.n_items.value)]
        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(labels, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts

        total_count = len(labels)
        if total_count > 0:
            observed_distribution = counts / total_count
            uniform_distribution = np.ones(N) / N
            # Skip empty bins: their 0 * log(0) term is taken as 0.
            mask = observed_distribution > 0
            kl_divergence = np.sum(observed_distribution[mask] *
                                   np.log(observed_distribution[mask] / uniform_distribution[mask]))
        else:
            # Empty memory: nothing to compare, and avoids a 0/0 division.
            kl_divergence = 0.0

        if len(self.eviction_ages) > 0:
            average_duration = np.mean(self.eviction_ages)
            e_min = np.min(self.eviction_ages)
            e_max = np.max(self.eviction_ages)
        else:
            average_duration = 0
            e_min = 0
            e_max = 0

        wandb.log({
            'mem_step': self.step,
            'avg_mem_duration': average_duration,
            'ltm_class_balance': kl_divergence,
            'min_duration': e_min,
            'max_duration': e_max,
        })

    def add(self, x, y):
        """Offer a batch to the reservoir.

        While the memory is filling, each item goes into the next free slot.
        Afterwards the item with 0-based stream index ``n`` replaces a uniformly
        chosen slot with probability ``capacity / (n + 1)`` (Algorithm R).

        Args:
            x: Batch of images indexable by position.
            y: Labels aligned with ``x``.

        Returns:
            Number of items written during the fill phase. Reservoir
            replacements are not counted.
        """
        added = 0
        for i in range(len(x)):
            seen = int(self.n_items.value)
            if seen < self.capacity:
                added += 1
                with self.lock:
                    self.memory[seen] = x[i]
                self.mem_labels[seen] = y[i]
            else:
                # randint is inclusive on both ends: seen + 1 equally likely outcomes.
                replace_index = random.randint(0, seen)
                if replace_index < self.capacity:
                    with self.lock:
                        self.memory[replace_index] = x[i]
                    self.mem_labels[replace_index] = y[i]
                    self.eviction_ages.append(self.mem_ages[replace_index])
                    self.mem_ages[replace_index] = 0

            self.n_items.value += 1

        if int(self.n_items.value) >= self.capacity:
            self.mem_ages += 1
        else:
            self.mem_ages[:int(self.n_items.value) % self.capacity] += 1

        self.step += 1
        return added

    def fetch_data(self):
        """Sample one transformed batch of ``batch_size`` stored images.

        Indices are drawn without replacement when at least ``batch_size``
        images are stored. Otherwise they are drawn with replacement, so a
        partially filled memory still yields a full batch instead of raising
        (which would kill a producer process).

        Returns:
            Tensor of shape ``(batch_size, ...)``, or ``(n_views, batch_size, ...)``
            when ``transform`` returns a list/tuple of views.

        Raises:
            ValueError: If the memory is empty.
        """
        available = min(int(self.n_items.value), self.capacity)
        if available == 0:
            raise ValueError('cannot sample from an empty memory')
        inds = np.random.choice(available, self.batch_size,
                                replace=available < self.batch_size)

        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            batch = torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        else:
            batch = torch.stack(batch)

        return batch

    def persistent_producer(self):
        """Producer loop: put batches on ``self.queue`` until ``stop_signal`` is set.

        Each process seeds ``random``, NumPy and torch from its PID, so sibling
        producers draw different batches. While the memory is still empty it
        polls instead of calling :meth:`fetch_data`, which would raise and end
        the process.
        """
        seed = int(os.getpid())

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        while not self.stop_signal.is_set():
            if min(int(self.n_items.value), self.capacity) == 0:
                time.sleep(0.01)
                continue
            data = self.fetch_data()
            self.queue.put(data)  # Blocks if the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        print('STARTING PRODUCERS')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)

        self.producers_exist = True

    def sample_memory(self):
        """Return the next prefetched batch, starting the producers on first use.

        Blocks until a producer has put a batch on the queue.
        """
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Stop and reap the producers, then release the shared-memory block.

        Producers are terminated rather than waited for, because one blocked
        on a full queue never re-checks ``stop_signal``. Each is then joined so
        no zombie processes remain. The memory is unusable afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        # Drop the ndarray view on shm.buf first: a live view keeps a buffer
        # export open, which can make SharedMemory.close() raise BufferError.
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')

class ParallelSCALEMemory:
    """Image memory that keeps a subset of items with well-spread embeddings.

    Whenever more than ``capacity`` items are held, ``diversipy.subset.psa_select``
    chooses ``capacity`` embeddings among the stored and new ones. Only the
    matching images (and labels) are kept.

    NOTE(review): unlike the other memories in this module, ``memory`` and
    ``n_items`` are ordinary process-local objects, and :meth:`add` rebinds
    ``self.memory``. Producer processes created by :meth:`start_producers`
    therefore sample from the memory as it was when they started; later
    additions are not visible to them. Confirm this is intended, or move the
    storage into shared memory like the other classes.

    Args:
        img_size: Height and width of the stored square RGB images.
        embedding_size: Dimensionality of the per-item embeddings.
        capacity: Maximum number of stored items.
        batch_size: Items per batch returned by :meth:`fetch_data`.
        transform: Callable applied to each stored image; must return a tensor
            or a list/tuple of tensors (one per view).
        num_producers: Number of background producer processes.
        queue_max_size: Maximum number of prefetched batches in the queue.
    """

    def __init__(self, img_size, embedding_size, capacity, batch_size, transform,
                 num_producers, queue_max_size):

        self.num_examples_seen = 0
        self.capacity = capacity
        self.batch_size = batch_size
        self.memory = np.zeros((0, img_size, img_size, 3))
        self.embeddings = np.zeros((0, embedding_size))
        # One label per stored item (-1 when add() was called without labels).
        self.mem_labels = -1 * np.ones(0)

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()
        self.producers = []
        self.producers_exist = False

        self.transform = transform

        self.n_items = 0  # total items ever passed to add()

    @staticmethod
    def _match_rows(selected, candidates):
        """Map each row of ``selected`` to the index of an identical row of ``candidates``.

        Identical candidate rows are handed out in order. If two copies of a
        duplicated embedding are selected, both stored items are kept instead
        of the first one twice. Runs in O(rows * dim).
        """
        positions = {}
        for idx, row in enumerate(candidates):
            positions.setdefault(row.tobytes(), deque()).append(idx)
        return np.array(
            [positions[np.asarray(row, dtype=candidates.dtype).tobytes()].popleft()
             for row in selected],
            dtype=np.int64)

    def add(self, x, embeddings, y=None):
        """Add a batch, then keep a diverse subset of at most ``capacity`` items.

        Args:
            x: Batch of images, concatenated onto ``self.memory`` along axis 0.
            embeddings: Torch tensor of per-item embeddings aligned with ``x``;
                ``.detach().cpu().numpy()`` is applied to it.
            y: Optional labels aligned with ``x``. They are stored in
                ``mem_labels`` so that :meth:`get_task_counts` can report them.
                Defaults to ``-1`` for every item.
        """
        self.n_items += x.shape[0]

        all_images = np.concatenate((self.memory, x))
        all_embeddings = np.concatenate((self.embeddings, embeddings.detach().cpu().numpy()))
        if y is None:
            new_labels = -1 * np.ones(x.shape[0])
        else:
            new_labels = np.array([float(label) for label in y])
        all_labels = np.concatenate((self.mem_labels, new_labels))

        # Keep everything until the capacity is exceeded.
        select_indices = np.arange(all_embeddings.shape[0])

        if all_embeddings.shape[0] > self.capacity:  # needs subset selection
            select_embeddings = diversipy.subset.psa_select(all_embeddings, self.capacity)
            # psa_select returns embedding rows, not indices; map them back.
            select_indices = self._match_rows(select_embeddings, all_embeddings)

        self.memory = all_images[select_indices]
        self.embeddings = all_embeddings[select_indices]
        self.mem_labels = all_labels[select_indices]

    def get_task_counts(self):
        """Return the number of stored items per distinct label (sorted by label)."""
        _, counts = np.unique(self.mem_labels, return_counts=True)
        return counts

    def fetch_data(self):
        """Sample one transformed batch of ``batch_size`` stored images.

        Indices are drawn without replacement when at least ``batch_size``
        items are stored, and with replacement otherwise.

        Returns:
            Tensor of shape ``(batch_size, ...)``, or ``(n_views, batch_size, ...)``
            when ``transform`` returns a list/tuple of views.

        Raises:
            ValueError: If the memory is empty.
        """
        available = min(self.n_items, self.capacity)
        if available == 0:
            raise ValueError('cannot sample from an empty memory')
        inds = np.random.choice(available, self.batch_size,
                                replace=available < self.batch_size)
        with self.lock:
            mem_exes = self.memory[inds]

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            batch = torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        else:
            batch = torch.stack(batch)

        return batch

    def persistent_producer(self):
        """Producer loop: put batches on ``self.queue`` until ``stop_signal`` is set.

        Seeds ``random``, NumPy and torch from the PID. Polls instead of
        calling :meth:`fetch_data` while no items are stored.
        """
        seed = int(os.getpid())

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            if min(self.n_items, self.capacity) == 0:
                time.sleep(0.01)
                continue
            data = self.fetch_data()
            self.queue.put(data)  # Blocks if the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)

        self.producers_exist = True

    def sample_memory(self):
        """Return the next prefetched batch, starting the producers on first use."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Terminate and reap all producer processes."""
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        print('Finished Shutdown')

class ParallelQueueMemory:
    """Fixed-capacity FIFO image memory (ring buffer) in shared memory.

    Images are stored as ``uint8`` in a ``multiprocessing.shared_memory`` block
    so that producer processes read what :meth:`add` writes. Once full, each
    new item overwrites the oldest slot, and the overwritten item is returned
    to the caller.

    Args:
        img_size: Height and width of the stored square RGB images.
        capacity: Maximum number of stored images.
        batch_size: Images per batch returned by :meth:`fetch_data`.
        transform: Callable applied to each stored image; must return a tensor
            or a list/tuple of tensors (one per view).
        num_producers: Number of background producer processes.
        queue_max_size: Maximum number of prefetched batches in the queue.
    """

    def __init__(self, img_size, capacity, batch_size, transform,
                 num_producers, queue_max_size):

        self.capacity = capacity
        shape = (self.capacity, img_size, img_size, 3)

        # uint8 storage: one byte per element.
        self.shm = shared_memory.SharedMemory(
            create=True, size=self.capacity * img_size * img_size * 3)

        # NumPy array backed by the shared block, filled with ones in place.
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        self.mem_labels = -1 * np.ones(self.capacity)  # -1 marks a never-filled slot
        self.mem_ages = np.zeros(self.capacity)  # add() calls each slot has survived
        self.eviction_ages = []  # age of each slot at the moment it was overwritten

        self.n_items = mp.Value('i', 0)  # total items ever added; shared with producers
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()  # guards reads/writes of self.memory across processes
        self.producers = []
        self.producers_exist = False

        self.transform = transform
        self.step = 0  # number of add() calls

    def log_class_balance(self, N=20):
        """Log label balance and eviction-age statistics to Weights & Biases.

        Logs the KL divergence between the label histogram of the filled slots
        and a uniform distribution over ``N`` classes, plus the mean, min and
        max age recorded for evicted items.

        Args:
            N: Number of classes; stored labels must lie in ``[0, N)``.
        """
        labels = self.mem_labels[:int(self.n_items.value)]
        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(labels, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts

        total_count = len(labels)
        if total_count > 0:
            observed_distribution = counts / total_count
            uniform_distribution = np.ones(N) / N
            # Skip empty bins: their 0 * log(0) term is taken as 0.
            mask = observed_distribution > 0
            kl_divergence = np.sum(observed_distribution[mask] *
                                   np.log(observed_distribution[mask] / uniform_distribution[mask]))
        else:
            # Empty memory: nothing to compare, and avoids a 0/0 division.
            kl_divergence = 0.0

        if len(self.eviction_ages) > 0:
            average_duration = np.mean(self.eviction_ages)
            e_min = np.min(self.eviction_ages)
            e_max = np.max(self.eviction_ages)
        else:
            average_duration = 0
            e_min = 0
            e_max = 0

        wandb.log({
            'mem_step': self.step,
            'avg_mem_duration': average_duration,
            'stm_class_balance': kl_divergence,
            'min_duration': e_min,
            'max_duration': e_max,
        })

    def add(self, x, y):
        """Append a batch FIFO-style, overwriting the oldest slots once full.

        Args:
            x: Batch of images indexable by position.
            y: Labels aligned with ``x``.

        Returns:
            ``(evicted_x, evicted_y)``: independent copies of the overwritten
            images and their labels, in eviction order. Both are empty until
            the memory is full.
        """
        evicted_x = []
        evicted_y = []
        for i in range(len(x)):
            slot = int(self.n_items.value) % self.capacity
            if int(self.n_items.value) >= self.capacity:
                with self.lock:
                    # .copy(): self.memory[slot] is a view that the write below
                    # would otherwise overwrite with the new image.
                    evicted_x.append(self.memory[slot].copy())
                evicted_y.append(self.mem_labels[slot])
                self.eviction_ages.append(self.mem_ages[slot])
                self.mem_ages[slot] = 0

            with self.lock:
                self.memory[slot] = x[i]

            self.mem_labels[slot] = y[i]
            self.n_items.value += 1

        if int(self.n_items.value) >= self.capacity:
            self.mem_ages += 1
        else:
            self.mem_ages[:int(self.n_items.value) % self.capacity] += 1

        self.step += 1

        return evicted_x, evicted_y

    def fetch_data(self):
        """Sample one transformed batch of ``batch_size`` stored images.

        Indices are drawn without replacement when at least ``batch_size``
        images are stored, and with replacement otherwise. A partially filled
        memory therefore still yields a full batch.

        Returns:
            Tensor of shape ``(batch_size, ...)``, or ``(n_views, batch_size, ...)``
            when ``transform`` returns a list/tuple of views.

        Raises:
            ValueError: If the memory is empty.
        """
        available = min(int(self.n_items.value), self.capacity)
        if available == 0:
            raise ValueError('cannot sample from an empty memory')
        inds = np.random.choice(available, self.batch_size,
                                replace=available < self.batch_size)

        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            batch = torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        else:
            batch = torch.stack(batch)

        return batch

    def persistent_producer(self):
        """Producer loop: put batches on ``self.queue`` until ``stop_signal`` is set.

        Seeds ``random``, NumPy and torch from the PID. While the memory is
        empty it polls instead of calling :meth:`fetch_data`, which would raise
        and end the process.
        """
        seed = int(os.getpid())

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            if min(int(self.n_items.value), self.capacity) == 0:
                time.sleep(0.01)
                continue
            data = self.fetch_data()
            self.queue.put(data)  # Blocks if the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        print('STARTING QUEUE')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)

        self.producers_exist = True

    def sample_memory(self):
        """Return the next prefetched batch, starting the producers on first use."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Stop and reap the producers, then release the shared-memory block.

        Producers are terminated, because one blocked on a full queue never
        re-checks ``stop_signal``. They are then joined so no zombie processes
        remain. The memory is unusable afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        # Drop the ndarray view on shm.buf first: a live view keeps a buffer
        # export open, which can make SharedMemory.close() raise BufferError.
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')

class ParallelLossBufferMemory:
    """Shared-memory image buffer that admits items with high training loss.

    While filling, every item is stored. Once full, an item is admitted only
    if its loss exceeds the ``beta``-th percentile of the last ``k`` observed
    losses. Admitted items overwrite slots in ring-buffer order.

    Args:
        img_size: Height and width of the stored square RGB images.
        capacity: Maximum number of stored images.
        batch_size: Images per batch returned by :meth:`fetch_data`.
        transform: Callable applied to each stored image; must return a tensor
            or a list/tuple of tensors (one per view).
        beta: Percentile in ``[0, 100]`` passed to ``np.percentile`` to build
            the admission threshold.
        k: Number of most recent losses kept in ``loss_history``.
        num_producers: Number of background producer processes.
        queue_max_size: Maximum number of prefetched batches in the queue.
    """

    def __init__(self, img_size, capacity, batch_size, transform, beta, k,
                 num_producers, queue_max_size):

        self.capacity = capacity
        shape = (self.capacity, img_size, img_size, 3)

        # uint8 storage: one byte per element.
        self.shm = shared_memory.SharedMemory(
            create=True, size=self.capacity * img_size * img_size * 3)

        # NumPy array backed by the shared block, filled with ones in place.
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        self.mem_labels = -1 * np.ones(self.capacity)  # -1 marks a never-filled slot
        self.mem_ages = np.zeros(self.capacity)  # add() calls each slot has survived
        self.eviction_ages = []  # age of each slot at the moment it was overwritten

        self.n_items = mp.Value('i', 0)  # total admitted items; shared with producers
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()  # guards reads/writes of self.memory across processes
        self.producers = []
        self.producers_exist = False

        self.beta = beta
        self.k = k
        self.loss_history = deque(maxlen=k)
        self.step = 0  # number of add() calls

        self.transform = transform

    def log_class_balance(self, N=20):
        """Log label balance and eviction-age statistics to Weights & Biases.

        Logs the KL divergence between the label histogram of the filled slots
        and a uniform distribution over ``N`` classes, plus the mean, min and
        max age recorded for evicted items.

        Args:
            N: Number of classes; stored labels must lie in ``[0, N)``.
        """
        labels = self.mem_labels[:int(self.n_items.value)]
        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(labels, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts

        total_count = len(labels)
        if total_count > 0:
            observed_distribution = counts / total_count
            uniform_distribution = np.ones(N) / N
            # Skip empty bins: their 0 * log(0) term is taken as 0.
            mask = observed_distribution > 0
            kl_divergence = np.sum(observed_distribution[mask] *
                                   np.log(observed_distribution[mask] / uniform_distribution[mask]))
        else:
            # Empty memory: nothing to compare, and avoids a 0/0 division.
            kl_divergence = 0.0

        if len(self.eviction_ages) > 0:
            average_duration = np.mean(self.eviction_ages)
            e_min = np.min(self.eviction_ages)
            e_max = np.max(self.eviction_ages)
        else:
            average_duration = 0
            e_min = 0
            e_max = 0

        wandb.log({
            'mem_step': self.step,
            'avg_mem_duration': average_duration,
            'stm_class_balance': kl_divergence,
            'min_duration': e_min,
            'max_duration': e_max,
        })

    def add(self, x, y, loss):
        """Insert the high-loss part of a batch and log admission stats to wandb.

        Args:
            x: Batch of images indexable by position.
            y: Labels aligned with ``x``.
            loss: Per-item losses aligned with ``x``. They are appended to
                ``loss_history`` after the batch is processed.

        Returns:
            ``(evicted_x, evicted_y)`` containing, in order:
            - for admitted items, copies of the overwritten images and their
              labels;
            - for rejected items, the items themselves.
        """
        evicted_x = []
        evicted_y = []
        count = 0
        thresh = 0
        # loss_history only changes after the loop, so the percentile is the
        # same for every item of this call: compute it at most once.
        full_thresh = None
        for i in range(len(x)):
            slot = int(self.n_items.value) % self.capacity
            if int(self.n_items.value) >= self.capacity:
                if full_thresh is None:
                    if len(self.loss_history) > 0:
                        full_thresh = np.percentile(self.loss_history, self.beta)
                    else:
                        # No loss history yet (the first batch overflowed the
                        # memory): admit everything instead of crashing on an
                        # empty percentile.
                        full_thresh = -np.inf
                thresh = full_thresh
                if loss[i] > thresh:
                    count += 1
                    with self.lock:
                        # .copy(): the slot view is overwritten on the next line.
                        evicted_x.append(self.memory[slot].copy())
                        self.memory[slot] = x[i]

                    evicted_y.append(self.mem_labels[slot])
                    self.mem_labels[slot] = y[i]
                    self.eviction_ages.append(self.mem_ages[slot])
                    self.mem_ages[slot] = 0
                    self.n_items.value += 1
                else:
                    evicted_x.append(x[i])
                    evicted_y.append(y[i])
            else:
                thresh = 0
                count += 1
                with self.lock:
                    self.memory[slot] = x[i]

                self.mem_labels[slot] = y[i]
                self.n_items.value += 1

        if int(self.n_items.value) >= self.capacity:
            self.mem_ages += 1
        else:
            self.mem_ages[:int(self.n_items.value)] += 1

        wandb.log({
            'mem_step': self.step,
            'threshold': thresh,
            'added_count': count,
            'min_loss': min(loss),
            'max_loss': max(loss)
        })

        self.step += 1

        self.loss_history.extend(loss)

        return evicted_x, evicted_y

    def fetch_data(self):
        """Sample one transformed batch of ``batch_size`` stored images.

        Indices are drawn without replacement when at least ``batch_size``
        images are stored, and with replacement otherwise.

        Returns:
            Tensor of shape ``(batch_size, ...)``, or ``(n_views, batch_size, ...)``
            when ``transform`` returns a list/tuple of views.

        Raises:
            ValueError: If the memory is empty.
        """
        available = min(int(self.n_items.value), self.capacity)
        if available == 0:
            raise ValueError('cannot sample from an empty memory')
        inds = np.random.choice(available, self.batch_size,
                                replace=available < self.batch_size)

        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            batch = torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        else:
            batch = torch.stack(batch)

        return batch

    def persistent_producer(self):
        """Producer loop: put batches on ``self.queue`` until ``stop_signal`` is set.

        Seeds ``random``, NumPy and torch from the PID. While the memory is
        empty it polls instead of calling :meth:`fetch_data`, which would raise
        and end the process.
        """
        seed = int(os.getpid())

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            if min(int(self.n_items.value), self.capacity) == 0:
                time.sleep(0.01)
                continue
            data = self.fetch_data()
            self.queue.put(data)  # Blocks if the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        print('STARTING QUEUE')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)

        self.producers_exist = True

    def sample_memory(self):
        """Return the next prefetched batch, starting the producers on first use."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Stop and reap the producers, then release the shared-memory block.

        Producers are terminated, because one blocked on a full queue never
        re-checks ``stop_signal``. They are then joined so no zombie processes
        remain. The memory is unusable afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        # Drop the ndarray view on shm.buf first: a live view keeps a buffer
        # export open, which can make SharedMemory.close() raise BufferError.
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')

class ParallelRandomBufferMemory:
    """Shared-memory FIFO image buffer fed with ``k`` random items per batch.

    Each :meth:`add` call picks ``k`` items of the incoming batch at random and
    appends them ring-buffer style. Once full, each insertion overwrites the
    oldest slot.

    Args:
        img_size: Height and width of the stored square RGB images.
        capacity: Maximum number of stored images.
        batch_size: Images per batch returned by :meth:`fetch_data`.
        transform: Callable applied to each stored image; must return a tensor
            or a list/tuple of tensors (one per view).
        k: Number of items inserted per :meth:`add` call.
        num_producers: Number of background producer processes.
        queue_max_size: Maximum number of prefetched batches in the queue.
    """

    def __init__(self, img_size, capacity, batch_size, transform, k,
                 num_producers, queue_max_size):

        self.capacity = capacity
        self.k = k
        shape = (self.capacity, img_size, img_size, 3)

        # uint8 storage: one byte per element.
        self.shm = shared_memory.SharedMemory(
            create=True, size=self.capacity * img_size * img_size * 3)

        # NumPy array backed by the shared block, filled with ones in place.
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        self.mem_labels = -1 * np.ones(self.capacity)  # -1 marks a never-filled slot
        self.mem_ages = np.zeros(self.capacity)  # add() calls each slot has survived
        self.eviction_ages = []  # age of each slot at the moment it was overwritten

        self.n_items = mp.Value('i', 0)  # total items ever inserted; shared with producers
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()  # guards reads/writes of self.memory across processes
        self.producers = []
        self.producers_exist = False

        self.step = 0  # number of add() calls

        self.transform = transform

    def log_class_balance(self, N=20):
        """Log label balance and eviction-age statistics to Weights & Biases.

        Logs the KL divergence between the label histogram of the filled slots
        and a uniform distribution over ``N`` classes, plus the mean, min and
        max age recorded for evicted items.

        Args:
            N: Number of classes; stored labels must lie in ``[0, N)``.
        """
        labels = self.mem_labels[:int(self.n_items.value)]
        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(labels, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts

        total_count = len(labels)
        if total_count > 0:
            observed_distribution = counts / total_count
            uniform_distribution = np.ones(N) / N
            # Skip empty bins: their 0 * log(0) term is taken as 0.
            mask = observed_distribution > 0
            kl_divergence = np.sum(observed_distribution[mask] *
                                   np.log(observed_distribution[mask] / uniform_distribution[mask]))
        else:
            # Empty memory: nothing to compare, and avoids a 0/0 division.
            kl_divergence = 0.0

        if len(self.eviction_ages) > 0:
            average_duration = np.mean(self.eviction_ages)
            e_min = np.min(self.eviction_ages)
            e_max = np.max(self.eviction_ages)
        else:
            average_duration = 0
            e_min = 0
            e_max = 0

        wandb.log({
            'mem_step': self.step,
            'avg_mem_duration': average_duration,
            'stm_class_balance': kl_divergence,
            'min_duration': e_min,
            'max_duration': e_max,
        })

    def add(self, x, y):
        """Insert ``k`` randomly chosen items of the batch FIFO-style.

        When the batch holds at least ``k`` items, they are drawn without
        replacement, so one image is not stored twice. Otherwise they are
        drawn with replacement. Exactly ``k`` items are always inserted.

        Args:
            x: Batch of images indexable by position.
            y: Labels aligned with ``x``.

        Returns:
            ``(evicted_x, evicted_y)``: independent copies of the overwritten
            images and their labels, in eviction order.
        """
        evicted_x = []
        evicted_y = []
        inds = np.random.choice(len(x), self.k, replace=self.k > len(x)).tolist()
        for i in inds:
            slot = int(self.n_items.value) % self.capacity
            if int(self.n_items.value) >= self.capacity:
                with self.lock:
                    # .copy(): self.memory[slot] is a view that the write below
                    # would otherwise overwrite with the new image.
                    evicted_x.append(self.memory[slot].copy())
                evicted_y.append(self.mem_labels[slot])
                self.eviction_ages.append(self.mem_ages[slot])
                self.mem_ages[slot] = 0

            with self.lock:
                self.memory[slot] = x[i]

            self.mem_labels[slot] = y[i]
            self.n_items.value += 1

        if int(self.n_items.value) >= self.capacity:
            self.mem_ages += 1
        else:
            self.mem_ages[:int(self.n_items.value)] += 1

        self.step += 1

        return evicted_x, evicted_y

    def fetch_data(self):
        """Sample one transformed batch of ``batch_size`` stored images.

        Indices are drawn without replacement when at least ``batch_size``
        images are stored, and with replacement otherwise.

        Returns:
            Tensor of shape ``(batch_size, ...)``, or ``(n_views, batch_size, ...)``
            when ``transform`` returns a list/tuple of views.

        Raises:
            ValueError: If the memory is empty.
        """
        available = min(int(self.n_items.value), self.capacity)
        if available == 0:
            raise ValueError('cannot sample from an empty memory')
        inds = np.random.choice(available, self.batch_size,
                                replace=available < self.batch_size)

        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            batch = torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        else:
            batch = torch.stack(batch)

        return batch

    def persistent_producer(self):
        """Producer loop: put batches on ``self.queue`` until ``stop_signal`` is set.

        Seeds ``random``, NumPy and torch from the PID. While the memory is
        empty it polls instead of calling :meth:`fetch_data`, which would raise
        and end the process.
        """
        seed = int(os.getpid())

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            if min(int(self.n_items.value), self.capacity) == 0:
                time.sleep(0.01)
                continue
            data = self.fetch_data()
            self.queue.put(data)  # Blocks if the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        print('STARTING QUEUE')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)

        self.producers_exist = True

    def sample_memory(self):
        """Return the next prefetched batch, starting the producers on first use."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Stop and reap the producers, then release the shared-memory block.

        Producers are terminated, because one blocked on a full queue never
        re-checks ``stop_signal``. They are then joined so no zombie processes
        remain. The memory is unusable afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        # Drop the ndarray view on shm.buf first: a live view keeps a buffer
        # export open, which can make SharedMemory.close() raise BufferError.
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')

class ParallelRepulsionMemory:
    """Growing image memory that only admits samples far from what it already stores.

    Images live in a ``multiprocessing.shared_memory`` block, so producer
    processes started by :meth:`start_producers` can draw augmented batches
    while the owning process adds and prunes items. Labels, ages, embeddings
    and raw-transformed copies are plain NumPy arrays and lists in the owning
    process.
    """

    def __init__(self, img_size, capacity, batch_size, transform, raw_transform, beta, k, prune,
                 num_producers, queue_max_size):
        # ``capacity`` is the warm-up size. The backing store is 100x larger so
        # the memory can keep growing once novelty-based admission starts.
        self.init_capacity = capacity
        self.capacity = capacity * 100

        # Allocate the shared block directly and fill it with ones in place.
        # Building an intermediate ``np.ones`` array first would create a
        # float64 temporary 8x the size of the uint8 buffer.
        shape = (self.capacity, img_size, img_size, 3)
        nbytes = self.capacity * img_size * img_size * 3 * np.dtype(np.uint8).itemsize
        self.shm = shared_memory.SharedMemory(create=True, size=nbytes)
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        self.mem_zs = np.zeros((self.capacity, 512))  # normalized embeddings per slot
        self.mem_labels = -1 * np.ones(self.capacity)
        self.mem_ages = -1 * np.ones(self.capacity)
        self.stored_raws = []  # raw_transform(x) per stored item, used for re-embedding

        self.n_items = mp.Value('i', 0)
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()
        self.producers = []
        self.producers_exist = False

        self.transform = transform
        self.raw_transform = raw_transform

        # ``k`` is accepted for interface compatibility; this class does not use it.
        self.beta = beta
        self.prune = prune
        self.thresh = None
        self.not_init = True
        self.step = 0

    def log_class_balance(self, N=20, num_tasks=5, classes_per_task=20):
        """Log per-class and per-task counts plus KL(observed || uniform) to wandb.

        Args:
            N: number of classes. Labels are used directly as indices, so they
                must lie in ``[0, N)``.
            num_tasks: number of tasks shown in the task histogram.
            classes_per_task: a label's task is ``label // classes_per_task``.
        """
        n = int(self.n_items.value)
        labels = self.mem_labels[:n]

        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(labels, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts
        class_data = [[i, counts[i]] for i in range(N)]
        class_tab = wandb.Table(data=class_data, columns=['classes', 'counts'])

        task_counts = np.zeros(num_tasks)
        unique_tasks, unique_task_counts = np.unique(labels // classes_per_task, return_counts=True)
        task_counts[unique_tasks.astype(np.int32)] = unique_task_counts
        task_data = [[i, task_counts[i]] for i in range(num_tasks)]
        task_tab = wandb.Table(data=task_data, columns=['tasks', 'counts'])

        observed_distribution = counts / len(labels)
        uniform_distribution = np.ones(N) / N
        # Skip empty classes: their 0 * log(0) term is defined as 0.
        mask = observed_distribution > 0
        kl_divergence = np.sum(observed_distribution[mask] *
                               np.log(observed_distribution[mask] / uniform_distribution[mask]))

        wandb.log({
            'mem_step': self.step,
            'ltm_class_balance': kl_divergence,
        })
        wandb.log(
            {
                f'class_counts_{self.step}': wandb.plot.bar(table=class_tab, label="classes", value="counts", title=f'Class Count Dist {self.step}'),
                f'task_counts_{self.step}': wandb.plot.bar(table=task_tab, label="tasks", value="counts", title=f'Task Count Dist {self.step}')
            }
        )

    def update_encoder(self, encoder=None):
        """Install ``encoder``, re-embed every stored item and refresh the novelty threshold.

        ``self.thresh`` becomes the ``beta``-th percentile of each item's
        nearest-neighbour cosine distance. If ``prune`` is enabled and more
        than 512 items are stored, items are then pruned (see :meth:`_prune`).
        """
        self.encoder = encoder
        self.encoder.eval()

        n = int(self.n_items.value)
        if n <= 1:
            return

        # Embeddings are detached anyway, so skip building an autograd graph.
        with torch.no_grad():
            chunks = [self.encoder.embed(torch.stack(self.stored_raws[i:i + 512]).to('cuda'))
                      for i in range(0, len(self.stored_raws), 512)]
            self.mem_zs[:n] = torch.nn.functional.normalize(torch.cat(chunks)).cpu().numpy()

        memz = torch.tensor(self.mem_zs[:n], dtype=torch.float32).to('cuda')
        dist = 1 - torch.matmul(memz, memz.T)
        # Drop the zero self-distances on the diagonal: (n, n) -> (n, n - 1).
        off_diag = ~torch.eye(n, dtype=torch.bool, device=dist.device)
        dist = dist[off_diag].view(n, n - 1)

        closest = torch.min(dist, dim=1).values.cpu().numpy()
        self.thresh = np.percentile(closest, self.beta)

        if n > 512 and self.prune:
            self._prune(dist, n)

    def _prune(self, dist, n):
        """Drop items that are both crowded and, with age-weighted probability, old."""
        print('Pruning', flush=True)
        avg = torch.mean(dist, dim=1).cpu().numpy()
        prune_thresh = np.percentile(avg, 100 - self.beta)

        # Normalise by the oldest *live* item. Slots past n hold stale or -1 ages.
        ages = self.mem_ages[:n]
        keep_inds = np.flatnonzero((avg > prune_thresh) |
                                   ((ages / np.max(ages)) < np.random.uniform(0, 1, n)))

        keep_mask = np.zeros(n, dtype=bool)
        keep_mask[keep_inds] = True
        pruned_ages = ages[~keep_mask]  # copy taken before compaction overwrites ages
        prune_age = np.mean(pruned_ages) if pruned_ages.size else 0

        num_pruned = n - len(keep_inds)
        self._compact(keep_inds)

        wandb.log({
            'mem_step': self.step,
            'num_pruned': num_pruned,
            'prune_age': prune_age
        })

    def _compact(self, keep_inds):
        """Move the rows at ``keep_inds`` to the front of every per-item store and shrink ``n_items``."""
        n_keep = len(keep_inds)
        with self.lock:
            self.memory[:n_keep] = self.memory[keep_inds]
            self.n_items.value = n_keep
        self.mem_zs[:n_keep] = self.mem_zs[keep_inds]
        self.mem_ages[:n_keep] = self.mem_ages[keep_inds]
        self.mem_labels[:n_keep] = self.mem_labels[keep_inds]
        self.stored_raws = [self.stored_raws[i] for i in keep_inds]

    def _insert(self, x, y, z, raw):
        """Append one item (image, label, embedding, raw copy) at slot ``n_items``."""
        slot = int(self.n_items.value)
        with self.lock:
            self.memory[slot] = x
        self.mem_labels[slot] = y
        self.mem_zs[slot] = z.cpu()
        self.mem_ages[slot] = 0
        self.stored_raws.append(raw)
        self.n_items.value += 1

    def add(self, x, y):
        """Offer a batch ``(x, y)`` to the memory.

        During warm-up (fewer than ``init_capacity`` items) each sample is kept
        with probability 0.25. Afterwards a sample is kept only if its cosine
        distance to every stored embedding exceeds ``self.thresh``. Insertion
        stops once the backing store (``capacity``) is full. Every stored
        item's age increases by one per call.
        """
        raws_temp = [self.raw_transform(x[i]) for i in range(len(x))]
        with torch.no_grad():
            new_zs = torch.nn.functional.normalize(self.encoder.embed(torch.stack(raws_temp).to('cuda')))

        if int(self.n_items.value) < self.init_capacity:
            for i in range(len(x)):
                if int(self.n_items.value) >= self.init_capacity:
                    break
                if np.random.uniform(0, 1) >= 0.75:
                    self._insert(x[i], y[i], new_zs[i], raws_temp[i])
        else:
            n = int(self.n_items.value)
            bank = torch.tensor(self.mem_zs[:n], dtype=torch.float32).to('cuda')
            distances = 1 - torch.mm(new_zs, bank.T)  # (M x N)
            closest = torch.amin(distances, dim=1).cpu().numpy()

            possible_new = np.flatnonzero(closest > self.thresh)
            for ind in possible_new:
                if int(self.n_items.value) >= self.capacity:
                    break  # backing store full; avoid writing past the end
                ind = int(ind)
                self._insert(x[ind], y[ind], new_zs[ind], raws_temp[ind])

            wandb.log({
                'mem_step': self.step,
                'repulsion_threshold': self.thresh,
                'mean_closest': np.mean(closest),
                'var_closest': np.var(closest),
                'novel_count': len(possible_new),
                'ltm_size': int(self.n_items.value)
                })

        self.mem_ages[:int(self.n_items.value)] += 1
        self.step += 1

    def fetch_data(self):
        """Sample ``batch_size`` distinct stored images and apply ``transform``.

        If ``transform`` returns a list/tuple of views per image, the result is
        stacked as (views, batch, ...). Otherwise it is stacked as (batch, ...).
        Raises ValueError if fewer than ``batch_size`` items are stored.
        """
        inds = np.random.choice(min(int(self.n_items.value), self.capacity),
                                self.batch_size, replace=False)
        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            return torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        return torch.stack(batch)

    def persistent_producer(self):
        """Producer loop: enqueue batches until ``stop_signal`` is set (PID-seeded RNGs)."""
        seed = int(os.getpid())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            self.queue.put(self.fetch_data())  # blocks while the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        print('STARTING QUEUE')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)
        self.producers_exist = True

    def sample_memory(self):
        """Return the next pre-built batch, starting producers on first use."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Terminate and join producers, then release the shared-memory block.

        The ndarray view over ``shm.buf`` is dropped first, because a live
        export would make ``shm.close()`` fail. ``self.memory`` is ``None``
        afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')


class ParallelRepulsionMemory2:
    """Repulsion memory whose pruning uses per-age-bin z-scores of crowding.

    Items are admitted as in a novelty-threshold memory. When pruning, items
    are grouped into 10 age bins, and an item is dropped if its mean distance
    to the others is more than ``k`` standard deviations below its bin's mean.
    Images live in shared memory so producer processes can sample batches.
    """

    def __init__(self, img_size, capacity, batch_size, transform, raw_transform, beta, k, prune,
                 num_producers, queue_max_size, num_tasks, classes_per_task):
        # ``capacity`` is the warm-up size; the backing store is 100x larger.
        self.init_capacity = capacity
        self.capacity = capacity * 100

        # Allocate the shared block directly and fill it in place. This avoids
        # a float64 ``np.ones`` temporary 8x the size of the uint8 buffer.
        shape = (self.capacity, img_size, img_size, 3)
        nbytes = self.capacity * img_size * img_size * 3 * np.dtype(np.uint8).itemsize
        self.shm = shared_memory.SharedMemory(create=True, size=nbytes)
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        self.mem_zs = np.zeros((self.capacity, 512))  # normalized embeddings per slot
        self.mem_labels = -1 * np.ones(self.capacity)
        self.mem_ages = -1 * np.ones(self.capacity)
        self.stored_raws = []  # raw_transform(x) per stored item, used for re-embedding

        self.n_items = mp.Value('i', 0)
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()
        self.producers = []
        self.producers_exist = False

        self.transform = transform
        self.raw_transform = raw_transform

        self.k = k  # z-score cut-off used when pruning
        self.beta = beta
        self.prune = prune
        self.thresh = 0.0
        self.not_init = True
        self.step = 0

        # Logging
        self.classes_per_task = classes_per_task
        self.num_tasks = num_tasks

    def log_class_balance(self, N, current_task):
        """Log per-class and per-task counts plus KL(observed || uniform) to wandb.

        ``current_task`` is accepted for interface compatibility and unused.
        """
        n = int(self.n_items.value)
        labels = self.mem_labels[:n]

        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(labels, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts
        class_data = [[i, counts[i]] for i in range(N)]
        class_tab = wandb.Table(data=class_data, columns=['classes', 'counts'])

        task_counts = np.zeros(self.num_tasks)
        unique_tasks, unique_task_counts = np.unique(labels // self.classes_per_task, return_counts=True)
        task_counts[unique_tasks.astype(np.int32)] = unique_task_counts
        task_data = [[i, task_counts[i]] for i in range(self.num_tasks)]
        task_tab = wandb.Table(data=task_data, columns=['tasks', 'counts'])

        observed_distribution = counts / len(labels)
        uniform_distribution = np.ones(N) / N
        # Skip empty classes: their 0 * log(0) term is defined as 0.
        mask = observed_distribution > 0
        kl_divergence = np.sum(observed_distribution[mask] *
                               np.log(observed_distribution[mask] / uniform_distribution[mask]))

        wandb.log({
            'mem_step': self.step,
            'ltm_class_balance': kl_divergence,
        })
        wandb.log(
            {
                f'class_counts_{self.step}': wandb.plot.bar(table=class_tab, label="classes", value="counts", title=f'Class Count Dist {self.step}'),
                f'task_counts_{self.step}': wandb.plot.bar(table=task_tab, label="tasks", value="counts", title=f'Task Count Dist {self.step}')
            }
        )

    def update_encoder(self, encoder=None):
        """Install ``encoder``, re-embed stored items, refresh ``thresh`` and optionally prune."""
        self.encoder = encoder
        self.encoder.eval()

        n = int(self.n_items.value)
        if n <= 1:
            return

        # Embeddings are detached anyway, so skip building an autograd graph.
        with torch.no_grad():
            chunks = [self.encoder.embed(torch.stack(self.stored_raws[i:i + 512]).to('cuda'))
                      for i in range(0, len(self.stored_raws), 512)]
            self.mem_zs[:n] = torch.nn.functional.normalize(torch.cat(chunks)).cpu().numpy()

        memz = torch.tensor(self.mem_zs[:n], dtype=torch.float32).to('cuda')
        dist = 1 - torch.matmul(memz, memz.T)
        # Drop the zero self-distances on the diagonal: (n, n) -> (n, n - 1).
        off_diag = ~torch.eye(n, dtype=torch.bool, device=dist.device)
        dist = dist[off_diag].view(n, n - 1)

        closest = torch.min(dist, dim=1).values.cpu().numpy()
        self.thresh = np.percentile(closest, self.beta)

        if n > 512 and self.prune:
            self._prune(dist, n)

    @staticmethod
    def _binned_zscores(ages, avg_dist, num_bins=10):
        """Z-score ``avg_dist`` separately within ``num_bins`` equal-width age bins.

        Returns ``(all_zscores, min_zscores)``. ``min_zscores`` has one entry
        per bin, and empty bins report 0. A bin whose values are all equal
        (including a single-item bin) gets z-score 0 rather than NaN. A NaN
        would fail any ``> -k`` test and silently prune the item.
        """
        bins = np.digitize(ages, bins=np.linspace(0, np.max(ages), num=num_bins))
        all_zscores = np.zeros_like(avg_dist)
        min_zscores = []
        for b in range(1, num_bins + 1):
            members = np.flatnonzero(bins == b)
            if members.size == 0:
                min_zscores.append(0)
                continue
            vals = avg_dist[members]
            std = np.std(vals)
            z = (vals - np.mean(vals)) / std if std > 0 else np.zeros_like(vals)
            min_zscores.append(np.min(z))
            all_zscores[members] = z
        return all_zscores, min_zscores

    def _prune(self, dist, n):
        """Drop items whose crowding z-score within their age bin is below ``-k``."""
        print('Pruning', flush=True)
        # Use live ages only; slots past n hold stale or -1 values.
        ages = self.mem_ages[:n]
        # Row means of the cosine-distance matrix (low = crowded).
        avg_dist = torch.mean(dist, dim=1).cpu().numpy()
        all_zscores, min_zscores = self._binned_zscores(ages, avg_dist)

        scores = {f'min_z_bin_{i}': min_zscores[i] for i in range(len(min_zscores))}
        scores.update({'mem_step': self.step})
        wandb.log(scores)

        keep_inds = np.flatnonzero(all_zscores > -self.k)

        keep_mask = np.zeros(n, dtype=bool)
        keep_mask[keep_inds] = True
        pruned_ages = ages[~keep_mask]  # copy taken before compaction overwrites ages
        num_pruned = n - len(keep_inds)
        prune_age = np.mean(pruned_ages) if num_pruned > 0 else 0

        self._compact(keep_inds)

        wandb.log({
            'mem_step': self.step,
            'num_pruned': num_pruned,
            'prune_age': prune_age
        })

    def _compact(self, keep_inds):
        """Move the rows at ``keep_inds`` to the front of every per-item store and shrink ``n_items``."""
        n_keep = len(keep_inds)
        with self.lock:
            self.memory[:n_keep] = self.memory[keep_inds]
            self.n_items.value = n_keep
        self.mem_zs[:n_keep] = self.mem_zs[keep_inds]
        self.mem_ages[:n_keep] = self.mem_ages[keep_inds]
        self.mem_labels[:n_keep] = self.mem_labels[keep_inds]
        self.stored_raws = [self.stored_raws[i] for i in keep_inds]

    def _insert(self, x, y, z, raw):
        """Append one item (image, label, embedding, raw copy) at slot ``n_items``."""
        slot = int(self.n_items.value)
        with self.lock:
            self.memory[slot] = x
        self.mem_labels[slot] = y
        self.mem_zs[slot] = z.cpu()
        self.mem_ages[slot] = 0
        self.stored_raws.append(raw)
        self.n_items.value += 1

    def add(self, x, y):
        """Offer a batch ``(x, y)`` to the memory.

        During warm-up each sample is kept with probability 0.25. Afterwards a
        sample is kept only if its nearest stored neighbour is farther than
        ``self.thresh``. Insertion stops once ``capacity`` slots are used.
        Stored items age by one per call.
        """
        raws_temp = [self.raw_transform(x[i]) for i in range(len(x))]
        with torch.no_grad():
            new_zs = torch.nn.functional.normalize(self.encoder.embed(torch.stack(raws_temp).to('cuda')))

        if int(self.n_items.value) < self.init_capacity:
            for i in range(len(x)):
                if int(self.n_items.value) >= self.init_capacity:
                    break
                if np.random.uniform(0, 1) >= 0.75:
                    self._insert(x[i], y[i], new_zs[i], raws_temp[i])
        else:
            n = int(self.n_items.value)
            bank = torch.tensor(self.mem_zs[:n], dtype=torch.float32).to('cuda')
            distances = 1 - torch.mm(new_zs, bank.T)  # (M x N)
            closest = torch.amin(distances, dim=1).cpu().numpy()

            possible_new = np.flatnonzero(closest > self.thresh)
            for ind in possible_new:
                if int(self.n_items.value) >= self.capacity:
                    break  # backing store full; avoid writing past the end
                ind = int(ind)
                self._insert(x[ind], y[ind], new_zs[ind], raws_temp[ind])

            wandb.log({
                'mem_step': self.step,
                'repulsion_threshold': self.thresh,
                'mean_closest': np.mean(closest),
                'var_closest': np.var(closest),
                'novel_count': len(possible_new),
                'ltm_size': int(self.n_items.value)
                })

        self.mem_ages[:int(self.n_items.value)] += 1
        self.step += 1

    def fetch_data(self):
        """Sample ``batch_size`` distinct stored images and apply ``transform``.

        Multi-view transforms (list/tuple output) are stacked as (views, batch, ...).
        Raises ValueError if fewer than ``batch_size`` items are stored.
        """
        inds = np.random.choice(min(int(self.n_items.value), self.capacity),
                                self.batch_size, replace=False)
        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            return torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        return torch.stack(batch)

    def persistent_producer(self):
        """Producer loop: enqueue batches until ``stop_signal`` is set (PID-seeded RNGs)."""
        seed = int(os.getpid())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            self.queue.put(self.fetch_data())  # blocks while the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        print('STARTING QUEUE')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)
        self.producers_exist = True

    def sample_memory(self):
        """Return the next pre-built batch, starting producers on first use."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Terminate and join producers, then release the shared-memory block.

        The ndarray view over ``shm.buf`` is dropped first, because a live
        export would make ``shm.close()`` fail. ``self.memory`` is ``None``
        afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')

class ParallelRandomSamplingMemory:
    """Growing image memory that stores each incoming sample with probability ``k``.

    Images live in shared memory so producer processes can sample batches.
    The backing store holds ``capacity * 100`` images. Insertion stops when it
    is full.
    """

    def __init__(self, img_size, capacity, batch_size, transform, k,
                 num_producers, queue_max_size):
        self.init_capacity = capacity
        self.capacity = capacity * 100
        self.k = k  # per-sample insertion probability

        # Allocate the shared block directly and fill it in place. This avoids
        # a float64 ``np.ones`` temporary 8x the size of the uint8 buffer.
        shape = (self.capacity, img_size, img_size, 3)
        nbytes = self.capacity * img_size * img_size * 3 * np.dtype(np.uint8).itemsize
        self.shm = shared_memory.SharedMemory(create=True, size=nbytes)
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        self.mem_labels = -1 * np.ones(self.capacity)
        self.mem_ages = np.zeros(self.capacity)
        self.eviction_ages = []

        self.n_items = mp.Value('i', 0)
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()
        self.producers = []
        self.producers_exist = False

        self.step = 0
        self.transform = transform

    def log_class_balance(self, N=20, num_tasks=5, classes_per_task=20):
        """Log per-class and per-task counts plus KL(observed || uniform) to wandb.

        Args:
            N: number of classes. Labels must lie in ``[0, N)``.
            num_tasks: number of tasks shown in the task histogram.
            classes_per_task: a label's task is ``label // classes_per_task``.
        """
        n = int(self.n_items.value)
        labels = self.mem_labels[:n]

        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(labels, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts
        class_data = [[i, counts[i]] for i in range(N)]
        class_tab = wandb.Table(data=class_data, columns=['classes', 'counts'])

        task_counts = np.zeros(num_tasks)
        unique_tasks, unique_task_counts = np.unique(labels // classes_per_task, return_counts=True)
        task_counts[unique_tasks.astype(np.int32)] = unique_task_counts
        task_data = [[i, task_counts[i]] for i in range(num_tasks)]
        task_tab = wandb.Table(data=task_data, columns=['tasks', 'counts'])

        observed_distribution = counts / len(labels)
        uniform_distribution = np.ones(N) / N
        # Skip empty classes: their 0 * log(0) term is defined as 0.
        mask = observed_distribution > 0
        kl_divergence = np.sum(observed_distribution[mask] *
                               np.log(observed_distribution[mask] / uniform_distribution[mask]))

        wandb.log({
            'mem_step': self.step,
            'ltm_class_balance': kl_divergence,
        })
        wandb.log(
            {
                f'class_counts_{self.step}': wandb.plot.bar(table=class_tab, label="classes", value="counts", title=f'Class Count Dist {self.step}'),
                f'task_counts_{self.step}': wandb.plot.bar(table=task_tab, label="tasks", value="counts", title=f'Task Count Dist {self.step}')
            }
        )

    def _insert_selected(self, x, y, inds):
        """Append ``x[i], y[i]`` for each ``i`` in ``inds``; return (evicted_x, evicted_y).

        Once ``init_capacity`` items are stored, the previous contents of the
        slot being written are reported as "evicted". They are copied out of
        shared memory, because the slot is overwritten immediately afterwards.
        Insertion stops when the backing store is full.
        NOTE(review): since ``n_items`` only grows here, the slot at ``n_items``
        has never held a real item, so "evicted" entries are the initial
        fill values. Confirm whether replacement of a random slot was intended.
        """
        evicted_x, evicted_y = [], []
        for i in inds:
            slot = int(self.n_items.value)
            if slot >= self.capacity:
                break  # backing store full; avoid writing past the end
            if slot >= self.init_capacity:
                with self.lock:
                    evicted_x.append(self.memory[slot].copy())
                evicted_y.append(self.mem_labels[slot])
                self.eviction_ages.append(self.mem_ages[slot])

            with self.lock:
                self.memory[slot] = x[i]
            self.mem_labels[slot] = y[i]
            self.mem_ages[slot] = 0
            self.n_items.value += 1
        return evicted_x, evicted_y

    def add(self, x, y):
        """Store each sample of ``(x, y)`` independently with probability ``k``.

        Returns ``(evicted_x, evicted_y)`` as described in :meth:`_insert_selected`.
        Every stored item's age increases by one per call.
        """
        p = np.random.uniform(0, 1, len(x))
        evicted_x, evicted_y = self._insert_selected(x, y, np.flatnonzero(p < self.k))

        # n_items never exceeds capacity, so this also covers the "full" case.
        self.mem_ages[:int(self.n_items.value)] += 1
        self.step += 1

        wandb.log({
                'mem_step': self.step,
                'ltm_size': int(self.n_items.value)
                })

        return evicted_x, evicted_y

    def fetch_data(self):
        """Sample ``batch_size`` distinct stored images and apply ``transform``.

        Multi-view transforms (list/tuple output) are stacked as (views, batch, ...).
        Raises ValueError if fewer than ``batch_size`` items are stored.
        """
        inds = np.random.choice(min(int(self.n_items.value), self.capacity),
                                self.batch_size, replace=False)
        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            return torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        return torch.stack(batch)

    def persistent_producer(self):
        """Producer loop: enqueue batches until ``stop_signal`` is set (PID-seeded RNGs)."""
        seed = int(os.getpid())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            self.queue.put(self.fetch_data())  # blocks while the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running :meth:`persistent_producer`."""
        print('STARTING QUEUE')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)
        self.producers_exist = True

    def sample_memory(self):
        """Return the next pre-built batch, starting producers on first use."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Terminate and join producers, then release the shared-memory block.

        The ndarray view over ``shm.buf`` is dropped first, because a live
        export would make ``shm.close()`` fail. ``self.memory`` is ``None``
        afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')

class ParallelFixedThreshMemory:
    """Novelty-gated image memory whose batches are produced by background processes.

    Images are stored in a ``multiprocessing.shared_memory`` buffer. Producer
    processes, started lazily by :meth:`sample_memory`, copy random rows out of it,
    transform them and push batches onto ``self.queue``.

    Admission policy (see :meth:`add`):
      * Warm-up: while fewer than ``init_capacity`` items are stored, each incoming
        image is kept with probability 0.25.
      * Novelty: afterwards an image is kept only when the cosine distance between
        its normalised embedding and the nearest stored embedding exceeds ``thresh``.

    The buffer has room for ``100 * capacity`` images. Once it is full, further novel
    images are dropped.

    NOTE(review): ``beta`` and ``k`` are stored but not used by any method here.
    NOTE(review): producers are started with the bound method
    ``self.persistent_producer`` and read ``self.memory`` directly. This looks like it
    relies on the 'fork' start method (under 'spawn' the ndarray would be pickled by
    value and no longer shared) — confirm.
    """

    def __init__(self, img_size, capacity, batch_size, transform, raw_transform, beta, k, thresh,
                 num_producers, queue_max_size, device='cuda'):
        """
        Args:
            img_size: height/width of the square RGB images stored (uint8, HxWx3).
            capacity: warm-up size; the buffer itself holds ``100 * capacity`` images.
            batch_size: number of distinct images per batch produced by ``fetch_data``.
            transform: applied to each sampled image in the producers; may return a
                tensor or a list/tuple of tensors (multiple views).
            raw_transform: maps an image to the tensor fed to ``encoder.embed``.
            beta, k: stored only (unused in this class).
            thresh: cosine-distance threshold for admitting novel images.
            num_producers: number of producer processes.
            queue_max_size: bound on prefetched batches in ``self.queue``.
            device: torch device used for embedding and similarity computations.
        """
        self.init_capacity = capacity
        self.capacity = capacity * 100

        # Allocate the shared image buffer and fill it with ones in place. Building
        # the contents with np.ones(...) first would create a float64 temporary 8x
        # the buffer size.
        shape = (self.capacity, img_size, img_size, 3)
        nbytes = self.capacity * img_size * img_size * 3  # uint8 -> 1 byte per element
        self.shm = shared_memory.SharedMemory(create=True, size=nbytes)
        self.memory = np.ndarray(shape, dtype=np.uint8, buffer=self.shm.buf)
        self.memory.fill(1)

        # Per-slot metadata. Only self.memory lives in shared memory; these arrays are
        # only valid in the process that calls add().
        self.mem_zs = np.zeros((self.capacity, 512))  # normalised embeddings (dim hard-coded to 512)
        self.mem_labels = -1 * np.ones(self.capacity)
        self.mem_ages = -1 * np.ones(self.capacity)

        # raw_transform outputs of stored images, kept so update_encoder can re-embed them.
        self.stored_raws = []

        self.n_items = mp.Value('i', 0)  # shared so producers see the current fill level
        self.batch_size = batch_size

        self.num_producers = num_producers
        self.queue = mp.Queue(maxsize=queue_max_size)
        self.stop_signal = mp.Event()
        self.lock = mp.Lock()  # guards reads/writes of self.memory rows
        self.producers = []
        self.producers_exist = False

        self.transform = transform
        self.raw_transform = raw_transform
        self.device = device

        self.k = k
        self.beta = beta
        self.thresh = thresh
        self.not_init = True
        self.step = 0

    def log_class_balance(self, N=20):
        """Log label/task histograms and KL(observed || uniform) over ``N`` classes to wandb.

        Task ids are ``label // 20`` with 5 tasks hard-coded, so labels are assumed to
        lie in ``[0, 100)``.

        NOTE(review): with the default ``N=20`` any stored label >= 20 raises
        IndexError when filling ``counts``. Callers presumably pass ``N=100`` — confirm.

        With an empty memory the observed distribution is all-NaN (0/0) and the
        logged KL value is 0.0.
        """
        stored = self.mem_labels[:int(self.n_items.value)]

        counts = np.zeros(N)
        unique_labels, unique_counts = np.unique(stored, return_counts=True)
        counts[unique_labels.astype(np.int32)] = unique_counts

        class_data = [[i, counts[i]] for i in range(N)]
        class_tab = wandb.Table(data=class_data, columns=['classes', 'counts'])

        task_counts = np.zeros(5)
        unique_labels_task, unique_counts_task = np.unique(stored // 20, return_counts=True)
        task_counts[unique_labels_task.astype(np.int32)] = unique_counts_task
        task_data = [[i, task_counts[i]] for i in range(5)]
        task_tab = wandb.Table(data=task_data, columns=['tasks', 'counts'])

        # Observed label distribution vs. uniform.
        total_count = len(stored)
        observed_distribution = counts / total_count
        uniform_distribution = np.ones(N) / N

        # Only non-zero bins contribute; avoids log(0).
        mask = observed_distribution > 0
        kl_divergence = np.sum(observed_distribution[mask] *
                               np.log(observed_distribution[mask] / uniform_distribution[mask]))

        wandb.log({
            'mem_step': self.step,
            'ltm_class_balance': kl_divergence,
        })

        wandb.log(
            {
            f'class_counts_{self.step}': wandb.plot.bar(table=class_tab, label="classes", value="counts", title=f'Class Count Dist {self.step}'),
            f'task_counts_{self.step}': wandb.plot.bar(table=task_tab, label="tasks", value="counts", title=f'Task Count Dist {self.step}')
            }
        )

    def update_encoder(self, encoder=None):
        """Install ``encoder`` and re-embed every stored image with it.

        The encoder is switched to eval mode and left there. :meth:`add` later calls
        ``self.encoder.embed`` without changing the mode.

        Raw tensors are embedded in chunks of 512 on ``self.device``, then
        L2-normalised and written into ``self.mem_zs``. This runs whenever at least
        one item is stored, so a single stored item also gets a fresh embedding.
        """
        self.encoder = encoder
        self.encoder.eval()

        n = int(self.n_items.value)
        if n == 0:
            return

        with torch.no_grad():
            chunks = [
                self.encoder.embed(torch.stack(self.stored_raws[i:i + 512]).to(self.device))
                for i in range(0, len(self.stored_raws), 512)
            ]
            zs = torch.nn.functional.normalize(torch.cat(chunks))
        self.mem_zs[:n] = zs.cpu().numpy()

    def _store(self, image, label, z, raw):
        """Write one accepted example into the next free slot and bump ``n_items``."""
        slot = int(self.n_items.value)
        with self.lock:  # producers copy image rows under the same lock
            self.memory[slot] = image
        self.mem_labels[slot] = label
        self.mem_zs[slot] = z
        self.mem_ages[slot] = 0
        self.stored_raws.append(raw)
        self.n_items.value += 1

    def add(self, x, y):
        """Offer a batch of images to memory, storing those the admission policy accepts.

        Args:
            x: indexable batch of images; each ``x[i]`` is written into a memory row
               and passed through ``raw_transform``. Presumably HxWx3 uint8-compatible
               — confirm against the caller.
            y: labels aligned with ``x``.

        Behaviour by phase:
          * Warm-up (fewer than ``init_capacity`` stored): each image is kept with
            probability 0.25 until ``init_capacity`` is reached.
          * Novelty: an image is kept when its cosine distance to the nearest stored
            embedding exceeds ``self.thresh``. Images in the same batch are not
            compared with each other. Once the buffer (``self.capacity``) is full,
            remaining novel images are dropped. Statistics are logged to wandb.

        Every call ages all stored items by one and advances ``self.step``.
        Requires :meth:`update_encoder` to have been called first.
        """
        raws_temp = [self.raw_transform(x[i]) for i in range(len(x))]
        with torch.no_grad():
            new_zs = torch.nn.functional.normalize(
                self.encoder.embed(torch.stack(raws_temp).to(self.device)))
        new_zs_np = new_zs.cpu().numpy()  # one host transfer instead of one per item

        if int(self.n_items.value) < self.init_capacity:
            for i in range(len(x)):
                if int(self.n_items.value) >= self.init_capacity:
                    break
                if np.random.uniform(0, 1) >= 0.75:
                    self._store(x[i], y[i], new_zs_np[i], raws_temp[i])

        else:
            n = int(self.n_items.value)
            mem_zs = torch.tensor(self.mem_zs[:n], dtype=torch.float32).to(self.device)
            distances = 1 - torch.mm(new_zs, mem_zs.T)  # (M x N) cosine distances
            closest = torch.amin(distances, dim=1).cpu().numpy()

            possible_new = np.flatnonzero(closest > self.thresh)

            for ind in possible_new.tolist():
                if int(self.n_items.value) >= self.capacity:
                    break  # buffer full: drop the remaining novel images
                self._store(x[ind], y[ind], new_zs_np[ind], raws_temp[ind])

            wandb.log({
                'mem_step': self.step,
                'repulsion_threshold': self.thresh,
                'mean_closest': np.mean(closest),
                'var_closest': np.var(closest),
                'novel_count': len(possible_new),
                'ltm_size': int(self.n_items.value)
                })

        self.mem_ages[:int(self.n_items.value)] += 1
        self.step += 1

    def fetch_data(self):
        """Sample ``batch_size`` distinct stored images and return them transformed.

        Raises ``ValueError`` (from ``np.random.choice`` with ``replace=False``) when
        fewer than ``batch_size`` items are stored.

        If ``transform`` returns a list/tuple of tensors per image, the result is
        stacked to ``(views, batch, ...)``; otherwise it is ``(batch, ...)``.
        """
        inds = np.random.choice(min(int(self.n_items.value), self.capacity),
                                self.batch_size, replace=False)

        with self.lock:
            mem_exes = self.memory[inds].copy()

        batch = [self.transform(ex) for ex in mem_exes]
        if isinstance(batch[0], (list, tuple)):
            batch = torch.stack([torch.stack(ex) for ex in batch]).swapaxes(0, 1)
        else:
            batch = torch.stack(batch)

        return batch

    def persistent_producer(self):
        """Producer loop run inside each worker process until ``stop_signal`` is set.

        Seeds Python/NumPy/torch RNGs from the worker's PID so that workers differ.
        While too few items are stored for ``fetch_data`` to sample without
        replacement, the worker polls instead of crashing. A crash would leave
        ``sample_memory`` blocked on ``queue.get()`` forever.
        """
        seed = int(os.getpid())

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        while not self.stop_signal.is_set():
            if min(int(self.n_items.value), self.capacity) < self.batch_size:
                time.sleep(0.01)  # wait for add() to store more items
                continue
            data = self.fetch_data()
            self.queue.put(data)  # Blocks if the queue is full

    def start_producers(self):
        """Start ``num_producers`` daemon processes running ``persistent_producer``."""
        print('STARTING QUEUE')
        self.stop_signal.clear()
        for _ in range(self.num_producers):
            p = mp.Process(target=self.persistent_producer, daemon=True)
            p.start()
            self.producers.append(p)

        self.producers_exist = True

    def sample_memory(self):
        """Return the next prefetched batch, starting producers on first use (blocking)."""
        if not self.producers_exist:
            self.start_producers()
        return self.queue.get()

    def shutdown(self):
        """Terminate and reap producer processes, then release the shared-memory buffer.

        Producers blocked on a full queue never observe ``stop_signal``, so they are
        terminated and then joined. ``self.memory`` (a view onto ``shm.buf``) is
        dropped before the segment is unmapped, so it cannot be read afterwards.
        """
        self.stop_signal.set()
        for p in self.producers:
            p.terminate()
        for p in self.producers:
            p.join()
        self.producers.clear()
        self.producers_exist = False
        self.memory = None
        self.shm.close()
        self.shm.unlink()
        print('Finished Shutdown')