import os

from peagang.models.ggg import PEAWGAN_HyperParameters
from .base import ex
import attr

# Shared base learning rate for the named configs below: `egonet` trains its
# discriminator at 3 * BASE_LR and its generator at BASE_LR (TTUR).
BASE_LR = 0.001


@ex.config
def config():
    # Base Sacred config scope: Sacred records every local variable assigned
    # in this body as a config entry, so avoid temporary helper variables here.
    # Named configs in this file assign their own `hyper` dicts, whose keys
    # take precedence over the defaults built below.
    base_dir = os.path.abspath(os.getcwd())  # absolute path of the current working directory

    # NOTE(review): PEAWGAN_HyperParameters is constructed with a single
    # positional None argument; confirm what that argument controls.
    hyper = attr.asdict(PEAWGAN_HyperParameters(None))  # default hyper-parameters as a plain dict
    data_dir = os.path.join(base_dir, "peagang/data/")  # <base_dir>/peagang/data/

    epochs = 1001  # presumably the number of training epochs

    model_n = None

    save_dir = "peagang"
    overfit_pct = 0.0
    ckpt_period = 100  # presumably the checkpoint interval; confirm its unit in the training loop
    detect_anomaly = False
    deep = False
    deep_disc = False
    deep_gen = False
    forward_clip = False


@ex.named_config
def egonet():
    # Egonet experiment: attention architecture whose discriminator is trained
    # at 3 * BASE_LR versus the generator's BASE_LR (two time-scale update
    # rule), with `ema` enabled in both optimizer argument dicts.
    # NOTE: `finneti_MLP` keeps its original spelling; it presumably has to
    # match the hyper-parameter field name defined elsewhere.
    # Hyper-parameter overrides for the egonet dataset.
    hyper = {
        "dataset": "egonet",
        "device": "cuda:0",
        "n_attention_layers": 16,
        "cut_train_size": False,
        "edge_readout": "attention_weights",
        "score_function": "softmax",
        "architecture": "attention",
        "dataset_kwargs": {"dir": "data"},
        "embed_dim": 29,
        "finetti_dim": 50,
        "batch_size": 5,
        "num_heads": 10,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128, 256],
        "disc_readout_hidden": 128,
        "cycle_opt": "finetti_noDS",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "structured_features": True,
        "node_feature_dim": 4,
        "use_laplacian": True,
        "score_penalty": 0.0,
        "gen_spectral_norm": "nondiff",
        "temperature": 3 / 3.0,  # lower => more discrete, less smooth
        "disc_contrast": "fake_fake",
        "generator_every": 2,
        "disc_optim_args": {
            "lr": 3 * BASE_LR,  # TTUR: discriminator lr is 3x the generator's
            "betas": (0.0, 0.9999),
            "eps": 1e-8,
            "weight_decay": 1e-3,
            "ema": True,
            "ema_start": 10,
        },
        "gen_optim_args": {
            "lr": BASE_LR,
            "betas": (0.0, 0.9999),
            "eps": 1e-8,
            "weight_decay": 1e-3,
            "ema": True,
            "ema_start": 10,
        },
    }

@ex.named_config
def report_base():
    # Attention-architecture baseline on MolGAN_5k.
    # NOTE(review): this passes the dataset `dir=`, while attention_qm9 and
    # mlprow_qm9 pass the same dataset `DATA_DIR=`; confirm which key the
    # loader accepts.
    # Hyper-parameter overrides for the MolGAN_5k attention baseline.
    hyper = {
        "dataset": "MolGAN_5k",
        "device": "cuda:2",
        "n_attention_layers": 6,
        "cut_train_size": False,
        "edge_readout": "attention_weights",
        "architecture": "attention",
        "dataset_kwargs": {"dir": "data"},
        "label_one_hot": 5,
        "embed_dim": 25,
        "finetti_dim": 25,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "score_function": "softmax",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "structured_features": True,
        "node_feature_dim": 4,
        "use_laplacian": True,
        "disc_contrast": "fake_fake",
    }

