from Base import *
import numpy as np

# Network / task dimensions shared by every replicate of run().
n_layers = 5     # number of layers passed to task(...)
input_dim = 4    # input dimension of both target and source tasks
output_dim = 1   # output dimension of the target task
n_units = 4      # hidden width; also the source task's output dimension

# Seed numpy's global RNG in this (parent) process.
np.random.seed(1)
def run():
    """Run one replicate of the source -> target representation experiment.

    A target task ``ta`` and a source task ``so`` are built from the module-level
    sizes (``input_dim``, ``output_dim``, ``n_layers``, ``n_units``).  The source
    is given the target's representation via ``so.set_representation(ta, ...)``
    (exact semantics live in ``Base`` -- not visible here).

    For each target sample size n_ta in (10, 100, 1000) and each source sample
    size n_so in (100, 1000, 10000), a fresh source is trained on n_so samples,
    its representation is handed to a freshly reset target via
    ``ta.set_representation(so, ...)``, and the target is trained on the n_ta
    target samples with ``with_rep=True``.  Once per n_ta, a baseline target is
    trained with ``with_rep=False``.

    Returns:
        numpy array of shape (4, 3).  Column k corresponds to target size
        10, 100, 1000; rows 0-2 hold the ``ta.training(..., with_rep=True)``
        result for source sizes 100, 1000, 10000, and row 3 holds the
        ``with_rep=False`` baseline.  (What ``training`` returns -- presumably
        a loss -- is defined in Base; TODO confirm.)
    """
    ta = task(input_dim=input_dim, output_dim=output_dim,
              n_layers=n_layers, n_units=n_units)
    ta.init_param(sigma=0.5)
    ta.set_true()

    # Source task: n_units outputs, built on the target's representation.
    so = task(input_dim=input_dim, output_dim=n_units,
              n_layers=n_layers, n_units=n_units, no_output=False)
    so.init_param()
    so.set_representation(ta, n_rep=n_layers)
    so.set_true()
    so.init_param()  # re-initialise after set_true, before any training
    so_copy = deepcopy(so)

    ta.init_param()  # set random before training
    ta_copy = deepcopy(ta)

    target_sizes = [10 ** k for k in range(1, 4)]  # 10, 100, 1000
    source_sizes = [10 ** k for k in range(2, 5)]  # 100, 1000, 10000
    results = np.zeros((len(source_sizes) + 1, len(target_sizes)))

    for col, n_ta in enumerate(target_sizes):
        # Draw target data from a pristine copy so it never depends on the
        # state left behind by the previous iteration's baseline training.
        ta = deepcopy(ta_copy)
        dat_ta = ta.gen_data(n_ta, sigma=0.5)
        # The smallest target set is trained with batch_size=10 instead of
        # epoch=10, exactly as in the original experiment.
        smallest = col == 0

        for row, n_so in enumerate(source_sizes):
            # Fresh, untrained copies for every (n_ta, n_so) pair.
            ta = deepcopy(ta_copy)
            so = deepcopy(so_copy)

            dat_so = so.gen_data(n_so, sigma=0.5)
            so.training(dat_so, epoch=10)
            ta.set_representation(so, n_rep=n_layers)
            if smallest:
                results[row, col] = ta.training(dat_ta, with_rep=True, batch_size=10)
            else:
                results[row, col] = ta.training(dat_ta, with_rep=True, epoch=10)

        # Baseline: same target data, no transferred representation.
        ta = deepcopy(ta_copy)
        if smallest:
            results[-1, col] = ta.training(dat_ta, with_rep=False, batch_size=10)
        else:
            # NOTE(review): unlike the with_rep runs above, no epoch=10 is
            # passed here, so the baseline uses training()'s default epoch
            # count -- confirm this asymmetry is intended.
            results[-1, col] = ta.training(dat_ta, with_rep=False)
    return results

# Experiment driver: run independent replicates of run() in parallel worker
# processes and summarise them element-wise across replicates.
from joblib import Parallel, delayed

N_REPS = 100  # number of independent replicates of run()
N_JOBS = 5    # number of worker processes


def _seeded_run(seed):
    """Seed numpy's global RNG inside the worker, then execute one run().

    The parent's np.random.seed(...) does not reproducibly carry over to joblib
    worker processes (per joblib's documentation on random state in Parallel),
    so each replicate is seeded explicitly.
    NOTE(review): only numpy's global RNG is seeded here; if Base draws from
    another RNG (e.g. torch or the stdlib random module), seed it too.
    """
    np.random.seed(seed)
    return run()


# Per-replicate seeds are drawn from the parent RNG (seeded earlier in this
# file), so the whole study is reproducible and every replicate is distinct.
replicate_seeds = np.random.randint(0, 2 ** 31 - 1, size=N_REPS)
total_results = Parallel(n_jobs=N_JOBS)(
    delayed(_seeded_run)(int(seed)) for seed in replicate_seeds
)

dat = np.array(total_results)  # (N_REPS, *shape of run()'s result)
m = np.mean(dat, 0)  # mean over replicates
s = np.std(dat, 0)   # population std (ddof=0) over replicates
print(m)
print(s)
