"""GradStudent is the orchestrator of the experiment"""

import pickle as pkl
import time
import os
from os import getcwd, listdir, path

import torch
import torch.multiprocessing as mp
from tqdm import tqdm

from codes.experiment.experiment import prepare_and_run_experiment as exp
from codes.experiment.signature_experiment import prepare_and_run_experiment as sig_exp
from codes.logbook.filesystem_logger import write_config_log, write_message_logs
from codes.logbook.logbook import LogBook
from codes.utils.config import get_config
from codes.utils.data import DataUtility
from codes.utils.util import _import_module, set_seed


class GradStudent:
    """GradStudent Class

    In practice, it is a thin class to support multiple experiments at once:
    it builds the config, loads the train/valid/test splits, instantiates the
    model, and then runs one or more experiments with them."""

    def __init__(self, config_id):
        """Build config, logbook, data splits and model for ``config_id``.

        :param config_id: identifier forwarded to ``get_config``.
        """
        self.config = bootstrap_config(config_id)
        self.logbook = LogBook(self.config)
        self.num_experiments = self.config.general.num_experiments
        torch.set_num_threads(self.num_experiments)
        self.device = self.config.general.device
        self.train_data = self.initialize_data(mode="train")
        self.valid_data = self.initialize_data(mode="valid")
        self.test_data = self.initialize_data(mode="test")
        # The model is built only after the data is loaded because
        # ``initialize_data`` writes ``config.model.num_classes``.
        self.model = self.bootstrap_model()

    def bootstrap_model(self):
        """Method to instantiate the models that will be common to all
        the experiments.

        :return: the model chosen by ``choose_model``, moved to ``self.device``.
        """
        model = choose_model(self.config)
        model.to(self.device)
        return model

    def checkpoint(self):
        """Return a freshly instantiated model on ``self.device``.

        NOTE(review): despite its name this does not save or restore any
        state; it currently behaves exactly like ``bootstrap_model``.
        """
        return self.bootstrap_model()

    def initialize_data(self, mode="train"):
        """Load every rule folder of one data split.

        The split is read from ``~/checkpoint/lgw/data/<data_name>/<mode>``.
        Each sub-directory of it is wrapped in a ``DataUtility``.

        Side effects: sets ``self.num_graphworlds`` and
        ``self.config.model.num_classes`` from the loaded split.

        :param mode: split name, e.g. ``"train"``, ``"valid"`` or ``"test"``.
        :return: list of ``DataUtility`` objects, one per rule folder, in
            sorted folder-name order.
        :raises ValueError: if the split directory contains no sub-folders.
        """
        data_path = path.join(
            os.path.expanduser("~/checkpoint/lgw/data"),
            self.config.general.data_name,
            mode,
        )
        # ``listdir`` order is arbitrary; sorting makes the world order
        # reproducible, and splits with the same folder names line up.
        rule_folders = sorted(
            folder
            for folder in listdir(data_path)
            if path.isdir(path.join(data_path, folder))
        )
        if not rule_folders:
            raise ValueError(
                "No rule folders found in data split {}".format(data_path)
            )

        graphworld_list = []
        with tqdm(total=len(rule_folders)) as pb:
            for folder in rule_folders:
                graphworld_list.append(
                    DataUtility(
                        config=self.config, data_folder=path.join(data_path, folder)
                    )
                )
                pb.update(1)

        self.num_graphworlds = len(graphworld_list)
        # NOTE(review): every call overwrites this value. After __init__ it
        # therefore reflects the last split loaded ("test"), not the max over
        # all splits. Confirm this is intended.
        self.config.model.num_classes = max(
            dt.get_num_classes() for dt in graphworld_list
        )
        return graphworld_list

    def run(self):
        """Method to run the task.

        With a single experiment, it runs in this process. With more than one,
        one ``torch.multiprocessing`` process is started per experiment. Each
        gets its own config (``experiment_id``) and the shared ``self.model``.
        Multiple experiments are only supported on CPU.
        """
        write_message_logs(
            "Starting Experiment at {}".format(
                time.asctime(time.localtime(time.time()))
            )
        )
        write_config_log(self.config)
        write_message_logs("torch version = {}".format(torch.__version__))

        data = [self.train_data, self.valid_data, self.test_data]
        # Signature experiments are used unconditionally for now; switching
        # on ``config.general.has_signature`` is disabled.
        prepare_and_run_experiment = sig_exp

        if self.num_experiments == 1 or self.num_experiments < 1:
            prepare_and_run_experiment(
                config=self.config,
                model=self.model,
                data=data,
                logbook=self.logbook,
            )
            return

        # ``torch.device`` accepts both strings and device objects, so this
        # works whichever form the config stores.
        if torch.device(self.device).type != "cpu":
            write_message_logs("Multi GPU training not supported.")
            return

        # Move parameters to shared memory so every worker uses the same storage.
        self.model.share_memory()

        processes = []
        for experiment_id in range(self.num_experiments):
            config = get_config(self.config.general.id, experiment_id=experiment_id)
            proc = mp.Process(
                target=prepare_and_run_experiment,
                # Same keyword interface as the single-process call above.
                # NOTE(review): each experiment gets a logbook built from its
                # own config. Confirm LogBook is safe to create per experiment.
                kwargs=dict(
                    config=config,
                    model=self.model,
                    data=data,
                    logbook=LogBook(config),
                ),
            )
            proc.start()
            processes.append(proc)
        for experiment_id, proc in enumerate(processes):
            proc.join()
            if proc.exitcode != 0:
                write_message_logs(
                    "Experiment {} exited with code {}".format(
                        experiment_id, proc.exitcode
                    )
                )


def bootstrap_config(config_id):
    """Build the config for experiment 0 and seed the RNGs from it.

    :param config_id: identifier forwarded to ``get_config``.
    :return: the config from ``get_config``, after ``set_seed`` has been
        called with its ``general.seed``.
    """
    cfg = get_config(config_id, experiment_id=0)
    seed_value = cfg.general.seed
    set_seed(seed=seed_value)
    return cfg


def choose_model(config):
    """Instantiate the model selected by ``config.model``.

    The dotted path ``<config.model.base_path>.<config.model.name>`` is
    resolved with ``_import_module``. The resolved object is then called with
    ``config``.

    :param config: experiment config providing ``model.base_path`` and
        ``model.name``.
    :return: the result of calling the resolved object with ``config``.
    """
    dotted_path = ".".join((config.model.base_path, config.model.name))
    model_factory = _import_module(dotted_path)
    return model_factory(config)
