import os
import sys
sys.path.append('/relnet')


from relnet.agent.baselines import PredictMeanAgent, PredictMedianAgent
import networkx as nx

from relnet.state.graph_dataset import GraphDataset
from relnet.evaluation.eval_utils import find_max_nodes, find_max_edges, find_max_diameter
from relnet.agent.gnn.prediction_agent import *
from relnet.agent.mlp.mlp_agent import MLPAgent, SumMLPAgent, RawMLPAgent
from relnet.objective_functions.objective_functions import DummyObjective, MLU
from relnet.evaluation.experiment_conditions import get_default_gen_params, get_default_file_paths, get_default_options
from relnet.state.state_generators import TmGenStateGenerator

def main():
    """Train one routing agent on a single topology and report its scores.

    Generates graph states for ``graph_name`` with ``TmGenStateGenerator``,
    wraps them in a ``GraphDataset`` and splits them into train, validation
    and test sets by seed. It then trains ``agent_class`` for
    ``num_training_steps`` steps. Finally it prints the parameter count and
    the result of ``predict_and_score`` on the validation and test splits.
    """
    # Dataset sizes per split. A quick smoke test previously used 5 per split.
    num_train_graphs = 1000
    num_validation_graphs = 1000
    num_test_graphs = 1000
    eval_on_train = False
    os.environ["RN_GNN_PRINT"] = "0"

    # Topology to train on. Other names previously used with this script:
    # Chiesa, Abilene, Aconet, SNetDifferentCaps, Internetmci, Geant2011,
    # Evolink, Iij, Agis, Cernet, Cesnet201006, Internode, SwitchL3, Sinet,
    # Ulaknet.
    graph_name = "Uninett2011"

    gen_params = get_default_gen_params()
    # An earlier experiment id suffix was '_ssp_newtest'.
    file_paths = get_default_file_paths(experiment_id=f'{graph_name}_ssp_mlptest')

    # Agent to train. Alternatives previously swapped in here:
    # PredictMedianAgent, PredictMeanAgent, SumMLPAgent, RawMLPAgent,
    # LabelOrderSummedDemandsRGATAgent, UniqueColorEdgeRawDemandsRGATAgent.
    agent_class = UniformRawDemandsRGATAgent

    hyperparams = agent_class.get_default_hyperparameters()
    hyperparams['learning_rate'] = 0.005
    hyperparams['lf_dim'] = 32  # 64 and 128 were also noted as options
    hyperparams['layers_lt_diam'] = 0
    hyperparams['batch_size'] = 16  # 32 was also used
    hyperparams['activation_fn'] = "relu"
    hyperparams['subgraph_agg'] = "sum"
    hyperparams['use_node_id'] = False
    # Other knobs left at their defaults here (previously experimented with):
    # num_layers, first_hidden_size, virtual_node_demands
    # ('existing_neg' / 'zeros'), virtual_edge_caps ('existing' / 'minus_ones').

    gen = TmGenStateGenerator(file_paths)
    model_seed = 0
    hyps_id = 0

    # Topology-variant generation is disabled. An earlier configuration used
    # var_types = ["NR"] and var_count = 50. A separate variant percentage
    # was also configured, but it was never passed to generate_many or
    # split_from_seeds, so it is omitted.
    var_types = []
    var_count = None

    objective_function = MLU()
    use_ecmp = False  # passed straight through to generate_many

    options = get_default_options(file_paths)
    options['random_seed'] = model_seed
    options['model_identifier_prefix'] = file_paths.construct_model_identifier_prefix(
        agent_class.algorithm_name,
        objective_function.name,
        gen.name,
        model_seed,
        hyps_id,
    )
    options['log_tf_summaries'] = False
    options['log_memory_usage'] = False

    graph_seeds = TmGenStateGenerator.construct_network_seeds(
        eval_on_train, model_seed,
        num_train_graphs, num_validation_graphs, num_test_graphs,
        separate_graphs_per_model_seed=True,
    )
    print(graph_seeds)

    all_gs = gen.generate_many(graph_name, gen_params, graph_seeds, objective_function, use_ecmp,
                               var_types=var_types, var_count=var_count)

    gds = GraphDataset(file_paths, graph_name, objective_function.name,
                       TmGenStateGenerator.name, graphs=all_gs)
    # Drop the local reference once the dataset has been built from the
    # generated graphs.
    del all_gs
    options['graph_ds'] = gds
    options['use_pyg_cache'] = True

    all_hashes = gds.get_all_graph_hashes()
    train_graphs, validation_graphs, test_graphs = TmGenStateGenerator.split_from_seeds(
        all_hashes, graph_seeds, var_types=var_types, var_count=var_count,
        disjoint_topologies=False,
    )

    agent = agent_class()
    num_training_steps = 1000  # 3000 was also used
    agent.setup(options, hyperparams)

    agent.train(train_graphs, validation_graphs, num_training_steps)
    print(f"number of params: {agent.count_parameters()}")

    vg_perf = agent.predict_and_score(validation_graphs, {})
    print(f"final performance on val set was {vg_perf}")

    avg_perf = agent.predict_and_score(test_graphs, {})
    print(f"final eval performance was {avg_perf}")


if __name__ == '__main__':
    main()


