import argparse
import pickle
import random
import traceback
from copy import copy, deepcopy
from datetime import datetime

from relnet.evaluation.eval_utils import generate_search_space, construct_search_spaces, score_predictions, \
    print_pred_scores
from relnet.evaluation.experiment_conditions import get_conditions_for_experiment, get_default_file_paths
from relnet.io.file_paths import FilePaths
from relnet.io.storage import EvaluationStorage
from relnet.state.graph_dataset import GraphDataset
from relnet.state.state_generators import create_generator_instance
from relnet.utils.config_utils import get_logger_instance
from relnet.utils.general_utils import chunks
from tasks import OptimizeHyperparamsTask, EvaluateTask

def build_dataset(experiment_conditions, parent_dir, experiment_id):
    """Generate graph instances for every (network generator, objective function) pair.

    The work is done on a deep copy of ``experiment_conditions`` so that
    seeding the generators does not modify the caller's object.

    Args:
        experiment_conditions: Experiment configuration; read for generator
            parameters, model seeds, topology variations, network generators
            and objective functions.
        parent_dir: Root directory handed to ``FilePaths``.
        experiment_id: Experiment identifier handed to ``FilePaths``.
    """
    print(f"building graph dataset.")
    file_paths = FilePaths(parent_dir, experiment_id)

    conditions = deepcopy(experiment_conditions)
    generation_params = conditions.gen_params

    for seed_index, model_seed in enumerate(conditions.experiment_params['model_seeds']):
        # Only the first model seed is processed unless the conditions ask
        # for separate graphs per model seed.
        if not (seed_index == 0 or conditions.separate_graphs_per_model_seed):
            break

        conditions.set_generator_seeds(model_seed)
        graph_seeds = conditions.get_seeds_as_tuple()

        variations = conditions.topology_variations
        var_types = variations['var_types']
        var_count = variations['var_count']

        generator_objective_pairs = (
            (generator, objective)
            for generator in conditions.network_generators
            for objective in conditions.objective_functions
        )
        for network_generator, objective_function in generator_objective_pairs:
            generator_instance = create_generator_instance(network_generator, file_paths)
            graph_instances = generator_instance.generate_many(
                conditions.graph_name,
                generation_params,
                graph_seeds,
                objective_function(),
                conditions.use_ecmp,
                var_types=var_types,
                var_count=var_count,
            )

            # NOTE(review): the dataset object is never read afterwards;
            # presumably it is constructed for its side effects - confirm.
            dataset = GraphDataset(
                file_paths,
                conditions.graph_name,
                objective_function.name,
                network_generator.name,
                graphs=graph_instances,
            )

def setup_hyperopt_part(experiment_conditions, parent_dir, existing_experiment_id, hyps_chunk_size, seeds_chunk_size):
    """Record the experiment details and create the hyperparameter optimisation tasks.

    Args:
        experiment_conditions: Experiment configuration to store and to build tasks from.
        parent_dir: Root directory handed to ``FilePaths``.
        existing_experiment_id: Identifier of the experiment being set up.
        hyps_chunk_size: Forwarded to ``setup_hyperparameter_optimisations``.
        seeds_chunk_size: Forwarded to ``setup_hyperparameter_optimisations``.
    """
    # Capture the start time before anything else happens.
    started_at = datetime.now()
    started_str = started_at.strftime(FilePaths.DATE_FORMAT)
    started_millis = started_at.timestamp()

    file_paths = FilePaths(parent_dir, existing_experiment_id)
    storage = EvaluationStorage(file_paths)

    search_spaces = construct_search_spaces(experiment_conditions)
    storage.insert_experiment_details(experiment_conditions, started_str, started_millis, search_spaces)

    log_path = str(file_paths.construct_log_filepath())
    logger = get_logger_instance(log_path)

    setup_hyperparameter_optimisations(
        storage,
        file_paths,
        experiment_conditions,
        existing_experiment_id,
        hyps_chunk_size,
        seeds_chunk_size,
    )

    logger.info(
        f"{datetime.now().strftime(FilePaths.DATE_FORMAT)} Completed setting up hyperparameter optimization tasks.")


