import numpy as np
from bilevel.ExpertsAbstract import *
import pandas as pd

def sherman_inv(previnv, v):
    '''
        Rank-one inverse update via the Sherman-Morrison formula.

        Given previnv = A^{-1}, returns (A + v v^T)^{-1} (the special case
        u = v of the general (A + u v^T)^{-1} identity):

            A^{-1} - (A^{-1} v)(v^T A^{-1}) / (1 + v^T A^{-1} v)

        previnv : (d, d) array holding the current inverse.
        v       : update vector with d entries; a flat (d,) array and a
                  (d, 1) column are both accepted.

        Returns a new (d, d) array; previnv is not modified.

        The denominator is not checked: it is zero exactly when A + v v^T
        is singular, in which case the result contains inf/nan.
    '''
    vec = np.asarray(v).reshape(-1)
    # Two matrix-vector products plus one outer product keep the update at
    # O(d^2); chaining previnv @ v @ v.T @ previnv would end in a full
    # (d, d) @ (d, d) product instead.
    left = previnv @ vec    # A^{-1} v
    right = vec @ previnv   # v^T A^{-1} (not assumed equal to left)
    return previnv - np.outer(left, right) / (1 + vec @ left)

class Manual_inv_LinearExpert(Expert):
    '''
        Linear expert whose weights are refit after every round without
        ever calling a matrix inversion routine.

        The inverse starts at (l2_pen * I)^{-1} and receives one rank-one
        update x_t x_t^T per round through sherman_inv; the weights are
        that inverse applied to the running sum of y_t * x_t.
    '''

    def __init__(self, X_dat_np: np.array, y_dat_np: np.array, l2_pen = 1.0):
        '''
            X_dat_np : 2-D array, row t is the feature vector of round t.
            y_dat_np : array of labels, indexed by round.
            l2_pen   : scale of the identity the inverse is initialised
                       from (the inverse starts at I / l2_pen).
        '''
        n_features = X_dat_np.shape[1]

        self.name = "Manual inversion"
        self.X_dat_np = X_dat_np
        self.y_dat_np = y_dat_np
        self.l2_pen = l2_pen
        self.dim = n_features

        # Per-round history, kept as lists until cleanup().
        self.y_predarr = []
        self.loss_tarr = []

        # Model state: weights, current inverse, running sum of y_t * x_t.
        self.theta_pred = np.zeros(n_features)
        self.previnv = (1.0/ self.l2_pen) * np.identity(self.dim)
        self.xt_labelprodsum = np.zeros(n_features)

    def update_theta_pred(self, t) -> None:
        '''
            Fold row t into the inverse, then recompute the weights from
            the updated inverse and the current label-weighted sum.
        '''
        column = self.X_dat_np[t].reshape(-1,1)
        self.previnv = sherman_inv(self.previnv, column)
        self.theta_pred = self.previnv @ self.xt_labelprodsum

    def get_ypred_t(self, t) -> None:
        '''
            Predict round t with the current weights and append the
            prediction to y_predarr (nothing is returned).
        '''
        self.y_predarr.append(np.dot(self.X_dat_np[t], self.theta_pred))

    def update_t(self, t) -> None:
        '''
            Score the most recent prediction against label t (squared
            error), add y_t * x_t to the running sum and refit the weights.
        '''
        y_t, x_t = self.y_dat_np[t], self.X_dat_np[t]
        residual = self.y_predarr[-1] - y_t
        self.loss_tarr.append(residual**2)
        # In-place accumulation, must happen before the weights are refit.
        self.xt_labelprodsum += (y_t * x_t)
        self.update_theta_pred(t)

    def cleanup(self) -> None:
        '''
            Drop the references to the data and turn the prediction and
            loss histories into numpy arrays.
        '''
        self.X_dat_np = self.y_dat_np = None
        self.y_predarr = np.array(self.y_predarr)
        self.loss_tarr = np.array(self.loss_tarr)


