from copy import deepcopy

from relnet.agent.baselines import PredictMeanAgent, PredictMedianAgent
from relnet.agent.gnn.prediction_agent import *
from relnet.agent.mlp.mlp_agent import SumMLPAgent, RawMLPAgent
from relnet.evaluation.eval_utils import get_model_seed
from relnet.io.file_paths import FilePaths
from relnet.objective_functions.objective_functions import *
from relnet.state.state_generators import TmGenStateGenerator


class ExperimentConditions(object):
    """Base configuration shared by every experiment type.

    Stores the target graph name, evaluation flags, default state-generator
    parameters, and the objective functions / network generators used by the
    experiment. Subclasses are expected to set ``experiment_params``,
    ``agents_models`` and ``agents_baseline``; ``set_generator_seeds`` and
    ``update_relevant_agents`` read those attributes.
    """

    # Attributes omitted from the string representation produced by __str__.
    _STR_EXCLUDED_ATTRS = ("agents_models", "agents_baseline",
                           "objective_functions", "network_generators")

    def __init__(self, graph_name, eval_on_train, use_ecmp, dms_mult):
        # gen_params is assigned first so it keeps its leading position in
        # __dict__ (and therefore in the __str__ output).
        self.gen_params = get_default_gen_params()
        self.graph_name = graph_name
        self.eval_on_train = eval_on_train
        self.use_ecmp = use_ecmp
        self.dms_mult = dms_mult

        self.objective_functions = [
            MLU
        ]

        self.network_generators = [
            TmGenStateGenerator
        ]

        self.hyps_chunk_size = 1
        self.seeds_chunk_size = 5

        self.use_pyg_cache_dir = False
        self.log_memory_usage = False

        self.separate_graphs_per_model_seed = False

        # Can be used to skip some random seeds in case training failed.
        self.model_seeds_to_skip = {}

    def set_generator_seeds(self, model_seed):
        """Compute and store the train/validation/test seeds for ``model_seed``.

        Delegates to ``TmGenStateGenerator.construct_network_seeds`` with the
        graph counts from ``self.experiment_params`` (set by subclasses), and
        stores the result in ``train_seeds``, ``validation_seeds`` and
        ``test_seeds``.
        """
        self.train_seeds, self.validation_seeds, self.test_seeds = TmGenStateGenerator.construct_network_seeds(
            self.eval_on_train,
            model_seed,
            self.experiment_params['train_graphs'],
            self.experiment_params['validation_graphs'],
            self.experiment_params['test_graphs'],
            separate_graphs_per_model_seed=self.separate_graphs_per_model_seed
        )

    def get_seeds_as_tuple(self):
        """Return ``(train_seeds, validation_seeds, test_seeds)``."""
        return self.train_seeds, self.validation_seeds, self.test_seeds

    def update_relevant_agents(self):
        """Set ``relevant_agents`` to a deep copy of ``agents_models``."""
        self.relevant_agents = deepcopy(self.agents_models)

    def __str__(self):
        """Return the instance attributes as a dict string.

        Attributes listed in ``_STR_EXCLUDED_ATTRS`` are left out; any of them
        that were never set are simply skipped instead of raising KeyError.
        """
        # A filtered shallow view is enough: we only drop keys and stringify,
        # so there is no need to deep-copy every attribute value.
        as_dict = {key: value for key, value in self.__dict__.items()
                   if key not in self._STR_EXCLUDED_ATTRS}
        return str(as_dict)

    def __repr__(self):
        return self.__str__()


