import numpy as np
import matplotlib.pyplot as plt

from scipy import linalg
from scipy.stats import linregress

from datetime import datetime
import importlib
import pickle
import os

import meta_learn as meta
importlib.reload(meta)


def dump_script(
  dirname, script_file, dest=None, timestamp=None, file_list=None):
  """Snapshot the source files and command line of an experiment run.

  Copies every file in `file_list` (default: all *.py files in the current
  working directory) and `script_file` into `dest`, then writes the invoking
  command line (`sys.argv` joined by spaces) to `dest/command.txt`.

  Args:
    dirname: parent directory used to build `dest` when `dest` is None.
    script_file: path of the driving script to copy; skipped when None.
    dest: destination directory; defaults to `dirname/script_<timestamp>`.
    timestamp: suffix for the default `dest`; defaults to the current local
      time formatted as "%Y%m%d-%H%M%S".
    file_list: paths to copy; defaults to glob.glob("*.py").

  Raises:
    FileExistsError: if `dest` already exists (snapshots are never reused).
  """
  import glob, os, shutil, sys
  from datetime import datetime

  if dest is None:
    if timestamp is None:
      timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    dest = os.path.join(dirname, 'script_{}'.format(timestamp))
  # makedirs also creates a missing parent directory, but (without
  # exist_ok) still refuses to write into an existing snapshot directory.
  os.makedirs(dest)

  print('copying files to {}'.format(dest))
  if file_list is None:
    file_list = glob.glob("*.py")
  for src in file_list:
    print('copying {}'.format(src))
    shutil.copy2(src, dest)
  # script_file may legitimately be None (algos_vs_var_metrics defaults it
  # to None); shutil.copy2(None, ...) would raise TypeError.
  if script_file is not None:
    print('copying {}'.format(script_file))
    shutil.copy2(script_file, dest)

  with open(os.path.join(dest, "command.txt"), "w", encoding="utf-8") as f:
    f.write(" ".join(sys.argv) + "\n")

def save_output(output, timestamp, dir_name=None):
  """Pickle `output` to `<dir_name>/output_<timestamp>.pkl`.

  Args:
    output: any picklable object (e.g. an experiment's results dict).
    timestamp: value formatted into the file name.
    dir_name: existing target directory; defaults to the current directory.

  Returns:
    The path of the written pickle file.
  """
  if dir_name is None:
    dir_name = '.'
  path = os.path.join(dir_name, 'output_{}.pkl'.format(timestamp))
  # `with` guarantees the file is closed even if pickling fails midway.
  with open(path, 'wb') as f:
    pickle.dump(output, f)
  return path


