import numpy as np
import matplotlib.pyplot as plt
plt.style.use('seaborn')
from hgaussians import create_design_matrix, compute_posterior_bias_variance

def plot(tau, P, M, n, k, source_noise, target_noise, name='tmp', trials=100):
  """
  Plot trial-averaged risk for two sweeps and save the figure and raw curves.

  Two sweeps are evaluated in every trial and then averaged over trials:

  * k-sweep: the first ``M`` source design matrices are kept fixed while the
    number of novel-task rows ``_k`` runs over ``krange``.
  * M-sweep: the novel task keeps ``k`` rows while the number of source
    tasks ``_M`` runs over ``Mrange``.

  The task parameters (``thetas``) are drawn once and shared by all trials;
  only the design matrices are redrawn per trial.

  Args:
    tau: mean vector of the multivariate normal the task parameters are
      drawn from; its length sets the feature dimension ``d``.
    P: covariance passed to ``np.random.multivariate_normal`` and forwarded
      to ``compute_posterior_bias_variance``.
    M: number of source tasks used throughout the k-sweep.
    n: number of rows requested for each source design matrix.
    k: number of novel-task rows used in the M-sweep; also offsets the
      start of the k-sweep (``krange`` starts at ``k``).
    source_noise: forwarded unchanged to ``compute_posterior_bias_variance``.
    target_noise: forwarded unchanged to ``compute_posterior_bias_variance``.
    name: tag inserted into every output file name.
    trials: number of design-matrix redraws to average over.

  Side effects:
    Writes ``plot{name}.pdf``, shows the figure, then stores the x-values
    and averaged risks with ``np.save`` under ``plot{name}_k_range``,
    ``plot{name}_Mn_range``, ``plot{name}_k_risk_2`` and
    ``plot{name}_M_risk_2``.

  NOTE(review): assumes ``create_design_matrix(d, n=r)`` returns an array
  with one row per sample (it is sliced as ``Xnovel[:_k]``) -- confirm
  against ``hgaussians``.
  """
  d = len(tau)
  print('Sampling tasks')
  krange = (k - 10) + np.logspace(1, 5, 50).astype(np.int32)
  Mrange = np.arange(2, 800, 5).astype(np.int32)
  # Smaller sweeps left by the original author (swap in for a quick run):
  # krange = np.logspace(1, 2, 50).astype(np.int32)
  # Mrange = np.arange(1, 20).astype(np.int32)

  # Sample all the thetas we need: one per potential source task plus one
  # extra, because each sweep treats the last theta of its slice as the
  # novel task's parameter.
  thetas = [np.random.multivariate_normal(tau, P) for _ in range(Mrange.max() + 1)]
  max_meta_data = np.max(n * Mrange)

  # Cut the k-sweep just after the first value that reaches the largest
  # amount of source data (M * n), so both curves span a similar x-range.
  # If no value reaches it, keep the whole sweep; the previous argmin-based
  # cut returned index 0 in that case and collapsed the sweep to one point.
  reached = krange >= max_meta_data
  if reached.any():
    krange = krange[:np.argmax(reached) + 1]

  exp_risks_k = []
  exp_risks_M = []

  for j in range(trials):
    print("Working on trial {0}/{1}".format(j+1, trials))
    # Fresh design matrices every trial; thetas stay fixed across trials.
    Xs = [create_design_matrix(d, n=n) for _ in range(Mrange.max())]
    Xnovel = create_design_matrix(d, n=krange.max())
    print('Created design matrices!')
    print()

    print('Computing risk as k increases...')
    # Fixed M source tasks, growing prefix of the novel task's rows.
    exp_risk_k = [
      compute_posterior_bias_variance(
        Xs[:M] + [Xnovel[:_k]], thetas[:M+1], P, source_noise, target_noise)
      for _k in krange
    ]

    print('Computing risk as M increases...')
    # Fixed k novel-task rows, growing number of source tasks.
    exp_risk_M = [
      compute_posterior_bias_variance(
        Xs[:_M] + [Xnovel[:k]], thetas[:_M+1], P, source_noise, target_noise)
      for _M in Mrange
    ]
    exp_risks_k.append(exp_risk_k)
    exp_risks_M.append(exp_risk_M)

  # Average each curve over trials (axis 0 indexes the trial).
  exp_risk_k = np.mean(exp_risks_k, 0)
  exp_risk_M = np.mean(exp_risks_M, 0)

  plt.figure(figsize=(8,4))
  # NOTE(review): the legend bounds 8000 / 800 and the "- 10" offsets are
  # hard-coded and do not track the actual krange / Mrange; the curves are
  # drawn against k and M*n while the axis label says Mn + k. Left unchanged
  # pending confirmation of the intended labelling.
  plt.semilogy(krange, exp_risk_k, label=r'$M={0}, n={1}$, $k \in [{2},{3}]$'.format(M,n, krange[0]-10, 8000))
  plt.semilogy(n * Mrange, exp_risk_M, label=r'$k={1}, n={0}$, $M \in [{2}, {3}]$'.format(n,k-10, Mrange[0], 800))
  plt.title("Expected error over varying dataset sizes", fontsize=16)
  plt.ylabel('Risk', fontsize=14)
  plt.xlabel(r'Total data samples ($Mn + k$)', fontsize=14)
  plt.xlim(0, max_meta_data)
  plt.legend(fontsize=14)
  plt.tight_layout()
  plt.savefig('plot{}.pdf'.format(name))
  plt.show()
  np.save('plot{}_k_range'.format(name), krange)
  np.save('plot{}_Mn_range'.format(name), n * Mrange)
  np.save('plot{}_k_risk_2'.format(name), exp_risk_k)
  np.save('plot{}_M_risk_2'.format(name), exp_risk_M)

