"""Build model and coordinate training."""

import os
from typing import Sequence

import numpy as np
import tensorflow as tf
import pickle 

from src import sa
from src import em
from src import data, models 
from src import util as my_utils 


# Default experiment settings passed by `main` to `analyze_GPT2_models`.
_MAX_PROMPT_LEN = 5  # passed as `k_max`; analyze_GPT2_models does not read it
_INPUT_SIZE = 20  # input dimension (`input_size`)
_NUM_COMPONENTS = 5  # number of mixture components (`num_components`)
_NOISE_LEVEL = 1.0  # `noise_level` used for data generation and baselines

def run_EM_alg(
    n,
    noise_level,
    num_components,
    components,
    train_prompt_len,
    input_size,
    num_MCMC,
    num_replicates,
    eval_prompt_len,
    max_attempts=None,
):
  """Run the EM baseline over several replicates and collect its errors.

  Each replicate draws a fresh training set with `em.construct_dataset`, fits
  it with `em.main_algo_for_us` and scores the estimate with
  `em.get_list_error` and `em.get_prompt_errors`. If any of these steps raises,
  the replicate is restarted from scratch with newly drawn data.

  Args:
    n: sample-size argument forwarded to `em.construct_dataset`.
    noise_level: noise level forwarded to data generation, EM and scoring.
    num_components: number of cluster components.
    components: mixture centers forwarded to `em.construct_dataset`.
    train_prompt_len: batch size (prompt len) for training.
    input_size: dimension.
    num_MCMC: number of MCMC trials to use when computing prompt errors.
    num_replicates: number of replicates (over cluster draw).
    eval_prompt_len: number of prompt lengths scored; size of the last axis
      of `prompt_errors`.
    max_attempts: maximum attempts per replicate before the last error is
      re-raised. None (default) retries indefinitely.

  Returns:
    A tuple `(best_possible_errors, prompt_errors, weights_from_reps)`:
    an array of shape (num_replicates,), an array of shape
    (num_replicates, num_MCMC, eval_prompt_len), and a list holding exactly
    one EM estimate per replicate.
  """
  best_possible_errors = np.zeros(num_replicates)
  prompt_errors = np.zeros((num_replicates, num_MCMC, eval_prompt_len))
  weights_from_reps = []
  for t in range(num_replicates):
    attempt = 0
    while True:
      attempt += 1
      try:
        X, Y, component_parameters = em.construct_dataset(num_components,
                                                          components,
                                                          input_size,
                                                          train_prompt_len,
                                                          n,
                                                          noise_level)
        est_weights = em.main_algo_for_us((X, Y), num_components, noise_level)
        best_possible_error = em.get_list_error(est_weights,
                                                component_parameters)
        replicate_prompt_errors = em.get_prompt_errors(
            num_trials=num_MCMC,
            input_size=input_size,
            estimated_weights=est_weights,
            true_weights=component_parameters,
            noise_level=noise_level,
        )
      except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
        if max_attempts is not None and attempt >= max_attempts:
          raise
        continue  # redraw the data and try this replicate again
      break
    # Store results only once the whole replicate succeeded, so a failure in
    # scoring cannot leave an extra (orphan) entry in `weights_from_reps`.
    weights_from_reps.append(est_weights)
    best_possible_errors[t] = best_possible_error
    prompt_errors[t, :, :] = replicate_prompt_errors
  return best_possible_errors, prompt_errors, weights_from_reps