def setup_eval_part(experiment_conditions, parent_dir, existing_experiment_id, parallel_eval):
    """Construct the evaluation tasks for an existing experiment and store them.

    Args:
        experiment_conditions: Experiment configuration the tasks are built from.
        parent_dir: Root directory handed to ``FilePaths``.
        existing_experiment_id: Identifier of the experiment to evaluate.
        parallel_eval: Forwarded to ``construct_eval_tasks``.
    """
    # Directories are not (re)created here: setup_directories is disabled.
    file_paths = FilePaths(parent_dir, existing_experiment_id, setup_directories=False)
    storage = EvaluationStorage(file_paths)

    log_path = str(file_paths.construct_log_filepath())
    logger = get_logger_instance(log_path)

    eval_tasks = construct_eval_tasks(
        existing_experiment_id,
        file_paths,
        experiment_conditions,
        storage,
        parallel_eval,
    )

    logger.info(f"have just setup {len(eval_tasks)} evaluation tasks.")
    storage.store_tasks(eval_tasks, "eval")


def construct_eval_tasks(experiment_id,
                         file_paths,
                         original_experiment_conditions,
                         storage,
                         parallel_eval):
    """Build one ``EvaluateTask`` per (network generator, objective function, agent).

    For each objective function the agents considered are the relevant agents
    of the experiment conditions plus the baseline agents registered for that
    objective. Task ids are assigned consecutively, starting at 1.

    Args:
        experiment_id: Identifier used to look up the stored experiment
            details and optimal hyperparameters.
        file_paths: Provides the log file path for the logger.
        original_experiment_conditions: Experiment configuration; a deep copy
            is used so the caller's object is left untouched.
        storage: Storage object queried for search spaces and optimal
            hyperparameters, and handed to every task.
        parallel_eval: Accepted for interface compatibility; not used when
            constructing the tasks.

    Returns:
        list: The constructed ``EvaluateTask`` objects.
    """
    experiment_conditions = deepcopy(original_experiment_conditions)
    logger = get_logger_instance(str(file_paths.construct_log_filepath()))

    try:
        hyps_search_spaces = storage.get_experiment_details(experiment_id)["parameter_search_spaces"]
        optimal_hyperparams = storage.retrieve_optimal_hyperparams(
            experiment_id, experiment_conditions.model_seeds_to_skip, False)
    except (KeyError, ValueError):
        # Best-effort fallback: carry on without any stored hyperparameters.
        logger.warning("no hyperparameters retrieved as no configured agents require them.")
        logger.warning(traceback.format_exc())
        hyps_search_spaces = None
        optimal_hyperparams = {}

    # The seed list does not depend on the generator/objective/agent.
    model_seeds = experiment_conditions.experiment_params['model_seeds']

    tasks = []
    task_id = 1

    for network_generator in experiment_conditions.network_generators:
        for objective_function in experiment_conditions.objective_functions:
            relevant_agents = deepcopy(experiment_conditions.relevant_agents)
            relevant_agents.extend(experiment_conditions.agents_baseline[objective_function.name])

            for agent in relevant_agents:
                # When the lookup above failed there are no search spaces at
                # all; subscripting None here would crash the fallback path.
                if hyps_search_spaces is None:
                    hyp_space = None
                else:
                    hyp_space = hyps_search_spaces[objective_function.name][agent.algorithm_name]

                additional_opts = {}
                eval_make_action_kwargs = {}

                # Settings without stored optimal hyperparameters fall back to
                # an empty configuration with id 0.
                setting = (network_generator.name, objective_function.name, agent.algorithm_name)
                best_hyperparams, best_hyperparams_id = optimal_hyperparams.get(setting, ({}, 0))

                tasks.append(EvaluateTask(task_id,
                                          agent,
                                          objective_function,
                                          network_generator,
                                          hyp_space,
                                          best_hyperparams,
                                          best_hyperparams_id,
                                          experiment_conditions,
                                          storage,
                                          model_seeds,
                                          eval_make_action_kwargs=eval_make_action_kwargs,
                                          additional_opts=additional_opts
                                          ))
                task_id += 1

    return tasks


