import argparse
import json
import pickle
import time
from copy import deepcopy

import traceback

from relnet.io.file_paths import FilePaths
from relnet.state.graph_dataset import GraphDataset
from relnet.state.state_generators import create_generator_instance, TmGenStateGenerator
from relnet.utils.config_utils import get_logger_instance, date_format

class OptimizeHyperparamsTask(object):
    """Work unit that trains and scores an agent over a chunk of the hyperparameter search space.

    For every ``(hyperparams_id, combination)`` pair in ``search_space_chunk`` and every
    seed in ``model_seeds_chunk``, ``run`` sets up a fresh agent instance, trains it when
    the agent class reports ``is_trainable``, scores it on the validation graphs and
    hands the score to ``storage.write_hyperopt_results``.
    """

    def __init__(self,
                 task_id,
                 agent,
                 objective_function,
                 network_generator,
                 experiment_conditions,
                 storage,
                 parameter_keys,
                 search_space_chunk,
                 model_seeds_chunk,
                 train_kwargs=None,
                 eval_make_action_kwargs=None,
                 additional_opts=None
                 ):
        """Store the task configuration; no work happens until ``run`` is called.

        :param task_id: identifier of this task.
        :param agent: agent class; called with no arguments to create an instance and read
            for ``algorithm_name`` / ``is_trainable``.
        :param objective_function: object whose ``name`` is used for lookups and identifiers.
        :param network_generator: object whose ``name`` is used for lookups and identifiers.
        :param experiment_conditions: experiment settings; deep-copied for every seed.
        :param storage: object exposing ``file_paths`` and ``write_hyperopt_results``.
        :param parameter_keys: hyperparameter names, positionally aligned with each combination.
        :param search_space_chunk: iterable of ``(hyperparams_id, combination)`` pairs.
        :param model_seeds_chunk: iterable of model seeds to run for each combination.
        :param train_kwargs: optional extra keyword arguments for ``agent.train``.
        :param eval_make_action_kwargs: optional options passed (as one dict) to
            ``agent.predict_and_score``.
        :param additional_opts: optional overrides merged last into the run options.
        """
        self.task_id = task_id
        self.agent = agent
        self.objective_function = objective_function
        self.network_generator = network_generator
        self.experiment_conditions = experiment_conditions
        self.storage = storage
        self.parameter_keys = parameter_keys
        self.search_space_chunk = search_space_chunk
        self.model_seeds_chunk = model_seeds_chunk
        self.train_kwargs = train_kwargs
        self.eval_make_action_kwargs = eval_make_action_kwargs
        self.additional_opts = additional_opts

    def run(self):
        """Train and score the agent for each hyperparameter combination and seed of this chunk.

        Ordinary exceptions raised while setting up, training or scoring one
        (combination, seed) pair are logged as warnings and the loop moves on to the
        next pair. ``KeyboardInterrupt`` and ``SystemExit`` are deliberately not caught,
        so the task can still be interrupted or terminated.
        """
        file_paths = self.storage.file_paths
        logger = get_logger_instance(str(file_paths.construct_log_filepath()))

        algorithm_name = self.agent.algorithm_name
        objective_name = self.objective_function.name
        generator_name = self.network_generator.name

        # The skip-list key does not depend on the loop variables, so build it once.
        # NOTE(review): EvaluateTask in this file builds the same key in the order
        # (algorithm, objective, generator) -- confirm which ordering
        # experiment_conditions.model_seeds_to_skip is actually keyed by.
        setting = (generator_name, objective_name, algorithm_name)

        for hyperparams_id, combination in self.search_space_chunk:
            # Values in a combination are positionally aligned with parameter_keys.
            hyperparams = {self.parameter_keys[idx]: param_value
                           for idx, param_value in enumerate(combination)}
            logger.info(f"executing with hyps {hyperparams}")

            for model_seed in self.model_seeds_chunk:
                logger.info(f"executing for seed {model_seed}")
                # Work on a private copy so that seeding does not leak between iterations.
                exp_copy = deepcopy(self.experiment_conditions)
                exp_copy.set_generator_seeds(model_seed)

                seeds_to_skip = exp_copy.model_seeds_to_skip
                if setting in seeds_to_skip and model_seed in seeds_to_skip[setting]:
                    print(f"skipping seed {model_seed} for setting {setting} as configured.")
                    continue

                model_identifier_prefix = file_paths.construct_model_identifier_prefix(
                    algorithm_name, objective_name, generator_name, model_seed, hyperparams_id)

                graph_seeds = exp_copy.get_seeds_as_tuple()
                var_types = exp_copy.topology_variations['var_types']
                var_count = exp_copy.topology_variations['var_count']

                gds = GraphDataset(file_paths, self.experiment_conditions.graph_name,
                                   objective_name, generator_name)
                # Only the train and validation splits are needed during hyperparameter search.
                train_graphs, validation_graphs, _ = TmGenStateGenerator.split_from_seeds(
                    gds.get_all_graph_hashes(), graph_seeds,
                    var_types=var_types, var_count=var_count)

                agent_instance = self.agent()
                run_options = {
                    'use_pyg_cache_dir': exp_copy.use_pyg_cache_dir,
                    'log_memory_usage': exp_copy.log_memory_usage,
                    "graph_ds": gds,
                    "random_seed": model_seed,
                    "file_paths": file_paths,
                    "log_progress": True,
                    "log_filename": str(file_paths.construct_log_filepath()),
                    "model_identifier_prefix": model_identifier_prefix,
                    "log_tf_summaries": False,
                    "restore_model": False,
                }
                # Caller-supplied options are applied last so they win over the defaults above.
                run_options.update(self.additional_opts or {})

                try:
                    agent_instance.setup(run_options, hyperparams)
                    if self.agent.is_trainable:
                        max_steps = exp_copy.agent_budgets[objective_name][algorithm_name]
                        agent_instance.train(train_graphs, validation_graphs, max_steps,
                                             **(self.train_kwargs or {}))

                    perf = agent_instance.predict_and_score(validation_graphs,
                                                            self.eval_make_action_kwargs or {})
                    self.storage.write_hyperopt_results(model_identifier_prefix, perf)
                    agent_instance.finalize()
                except Exception:
                    # Best effort: one failing configuration must not abort the whole chunk.
                    # Catching Exception (not BaseException) lets Ctrl-C / SystemExit through.
                    # NOTE(review): finalize() is not reached on this path.
                    logger.warning("got exception while training & evaluating agent")
                    logger.warning(traceback.format_exc())