@ex.named_config
def pointnetst_QM9():
    # Deep-set ("deepset") architecture on MolGAN_5k; otherwise mirrors
    # report_base except for n_attention_layers.
    # NOTE(review): `dir=` vs the `DATA_DIR=` key used by attention_qm9 for
    # the same dataset; confirm which key the loader accepts.
    # Hyper-parameter overrides for the MolGAN_5k deep-set run.
    hyper = {
        "dataset": "MolGAN_5k",
        "device": "cuda:2",
        "n_attention_layers": 7,
        "cut_train_size": False,
        "edge_readout": "attention_weights",
        "architecture": "deepset",
        "dataset_kwargs": {"dir": "data"},
        "label_one_hot": 5,
        "embed_dim": 25,
        "finetti_dim": 25,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "score_function": "softmax",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "structured_features": True,
        "node_feature_dim": 4,
        "use_laplacian": True,
        "disc_contrast": "fake_fake",
    }

@ex.named_config
def mlprow_qm9():
    # Row-wise MLP ("mlp_row") generator with the "QQ_sig" edge readout on
    # MolGAN_5k, without structured node features.
    # NOTE(review): DATA_DIR is the absolute path "/peagang/data", whereas
    # config() builds data_dir relative to the working directory.
    # Hyper-parameter overrides for the MolGAN_5k mlp_row run.
    hyper = {
        "dataset": "MolGAN_5k",
        "device": "cuda:0",
        "n_attention_layers": 12,
        "disc_contrast": "fake_fake",
        "cut_train_size": False,
        "edge_readout": "QQ_sig",
        "architecture": "mlp_row",
        "MLP_layers": [128, 256, 512],
        "dataset_kwargs": {"DATA_DIR": "/peagang/data"},
        "label_one_hot": 5,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "structured_features": False,
    }

@ex.named_config
def mlprow_chordal9():
    # Row-wise MLP ("mlp_row") generator with the "QQ_sig" edge readout on the
    # anu_graphs_chordal9 dataset, with structured node features enabled.
    # Hyper-parameter overrides for the chordal9 mlp_row run.
    hyper = {
        "dataset": "anu_graphs_chordal9",
        "device": "cuda:0",
        "n_attention_layers": 12,
        "cut_train_size": False,
        "edge_readout": "QQ_sig",
        "architecture": "mlp_row",
        "disc_contrast": "fake_fake",
        "MLP_layers": [128, 256, 512],
        "dataset_kwargs": {"DATA_DIR": "/peagang/data"},
        "label_one_hot": 5,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "node_feature_dim": 4,
        "use_laplacian": True,
        "structured_features": True,
    }

@ex.named_config
def mlprow_commsmall20():
    # Row-wise MLP ("mlp_row") generator with the "QQ_sig" edge readout on
    # CommunitySmall_20, with structured node features enabled.
    # Hyper-parameter overrides for the CommunitySmall_20 mlp_row run.
    hyper = {
        "dataset": "CommunitySmall_20",
        "device": "cuda:0",
        "n_attention_layers": 12,
        "cut_train_size": False,
        "edge_readout": "QQ_sig",
        "disc_contrast": "fake_fake",
        "architecture": "mlp_row",
        "MLP_layers": [128, 256, 512],
        "dataset_kwargs": {"DATA_DIR": "/peagang/data"},
        "label_one_hot": 5,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "node_feature_dim": 4,
        "use_laplacian": True,
        "structured_features": True,
    }

@ex.named_config
def attention_qm9():
    # Attention architecture on MolGAN_5k with softmax scoring, no label
    # one-hot encoding and no structured node features.
    # Hyper-parameter overrides for the MolGAN_5k attention run.
    hyper = {
        "dataset": "MolGAN_5k",
        "device": "cuda:2",
        "n_attention_layers": 6,
        "cut_train_size": False,
        "disc_contrast": "fake_fake",
        "edge_readout": "attention_weights",
        "architecture": "attention",
        "dataset_kwargs": {"DATA_DIR": "/peagang/data"},
        "label_one_hot": None,
        "embed_dim": 25,
        "finetti_dim": 25,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "score_function": "softmax",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "structured_features": False,
    }

