
import jax
import gc
from helpers import get_mup_lrs_from_state, cast_to_bf16
import threading
import pprint
import jax.numpy as jnp

class MuTask(object):
  """Caches muP learning-rate state on the instance and injects it into `state`.

  `init_with_state` is called by `init_mup_state` but is not defined in this
  class; it must be supplied elsewhere (e.g. by a subclass).
  """

  # Class-level defaults so `get_mup_state` works even when nothing has
  # initialised these attributes yet; instance assignments shadow them.
  mup_state = None
  mup_eps_mult = None

  def _mup_device(self):
    """Return the device on which the cached muP values are placed."""
    # NOTE(review): this indexes the *global* device list by process index.
    # In multi-process runs that entry may not be local to this process;
    # `jax.local_devices()[0]` may be what is intended -- confirm before changing.
    return jax.devices()[jax.process_index()]

  def get_mup_state(self, state, eps_mult=None):
    """Attach the cached muP learning rates (and optional eps multiplier) to `state`.

    On the first call (while `self.mup_state` is None) the learning rates are
    derived from `state` via `get_mup_lrs_from_state`, moved to a device and
    cached on the instance. Later calls reuse the cached values.

    Args:
      state: mutable mapping; it is modified in place and also returned.
      eps_mult: optional value. When given, a cached
        `{'eps_mult': <device array>}` entry is stored under `state['eps_mult']`.
        The array is built from the first non-None value seen after the muP
        state was (re)built; later values reuse that cached array.

    Returns:
      The same `state` object with `'mup_lrs_to_use'` (and possibly
      `'eps_mult'`) set.

    Raises:
      ValueError: if the muP state has to be built and `state` is empty.
    """
    if self.mup_state is None:
      if not state:
        raise ValueError("State is empty, cannot get mup state from it")

      device = self._mup_device()
      print(device)
      lrs = get_mup_lrs_from_state(state)
      self.mup_state = jax.tree_util.tree_map(
          lambda leaf: jax.device_put(leaf, device), lrs)
      # The muP state was rebuilt, so drop any previously cached eps multiplier;
      # it is recreated below if the caller supplied one.
      self.mup_eps_mult = None

    state['mup_lrs_to_use'] = self.mup_state

    if eps_mult is not None:
      # Build lazily: the original only created this inside the branch above,
      # so a call with `eps_mult` after the muP state was already cached hit
      # an AttributeError on `self.mup_eps_mult`.
      if self.mup_eps_mult is None:
        self.mup_eps_mult = {
            'eps_mult': jax.device_put(jnp.array(eps_mult), self._mup_device())
        }
      state['eps_mult'] = self.mup_eps_mult
    return state

  def init_mup_state(self):
    """Run `init_with_state` once and discard its outputs.

    Original intent (per the author's comment): create and save the muP state
    outside of jit. Presumably `init_with_state` ends up calling
    `get_mup_state`, which populates the cache -- verify against the
    implementation of `init_with_state`.
    """
    key = jax.random.PRNGKey(0)
    params, state = self.init_with_state(key)
    # Only the side effects of the call are wanted; drop the references so the
    # collection below can reclaim them.
    del params
    del state

    # Force garbage collection in a separate thread to make it non-blocking.
    gc_thread = threading.Thread(target=gc.collect)
    gc_thread.start()
    # Deliberately not joined: the caller does not wait for the collection.