import pickle
import jax.numpy as jnp

class CumData:
  """Accumulate a stream of same-keyed dicts into a dict of lists.

  Each ``append`` adds one value per key, so ``self.data[key]`` holds the
  values seen for ``key`` in append order. ``self.data`` is ``None`` until the
  first dict is supplied (either to the constructor or to ``append``).
  """

  def __init__(self, data=None):
    """Create an accumulator, optionally seeded with one dict of values.

    Args:
      data: optional initial dict; each value becomes a one-element list.
        ``None`` (the default) leaves the accumulator empty.
    """
    self.data = None if data is None else self.dict_of_lists(data)

  def append(self, dictt):
    """Append one value per key from ``dictt``.

    The first dict seen fixes the key set. Every later dict must contain all
    of those keys (a missing key raises ``KeyError``). Any extra keys in a
    later dict are ignored.
    """
    if self.data is None:
      self.data = self.dict_of_lists(dictt)
      return
    for k, values in self.data.items():
      values.append(dictt[k])

  def dict_of_lists(self, dictt):
    """Return a new dict mapping each key of ``dictt`` to ``[value]``."""
    return {k: [v] for k, v in dictt.items()}

  def __str__(self):
    return str(self.data)

  def __getitem__(self, key):
    """Return the list of accumulated values for ``key``."""
    return self.data[key]

  def __eq__(self, other):
    # Defer to Python's default handling for foreign types instead of
    # crashing with AttributeError on ``other.data``.
    if not isinstance(other, CumData):
      return NotImplemented
    return self.data == other.data

  def __repr__(self):
    return repr(self.data)

  def to_array(self):
    """Convert every accumulated list to a ``jnp.array``.

    Returns:
      A new dict ``{key: jnp.array(values)}``, or ``None`` if nothing has
      been accumulated yet.
    """
    if self.data is None:
      return None
    return {k: jnp.array(values) for k, values in self.data.items()}

class ExpData:
  """One ``CumData`` accumulator per experiment parameter, with pickle I/O.

  Attributes:
    N: stored as given. Used in ``__str__`` (as the particle count) and in
      the file-name preamble.
    train_params: optional dict of training settings that ``file_preamble``
      encodes into file names. ``None`` means no preamble.
    params: parameter names, in the order ``append`` expects trial outputs.
    data: dict mapping each name in ``params`` to its ``CumData``.
  """

  def __init__(self, params, N=None, train_params=None):
    self.N = N
    self.train_params = train_params
    self.params = params
    self.data = {p: CumData() for p in params}

  def append(self, trial_out):
    """Append one trial's outputs, routing ``trial_out[j]`` to ``params[j]``.

    Each entry is passed to ``CumData.append`` and so must be a dict. Having
    more outputs than parameters raises ``IndexError``.
    """
    for j, value in enumerate(trial_out):
      self.data[self.params[j]].append(value)

  def _full_filename(self, filename):
    """Return ``filename`` followed by the preamble and the ``.pckl`` suffix."""
    return filename + self.file_preamble() + ".pckl"

  def save(self, filename):
    """Pickle ``self.data`` to ``filename + preamble + '.pckl'``."""
    full_filename = self._full_filename(filename)
    print("Saving Data...")
    # ``with`` guarantees the handle is closed even if pickling fails.
    with open(full_filename, 'wb') as f:
      pickle.dump(self.data, f)
    print("Saved data in file: ", full_filename)

  def load(self, filename):
    """Replace ``self.data`` with the pickle stored at the ``save`` path.

    NOTE: ``pickle.load`` can execute arbitrary code, so only load files
    from trusted sources.
    """
    full_filename = self._full_filename(filename)
    print("Reading from file: ", full_filename, " ...")
    with open(full_filename, 'rb') as f:
      self.data = pickle.load(f)
    print("Done!")

  def file_preamble(self):
    """Encode ``N`` and ``train_params`` into a file-name fragment.

    Returns ``""`` when ``train_params`` is ``None``. Otherwise every key read
    below must be present, or a ``KeyError`` is raised.
    """
    if self.train_params is None:
      return ""
    d = self.train_params
    parts = [
      "_N" + str(self.N),
      "_teach-" + str(d["TEACHER_MODE"]),
      "-fix_init" if d["TEACHER_FIXED"] else "-random",
      "_BS" + str(d["BATCH_SIZE"]),
      "_LR" + str(d["LR"]),
      "_TAU" + str(d["TAU"]),
      "_BETA" + str(d["BETA"]),
      "_T" + str(d["T_EPOCHS"]),
      "_g" + str(d["GRANULAR"]),
      "_std" + str(d["DATA_STD"]),
      "_equiv" + str(d["EQUIV_INIT"]),
      "_x" + str(d["N_reps"]) + "reps",
    ]
    return "".join(parts)

  def __str__(self):
    return "For "+ str(self.N) +" particles: " + str(self.data)

  def __repr__(self):
    return repr(self.data)

  def __getitem__(self, key):
    """Return the ``CumData`` accumulated for parameter ``key``."""
    return self.data[key]
