import tqdm
import os
import numpy as np
import sys
from utils.storage import build_experiment_folder, save_statistics, save_to_json
import time
import torch

from copy import deepcopy
from cca import CCAHook


class ExperimentBuilder(object):
    """Drives training, per-epoch evaluation, checkpointing and final test evaluation of a few-shot learner.

    The wrapped ``model`` is used through ``run_train_iter``, ``run_validation_iter``, ``save_model`` and
    ``load_model``; the object built from ``data`` is used through ``get_train_batches``, ``get_test_batches``,
    ``get_test_train_batches`` and ``dataset.seed``.
    """

    def __init__(self, args, data, model, device):
        """
        Initializes an experiment builder from the hyperparameters (args), a data provider factory (data), a meta
        learning system (model) and a device (e.g. gpu/cpu).
        :param args: Object holding all experiment hyperparameters. It must allow attribute assignment (e.g. an
            argparse Namespace) because ``args.continue_from_epoch`` is overwritten below.
        :param data: Data provider class/factory, called as ``data(args=args, current_iter=...)``.
        :param model: A meta learning system instance.
        :param device: Device/s to use for the experiment.
        """
        self.args, self.device = args, device
        num_thousand_iters = int(args.total_epochs * args.total_iter_per_epoch / 1000)
        tmp_maml_str = 'd' if args.TR_MAML else ''

        # The name encodes the key hyperparameters and determines the folder models/logs are written to.
        experiment_name = "ERM_{}k{}bs_{}way{}shot_{}gs_{}m_{}seed".format(
            num_thousand_iters, args.batch_size, args.num_classes_per_set, args.num_samples_per_class,
            args.number_of_training_steps_per_iter, tmp_maml_str, args.seed)

        self.model = model
        self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(
            experiment_name)
        print(experiment_name)
        self.total_losses = dict()
        self.state = {'best_val_acc': 0., 'best_val_iter': 0, 'current_iter': 0}
        self.start_epoch = 0
        self.max_models_to_save = self.args.max_models_to_save
        self.create_summary_csv = False

        # NOTE(review): this unconditionally overrides the caller's setting, so every run starts from scratch and
        # the 'latest' / numbered-checkpoint branches below are currently unreachable. Confirm this is intended.
        self.args.continue_from_epoch = 'from_scratch'

        if self.args.continue_from_epoch == 'from_scratch':
            self.create_summary_csv = True

        elif self.args.continue_from_epoch == 'latest':
            checkpoint = os.path.join(self.saved_models_filepath, "train_model_latest")
            print("attempting to find existing checkpoint", )
            if os.path.exists(checkpoint):
                self.state = \
                    self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                          model_idx='latest')
                self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)

            else:
                self.args.continue_from_epoch = 'from_scratch'
                self.create_summary_csv = True
        elif int(self.args.continue_from_epoch) >= 0:
            self.state = \
                self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                      model_idx=self.args.continue_from_epoch)
            self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)

        self.num_train_tasks = self.args.num_train_tasks
        self.num_test_tasks = self.args.num_test_tasks
        self.data = data(args=args, current_iter=self.state['current_iter'])

        print("train_seed {}, val_seed: {}, at start time".format(self.data.dataset.seed["train"],
                                                                  self.data.dataset.seed["val"]))
        self.total_epochs_before_pause = self.args.total_epochs_before_pause
        self.state['best_epoch'] = int(self.state['best_val_iter'] / self.args.total_iter_per_epoch)
        self.epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)
        self.augment_flag = 'omniglot' in self.args.dataset_name.lower()
        self.start_time = time.time()
        self.epochs_done_in_this_run = 0
        print("CURRENT STATE")
        print(self.state['current_iter'], int(self.args.total_iter_per_epoch * self.args.total_epochs))

    def build_summary_dict(self, total_losses, phase, summary_losses=None):
        """
        Builds/Updates a summary dict directly from the metric dict of the current iteration.
        :param total_losses: Dict mapping metric name -> list of per-iteration values.
        :param phase: Prefix for the summary keys (e.g. "train", "val", "test").
        :param summary_losses: Optional existing summary dict to update in place.
        :return: Dict with "<phase>_<metric>_mean" and "<phase>_<metric>_std" for every metric whose name does not
            contain 'importance'.
        """
        if summary_losses is None:
            summary_losses = dict()

        for key, values in total_losses.items():
            if 'importance' in key:
                continue
            summary_losses["{}_{}_mean".format(phase, key)] = np.mean(values)
            summary_losses["{}_{}_std".format(phase, key)] = np.std(values)

        return summary_losses

    def build_loss_summary_string(self, summary_losses):
        """
        Builds a progress bar summary string given current summary losses dictionary.
        :param summary_losses: Dict of metric name -> scalar value.
        :return: "key: value, " pairs (4 decimals) for every loss/accuracy key not containing 'importance'.
        """
        return "".join(
            "{}: {:.4f}, ".format(key, float(value))
            for key, value in summary_losses.items()
            if ("loss" in key or "acc" in key) and 'importance' not in key)

    def merge_two_dicts(self, first_dict, second_dict):
        """Given two dicts, merge them into a new dict as a shallow copy (second_dict wins on key clashes)."""
        merged = first_dict.copy()
        merged.update(second_dict)
        return merged

    @staticmethod
    def _accumulate_losses(total_losses, losses):
        """Append every metric of ``losses`` (cast to float) to its running list in ``total_losses``; returns it."""
        for key, value in losses.items():
            total_losses.setdefault(key, []).append(float(value))
        return total_losses

    def _accumulate_task_accuracies(self, full_accs, task_idxs, total_accs):
        """
        Split the batch's per-target accuracies by task and append each task's mean accuracy to ``total_accs``.
        :param full_accs: List of accuracy tensors for the whole batch. Assumed (from the slicing) to be laid out task
            by task, num_classes_per_set * num_target_samples entries per task -- confirm against run_validation_iter.
        :param task_idxs: Tensor whose row i holds, in column 0, the id of the i-th task of the batch.
        :param total_accs: Dict task id -> list of mean accuracies, updated in place.
        :return: ``total_accs``. Values are stored as Python floats so they never keep device tensors alive.
        """
        full_accs = torch.stack(full_accs)
        task_sep = int(self.args.num_classes_per_set * self.args.num_target_samples)
        for idx, task in enumerate(task_idxs.numpy()):
            task_mean = float(torch.mean(full_accs[task_sep * idx:task_sep * (idx + 1)]))
            total_accs.setdefault(task[0].item(), []).append(task_mean)
        return total_accs

    @staticmethod
    def _summarize_task_accuracies(total_accs, num_tasks, title):
        """
        Print per-task mean accuracies and summary statistics.
        :param total_accs: Dict task id -> list of accuracies collected for that task.
        :param num_tasks: Number of task ids (0 .. num_tasks-1) to report.
        :param title: Header line printed before the results.
        :return: Array of length num_tasks with each task's mean accuracy, -1 for tasks that were never sampled.
            The -1 placeholders are excluded from the printed statistics.
        """
        accs = np.full(num_tasks, -1.0)
        for task_id in range(num_tasks):
            if task_id in total_accs:
                accs[task_id] = np.mean(np.asarray(total_accs[task_id]))

        print(title)
        print(accs)
        sampled = np.flatnonzero(accs >= 0)
        if sampled.size < num_tasks:
            print("Tasks never sampled (excluded from statistics): {}".format(np.flatnonzero(accs < 0)))
        if sampled.size == 0:
            return accs

        sampled_accs = accs[sampled]
        print("Mean: {}".format(np.mean(sampled_accs)))
        print("Min: {}".format(np.min(sampled_accs)))
        print("Std Dev: {}".format(np.std(sampled_accs)))
        print("Max: {}".format(np.max(sampled_accs)))
        # Ids of the three worst-performing sampled tasks.
        print(sampled[np.argsort(sampled_accs)][:3])
        return accs

    def train_iteration(self, train_sample, sample_idx, epoch_idx, total_losses, current_iter, pbar_train):
        """
        Runs a training iteration, updates the progress bar and returns the current epoch train summary.
        :param train_sample: Tuple (x_support, x_target, y_support, y_target, seed, selected_task) from the provider.
        :param sample_idx: Index of the incoming sample; data shapes are printed when it is 0.
        :param epoch_idx: The (possibly fractional) epoch index passed to the model.
        :param total_losses: Running dict metric -> list of values, updated in place.
        :param current_iter: The current training iteration in relation to the whole experiment.
        :param pbar_train: The progress bar of the training.
        :return: (train_losses, current_iter + 1), where train_losses holds mean/std of every accumulated metric.
        """
        x_support_set, x_target_set, y_support_set, y_target_set, _seed, selected_task = train_sample
        data_batch = (x_support_set, x_target_set, y_support_set, y_target_set, selected_task)

        if sample_idx == 0:
            print("shape of data", x_support_set.shape, x_target_set.shape, y_support_set.shape,
                  y_target_set.shape)

        losses, _ = self.model.run_train_iter(data_batch=data_batch, epoch=epoch_idx)
        self._accumulate_losses(total_losses, losses)

        train_losses = self.build_summary_dict(total_losses=total_losses, phase="train")
        train_output_update = self.build_loss_summary_string(losses)

        pbar_train.update(1)
        pbar_train.set_description("training phase {} -> {}".format(self.epoch, train_output_update))

        return train_losses, current_iter + 1

    def evaluation_iteration(self, val_sample, total_losses, total_accs, pbar_val, phase):
        """
        Runs a validation iteration, updates the progress bar and returns the running summaries.
        :param val_sample: Tuple (x_support, x_target, y_support, y_target, seed, task_idxs) from the provider.
        :param total_losses: Running dict metric -> list of values, updated in place.
        :param total_accs: Running dict task id -> list of mean accuracies, updated in place.
        :param pbar_val: The progress bar of the val stage.
        :param phase: Prefix used for the summary keys (e.g. "val").
        :return: (val_losses, total_losses, total_accs), val_losses being the mean/std summary so far.
        """
        x_support_set, x_target_set, y_support_set, y_target_set, _seed, task_idxs = val_sample
        data_batch = (x_support_set, x_target_set, y_support_set, y_target_set, task_idxs)

        losses, full_accs, _, _dists = self.model.run_validation_iter(data_batch=data_batch, hook=False)
        self._accumulate_losses(total_losses, losses)
        self._accumulate_task_accuracies(full_accs, task_idxs, total_accs)

        val_losses = self.build_summary_dict(total_losses=total_losses, phase=phase)
        val_output_update = self.build_loss_summary_string(losses)

        pbar_val.update(1)
        pbar_val.set_description(
            "val_phase {} -> {}".format(self.epoch, val_output_update))
        return val_losses, total_losses, total_accs

    def test_evaluation_iteration(self, val_sample, total_losses, total_accs, model_idx, sample_idx,
                                  per_model_per_batch_preds, pbar_test):
        """
        Runs a test iteration, stores the model's predictions and returns the running summaries.
        :param val_sample: Tuple (x_support, x_target, y_support, y_target, seed, task_idxs) from the provider.
        :param total_losses: Running dict metric -> list of values, updated in place.
        :param total_accs: Running dict task id -> list of mean accuracies, updated in place.
        :param model_idx: Index into per_model_per_batch_preds where this batch's predictions are appended.
        :param sample_idx: Index of the sample (unused, kept for interface compatibility).
        :param per_model_per_batch_preds: List of prediction lists, one per evaluated model.
        :param pbar_test: The progress bar of the test stage.
        :return: (per_model_per_batch_preds, test_losses, total_losses, total_accs); test_losses keys use the
            "test_" prefix.
        """
        x_support_set, x_target_set, y_support_set, y_target_set, _seed, task_idxs = val_sample
        data_batch = (x_support_set, x_target_set, y_support_set, y_target_set, task_idxs)

        losses, full_accs, per_task_preds, _dists = self.model.run_validation_iter(data_batch=data_batch,
                                                                                   hook=False)
        self._accumulate_losses(total_losses, losses)
        self._accumulate_task_accuracies(full_accs, task_idxs, total_accs)

        # Previously phase=0, which produced keys like "0_loss_mean" in test_summary.csv.
        test_losses = self.build_summary_dict(total_losses=total_losses, phase="test")

        per_model_per_batch_preds[model_idx].extend(list(per_task_preds))

        test_output_update = self.build_loss_summary_string(losses)

        pbar_test.update(1)
        pbar_test.set_description(
            "test_phase {} -> {}".format(self.epoch, test_output_update))

        return per_model_per_batch_preds, test_losses, total_losses, total_accs

    def save_models(self, model, epoch, state):
        """
        Saves two copies of the current model: "train_model_<epoch>" kept for history and "train_model_latest",
        which a later run can resume from if training/validation is interrupted.
        :param model: Current meta learning model (must provide ``save_model``).
        :param epoch: Current epoch, used in the checkpoint folder name.
        :param state: Current model and experiment state dict.
        """
        model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, "train_model_{}".format(int(epoch))),
                         state=state)

        model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, "train_model_latest"),
                         state=state)

        print("saved models to", self.saved_models_filepath)

    def pack_and_save_metrics(self, start_time, create_summary_csv, train_losses, val_losses, state):
        """
        Append this epoch's train/val summaries to ``state['per_epoch_statistics']`` and to the statistics csv.
        :param start_time: The start time of the current epoch.
        :param create_summary_csv: If True, (re)create the csv with a header row before appending.
        :param train_losses: A dictionary with the current train losses.
        :param val_losses: A dictionary with the current val losses.
        :param state: Experiment state dict, updated in place.
        :return: (new start time for the next epoch, state).
        """
        epoch_summary_losses = self.merge_two_dicts(first_dict=train_losses, second_dict=val_losses)

        per_epoch_statistics = state.setdefault('per_epoch_statistics', dict())
        for key, value in epoch_summary_losses.items():
            per_epoch_statistics.setdefault(key, []).append(value)

        epoch_summary_string = self.build_loss_summary_string(epoch_summary_losses)
        epoch_summary_losses["epoch"] = self.epoch
        epoch_summary_losses['epoch_run_time'] = time.time() - start_time

        if create_summary_csv:
            self.summary_statistics_filepath = save_statistics(self.logs_filepath, list(epoch_summary_losses.keys()),
                                                               create=True)
            self.create_summary_csv = False

        start_time = time.time()
        print("epoch {} -> {}".format(epoch_summary_losses["epoch"], epoch_summary_string))

        self.summary_statistics_filepath = save_statistics(self.logs_filepath,
                                                           list(epoch_summary_losses.values()))
        return start_time, state

    def evaluated_test_set_using_the_best_models(self, top_n_models, dataset_name):
        """
        Reload the checkpoint with the highest recorded ``val_accuracy_mean`` and evaluate it on held-out tasks.

        Per-task accuracies are printed; the aggregated test metrics plus the overall argmax accuracy are written
        to ``test_summary.csv`` in the logs folder. ``self.state`` is restored afterwards.
        :param top_n_models: Number of best epochs to list; only the single best checkpoint is evaluated.
        :param dataset_name: 'train' evaluates new tasks from the training classes (``get_test_train_batches``);
            any other value evaluates test tasks (``get_test_batches``).
        :return: Array with each task's mean accuracy; -1 for tasks that were never sampled.
        :raises ValueError: If there is no checkpoint to select or the provider yields no batches.
        """
        val_acc = np.copy(self.state['per_epoch_statistics']['val_accuracy_mean'])
        sorted_idx = np.argsort(val_acc, axis=0).astype(dtype=np.int32)[::-1][:top_n_models]
        val_idx = np.arange(len(val_acc))[sorted_idx]
        print(sorted_idx)
        print(val_acc[sorted_idx])

        if len(val_idx) == 0:
            raise ValueError("No checkpoint to evaluate: no per-epoch validation statistics or top_n_models < 1.")
        # Per-epoch statistics are 0-based while checkpoints are named train_model_<epoch> with epochs counted
        # from 1 (self.epoch is incremented before save_models), hence the +1 below.
        best_model_idx = val_idx[0]

        num_batches = int(self.args.num_evaluation_tasks / self.args.batch_size)
        per_model_per_batch_preds = [[]]
        targets = []
        total_losses = dict()
        total_accs = dict()
        tst_losses = None
        old_state = self.state
        try:
            self.state = self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                               model_idx=best_model_idx + 1)
            with tqdm.tqdm(total=num_batches) as pbar_test:
                if dataset_name == 'train':
                    batches = self.data.get_test_train_batches(total_batches=num_batches, augment_images=False)
                    num_tasks = self.num_train_tasks
                else:
                    batches = self.data.get_test_batches(total_batches=num_batches, augment_images=False)
                    num_tasks = self.num_test_tasks

                for sample_idx, test_sample in enumerate(batches):
                    targets.extend(np.array(test_sample[3]))
                    per_model_per_batch_preds, tst_losses, total_losses, total_accs = self.test_evaluation_iteration(
                        val_sample=test_sample,
                        total_losses=total_losses,
                        total_accs=total_accs,
                        sample_idx=sample_idx,
                        model_idx=0,
                        per_model_per_batch_preds=per_model_per_batch_preds,
                        pbar_test=pbar_test)

                accs = self._summarize_task_accuracies(total_accs, num_tasks, "ACCURACIES")
        finally:
            self.state = old_state

        if tst_losses is None:
            raise ValueError("The data provider yielded no '{}' evaluation batches.".format(dataset_name))

        # Converted once here rather than on every batch.
        per_batch_preds = np.asarray(per_model_per_batch_preds[0])
        per_batch_max = np.argmax(per_batch_preds, axis=2)
        per_batch_targets = np.array(targets).reshape(per_batch_max.shape)
        is_correct = np.equal(per_batch_targets, per_batch_max)

        tst_losses["test_accuracy_mean"] = np.mean(is_correct)
        tst_losses["test_accuracy_std"] = np.std(is_correct)

        save_statistics(self.logs_filepath, list(tst_losses.keys()), create=True, filename="test_summary.csv")
        summary_statistics_filepath = save_statistics(self.logs_filepath, list(tst_losses.values()),
                                                      create=False, filename="test_summary.csv")
        print("saved test performance at", summary_statistics_filepath)
        return accs

    def _evaluate_batches(self, batches, pbar, phase='val'):
        """Run evaluation_iteration over ``batches`` with fresh accumulators; return (summary, per-task accs)."""
        val_losses, total_losses, total_accs = dict(), dict(), dict()
        for val_sample in batches:
            val_losses, total_losses, total_accs = self.evaluation_iteration(val_sample=val_sample,
                                                                             total_losses=total_losses,
                                                                             total_accs=total_accs,
                                                                             pbar_val=pbar, phase=phase)
        return val_losses, total_accs

    def _update_best_validation(self, val_losses):
        """Record the current iteration as best if its val accuracy beats the stored best (skipped if absent)."""
        current_val_acc = val_losses.get("val_accuracy_mean")
        if current_val_acc is not None and current_val_acc > self.state['best_val_acc']:
            print("Best validation accuracy", current_val_acc)
            self.state['best_val_acc'] = current_val_acc
            self.state['best_val_iter'] = self.state['current_iter']
            self.state['best_epoch'] = int(self.state['best_val_iter'] / self.args.total_iter_per_epoch)

    def _finish_epoch(self, train_losses, val_losses):
        """Advance the epoch, checkpoint, log metrics, and exit the process once the per-run epoch budget is used."""
        self.epoch += 1
        self.state = self.merge_two_dicts(first_dict=self.merge_two_dicts(first_dict=self.state,
                                                                          second_dict=train_losses),
                                          second_dict=val_losses)

        self.save_models(model=self.model, epoch=self.epoch, state=self.state)

        self.start_time, self.state = self.pack_and_save_metrics(start_time=self.start_time,
                                                                 create_summary_csv=self.create_summary_csv,
                                                                 train_losses=train_losses,
                                                                 val_losses=val_losses,
                                                                 state=self.state)

        self.total_losses = dict()
        self.epochs_done_in_this_run += 1

        save_to_json(filename=os.path.join(self.logs_filepath, "summary_statistics.json"),
                     dict_to_store=self.state['per_epoch_statistics'])

        if self.epochs_done_in_this_run >= self.total_epochs_before_pause:
            print("train_seed {}, val_seed: {}, at pause time".format(self.data.dataset.seed["train"],
                                                                      self.data.dataset.seed["val"]))
            sys.exit()

    def run_experiment(self):
        """
        Train until total_epochs * total_iter_per_epoch iterations have run (unless evaluate_on_test_set_only).
        At every epoch boundary the model is evaluated on test tasks and on new tasks from the training classes,
        checkpointed, and metrics are logged; the process exits after total_epochs_before_pause epochs in this run.
        Finally the best checkpoint is evaluated on train-class and test tasks. Returns None.
        """
        torch.cuda.empty_cache()
        total_train_iters = int(self.args.total_iter_per_epoch * self.args.total_epochs)
        num_eval_batches = int(self.args.num_evaluation_tasks / self.args.batch_size)

        with tqdm.tqdm(initial=self.state['current_iter'], total=total_train_iters) as pbar_train:

            while (self.state['current_iter'] < (self.args.total_epochs * self.args.total_iter_per_epoch)) and (
                    self.args.evaluate_on_test_set_only == False):

                train_batches = self.data.get_train_batches(
                    total_batches=total_train_iters - self.state['current_iter'],
                    augment_images=self.augment_flag)
                for train_sample in train_batches:
                    train_losses, self.state['current_iter'] = self.train_iteration(
                        train_sample=train_sample,
                        total_losses=self.total_losses,
                        epoch_idx=(self.state['current_iter'] / self.args.total_iter_per_epoch),
                        pbar_train=pbar_train,
                        current_iter=self.state['current_iter'],
                        sample_idx=self.state['current_iter'])

                    if self.state['current_iter'] % self.args.total_iter_per_epoch != 0:
                        continue

                    with tqdm.tqdm(total=num_eval_batches) as pbar_val:
                        val_losses, total_accs = self._evaluate_batches(
                            self.data.get_test_batches(total_batches=num_eval_batches, augment_images=False),
                            pbar_val)
                        self._update_best_validation(val_losses)
                        self._summarize_task_accuracies(total_accs, self.num_test_tasks, "TEST ACCURACIES")

                        # NOTE(review): val_losses is replaced here, so the per-epoch 'val_*' statistics saved in
                        # _finish_epoch (and later used to pick the best checkpoint) come from train-class tasks,
                        # whereas best_val_acc above comes from the test batches. Confirm this is intended.
                        val_losses, total_accs = self._evaluate_batches(
                            self.data.get_test_train_batches(total_batches=num_eval_batches, augment_images=False),
                            pbar_val)
                        self._summarize_task_accuracies(total_accs, self.num_train_tasks,
                                                        "ACCURACIES ON NEW TASKS FROM TRAINING CLASSES")

                    self._finish_epoch(train_losses, val_losses)

            accs_train = self.evaluated_test_set_using_the_best_models(top_n_models=1, dataset_name='train')
            accs_test1 = self.evaluated_test_set_using_the_best_models(top_n_models=1, dataset_name='test')
