import numpy as np

def create_design_matrix(d, x=None, n=None):
  """Build a polynomial design matrix whose column ``j`` is ``x ** j``.

  Args:
    d: number of columns (powers ``0 .. d-1``).
    x: input points; when None, ``n`` points are drawn uniformly from [-1, 1).
    n: number of points to draw when ``x`` is None.

  Returns:
    Array stacking ``x ** 0, ..., x ** (d - 1)`` along axis 1.

  Raises:
    ValueError: if both ``x`` and ``n`` are None.
  """
  if x is None and n is None:
    raise ValueError('Both `x` and `n` cannot be `None`')
  if x is None:
    x = np.random.uniform(-1, 1, n)
  columns = [x ** power for power in range(d)]
  return np.stack(columns, axis=1)


def sample_task_data(tau, P, task_var, n, theta=None):
  """Sample one regression task and ``n`` noisy observations from it.

  Args:
    tau: mean vector of the distribution over task parameters; its length
      ``d`` sets the number of polynomial features.
    P: ``d x d`` covariance of the task parameters around ``tau``.
    task_var: observation-noise variance.
    n: number of observations to draw.
    theta: optional fixed task parameters; drawn from ``N(tau, P)`` if None.

  Returns:
    ``(x, y, theta)``: the design matrix from ``create_design_matrix``, the
    ``n`` noisy targets, and the task parameters that generated them.
  """
  if theta is None:
    theta = np.random.multivariate_normal(tau, P)
  x = create_design_matrix(tau.shape[0], n=n)
  # np.random.normal expects a standard deviation, while task_var is a
  # variance (compute_posterior_bias_variance adds the same noise value to a
  # covariance diagonal), so take its square root here.
  y = x.dot(theta) + np.random.normal(0, np.sqrt(task_var), n)
  return x, y, theta


def compute_posterior_bias_variance(Xs, thetas, P, source_noise, target_noise):
  """Return the expected squared error of the Bayesian target-task estimate.

  Model encoded below: task ``i`` observes ``y_i = Xs[i] @ thetas[i] + eps_i``
  with ``thetas[i]`` distributed around a shared mean ``tau`` with covariance
  ``P``. ``tau`` is inferred from the source tasks ``Xs[:-1]`` with no prior
  term (flat prior), giving ``mu_tau`` and covariance ``S_tau``. The target
  parameters are estimated by the posterior mean

      theta_hat = S_theta @ (X_t.T @ y_t / target_noise + L @ mu_tau),
      L = (P + S_tau)^{-1},  S_theta = (X_t.T @ X_t / target_noise + L)^{-1}.

  The result is ``trace(Cov[theta_hat]) + ||E[theta_hat] - thetas[-1]||^2``,
  with expectations over observation noise and all ``thetas`` held fixed.

  Args:
    Xs: design matrices; ``Xs[:-1]`` are source tasks, ``Xs[-1]`` the target.
      All share the same number of columns ``d``; row counts may differ.
    thetas: true parameter vectors aligned with ``Xs`` (target last).
    P: ``d x d`` covariance of task parameters around ``tau``.
    source_noise: observation-noise variance of the source tasks.
    target_noise: observation-noise variance of the target task.

  Returns:
    Scalar expected squared error (variance plus squared bias).

  Note: requires at least one source task, with the source designs jointly
  determining ``tau`` (otherwise the matrix inversion fails).
  """
  num_source_tasks = len(Xs) - 1
  source_Xs = Xs[:num_source_tasks]
  X_target = Xs[-1]
  theta_target = thetas[-1]

  # Precision of each source task's observations given tau:
  # C_i = (X_i P X_i^T + source_noise I)^{-1}.
  post_Cs = [np.linalg.inv(X_i.dot(P).dot(X_i.T) + source_noise * np.eye(X_i.shape[0]))
             for X_i in source_Xs]
  post_precision_tau = sum(X_i.T.dot(C_i).dot(X_i) for X_i, C_i in zip(source_Xs, post_Cs))
  post_cov_tau = np.linalg.inv(post_precision_tau)

  # Prior on the target parameters implied by the sources: N(mu_tau, P + S_tau).
  post_cov_theta_tau = P + post_cov_tau
  post_precision_theta_tau = np.linalg.inv(post_cov_theta_tau)
  post_precision_theta = X_target.T.dot(X_target) / target_noise + post_precision_theta_tau
  post_cov_theta = np.linalg.inv(post_precision_theta)

  # Contribution of target-task noise: gain S_theta X_t^T / target_noise.
  left_cov = post_cov_theta.dot(X_target.T)
  left_cov = left_cov.dot(left_cov.T) / target_noise

  # mu_tau = S_tau sum_i X_i^T C_i y_i, hence
  # Cov[mu_tau] = source_noise * sum_i B_i B_i^T with B_i = S_tau X_i^T C_i (d x n_i).
  mu_tau_gains = [post_cov_tau.dot(X_i.T).dot(C_i) for X_i, C_i in zip(source_Xs, post_Cs)]
  post_mu_tau_cov = source_noise * sum(B.dot(B.T) for B in mu_tau_gains)

  # theta_hat depends on mu_tau through S_theta (P + S_tau)^{-1}; this product
  # is not symmetric in general, so the sandwich needs its transpose.
  prior_gain = post_cov_theta.dot(post_precision_theta_tau)
  right_cov = prior_gain.dot(post_mu_tau_cov).dot(prior_gain.T)
  estimator_var = np.trace(left_cov + right_cov)

  # E[theta_hat] - theta_t = prior_gain S_tau sum_i X_i^T C_i X_i (theta_i - theta_t).
  bias_sum = sum(X_i.T.dot(C_i).dot(X_i).dot(theta_i - theta_target)
                 for X_i, C_i, theta_i in zip(source_Xs, post_Cs, thetas[:num_source_tasks]))
  bias = prior_gain.dot(post_cov_tau.dot(bias_sum))
  estimator_bias = bias.dot(bias)
  return estimator_var + estimator_bias

def sample_datasets(M, n, k, thetas=None):
  """Sample ``M`` source tasks with ``n`` points each plus a target task with ``k`` points.

  NOTE(review): reads module-level ``tau``, ``P``, ``source_noise`` and
  ``target_noise``, which are not defined in this excerpt; confirm they exist
  before calling.

  Args:
    M: number of source tasks.
    n: number of observations per source task.
    k: number of observations for the target task.
    thetas: optional sequence of fixed task parameters. Entry ``i < M`` is used
      for source task ``i`` and entry ``M`` for the target. Any missing entries
      (including all of them when None) are sampled by ``sample_task_data``.

  Returns:
    ``(Xs, Ys, thetas)``: lists of length ``M + 1`` with the target task last.
  """
  if thetas is None:
    thetas = []
  Xs, Ys, local_thetas = [], [], []
  for i in range(M + 1):
    # Use a caller-supplied theta when one exists for this index; otherwise let
    # sample_task_data draw it (the target index M was previously unguarded).
    theta = thetas[i] if i < len(thetas) else None
    if i < M:
      x, y, theta = sample_task_data(tau, P, source_noise, n, theta=theta)
    else:
      x, y, theta = sample_task_data(tau, P, target_noise, k, theta=theta)
    Xs.append(x)
    Ys.append(y)
    local_thetas.append(theta)
  return Xs, Ys, local_thetas