@ex.named_config
def attention_chordal9():
    # Attention architecture on anu_graphs_chordal9 with sigmoid scoring,
    # 50-dimensional embedding/finetti spaces and structured node features.
    # Hyper-parameter overrides for the chordal9 attention run.
    hyper = {
        "dataset": "anu_graphs_chordal9",
        "device": "cuda:2",
        "n_attention_layers": 6,
        "cut_train_size": False,
        "disc_contrast": "fake_fake",
        "edge_readout": "attention_weights",
        "architecture": "attention",
        "dataset_kwargs": {"DATA_DIR": "/peagang/data"},
        "label_one_hot": None,
        "embed_dim": 50,
        "finetti_dim": 50,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "score_function": "sigmoid",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "node_feature_dim": 4,
        "use_laplacian": True,
        "structured_features": True,
    }

@ex.named_config
def attention_community():
    # Shallow (3-layer) attention architecture on CommunitySmall_20 with
    # sigmoid scoring and 2-dimensional structured node features.
    # Hyper-parameter overrides for the CommunitySmall_20 attention run.
    hyper = {
        "dataset": "CommunitySmall_20",
        "device": "cuda:3",
        "n_attention_layers": 3,
        "disc_contrast": "fake_fake",
        "cut_train_size": False,
        "edge_readout": "attention_weights",
        "architecture": "attention",
        "dataset_kwargs": {"DATA_DIR": "/peagang/data"},
        "label_one_hot": None,
        "embed_dim": 25,
        "finetti_dim": 25,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "score_function": "sigmoid",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "node_feature_dim": 2,
        "use_laplacian": True,
        "structured_features": True,
    }

@ex.named_config
def attention_community_eigen():
    # Variant of attention_community with 4-dimensional node features and
    # `disc_eigenfeat` switched on.
    # Hyper-parameter overrides for the CommunitySmall_20 eigen-feature run.
    hyper = {
        "dataset": "CommunitySmall_20",
        "device": "cuda:3",
        "n_attention_layers": 3,
        "disc_contrast": "fake_fake",
        "cut_train_size": False,
        "edge_readout": "attention_weights",
        "architecture": "attention",
        "dataset_kwargs": {"DATA_DIR": "/peagang/data"},
        "label_one_hot": None,
        "embed_dim": 25,
        "finetti_dim": 25,
        "kc_flag": True,
        "disc_conv_channels": [32, 64, 64, 64, 128, 128, 128],
        "cycle_opt": "finetti_noDS",
        "score_function": "sigmoid",
        "finetti_trainable": True,
        "finetti_train_fix_context": False,
        "dynamic_finetti_creation": False,
        "replicated_Z": False,
        "finneti_MLP": False,
        "node_feature_dim": 4,
        "disc_eigenfeat": True,
        "use_laplacian": True,
        "structured_features": True,
    }

@ex.named_config
def condgen_dblp():
    # Minimal override for the condgen DBLP dataset; every other
    # hyper-parameter keeps its default from config().
    # NOTE(review): DATA_DIR contains an anonymized "/home/AUTHOR/..." path
    # that must be adapted per machine.
    # Hyper-parameter overrides for the condgen_dblp dataset.
    hyper = {
        "dataset": "condgen_dblp",
        "node_feature_dim": 10,
        "dataset_kwargs": {"DATA_DIR": "/home/AUTHOR/graphs/data_dblp"},
        "label_one_hot": None,
    }


@ex.named_config
def condgen_tcga():
    # Minimal override for the condgen TCGA dataset; every other
    # hyper-parameter keeps its default from config().
    # NOTE(review): DATA_DIR contains an anonymized "/home/AUTHOR/..." path
    # that must be adapted per machine.
    # Hyper-parameter overrides for the condgen_tcga dataset.
    hyper = {
        "dataset": "condgen_tcga",
        "label_one_hot": None,
        # TODO: check whether the node_feature+1 setup really needs 10 here.
        "node_feature_dim": 10,
        "dataset_kwargs": {"DATA_DIR": "/home/AUTHOR/graphs/data_tcga"},
    }



