"""GradStudent is the orchestrator of the experiment"""

import json
import pickle as pkl
import time
import os
from os import listdir, path
from typing import Optional

import torch
from tqdm import tqdm

from codes.experiment.checkpointable_multitask_experiment import MultitaskExperiment
from codes.experiment.checkpointable_meta_experiment import MetaExperiment
from codes.experiment.inference import InferenceExperiment

# from codes.experiment.signature_experiment import prepare_and_run_experiment as sig_exp
from codes.logbook.filesystem_logger import write_config_log, write_message_logs
from codes.logbook.logbook import LogBook
from codes.utils.checkpointable import Checkpointable
from codes.utils.config import get_config
from codes.utils.data import DataUtility
from codes.utils.util import _import_module, set_seed

# TODO: Unify data loading and model loading within this class itself:
#   - for running MAML experiments, load the TaskFamily, as it has a sampler;
#   - load and save model weights right from this class;
#   - use the classes defined in models.py here.


class CheckpointableGradStudent(Checkpointable):
    """Checkpointable GradStudent Class

    This class provides a mechanism to checkpoint the (otherwise stateless) GradStudent.
    It builds the config, loads the per-rule datasets from disk and drives
    either a training experiment (``run``) or an inference run (``evaluate``).
    """

    def __init__(self, config_id, load_checkpoint=True, seed=-1):
        """
        Build the config and logbook for this run
        :param config_id: id forwarded to ``bootstrap_config``
        :param load_checkpoint: kept for interface compatibility; not read
            by this constructor
        :param seed: seed forwarded to ``bootstrap_config``
        """
        self.config = bootstrap_config(config_id, seed)
        self.logbook = LogBook(self.config)
        self.num_experiments = self.config.general.num_experiments
        # torch's thread count is set to the configured number of experiments
        torch.set_num_threads(self.num_experiments)
        self.device = self.config.general.device
        # filled by load_label2id() / initialize_data()
        self.label2id = {}
        self.model = None

    def bootstrap_model(self):
        """Method to instantiate the models that will be common to all
        the experiments."""
        model = choose_model(self.config)
        model.to(self.device)
        return model

    def _data_dir(self, split):
        """Return the data folder of ``split`` for the configured dataset."""
        return path.join(
            os.path.expanduser("~/checkpoint/lgw/data"),
            self.config.general.data_name,
            split,
        )

    def _select_rule_folders(self, data_path, mode):
        """List the sub-folders of ``data_path``; in supervised train mode keep
        only the folder named exactly like ``config.general.train_rule``."""
        rule_folders = [
            folder
            for folder in listdir(data_path)
            if path.isdir(path.join(data_path, folder))
        ]
        if self.config.general.train_mode == "supervised" and mode == "train":
            train_rule = self.config.general.train_rule
            rule_folders = [r for r in rule_folders if r.split("/")[-1] == train_rule]
        return rule_folders

    def _rules_to_use(self, mode):
        """Return the comma separated rule names configured for ``mode``, or an
        empty set when no such restriction applies."""
        general = self.config.general
        if mode == "train" and "," in general.train_rule:
            return set(general.train_rule.split(","))
        if mode == "test" and "," in general.test_rule:
            return set(general.test_rule.split(","))
        return set()

    def load_label2id(self):
        """
        Load the label mapping stored next to the train data, if present
        :return: True when ``label2id.json`` was found and loaded into
            ``self.label2id``, False otherwise
        """
        labels_file = path.join(self._data_dir("train"), "label2id.json")
        if not path.isfile(labels_file):
            return False
        print("Loading labels from {}".format(labels_file))
        # context manager so the handle is closed deterministically
        with open(labels_file) as labels_fp:
            self.label2id = json.load(labels_fp)
        print("Found : {} labels".format(len(self.label2id)))
        return True

    def initialize_data(self, mode="train", override_mode=None):
        """
        Load and initialize data here

        Also sets ``self.num_graphworlds`` and ``config.model.num_classes``.
        :param mode: data split; selects the folder (unless overridden) and
            the rule filters
        :param override_mode: when truthy, folder name used instead of ``mode``
        :return: list of ``DataUtility`` objects, one per loaded rule folder
        """
        data_path = self._data_dir(override_mode if override_mode else mode)
        labels_file = path.join(data_path, "label2id.json")

        # If the labels file does not exist yet, a train load first runs a
        # dummy pass (populate_labels=True, no graphs) to create it, followed
        # by the real pass. Every other case runs the real pass only.
        labels_found = self.load_label2id()
        if not labels_found and mode == "train":
            populate_passes = [True, False]
        else:
            populate_passes = [False]

        # NOTE: caching the loaded list as "world_list.pkl" was disabled in the
        # original code (condition hard-wired to False); data is always rebuilt.
        rule_folders = self._select_rule_folders(data_path, mode)
        only_use = self._rules_to_use(mode)

        graphworld_list = []
        for populate_labels in populate_passes:
            if populate_labels:
                print("Dummy loading to populate the labels file")
            timer = time.time()
            with tqdm(total=len(rule_folders)) as pb:
                for folder in rule_folders:
                    folder_path = path.join(data_path, folder)
                    if not path.exists(path.join(folder_path, "config.json")):
                        continue
                    if only_use and folder not in only_use:
                        continue
                    dt = DataUtility(
                        config=self.config,
                        data_folder=folder_path,
                        label2id=self.label2id,
                        load_graph=not populate_labels,
                        populate_labels=populate_labels,
                    )
                    if populate_labels:
                        # carry the growing label map into the next folder and
                        # persist it; the dummy objects hold no graphs, so they
                        # are not returned (previously they were appended too,
                        # duplicating every world on a first-time train load)
                        self.label2id = dt.label2id
                        with open(labels_file, "w") as labels_fp:
                            json.dump(self.label2id, labels_fp)
                    else:
                        graphworld_list.append(dt)
                    pb.update(1)
            print("Data loading time : {}".format(time.time() - timer))

        self.num_graphworlds = len(graphworld_list)
        self.config.model.num_classes = len(self.label2id)
        return graphworld_list

    def run(self):
        """Method to run the task"""

        write_message_logs(
            "Starting Experiment at {}".format(
                time.asctime(time.localtime(time.time()))
            )
        )
        write_config_log(self.config)
        write_message_logs("torch version = {}".format(torch.__version__))

        if not self.config.general.is_meta:
            self.train_data = self.initialize_data(mode="train")
            self.valid_data = self.initialize_data(mode="valid")
            self.test_data = self.initialize_data(mode="test")
            self.experiment = MultitaskExperiment(
                config=self.config,
                model=self.model,
                data=[self.train_data, self.valid_data, self.test_data],
                logbook=self.logbook,
            )
        else:
            self.experiment = MetaExperiment(config=self.config, logbook=self.logbook)
        self.experiment.load_model()
        self.experiment.run()

    def prepare_evaluator(
        self,
        epoch: Optional[int] = None,
        test_data=None,
        zero_init=False,
        override_mode=None,
        label2id=None,
    ):
        """
        Create ``self.evaluator`` for the given (or freshly loaded) test data
        :param epoch: forwarded to the evaluator's ``reset``
        :param test_data: already loaded test data; when falsy the test split
            is loaded from disk
        :param zero_init: forwarded to the evaluator's ``reset``
        :param override_mode: forwarded to ``initialize_data``
        :param label2id: label mapping matching ``test_data``; defaults to the
            mapping loaded from the train folder
        """
        self.load_label2id()
        if test_data:
            if label2id is None:
                # fall back to the mapping just loaded, mirroring what
                # initialize_data() uses for num_classes
                label2id = self.label2id
            self.test_data = test_data
            self.num_graphworlds = len(test_data)
            self.config.model.num_classes = len(label2id)
        else:
            self.test_data = self.initialize_data(
                mode="test", override_mode=override_mode
            )
        self.evaluator = InferenceExperiment(
            self.config, self.logbook, [self.test_data]
        )
        self.evaluator.reset(epoch=epoch, zero_init=zero_init)

    def evaluate(
        self,
        epoch: Optional[int] = None,
        test_data=None,
        ale_mode=False,
        label2id=None,
    ):
        """
        Prepare the evaluator and run it
        :param epoch: forwarded to ``prepare_evaluator``
        :param test_data: forwarded to ``prepare_evaluator``
        :param ale_mode: forwarded to the evaluator's ``run``
        :param label2id: forwarded to ``prepare_evaluator``
        :return: whatever the evaluator's ``run`` returns
        """
        self.prepare_evaluator(epoch, test_data, label2id=label2id)
        return self.evaluator.run(ale_mode=ale_mode)


def bootstrap_config(config_id, seed=-1):
    """Method to generate the config (using config id) and set seeds

    A strictly positive ``seed`` is passed to ``set_seed``; any other value
    falls back to ``config.general.seed``.
    """
    config = get_config(config_id, experiment_id=0)
    chosen_seed = seed if seed > 0 else config.general.seed
    set_seed(seed=chosen_seed)
    return config


def choose_model(config):
    """
    Dynamically load the model named by the config and instantiate it
    :param config: config providing ``model.base_path`` and ``model.name``
    :return: result of calling the imported object with ``config``
    """
    qualified_name = ".".join((config.model.base_path, config.model.name))
    model_factory = _import_module(qualified_name)
    return model_factory(config)


# @ray.remote
# def load_data_parallel(config=None, data_folder=None, label2id=None, load_graph=False, populate_labels=False):
#     print("loading {}".format(data_folder))
#     dt = DataUtility(
#         config=config,
#         data_folder=data_folder,
#         label2id=label2id,
#         load_graph=load_graph,
#         populate_labels=populate_labels,
#     )
#     print("loaded {}".format(data_folder))
#     return dt