def setup_hyperparameter_optimisations(storage,
                                       file_paths,
                                       experiment_conditions,
                                       experiment_id,
                                       hyps_chunk_size,
                                       seeds_chunk_size):
    """Create and store the hyperparameter optimisation tasks.

    Tasks are built for every (network generator, objective function, agent)
    combination whose agent has a hyperparameter grid for that objective.
    Task ids continue consecutively across combinations, starting at 1.

    Args:
        storage: Storage object that receives the tasks; also passed to each task builder call.
        file_paths: Provides the log file path; also passed to each task builder call.
        experiment_conditions: Experiment configuration (agents, seeds, grids, generators, objectives).
        experiment_id: Forwarded to ``construct_parameter_search_tasks``.
        hyps_chunk_size: Forwarded to ``construct_parameter_search_tasks``.
        seeds_chunk_size: Forwarded to ``construct_parameter_search_tasks``.
    """
    agents = experiment_conditions.relevant_agents
    seeds = experiment_conditions.experiment_params['model_seeds']

    all_tasks = []
    next_task_id = 1

    for network_generator in experiment_conditions.network_generators:
        for obj_fun in experiment_conditions.objective_functions:
            for agent in agents:
                grids_for_objective = experiment_conditions.hyperparam_grids[obj_fun.name]
                # Agents without a grid for this objective get no tasks.
                if agent.algorithm_name not in grids_for_objective:
                    continue

                new_tasks = construct_parameter_search_tasks(
                    next_task_id,
                    agent,
                    obj_fun,
                    network_generator,
                    experiment_conditions,
                    storage,
                    file_paths,
                    grids_for_objective[agent.algorithm_name],
                    seeds,
                    experiment_id,
                    hyps_chunk_size,
                    seeds_chunk_size,
                )
                all_tasks += new_tasks
                next_task_id += len(new_tasks)

    log_path = str(file_paths.construct_log_filepath())
    logger = get_logger_instance(log_path)
    logger.info(f"created {len(all_tasks)} hyperparameter optimisation tasks.")
    storage.store_tasks(all_tasks, "hyperopt")


def construct_parameter_search_tasks(start_task_id,
                                     agent,
                                     objective_function,
                                     network_generator,
                                     experiment_conditions,
                                     storage,
                                     file_paths,
                                     parameter_grid,
                                     model_seeds,
                                     experiment_id,
                                     hyps_chunk_size,
                                     seeds_chunk_size):
    """Create one ``OptimizeHyperparamsTask`` per (search-space chunk, seed chunk) pair.

    The search space generated from ``parameter_grid`` is split into chunks of
    ``hyps_chunk_size`` and the model seeds into chunks of
    ``seeds_chunk_size``. Task ids start at ``start_task_id`` and increase by
    one per task, iterating seed chunks within each search-space chunk.

    ``file_paths`` and ``experiment_id`` are accepted but not used here.

    Returns:
        list: The constructed ``OptimizeHyperparamsTask`` objects.
    """
    parameter_keys = list(parameter_grid.keys())
    search_space = list(generate_search_space(parameter_grid).items())

    search_space_chunks = list(chunks(search_space, hyps_chunk_size))
    model_seed_chunks = list(chunks(model_seeds, seeds_chunk_size))

    print(search_space_chunks)
    print(model_seed_chunks)

    # The same two (empty) dicts are shared by every task built in this call.
    eval_make_action_kwargs = {}
    additional_opts = {}

    chunk_pairs = [
        (space_chunk, seed_chunk)
        for space_chunk in search_space_chunks
        for seed_chunk in model_seed_chunks
    ]

    return [
        OptimizeHyperparamsTask(task_id,
                                agent,
                                objective_function,
                                network_generator,
                                experiment_conditions,
                                storage,
                                parameter_keys,
                                space_chunk,
                                seed_chunk,
                                additional_opts=additional_opts,
                                eval_make_action_kwargs=eval_make_action_kwargs)
        for task_id, (space_chunk, seed_chunk) in enumerate(chunk_pairs, start=start_task_id)
    ]


