import numpy as np
from bilevel.ExpertsAbstract import Expert

class Adanormal_sleepingexps:
  """AdaNormalHedge-style aggregation over sleeping (intermittently awake) experts.

  At round ``t`` only experts with a nonzero entry in ``A_t[t]`` are awake.
  Awake experts are weighted by

      w(R, C) = 0.5 * (exp([R + 1]_+^2 / (3 (C + 1)))
                       - exp([R - 1]_+^2 / (3 (C + 1))))

  where ``R`` is the expert's cumulative instantaneous regret and ``C`` the
  cumulative absolute instantaneous regret. Both are accumulated only over
  rounds in which the expert was awake.

  ``update_metaexps_loss(t)`` reads ``p_t_arr[t]``, so
  ``get_prob_over_experts(t)`` must be called for a round before it.
  """

  def __init__(self, A_t: np.ndarray, experts: "list[Expert]"):
    """Set up the aggregator.

    Args:
      A_t: (T, N) activity matrix; ``A_t[t, i]`` nonzero means expert ``i`` is
        awake at round ``t``. Presumably 0/1-valued (TODO confirm): non-unit
        nonzero values scale both the expert's weight and its regret updates.
      experts: the experts, indexed like the columns of ``A_t`` (assumed one
        per column). Each must provide ``get_ypred_t``, ``update_t``,
        ``loss_tarr`` and ``cleanup``.
    """
    self.A_t = A_t
    self.T = A_t.shape[0]  # number of rounds
    self.N = A_t.shape[1]  # number of experts
    self.experts = experts
    self.p_t_arr = np.zeros((self.T, self.N))  # distribution played each round
    self.l_t = np.zeros((self.T, self.N))      # expert losses; 0 while asleep
    self.loss_ada_t_arr = np.zeros(self.T)     # aggregator loss each round
    # Filled in by build_cumloss_curve().
    self.cumloss_ada_allgroups = None
    self.cumloss_meta_exps = None
    self.cumloss_groupwise_ada = []
    self.cumloss_groupwise_metaexp = []
    self.r_t = np.zeros(self.N)  # instantaneous regret of the latest round
    self.R_t = np.zeros(self.N)  # cumulative regret
    self.C_t = np.zeros(self.N)  # cumulative absolute regret

  @staticmethod
  def _log_weight(R, C):
    """Return ``log w(R, C)`` elementwise, ``-inf`` where ``w(R, C) == 0``.

    Evaluating ``w`` directly overflows ``np.exp`` once the exponent
    ``[R + 1]_+^2 / (3 (C + 1))`` exceeds ~709 (e.g. R ~= C ~= 2100+). That
    yields ``inf - inf = nan`` weights. Factoring out the larger exponent
    keeps everything finite:

        log w = hi + log(1/2) + log(1 - exp(lo - hi)),  hi >= lo >= 0.
    """
    denom = 3.0 * (C + 1.0)
    hi = np.clip(R + 1.0, 0.0, None) ** 2 / denom
    lo = np.clip(R - 1.0, 0.0, None) ** 2 / denom
    # hi == lo (R <= -1) means w == 0; log(-expm1(0)) -> -inf is intended.
    with np.errstate(divide="ignore"):
      return hi + np.log(0.5) + np.log(-np.expm1(lo - hi))

  def get_prob_over_experts(self, t):
    """Compute, store in ``p_t_arr[t]`` and return the distribution for round ``t``.

    - No expert awake: uniform over all N experts.
    - Every awake expert has zero weight: uniform over awake experts
      (proportional to ``A_t[t]``).
    - Otherwise: proportional to ``w(R, C) * A_t[t]``.
    """
    a_t = self.A_t[t]
    if np.all(a_t == 0):
      self.p_t_arr[t] = np.ones(self.N) / self.N
      return self.p_t_arr[t]

    log_w = self._log_weight(self.R_t, self.C_t)
    # Only awake experts with strictly positive weight can receive mass.
    support = (a_t != 0) & np.isfinite(log_w)
    if not np.any(support):
      self.p_t_arr[t] = a_t / np.sum(a_t)
      return self.p_t_arr[t]

    # Shift by the max log-weight so the largest term is exp(0) == 1.
    shift = np.max(log_w[support])
    v = np.zeros(self.N)
    v[support] = np.exp(log_w[support] - shift) * a_t[support]
    self.p_t_arr[t] = v / np.sum(v)
    return self.p_t_arr[t]

  def update_metaexps_loss(self, t):
    """Run round ``t`` for awake experts and update the regret statistics.

    Each awake expert's ``get_ypred_t(t)`` and ``update_t(t)`` are called,
    and the last entry of its ``loss_tarr`` is recorded as its loss. The
    aggregator's loss is the ``p_t_arr[t]``-weighted sum of expert losses.
    """
    a_t = self.A_t[t]
    for index in np.flatnonzero(a_t):
      expert = self.experts[index]
      expert.get_ypred_t(t)
      expert.update_t(t)
      self.l_t[t, index] = expert.loss_tarr[-1]
    l_t_hat = np.dot(self.p_t_arr[t], self.l_t[t])
    self.loss_ada_t_arr[t] = l_t_hat
    # Multiplying by a_t gives sleeping experts zero regret this round, so
    # R and C only accumulate over rounds in which an expert was awake.
    self.r_t = (l_t_hat - self.l_t[t]) * a_t
    self.R_t += self.r_t
    self.C_t += np.abs(self.r_t)

  def build_cumloss_curve(self):
    """Build per-expert cumulative-loss curves.

    Sets:
      cumloss_ada_allgroups: (T, N) cumulative aggregator loss, weighted by
        ``A_t`` (i.e. counted only on awake rounds when ``A_t`` is 0/1).
      cumloss_meta_exps: (T, N) cumulative expert losses.
      cumloss_groupwise_ada[i] / cumloss_groupwise_metaexp[i]: column ``i``
        of the above, sampled at the rounds expert ``i`` was awake.

    The per-expert lists are rebuilt on each call, so repeated calls do not
    accumulate duplicate entries.
    """
    self.cumloss_ada_allgroups = np.cumsum(
        self.loss_ada_t_arr.reshape(-1, 1) * self.A_t, axis=0)
    self.cumloss_meta_exps = np.cumsum(self.l_t, axis=0)
    awake = self.A_t.astype(bool)
    self.cumloss_groupwise_ada = [
        self.cumloss_ada_allgroups[awake[:, i], i] for i in range(self.N)]
    self.cumloss_groupwise_metaexp = [
        self.cumloss_meta_exps[awake[:, i], i] for i in range(self.N)]

  def cleanup(self):
    """Drop the regret state, normalize curve containers and clean up experts.

    Safe to call whether or not ``build_cumloss_curve`` has been run.
    """
    self.r_t = None
    self.R_t = None
    self.C_t = None
    self.cumloss_groupwise_ada = [np.asarray(c) for c in self.cumloss_groupwise_ada]
    self.cumloss_groupwise_metaexp = [np.asarray(c) for c in self.cumloss_groupwise_metaexp]
    for expert in self.experts:
      expert.cleanup()