def run_SA(
    n,
    noise_level,
    num_components,
    components,
    train_prompt_len,
    input_size,
    num_MCMC,
    num_replicates,
    eval_prompt_len
):
  """Run the SA baseline over several replicates and collect its errors.

  Every replicate calls `sa.generate_batches` twice, each time with
  ceil(n / 2) batches of prompt length `train_prompt_len`. It fits
  `sa.main_algo_for_us` on the two batch sets and scores the estimate
  against `components`.

  Args:
    n: total sample size, split into two halves of ceil(n / 2).
    noise_level: noise level forwarded to batch generation, SA and scoring.
    num_components: number of cluster components.
    components: true mixture centers, used for data and as scoring target.
    train_prompt_len: batch size (prompt len) for training.
    input_size: dimension.
    num_MCMC: number of MCMC trials to use when computing prompt errors.
    num_replicates: number of replicates (over cluster draw).
    eval_prompt_len: number of prompt lengths scored; size of the last axis
      of the prompt-error array.

  Returns:
    A tuple of an array of shape (num_replicates,) from `sa.get_list_error`,
    an array of shape (num_replicates, num_MCMC, eval_prompt_len) from
    `em.get_prompt_errors`, and a list of the per-replicate SA estimates.
  """
  list_errors = np.zeros(num_replicates)
  per_prompt = np.zeros((num_replicates, num_MCMC, eval_prompt_len))
  fitted = []

  def draw_half(m):
    return sa.generate_batches(
        components, n=train_prompt_len, m=m, noise_level=noise_level
    )

  for rep in range(num_replicates):
    half = int(np.ceil(n / 2))
    first_half = draw_half(half)
    second_half = draw_half(half)
    weights = sa.main_algo_for_us(first_half, second_half, num_components,
                                  noise_level)
    fitted.append(weights)
    list_errors[rep] = sa.get_list_error(weights, components)
    per_prompt[rep, :, :] = em.get_prompt_errors(
        num_trials=num_MCMC,
        input_size=input_size,
        estimated_weights=weights,
        true_weights=components,
        noise_level=noise_level,
    )
  return list_errors, per_prompt, fitted


def get_PMA_error(X, Y, Xt, Yt, clusters, test_batch_size, input_size, noise_level):
  """Return the normalised squared error of the posterior-mean predictions.

  The predictions come from the first output of `get_preds`. The error is
  ||Yt - yp||^2 divided by (test_batch_size * input_size).
  """
  pma_predictions, _argmin_predictions = get_preds(
      X, Y, Xt, clusters, test_batch_size, noise_level)
  squared_norm = np.linalg.norm(np.array(Yt) - pma_predictions) ** 2
  return squared_norm / (test_batch_size * input_size)


def get_Argmin_error(X, Y, Xt, Yt, clusters, test_batch_size, input_size):
  """Return the normalised squared error of the argmin (best-fit) predictor.

  For each prompt, the row of `clusters` whose predictions X[b] @ c best fit
  Y[b] (smallest residual norm) is selected and applied to Xt[b]. Only the
  first `test_batch_size` prompts are predicted. The error is ||Yt - yp||^2
  divided by (test_batch_size * input_size).
  """
  fit_norms = np.linalg.norm((X @ clusters.T) - Y[:, :, np.newaxis], axis=1)
  best_cluster = np.argmin(fit_norms, axis=1)
  rows = np.arange(test_batch_size)
  # Integer-array indexing picks candidate best_cluster[b] for each row b.
  chosen = (Xt @ clusters.T)[rows, best_cluster[rows]]
  squared_norm = np.linalg.norm(np.array(Yt) - chosen) ** 2
  return squared_norm / (test_batch_size * input_size)


def get_preds(X, Y, Xt, clusters, test_batch_size, noise_level):
  """Predict query responses with the posterior-mean and argmin estimators.

  Each prompt b is scored against every candidate weight vector clusters[j]
  by the squared residual r[b, j] = ||X[b] @ clusters[j] - Y[b]||^2.

  * PMA: candidates are averaged with weights proportional to
    exp(-r[b, j] / (2 * noise_level)); the averaged vector is applied to Xt[b].
  * Argmin: the candidate with the smallest r[b, j] is applied to Xt[b].

  Shapes implied by the indexing below: X (batch, k, d), Y (batch, k),
  Xt (batch, d), clusters (m, d).

  Args:
    X, Y: prompt inputs and responses.
    Xt: query inputs to predict for.
    clusters: candidate weight vectors, one per row.
    test_batch_size: number of leading prompts to predict (<= batch).
    noise_level: enters as the variance in the posterior weights; must be
      nonzero.

  Returns:
    (yp_PMA, yp_Argmin), each a 1-D array of length `test_batch_size`.
  """
  residuals = np.linalg.norm((X @ clusters.T) - Y[:, :, np.newaxis], axis=1)**2
  rows = np.arange(test_batch_size)

  # Posterior weights as a softmax of -r / (2 * noise_level). Shifting each
  # row by its minimum residual (log-sum-exp trick) leaves the normalised
  # weights unchanged but keeps the largest weight at exp(0) = 1. Without the
  # shift, exp() underflows to 0 for every component when residuals are
  # large, and the posterior mean is lost.
  shifted = residuals - residuals.min(axis=1, keepdims=True)
  probs = np.exp(-(0.5 / noise_level) * shifted)
  probs = probs / np.sum(probs, axis=1, keepdims=True)
  w_hats = clusters.T @ probs.T  # column b is the posterior-mean vector for b
  yp_PMA = np.einsum('bd,db->b', Xt[rows], w_hats[:, rows])

  # Argmin: prediction of the single best-fitting candidate per prompt.
  idx_min = np.argmin(residuals, axis=1)
  all_preds = Xt @ clusters.T
  yp_Argmin = all_preds[rows, idx_min[rows]]

  # Fall back to argmin wherever the posterior mean is not finite
  # (e.g. non-finite residuals).
  fallback = ~np.isfinite(yp_PMA)
  yp_PMA[fallback] = yp_Argmin[fallback]
  return yp_PMA, yp_Argmin