class MainExperimentConditions(ExperimentConditions):
    """Conditions for the main experiment.

    Defines the graph counts per split, the agents to evaluate, a uniform
    step budget (``max_steps``) for every agent under the MLU objective, and
    a per-agent hyperparameter grid replicated for each objective function.
    """
    # Step budget granted to every agent in ``agents_models``.
    max_steps = 3000

    def __init__(self, graph_name, eval_on_train, use_ecmp, dms_mult):
        super().__init__(graph_name, eval_on_train, use_ecmp, dms_mult)

        # All three splits use the same graph count, scaled by dms_mult.
        graphs_per_split = int(1000 * self.dms_mult)
        self.experiment_params = {'train_graphs': graphs_per_split,
                                  'validation_graphs': graphs_per_split,
                                  'test_graphs': graphs_per_split,
                                  'num_runs': 10
                                  }

        self.experiment_params['model_seeds'] = [get_model_seed(run_num) for run_num in
                                                 range(self.experiment_params['num_runs'])]

        self.topology_variations = {
            "var_types": [],
            "var_count": None,
            "var_percentage": None
        }

        self.agents_models = [
            PredictMeanAgent,
            PredictMedianAgent,

            SumMLPAgent,
            UniformSummedDemandsRGATAgent,
            UniqueColorEdgeSummedDemandsRGATAgent,
            UniformSummedDemandsGCNAgent,
            UniformSummedDemandsSAGEAgent,

            RawMLPAgent,
            UniformRawDemandsRGATAgent,
            UniqueColorEdgeRawDemandsRGATAgent,
            UniformRawDemandsGCNAgent,
            UniformRawDemandsSAGEAgent
        ]

        self.agents_baseline = {
            MLU.name: [

            ],
        }

        # Derived from agents_models so the budget table can never drift out
        # of sync with the agent list: every agent gets max_steps.
        self.agent_budgets = {
            MLU.name: {agent.algorithm_name: self.max_steps
                       for agent in self.agents_models},
        }

        self.hyperparam_grids = self.create_hyperparam_grids()

    def create_hyperparam_grids(self):
        """Return ``{objective_name: {algorithm_name: {param: [values]}}}``.

        Each objective function receives its own deep copy of the base grid,
        so per-objective modifications cannot leak between objectives.
        """
        hyperparam_grid_base = {
            ## dummy baselines
            PredictMeanAgent.algorithm_name: {
                "dummy_param": [-1]
            },

            PredictMedianAgent.algorithm_name: {
                "dummy_param": [-1]
            },

            ## sum
            SumMLPAgent.algorithm_name: {
                "learning_rate": [0.01, 0.005, 0.001],
                "first_hidden_size": [64, 256],
                "batch_size": [16],
            },
            UniformSummedDemandsRGATAgent.algorithm_name: {
                "learning_rate": [0.01, 0.005, 0.001],
                "avg_caps_as_node_feats": [False],
                "lf_dim": [8, 32],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },
            UniqueColorEdgeSummedDemandsRGATAgent.algorithm_name: {
                "learning_rate": [0.01, 0.005, 0.001],
                "avg_caps_as_node_feats": [False],
                "lf_dim": [4, 16],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },
            UniformSummedDemandsGCNAgent.algorithm_name: {
                "avg_caps_as_node_feats": [True],
                "learning_rate": [0.01, 0.005, 0.001],
                "lf_dim": [8, 32],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },
            UniformSummedDemandsSAGEAgent.algorithm_name: {
                "avg_caps_as_node_feats": [True],
                "learning_rate": [0.01, 0.005, 0.001],
                "lf_dim": [8, 32],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },

            ### raw
            RawMLPAgent.algorithm_name: {
                "learning_rate": [0.01, 0.005, 0.001],
                "first_hidden_size": [64, 128],
                "batch_size": [16],
            },
            UniformRawDemandsRGATAgent.algorithm_name: {
                "learning_rate": [0.01, 0.005, 0.001],
                "avg_caps_as_node_feats": [False],
                "lf_dim": [8, 32],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },
            UniqueColorEdgeRawDemandsRGATAgent.algorithm_name: {
                "learning_rate": [0.01, 0.005, 0.001],
                "avg_caps_as_node_feats": [False],
                "lf_dim": [4, 16],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },
            UniformRawDemandsGCNAgent.algorithm_name: {
                "avg_caps_as_node_feats": [True],
                "learning_rate": [0.01, 0.005, 0.001],
                "lf_dim": [8, 32],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },
            UniformRawDemandsSAGEAgent.algorithm_name: {
                "avg_caps_as_node_feats": [True],
                "learning_rate": [0.01, 0.005, 0.001],
                "lf_dim": [8, 32],
                "layers_lt_diam": [0],
                "input_layer_heads": [1],
                "activation_fn": ["relu"],
                "subgraph_agg": ["sum"],
                "batch_size": [16],
            },
        }
        return {f.name: deepcopy(hyperparam_grid_base)
                for f in self.objective_functions}

class TopVarExperimentConditions(MainExperimentConditions):
    """Topology-variation variant of the main experiment conditions.

    Inherits agents, budgets and hyperparameter grids from
    MainExperimentConditions, then replaces the split sizes with a smaller
    graph count and enables ``"NR"`` topology variations.
    """

    def __init__(self, graph_name, eval_on_train, use_ecmp, dms_mult):
        super().__init__(graph_name, eval_on_train, use_ecmp, dms_mult)

        # Overrides the parent's experiment_params wholesale.
        split_size = int(40 * self.dms_mult)
        run_count = 10
        self.experiment_params = {
            'train_graphs': split_size,
            'validation_graphs': split_size,
            'test_graphs': split_size,
            'num_runs': run_count,
            'model_seeds': [get_model_seed(i) for i in range(run_count)],
        }

        self.topology_variations = dict(
            var_types=["NR"],
            var_count=25,
            var_percentage=20,
        )

def get_conditions_for_experiment(which, graph_name, eval_on_train, use_ecmp, dms_mult):
    """Build the conditions object for the named experiment.

    Args:
        which: experiment identifier, ``'main'`` or ``'topvar'``.
        graph_name, eval_on_train, use_ecmp, dms_mult: forwarded unchanged to
            the selected conditions class.

    Raises:
        ValueError: if ``which`` names no known experiment.
    """
    ctor_args = (graph_name, eval_on_train, use_ecmp, dms_mult)
    if which == 'main':
        return MainExperimentConditions(*ctor_args)
    if which == 'topvar':
        return TopVarExperimentConditions(*ctor_args)
    raise ValueError(f"experiment {which} not recognized!")

def get_default_gen_params():
    """Return a new dict with the default state-generator parameters."""
    # An alternative min_scale_factor of 1.25 was tried; 1 is the default.
    return {
        'min_scale_factor': 1,
        'locality': 0,
    }

def get_default_options(file_paths):
    """Return the default run options for ``file_paths``.

    Args:
        file_paths: object exposing ``construct_log_filepath()``; it is also
            stored under the ``"file_paths"`` key.

    Returns:
        A new options dict whose ``"log_filename"`` is the stringified log
        file path.
    """
    log_filename = str(file_paths.construct_log_filepath())
    return dict(
        log_progress=True,
        log_filename=log_filename,
        log_tf_summaries=True,
        random_seed=42,
        file_paths=file_paths,
        restore_model=False,
        use_pyg_cache_dir=False,
        log_memory_usage=False,
    )

def get_default_file_paths(experiment_id='development', *, parent_dir='/experiment_data'):
    """Return a ``FilePaths`` for ``experiment_id`` rooted at ``parent_dir``.

    Args:
        experiment_id: identifier of the experiment; defaults to
            ``'development'``.
        parent_dir: root directory for experiment data. Defaults to the
            previously hard-coded ``'/experiment_data'``; pass another path to
            store data elsewhere. Keyword-only so existing positional calls
            are unaffected.
    """
    return FilePaths(parent_dir, experiment_id)