def plot_sup(tau, P, n, k, source_noise, target_noise):
  """
  Two-task experiment: one source task and one novel task.

  Risk is evaluated along two sweeps: growing the novel task's row count
  (source fixed at ``n`` rows) and growing the source task's row count
  (novel fixed at ``k`` rows). Per the original author's note, the point is
  that observing many samples from the single other task leaves the bias
  large.

  Both curves are drawn on the current matplotlib figure, saved to
  ``plotClose.pdf``, shown, and the sweep values and risks are stored with
  ``np.save`` under the ``plotSup_*`` names.
  """
  dim = len(tau)

  # Same log-spaced offsets for both sweeps, shifted so they start at n / k.
  log_steps = np.logspace(1, 4, 50).astype(np.int32)
  nrange = (n - 10) + log_steps
  krange = (k - 10) + log_steps

  # Draw order (thetas first, then design matrices) is kept as in the
  # original so a seeded run reproduces the same samples.
  theta_1 = np.random.multivariate_normal(tau, P)
  theta_2 = np.random.multivariate_normal(tau, P)

  # Build each design matrix once at its maximum size; sweeps take prefixes.
  X_1 = create_design_matrix(dim, n=nrange.max())
  X_2 = create_design_matrix(dim, n=krange.max())

  def risk_for(source_rows, novel_rows):
    # Risk when using the first `source_rows` / `novel_rows` rows of each task.
    return compute_posterior_bias_variance(
      [X_1[:source_rows], X_2[:novel_rows]],
      [theta_1, theta_2],
      P, source_noise, target_noise)

  print('Computing risk as k increases...')
  exp_risk_k = [risk_for(n, novel_rows) for novel_rows in krange]

  print('Computing risk as n increases...')
  exp_risk_n = [risk_for(source_rows, k) for source_rows in nrange]

  plt.semilogy(krange, exp_risk_k, label=r'Increasing $k$')
  plt.semilogy(nrange, exp_risk_n, label=r'Increasing $n$')
  plt.ylabel('Risk', fontsize=14)
  plt.xlabel('Total data samples', fontsize=14)
  x_upper = max(np.max(nrange), np.max(krange))
  plt.xlim(left=0, right=x_upper)
  plt.legend(fontsize=14)
  plt.tight_layout()
  plt.savefig('plotClose.pdf')
  plt.show()

  # Persist the raw curves next to the figure.
  for path, values in (
      ('plotSup_k_range', krange),
      ('plotSup_Mn_range', nrange),
      ('plotSup_k_risk_2', exp_risk_k),
      ('plotSup_M_risk_2', exp_risk_n)):
    np.save(path, values)