def get_model_eval_results(model, exp_dir, input_size, noise_level, centers,
                           num_trials=25, test_batch_size=256):
  """Evaluate `model` and reference estimators on fresh prompts; save errors.

  For every prompt length k = 1, ..., 3 * input_size and each of `num_trials`
  trials, a test batch is drawn with `data.generate_batch`. The following
  errors are recorded:
    * OLS (`my_utils.get_ols_error`) and the model (`my_utils.get_model_error`);
    * when `centers` is not None, the argmin estimator (`get_Argmin_error`);
    * when `centers` is not None and `noise_level > 0`, the posterior-mean
      estimator (`get_PMA_error`).

  Each result is an array of shape (num_trials, 3 * input_size). Column i
  holds prompt length i + 1. The arrays are written to `exp_dir` as
  ols_results.npy and transformer_results.npy. PMA_results.npy and
  Argmin_results.npy are also written when `centers` is not None. PMA
  entries stay 0 when `noise_level <= 0`, because the PMA error is not
  computed in that case.

  Args:
    model: trained model passed to `my_utils.get_model_error`.
    exp_dir: directory the .npy result files are written to.
    input_size: input dimension; prompt lengths run up to 3 * input_size.
    noise_level: noise level for data generation and the PMA estimator.
    centers: mixture centers forwarded to `data.generate_batch`, or None.
    num_trials: number of test batches per prompt length.
    test_batch_size: number of prompts per test batch.
  """
  num_prompt_lengths = 3 * input_size
  result_shape = (num_trials, num_prompt_lengths)
  results_ols = np.zeros(result_shape)
  results_model = np.zeros(result_shape)
  results_argmin = np.zeros(result_shape)
  results_pma = np.zeros(result_shape)
  # The argmin / PMA errors need the centers: both use `clusters.T`, so they
  # must not run when centers is None (matching the save guard below).
  have_centers = centers is not None

  for i, prompt_length in enumerate(range(1, num_prompt_lengths + 1)):
    for t in range(num_trials):
      X, Y, Xt, Yt, Z = data.generate_batch(d=input_size,
                                            k=prompt_length,
                                            batch_size=test_batch_size,
                                            input_size=input_size,
                                            noise_level=noise_level,
                                            return_reg=True,
                                            return_noiseless=False,
                                            centers=centers)
      results_ols[t, i] = my_utils.get_ols_error(X, Y, Xt, Yt, test_batch_size)
      results_model[t, i] = my_utils.get_model_error(model, Z, Yt)
      if not have_centers:
        continue
      if noise_level > 0:
        results_pma[t, i] = get_PMA_error(X, Y, Xt, Yt, centers,
                                          test_batch_size, input_size,
                                          noise_level)
      results_argmin[t, i] = get_Argmin_error(X, Y, Xt, Yt, centers,
                                              test_batch_size, input_size)

  to_save = {
      'ols_results.npy': results_ols,
      'transformer_results.npy': results_model,
  }
  if have_centers:
    to_save['PMA_results.npy'] = results_pma
    to_save['Argmin_results.npy'] = results_argmin
  for filename, results in to_save.items():
    with open(os.path.join(exp_dir, filename), "wb") as file:
      np.save(file, results)
  