class EvaluateTask(object):
    """Work unit that evaluates restored models on the test graphs and stores per-graph results.

    For every ``(hyps_id, hyps)`` entry of ``hyp_space`` and every seed in ``model_seeds``,
    ``run`` sets up an agent with ``restore_model=True``, predicts on the test split and
    records one result row per prediction. All rows are written in a single
    ``storage.write_eval_results`` call at the end.
    """

    def __init__(self,
                 task_id,
                 agent,
                 objective_function,
                 network_generator,
                 hyp_space,
                 best_hyperparams,
                 best_hyperparams_id,
                 experiment_conditions,
                 storage,
                 model_seeds,
                 eval_make_action_kwargs=None,
                 additional_opts=None):
        """Store the task configuration; no work happens until ``run`` is called.

        :param task_id: identifier of this task; passed to ``storage.write_eval_results``.
        :param agent: agent class; called with no arguments to create an instance and read
            for ``algorithm_name``.
        :param objective_function: object whose ``name`` is used for lookups and result rows.
        :param network_generator: object whose ``name`` is used for lookups and result rows.
        :param hyp_space: mapping from hyperparameter id to the hyperparameters to evaluate.
        :param best_hyperparams: stored on the task; not read by ``run``.
        :param best_hyperparams_id: id compared against each ``hyps_id`` to set ``is_best_hyps``.
        :param experiment_conditions: experiment settings; deep-copied for every seed.
        :param storage: object exposing ``file_paths`` and ``write_eval_results``.
        :param model_seeds: iterable of model seeds to evaluate.
        :param eval_make_action_kwargs: optional extra keyword arguments for ``agent.predict``.
        :param additional_opts: optional overrides merged last into the run options.
        """
        self.task_id = task_id
        self.agent = agent
        self.objective_function = objective_function
        self.network_generator = network_generator

        self.hyp_space = hyp_space
        self.best_hyperparams = best_hyperparams
        self.best_hyperparams_id = best_hyperparams_id

        self.experiment_conditions = experiment_conditions
        self.storage = storage
        self.model_seeds = model_seeds
        self.eval_make_action_kwargs = eval_make_action_kwargs
        self.additional_opts = additional_opts

    def run(self):
        """Evaluate every (hyperparameters, seed) pair and write the collected result rows.

        Seeds listed in ``model_seeds_to_skip`` for this setting are skipped. A
        ``ValueError`` raised by ``predict`` also skips the pair, after finalizing the
        agent instance. Any other exception propagates to the caller.
        """
        file_paths = self.storage.file_paths
        log_filename = str(file_paths.construct_log_filepath())
        # The returned logger is not used directly here; the call is kept as in the
        # original code (presumably it configures logging for this file -- TODO confirm).
        logger = get_logger_instance(log_filename)
        local_results = []

        algorithm_name = self.agent.algorithm_name
        objective_name = self.objective_function.name
        generator_name = self.network_generator.name

        # NOTE(review): OptimizeHyperparamsTask in this file builds this key in the order
        # (generator, objective, algorithm) -- confirm which ordering
        # experiment_conditions.model_seeds_to_skip is actually keyed by.
        setting = (algorithm_name, objective_name, generator_name)
        predict_kwargs = self.eval_make_action_kwargs or {}

        for hyps_id, hyps in self.hyp_space.items():
            for model_seed in self.model_seeds:
                # Work on a private copy so that seeding does not leak between iterations.
                exp_copy = deepcopy(self.experiment_conditions)
                exp_copy.set_generator_seeds(model_seed)

                seeds_to_skip = exp_copy.model_seeds_to_skip
                if setting in seeds_to_skip and model_seed in seeds_to_skip[setting]:
                    continue

                graph_seeds = exp_copy.get_seeds_as_tuple()
                var_types = exp_copy.topology_variations['var_types']
                var_count = exp_copy.topology_variations['var_count']

                gds = GraphDataset(file_paths, self.experiment_conditions.graph_name,
                                   objective_name, generator_name)
                # Only the test split is used for evaluation.
                _, _, test_graphs = TmGenStateGenerator.split_from_seeds(
                    gds.get_all_graph_hashes(), graph_seeds,
                    var_types=var_types, var_count=var_count)

                agent_instance = self.agent()
                model_identifier_prefix = file_paths.construct_model_identifier_prefix(
                    algorithm_name, objective_name, generator_name, model_seed, hyps_id)
                run_options = {
                    'use_pyg_cache_dir': exp_copy.use_pyg_cache_dir,
                    'log_memory_usage': exp_copy.log_memory_usage,
                    "graph_ds": gds,
                    'random_seed': model_seed,
                    "restore_model": True,
                    "model_identifier_prefix": model_identifier_prefix,
                    "file_paths": file_paths,
                    "log_progress": True,
                    "log_filename": log_filename,
                }
                # Caller-supplied options are applied last so they win over the defaults above.
                run_options.update(self.additional_opts or {})

                agent_instance.setup(run_options, hyps)

                time_started_seconds = time.time()
                try:
                    predictions = agent_instance.predict(test_graphs, **predict_kwargs)
                except ValueError:
                    # temporary solution to GAT layers < 2 throwing ValueError.
                    # Finalize before skipping so the set-up agent is not left dangling.
                    agent_instance.finalize()
                    continue
                time_ended_seconds = time.time()

                # Wall-clock time of the whole predict call, in milliseconds.
                durations_ms = (time_ended_seconds - time_started_seconds) * 1000

                for i, prediction in enumerate(predictions):
                    g_md = gds.get_metadata_for_hash(test_graphs[i])
                    local_results.append({
                        'network_generator': generator_name,
                        'objective_function': objective_name,
                        'network_seed': g_md['generator_seed'],
                        'algorithm': algorithm_name,
                        'agent_seed': model_seed,
                        'graph_name': exp_copy.graph_name,
                        'prediction': prediction.item(),
                        'ground_truth': g_md['obj_fun_value'],
                        # Total predict time split evenly across the predictions.
                        'duration_ms': durations_ms / len(predictions),
                        'num_nodes': g_md['num_nodes'],
                        'num_edges': g_md['num_edges'],
                        'var_type': g_md['var_type'],
                        'var_number': g_md['var_number'],
                        'hyps_id': hyps_id,
                        'is_best_hyps': (hyps_id == self.best_hyperparams_id),
                    })

                agent_instance.finalize()

        self.storage.write_eval_results(local_results, self.task_id)