def get_base_opts():
    """Return a new, empty dictionary of base options."""
    return dict()

def main():
    """Parse the command line and set up the requested part of an experiment.

    Depending on ``--experiment_part`` this builds the graph dataset, sets up
    the hyperparameter optimisation tasks, or sets up the evaluation tasks.

    ``--hyps_chunk_size`` and ``--seeds_chunk_size`` override the chunk sizes
    of the experiment conditions when given; otherwise the values from the
    experiment conditions are used.
    """
    parser = argparse.ArgumentParser(description="Setup tasks for experiments.")
    parser.add_argument("--experiment_part", required=True, type=str,
                        help="Whether to setup hyperparameter optimisation or evaluation.",
                        choices=["build_dataset", "hyperopt", "eval"])

    parser.add_argument("--which", required=True, type=str,
                        help="Which experiment to run",
                        choices=["main", "topvar"])

    parser.add_argument("--parent_dir", type=str, default="/experiment_data",
                        help="Root path for storing experiment data.")
    parser.add_argument("--experiment_id", required=True, help="experiment id to use")

    parser.add_argument("--graph_name", type=str, required=True, help="Underlying graph topology name.")

    parser.add_argument("--dms_mult", type=float, required=False, default=1.0,
                        help="Multiplier for number of DMs to generate (applied to 250 main / 20 topvar). Default 1.")

    # The store_true flags below default to False when not given.
    parser.add_argument('--eval_on_train', dest='eval_on_train', action='store_true',
                        help="Whether to train/validate/test on the same sets of DMs (if true) or disjoint.")

    parser.add_argument('--hyps_chunk_size', required=False, type=int,
                        help="Number of hyperparameter configurations to pack in a single HPC task.")
    parser.add_argument('--seeds_chunk_size', required=False, type=int,
                        help="Number of random seeds to pack in a single HPC task.")

    parser.add_argument('--parallel_eval', dest='parallel_eval', action='store_true',
                        help="Whether to parallelize evaluation over different seeds. Care needed when number of graphs / seeds is large as may overwhelm queue.")

    parser.add_argument('--use_ecmp', dest='use_ecmp', action='store_true',
                        help="Whether to use ECMP (if given) or vanilla shortest paths otherwise.")

    args = parser.parse_args()

    experiment_conditions = get_conditions_for_experiment(
        args.which, args.graph_name, args.eval_on_train, args.use_ecmp, args.dms_mult)
    experiment_conditions.update_relevant_agents()

    if args.experiment_part == "build_dataset":
        build_dataset(experiment_conditions, args.parent_dir, args.experiment_id)
    elif args.experiment_part == "hyperopt":
        # Honour the chunk sizes given on the command line; previously they
        # were parsed but ignored. Fall back to the experiment conditions
        # when a flag is omitted.
        hyps_chunk_size = args.hyps_chunk_size
        if hyps_chunk_size is None:
            hyps_chunk_size = experiment_conditions.hyps_chunk_size
        seeds_chunk_size = args.seeds_chunk_size
        if seeds_chunk_size is None:
            seeds_chunk_size = experiment_conditions.seeds_chunk_size

        setup_hyperopt_part(experiment_conditions, args.parent_dir, args.experiment_id,
                            hyps_chunk_size, seeds_chunk_size)
    elif args.experiment_part == "eval":
        setup_eval_part(experiment_conditions, args.parent_dir, args.experiment_id, args.parallel_eval)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