def _save_baseline_results(expt_dir, prefix, best_possible_errors,
                           prompt_errors, weights):
  """Save one baseline's errors (.npy) and fitted weights (.pkl) to expt_dir."""
  with open(os.path.join(expt_dir, prefix + '_best_possible.npy'), "wb") as file:
    np.save(file, best_possible_errors)
  with open(os.path.join(expt_dir, prefix + '_prompt_errors.npy'), "wb") as file:
    np.save(file, prompt_errors)
  with open(os.path.join(expt_dir, prefix + '_weights.pkl'), "wb") as file:
    pickle.dump(weights, file)


def analyze_GPT2_models(input_size,
                        d_init,
                        k_init,
                        k_max,
                        num_components,
                        noise_level,
                        num_embed=256,
                        num_heads=8,
                        num_layers=12,
                        base_dir='path-goes-here'):
  """Evaluate a saved GPT-2 model and the SA/EM baselines for one experiment.

  The experiment directory is
  `<base_dir>/noise_level_<noise_level>_m_<num_components>_d_<input_size>`.
  It must contain mixture_centers.npy and saved_model.h5. All result files
  are written back into the same directory.

  Steps:
    1. load the mixture centers;
    2. rebuild `models.DecoderOnlyGPT2` and load its saved weights;
    3. compute model/OLS/argmin/PMA errors via `get_model_eval_results`;
    4. run the SA baseline and save sa_* files;
    5. run the EM baseline and save em_* files.

  Args:
    input_size: input dimension.
    d_init, k_init, k_max: accepted for compatibility; not used here.
    num_components: number of mixture components.
    noise_level: noise level for evaluation and baselines.
    num_embed, num_heads, num_layers: GPT-2 architecture; must match the
      saved weights.
    base_dir: root directory holding the experiment directories.
  """
  expt_dir = os.path.join(base_dir, '_'.join([
      'noise_level', str(noise_level),
      'm', str(num_components),
      'd', str(input_size),
  ]))

  # 1. Load the mixture centers.
  with open(os.path.join(expt_dir, "mixture_centers.npy"), "rb") as file:
    mixture_centers = np.load(file)

  # 2. Rebuild the model with the requested architecture and load its weights.
  batch_size, prompt_len = 256, 3 * input_size
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = models.DecoderOnlyGPT2(input_size, num_embed, num_heads, num_layers)
    model.build(input_shape=(batch_size, prompt_len, input_size))
    my_utils.load_weights(model, os.path.join(expt_dir, 'saved_model.h5'))

  # 3. Model, OLS, argmin and PMA errors.
  get_model_eval_results(model, expt_dir, input_size, noise_level,
                         mixture_centers)

  # 4-5. SA and EM baselines on the same mixture, with identical settings.
  baseline_kwargs = dict(
      n=30000,
      noise_level=noise_level,
      num_components=num_components,
      components=mixture_centers,
      train_prompt_len=2 * input_size + 1,
      input_size=input_size,
      num_MCMC=25,
      num_replicates=15,
      eval_prompt_len=3 * input_size,
  )
  _save_baseline_results(expt_dir, 'sa', *run_SA(**baseline_kwargs))
  _save_baseline_results(expt_dir, 'em', *run_EM_alg(**baseline_kwargs))


def main(argv: Sequence[str]) -> None:
  """Command-line entry point: evaluate one experiment with module defaults.

  Calls `analyze_GPT2_models` using the module-level `_INPUT_SIZE`,
  `_NUM_COMPONENTS`, `_NOISE_LEVEL` and `_MAX_PROMPT_LEN` settings.

  Args:
    argv: command-line arguments; at most one element (presumably the
      program name) is accepted.

  Raises:
    app.UsageError: if more than one argument is given.
  """
  # NOTE(review): `app` is not imported in this file (presumably
  # `from absl import app`) — confirm and add the import.
  num_extra = len(argv) - 1
  if num_extra > 0:
    raise app.UsageError('Too many command-line arguments.')

  settings = {
      'input_size': _INPUT_SIZE,
      'd_init': 5,
      'k_init': 11,
      'k_max': _MAX_PROMPT_LEN,
      'num_components': _NUM_COMPONENTS,
      'noise_level': _NOISE_LEVEL,
      'num_embed': 256,
      'num_heads': 8,
      'num_layers': 12,
  }
  analyze_GPT2_models(**settings)


if __name__ == '__main__':
  # NOTE(review): `app` is not imported in this file; this presumably expects
  # `from absl import app` (app.run parses flags, then calls main(argv)) —
  # confirm and add the import, otherwise this line raises NameError.
  app.run(main)