
from src import data
import numpy as np

def perform_batch_em(batches, dimension, noise_level, num_components, 
                     iterations=1000, epsilon=5e-3, truth=None):
  """
  Fit a mixture of linear regressions to batched data with EM.

  Responsibilities are computed per batch (one row of `assignments` per
  batch), i.e. all samples of a batch share one component.

  Args:
    batches: pair (X, Y). X is indexed as (num_batches, batch_len, dimension)
      and Y as (num_batches, batch_len).
    dimension: number of regression features.
    noise_level: noise variance used in the Gaussian likelihood. It is used
      as a divisor, so it is expected to be strictly positive.
    num_components: number of mixture components to fit.
    iterations: maximum number of EM iterations to run.
    epsilon: stop as soon as sqrt(avg_list_error(params, prev_params)) drops
      below this value.
    truth: optional array of reference parameters; when given, the error of
      the current estimate against it is printed after every iteration that
      does not trigger the convergence return.

  Returns:
    (params, proportions): params has shape (num_components, dimension) and
    proportions has shape (num_components,).
  """
  # X: (num_batches, batch_len, dimension); Y: (num_batches, batch_len)
  X, Y = batches
  num_batches = X.shape[0]
  sigsq = noise_level

  # Per-batch statistics X_j^T X_j and X_j^T y_j, reused by every M step.
  emp_covs = np.zeros((num_batches, dimension, dimension))
  rhs_vecs = np.zeros((num_batches, dimension))
  for j in range(num_batches):
    emp_covs[j, :, :] = X[j, :, :].T @ X[j, :, :]
    rhs_vecs[j, :] = X[j, :, :].T @ Y[j, :]

  ########################################################
  # Initialization
  ########################################################
  # initialize cluster proportions (marginal cluster probabilities)
  proportions = np.random.rand(num_components)
  proportions /= np.sum(proportions)
  # initialize cluster weights (parameters)
  params = generate_clusters(num_components, dimension)
  prev_params = np.copy(params)

  # Bounded loop: `iterations` caps the work so a run that never meets the
  # epsilon criterion still terminates.
  for _ in range(iterations):
    ################################################
    # E step
    ################################################
    # Squared residual norm of every component on every batch:
    # shape (num_batches, num_components).
    residuals = np.square(np.linalg.norm((X @ params.T) - Y[:, :, np.newaxis],
                                         axis=1))
    # Subtracting the per-batch minimum leaves the normalised posterior
    # unchanged but keeps the exponentials from underflowing.
    residuals -= np.min(residuals, axis=1)[:, None]
    # `residuals` is already a squared norm, so it enters the Gaussian
    # exponent as-is (no second squaring).
    probs = np.exp(-(0.5 / sigsq) * residuals) * proportions
    assignments = probs / np.sum(probs, axis=1)[:, None]

    ################################################
    # M step
    ################################################
    proportions = np.mean(assignments, axis=0)
    for k in range(num_components):
      # Responsibility-weighted normal equations for component k.
      L_mat = emp_covs.T @ assignments[:, k]
      R_vec = rhs_vecs.T @ assignments[:, k]
      params[k, :] = np.linalg.lstsq(L_mat, R_vec, rcond=None)[0]

    if np.sqrt(avg_list_error(params, prev_params)) < epsilon:
      return params, proportions
    prev_params = np.copy(params)

    if truth is not None:
      print("Error to ground truth: %s" % avg_list_error(truth, params))

  return params, proportions

def main_algo_for_us(batches, num_clusters, noise_level):
  """
  Run perform_batch_em on `batches` with iterations=20000 and epsilon=1e-3
  and return only the estimated parameters.

  The feature dimension is read from the last axis of batches[0].
  """
  feature_dim = batches[0].shape[2]
  estimated_params, _unused_proportions = perform_batch_em(
      batches, feature_dim, noise_level, num_clusters,
      iterations=20000, epsilon=1e-3)
  return estimated_params
  
  
def generate_clusters(num_clusters, input_size):
  """
  Draw num_clusters random directions in input_size dimensions.

  Each row is a standard normal sample rescaled to have Euclidean norm
  sqrt(input_size). Returns an array of shape (num_clusters, input_size).
  """
  gaussian_rows = np.random.randn(num_clusters, input_size)
  row_norms = np.linalg.norm(gaussian_rows, axis=1)
  unit_rows = gaussian_rows / row_norms[:, None]
  target_norm = np.sqrt(input_size)
  return unit_rows * target_norm

def get_list_error(list_of_weights, correct_weights):
  """
  List of weights is a num_list x dimension matrix of weights.
  Correct weights is a num_weights x dimension matrix of weights.
  We compute:
  mean_{i in [num_weights]} min_{j in [num_list]}
    norm(correct_weights[i] - list_of_weights[j])^2 / dimension
  i.e. for every correct weight vector the squared distance to its closest
  list entry, averaged over the correct weights and divided by dimension.

  If list_of_weights has no rows the result is infinite.
  """
  total_sq_dist = 0.0
  num_list = np.shape(list_of_weights)[0]
  num_weights = np.shape(correct_weights)[0]
  dimension = np.shape(correct_weights)[1]
  for i in range(num_weights):
    # Lowercase np.inf: the np.Inf alias no longer exists in NumPy 2.0.
    min_dist = np.inf
    for j in range(num_list):
      dist = np.linalg.norm(correct_weights[i] - list_of_weights[j])
      min_dist = min(min_dist, dist)
    total_sq_dist += min_dist ** 2
  return total_sq_dist / (dimension * num_weights)