def main():
    """Command-line entry point: load a previously generated task from disk and run it.

    The task file is looked up in the hyperopt or eval task directory of the
    experiment, depending on ``--experiment_part``, and is deserialized with pickle.
    """
    arg_parser = argparse.ArgumentParser(description="Run a given task.")
    arg_parser.add_argument("--experiment_part", required=True, type=str,
                            choices=["hyperopt", "eval"],
                            help="Whether to setup hyperparameter optimisation or evaluation.")
    arg_parser.add_argument("--parent_dir", type=str, default="/experiment_data",
                            help="Root path for storing experiment data.")
    arg_parser.add_argument("--experiment_id", required=True, help="experiment id to use")
    arg_parser.add_argument("--task_id", type=str, required=True,
                            help="Task id to run. Must have already been generated.")
    cli = arg_parser.parse_args()

    paths = FilePaths(cli.parent_dir, cli.experiment_id, setup_directories=False)
    if cli.experiment_part == "hyperopt":
        tasks_dir = paths.hyperopt_tasks_dir
    else:
        tasks_dir = paths.eval_tasks_dir
    task_path = tasks_dir / FilePaths.construct_task_filename(cli.experiment_part, cli.task_id)

    # NOTE(review): pickle executes arbitrary code on load; only run task files that
    # were produced by a trusted source.
    with open(task_path, 'rb') as task_fh:
        loaded_task = pickle.load(task_fh)

    loaded_task.run()



# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
































