def algos_vs_var_metrics(
    d, r, t, m, sigma, algos_dict,
    homogeneity=None,
    dist_x=None, dist_y=None,
    loglog=False, xlims=None, ylims=None,
    metrics=None, plot_metrics=None,
    results_dir=None, script_file=None,
    ):
  """Run every algorithm in `algos_dict` while sweeping one problem parameter.

  At most one of `d`, `r`, `t`, `m`, `sigma`, `homogeneity` may be a list or
  numpy array. That parameter is swept, and the final value of each metric
  is recorded per trial. If none is a list, the x-axis variable is 'k'. In
  that case the full per-iteration metric history of each trial is recorded
  and plotted against the iteration index 1..K.

  The outputs go to `results_dir/<algos>_vs_<var>_<timestamp>/`:
    - for each metric in `plot_metrics`, a trial-averaged plot, saved
      unlabeled and labeled, as PNG and PDF;
    - a copy of the scripts (via dump_script);
    - a pickle of the inputs and results (via save_output).

  Args:
    d, r, t, m, sigma, homogeneity: problem parameters. They are passed to
      `meta.MetaLearnProb(d, r, t, sigma=..., homogeneity=...)` and
      `prob.generate_data(m, ...)`. At most one may be a list/array.
    algos_dict: maps an algorithm name to a 5-element sequence
      `[func, params, N_trials, label, style]`, e.g.

        algos_dict = {
          'alt': [meta.apply_alt_min,
            {'N_step': 10, 'U_init': None, 'init_mom': True}, 10,
            'Alt-Min', 'o-'],
        }

      `func(prob, **params)` must return a dict containing '<metric>_list'
      for every requested metric. `label` and `style` are the matplotlib
      legend label and format string.
    dist_x, dist_y: forwarded to `prob.generate_data`.
    loglog: None -> linear axes; truthy -> log-log; otherwise -> semilogy.
    xlims, ylims: optional axis limits passed to plt.xlim / plt.ylim.
    metrics: metrics to record; defaults to
      ['dist_U', 'dist_U_spectral', 'avg_mse_loss'].
    plot_metrics: subset of `metrics` to plot; defaults to `metrics`.
    results_dir: root output directory; defaults to 'results/neurips_submit'.
    script_file: path of the driving script, forwarded to dump_script.

  Returns:
    metrics_list[algo][metric]. For a swept parameter, this is a list with
    one entry per swept value, each a list of per-trial final values. For
    'k', it is a list of per-trial metric histories.

  Raises:
    ValueError: if an algorithm requests fewer than one trial, more than one
      parameter is a list, the swept list is empty, or `plot_metrics` names a
      metric that is not in `metrics`.
  """
  N_trials = 1
  for algo, algo_setup in algos_dict.items():
    algo_N_trials = algo_setup[2]
    if algo_N_trials < 1:
      raise ValueError(
        'algorithm {!r} requests {} trials; need at least 1'.format(
          algo, algo_N_trials))
    N_trials = max(algo_N_trials, N_trials)

  # Working copy of the problem parameters; the swept entry is overwritten
  # per sweep value while the original arguments stay untouched for saving.
  params = {
    'd': d, 'r': r, 't': t, 'm': m,
    'sigma': sigma, 'homogeneity': homogeneity,
  }
  varying = [
    name for name, value in params.items()
    if isinstance(value, (list, np.ndarray))]
  if len(varying) > 1:
    raise ValueError('varying {} variables'.format(len(varying)))
  if varying:
    list_var = varying[0]
    var_list = params[list_var]
    if len(var_list) == 0:
      raise ValueError('sweep list for {!r} is empty'.format(list_var))
  else:
    # Nothing swept: plot metric histories against the iteration index.
    list_var, var_list = 'k', [None]

  if metrics is None:
    metrics = ['dist_U', 'dist_U_spectral', 'avg_mse_loss']
  if plot_metrics is None:
    plot_metrics = metrics
  # Validate before running any (potentially long) experiments.
  unknown = [metric for metric in plot_metrics if metric not in metrics]
  if unknown:
    raise ValueError(
      'plot_metrics {} are not in metrics {}'.format(unknown, metrics))

  timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
  results_dir = 'results/neurips_submit' if results_dir is None else results_dir
  expt_name = '{}_vs_{}_{}'.format(
      '-'.join(algos_dict.keys()), list_var, timestamp)
  dir_name = os.path.join(results_dir, expt_name)
  os.makedirs(dir_name, exist_ok=True)

  dump_script(
    dir_name, script_file, timestamp=timestamp,
    file_list=['meta_learn.py', 'expt_utils.py'])

  plt_titles = {
    'dist_U': 'Distance of U',
    'dist_U_spectral': 'Distance of U in spectral norm',
    'avg_mse_loss': 'Average MSE loss',
  }

  plt_ylabels = {
    'dist_U': 'Subspace Distance',
    'dist_U_spectral': 'Subspace Distance',
    'avg_mse_loss': 'Average MSE loss',
  }

  metrics_list = {
    algo: {metric: [] for metric in metrics} for algo in algos_dict}

  for var in var_list:
    if list_var == 'k':
      # Single pass: trial histories go straight into metrics_list.
      per_var_metrics = metrics_list
    else:
      params[list_var] = var
      # Per-value buffers, appended to metrics_list once all trials finish.
      per_var_metrics = {
        algo: {metric: [] for metric in metrics} for algo in algos_dict}

    for trial_idx in range(N_trials):
      prob = meta.MetaLearnProb(
        params['d'], params['r'], params['t'],
        sigma=params['sigma'], homogeneity=params['homogeneity'])
      prob.generate_data(
        params['m'], noise=True, dist_x=dist_x, dist_y=dist_y)

      for algo, algo_setup in algos_dict.items():
        algo_func, algo_params, algo_N_trials, _, _ = algo_setup
        # All algorithms share each trial's problem; algorithms with fewer
        # trials simply skip the later ones.
        if trial_idx >= algo_N_trials:
          continue

        output = algo_func(prob, **algo_params)
        for metric in metrics:
          history = output['{}_list'.format(metric)]
          # A sweep keeps only the final value; 'k' keeps the whole history.
          per_var_metrics[algo][metric].append(
            history if list_var == 'k' else history[-1])

    if list_var != 'k':
      for algo in algos_dict:
        for metric in metrics:
          metrics_list[algo][metric].append(per_var_metrics[algo][metric])

  if loglog is None:
    plot_func = plt.plot
  elif loglog:
    plot_func = plt.loglog
  else:
    plot_func = plt.semilogy

  for metric in plot_metrics:
    # Fresh figure so nothing already open in pyplot leaks into the plot.
    plt.figure()
    for algo, algo_setup in algos_dict.items():
      algo_label, algo_style = algo_setup[3], algo_setup[4]
      if list_var != 'k':
        # rows: swept values, columns: trials -> average over trials
        x_vals = var_list
        y_vals = np.mean(metrics_list[algo][metric], axis=1)
      else:
        # rows: trials, columns: iterations -> average over trials
        y_vals = np.mean(metrics_list[algo][metric], axis=0)
        x_vals = np.arange(1, y_vals.shape[0] + 1)
      plot_func(x_vals, y_vals, algo_style, label=algo_label)

    if xlims is not None:
      plt.xlim(xlims)
    if ylims is not None:
      plt.ylim(ylims)
    plt.legend()

    plot_file_name = '{}_{}'.format(metric, expt_name)
    plt.savefig(os.path.join(dir_name, plot_file_name + '.png'))
    plt.savefig(os.path.join(dir_name, plot_file_name + '.pdf'))

    # Save the same figure again with title and axis labels added; fall
    # back to the metric name for metrics without a predefined caption.
    plt.title(plt_titles.get(metric, metric))
    plt.xlabel(list_var)
    plt.ylabel(plt_ylabels.get(metric, metric))
    plt.savefig(os.path.join(dir_name, plot_file_name + '_labeled.png'))
    plt.savefig(os.path.join(dir_name, plot_file_name + '_labeled.pdf'))
    plt.close()

  # Saving results; the swept variable's key holds the full sweep list.
  outputs_dict = {
    'd': d,
    'r': r,
    't': t,
    'm': m,
    'sigma': sigma,
    'homogeneity': homogeneity,
    list_var: var_list,

    'list_var': list_var,
    'var_list': var_list,

    'algos_dict': algos_dict,

    'plt_titles': plt_titles,
    'plt_ylabels': plt_ylabels,

    'dist_x': dist_x,
    'dist_y': dist_y,

    'loglog': loglog,
    'metrics': metrics,
    'plot_metrics': plot_metrics,

    'results_dir': results_dir,
    'script_file': script_file,

    'metrics_list': metrics_list,
  }
  save_output(outputs_dict, timestamp, dir_name=dir_name)

  return metrics_list