def get_min_error(X, Y, Xt, Yt, clusters, test_batch_size, input_size):
  """
  Squared test error when each batch predicts with its best-fitting cluster.

  For every batch the cluster with the smallest residual norm on (X, Y) is
  selected, and its prediction on the matching row of Xt is compared with Yt.
  The squared error norm is divided by test_batch_size * input_size.
  """
  # Residual norm of each cluster on each batch: (num_batches, num_clusters).
  prompt_residuals = np.linalg.norm((X @ clusters.T) - Y[:, :, np.newaxis],
                                    axis=1)
  best_cluster = np.argmin(prompt_residuals, axis=1)
  # Prediction of every cluster for every test row.
  candidate_preds = Xt @ clusters.T
  rows = np.arange(test_batch_size)
  chosen_preds = candidate_preds[rows, best_cluster[rows]]
  error_norm = np.linalg.norm(np.array(Yt) - chosen_preds)
  return (error_norm ** 2) / (test_batch_size * input_size)

def get_bayes_opt_error(X, Y, Xt, Yt, clusters, test_batch_size, input_size, noise_level):
  """
  Squared test error of the first prediction vector returned by get_preds
  (the posterior-mean prediction), divided by test_batch_size * input_size.
  """
  posterior_mean_preds, _ = get_preds(X, Y, Xt, clusters, test_batch_size, noise_level)
  error_norm = np.linalg.norm(np.array(Yt) - posterior_mean_preds)
  return (error_norm ** 2) / (test_batch_size * input_size)

def get_prompt_errors(num_trials, input_size, estimated_weights, true_weights, noise_level):
  """
  Evaluate estimated_weights on freshly generated batches for every prompt
  length from 1 to ceil(3 * input_size).

  Returns an array of shape (num_trials, number of prompt lengths); entry
  [t, i] is the error of trial t at prompt length i + 1. With
  noise_level > 0 the error comes from get_bayes_opt_error, otherwise from
  get_min_error.
  """
  max_prompt_len = int(np.ceil(3 * input_size))
  errors = np.zeros((num_trials, max_prompt_len))
  for col in range(max_prompt_len):
    prompt_length = col + 1
    for trial in range(num_trials):
      # Every trial draws its own batch of num_trials tasks.
      batch = data.generate_batch(d=input_size,
                                  k=prompt_length,
                                  batch_size=num_trials,
                                  input_size=input_size,
                                  noise_level=noise_level,
                                  return_reg=True,
                                  centers=true_weights)
      X, Y, Xt, Yt, _ = batch
      if noise_level > 0:
        errors[trial, col] = get_bayes_opt_error(
            X, Y, Xt, Yt, estimated_weights, num_trials, input_size, noise_level)
      else:
        errors[trial, col] = get_min_error(
            X, Y, Xt, Yt, estimated_weights, num_trials, input_size)
  return errors

def get_preds(X, Y, Xt, clusters, test_batch_size, noise_level):
  """
  Predict the test targets of each batch from a fixed list of clusters.

  Args:
    X, Y: prompt inputs and targets; the code indexes X as
      (num_batches, batch_len, dimension) and Y as (num_batches, batch_len).
    Xt: test inputs, one row per batch.
    clusters: (num_clusters, dimension) candidate weight vectors.
    test_batch_size: number of test rows to predict; expected to equal the
      number of batches.
    noise_level: noise variance of the Gaussian likelihood; it is used as a
      divisor, so it is expected to be strictly positive.

  Returns:
    (yp_PMA, yp_Argmin): predictions using the posterior-mean weight vector
    (uniform prior over clusters) and predictions using the cluster with the
    smallest residual, respectively.
  """
  residuals = np.linalg.norm((X @ clusters.T) - Y[:, :, np.newaxis], axis=1)**2
  # compute posterior mean
  # Shift by the per-batch minimum before exponentiating: the normalised
  # posterior is unchanged, but the best cluster gets weight exp(0) = 1, so
  # the normaliser cannot underflow to zero when all residuals are large.
  shifted = residuals - np.min(residuals, axis=1)[:, None]
  probs = np.exp(-(0.5/noise_level) * shifted)
  norm_constant = np.sum(probs, axis=1)
  probs = probs / norm_constant[:, None]
  w_hats = clusters.T @ probs.T
  yp_PMA = np.array([Xt[i, :] @ w_hats[:, i] for i in range(test_batch_size)])
  # compute argmin
  idx_min = np.argmin(residuals, axis=1)
  all_preds = Xt @ clusters.T
  yp_Argmin = np.array([all_preds[i, idx_min[i]] for i in range(test_batch_size)])
  # Defensive fallback: if the normaliser is still zero or NaN (e.g.
  # non-finite residuals), use the argmin prediction instead.
  degenerate = ~(norm_constant > 0)
  yp_PMA[degenerate] = yp_Argmin[degenerate]
  return yp_PMA, yp_Argmin

def construct_dataset(num_clusters, components, input_size, batch_len, num_batches, noise_level): 
  """
  Generate a training set with data.generate_batch using `components` as the
  cluster centers.

  Returns (X, Y, components): the first two outputs of generate_batch and
  the components passed in. num_clusters is accepted but not used here.
  """
  generated = data.generate_batch(d=input_size,
                                  k=batch_len,
                                  batch_size=num_batches,
                                  input_size=input_size,
                                  return_reg=True,
                                  noise_level=noise_level,
                                  centers=components)
  # generate_batch is unpacked as five values; only the first two are kept.
  X, Y, _, _, _ = generated
  return X, Y, components

def avg_list_error(true_list, pred_list): 
  """
  Mean over the rows of true_list of the squared Euclidean distance to the
  closest row of pred_list.
  """
  row_count = true_list.shape[0]
  nearest_sq_dists = [
      np.min(np.linalg.norm(pred_list - true_list[row, :], axis=1)) ** 2
      for row in range(row_count)
  ]
  return sum(nearest_sq_dists, 0.0) / row_count