def get_threshold_risk_thing(tau, sigma_theta, sigma_source, M, n, k, noise_ratio_range, trials=5):
  """
  Evaluate the risk over a range of source/target noise ratios.

  For each trial, ``M`` source design matrices with ``n`` rows and one
  novel-task design matrix with ``k`` rows are drawn, and
  ``compute_posterior_bias_variance`` is evaluated once per target noise
  level ``sigma_source / ratio``. Task parameters are drawn once and reused
  by every trial.

  Args:
    tau: mean vector for the task parameters; its length sets ``d``.
    sigma_theta: scale of the isotropic covariance ``sigma_theta * I``.
    sigma_source: source-task noise, forwarded as ``source_noise``.
    M: number of source tasks.
    n: rows requested per source design matrix.
    k: rows requested for the novel-task design matrix.
    noise_ratio_range: ratio(s) ``source_noise / target_noise`` to sweep.
      May be an ndarray, a plain sequence, or a single number.
    trials: number of design-matrix redraws.

  Returns:
    A list with one entry per trial; each entry is a list holding the value
    returned by ``compute_posterior_bias_variance`` for every ratio, in
    the order given.
  """
  d = len(tau)
  P = sigma_theta * np.eye(d)

  source_noise = sigma_source
  # Coerce to an array of at least one dimension so that lists, tuples and
  # scalars work too; dividing a float by a plain list raises TypeError and
  # a bare scalar cannot be iterated below. ndarray input is unaffected.
  noise_ratios = np.atleast_1d(noise_ratio_range)
  target_noise = source_noise / noise_ratios

  # M source-task parameters plus one for the novel task.
  thetas = [np.random.multivariate_normal(tau, P) for _ in range(M + 1)]
  exp_risks = []

  for _ in range(trials):
    Xs = [create_design_matrix(d, n=n) for _ in range(M)]
    Xnovel = create_design_matrix(d, n=k)

    # The novel task goes last, matching the position of its theta.
    X_tasks = Xs + [Xnovel[:k]]

    exp_risks.append([
      compute_posterior_bias_variance(X_tasks, thetas, P, source_noise, tnoise)
      for tnoise in target_noise
    ])
  return exp_risks




if __name__ == '__main__':
  # Switches for the two experiments this script can run.
  run_dataset_size_experiment = True
  run_task_difficulty_experiment = True

  # Shared problem setup.
  tau = 0.1 * np.array([0, 1.0, 2.0, 0.0, 0.0, 3.0, 1.0])
  d = len(tau)
  P = 0.1 * np.eye(d)
  M, n, k = 2, 20, 30
  source_noise, target_noise = 0.1, 0.01

  if run_dataset_size_experiment:
    # This experiment overrides the default noise levels.
    source_noise, target_noise = 0.2, 0.001
    plot(tau, P, M, n, k, source_noise, target_noise, name='B')

  if run_task_difficulty_experiment:
    noise_ratios = np.logspace(-3, 8, 100)

    # One row per curve: the (M, n, k) setting and its legend label.
    curve_specs = [
      ((50, 20, 5), r'$(M,n,k) = (50, 20, 5)$'),
      ((50, 20, 1000), r'$(M,n,k) = (50, 20, 1000)$'),
      ((50, 200, 5), r'$(M,n,k) = (50, 200, 5)$'),
    ]

    # Compute every trial-averaged curve before touching the figure, in the
    # listed order, so random draws happen in the same sequence as before.
    mean_risks = [
      np.mean(
        get_threshold_risk_thing(tau, 0.001, 1.0, num_tasks, num_rows, num_novel, noise_ratios),
        0)
      for (num_tasks, num_rows, num_novel), _ in curve_specs
    ]

    plt.figure(figsize=(8,4))
    # The x-axis is the reciprocal of the swept noise ratio.
    for (_, curve_label), curve in zip(curve_specs, mean_risks):
      plt.loglog(1.0 / noise_ratios, curve, label=curve_label)

    plt.legend(framealpha=0.8, frameon=True, fontsize=14)
    plt.ylabel('Risk', fontsize=14)
    plt.xlabel('Novel task noise', fontsize=14)
    plt.title("Expected error over varying novel task difficulty", fontsize=16)
    plt.xlim(10**-8, 10**3)
    plt.tight_layout()
    plt.savefig("varying_task_difficulty.pdf")
    plt.show()
