from argparse import Namespace

import attr


@attr.s
class PEAWGAN_HyperParameters:
    """Hyper-parameters for the PEAWGAN model, its data loading and optimizers.

    Every field has a default, so ``PEAWGAN_HyperParameters()`` yields the
    default configuration. Fields declared with an ``attr.validators.in_``
    validator only accept the listed options; this is checked at construction.
    """

    # hyper parameters
    # Known dataset options: MolGAN, MolGAN_5k, MolGAN_kC4, MolGAN_kC5,
    # MolGAN_kC6, CommunitySmall_12, CommunitySmall_20, anu_graphs_chordal9
    dataset = attr.ib(default="MolGAN_5k")
    batch_size = attr.ib(default=20)
    shuffle = attr.ib(default=True)
    architecture = attr.ib(
        default="attention",
        validator=attr.validators.in_({"attention", "mlp", "deepset", "rnn"}),
    )
    # Separate disc and gen learning rates for TTUR
    # (https://github.com/bioinf-jku/TTUR/blob/master/WGAN_GP/gan_64x64_FID.py
    # uses disc 3e-4, gen 1e-4, betas=(0, 0.9)).
    # The EMA option is motivated by https://arxiv.org/pdf/1806.04498.pdf
    # Both defaults below are the recommended settings from the WGAN-GP /
    # optimistic Adam papers, but with half the learning rate.
    disc_optim_args = attr.ib(
        factory=lambda: dict(
            lr=1e-4,
            betas=(0.5, 0.9999),
            eps=1e-8,
            weight_decay=1e-3,
            ema=False,
            ema_start=100,
        )
    )
    gen_optim_args = attr.ib(
        factory=lambda: dict(
            lr=1e-4,
            betas=(0.5, 0.9999),
            eps=1e-8,
            weight_decay=1e-3,
            ema=False,
            ema_start=100,
        )
    )
    extra_adam = attr.ib(default=True)
    reduce_every = attr.ib(default=100)
    lr_gamma = attr.ib(default=0.1)
    embed_dim = attr.ib(default=50)
    finetti_dim = attr.ib(default=50)
    label_one_hot = attr.ib(default=5)  # same as in node feature dim...
    # Possibly extended by "+ (9 - 2) + 1" (number of cycles + number of
    # nodes in graph), per an earlier commented-out expression.
    node_feature_dim = attr.ib(default=5)
    # A factory (not a list literal) so each instance gets its own list;
    # a literal `default=[...]` is one list object shared by all instances.
    disc_conv_channels = attr.ib(factory=lambda: [32, 64, 64, 64])
    LP = attr.ib(default=True)  # leaky penalty: False, True, or the string "ZP"
    penalty_lambda = attr.ib(default=5)
    penalty_onfake = attr.ib(default=False)
    penalty_onreal = attr.ib(default=False)
    generator_every = attr.ib(default=5)
    attention_mode = attr.ib(
        default="QK", validator=attr.validators.in_({"QQ", "QK"})
    )
    edge_readout = attr.ib(
        default="attention_weights",
        validator=attr.validators.in_(
            {
                "biased_sigmoid",
                "rescaled_softmax",
                "gaussian_kernel",
                "attention_weights",
            }
        ),
    )
    # "nodes"/"scalar" for biased_sigmoid, True/False for rescaled_softmax
    edge_bias_mode = attr.ib(default="scalar")
    # NOTE(review): presumably a hidden width for the edge-bias module; the
    # original comment was a copy of edge_bias_mode's — TODO confirm.
    edge_bias_hidden = attr.ib(default=128)
    cycle_opt = attr.ib(
        default="finetti_noDS",
        validator=attr.validators.in_(
            {"standard", "finetti_noDS", "finetti_ds"}
        ),
    )
    disc_readout_hidden = attr.ib(default=32)
    n_attention_layers = attr.ib(default=12)
    num_workers = attr.ib(default=0)
    attention_inner_layers = attr.ib(type=list, factory=list)
    num_heads = attr.ib(default=1)
    discretization = attr.ib(default="relaxed_bernoulli")
    disc_swish = attr.ib(default=True)
    deep_gen_inner_act = attr.ib(default=None)  # e.g. swish, relu
    deep_gen_out_act = attr.ib(default=None)
    disc_spectral_norm = attr.ib(default=None)  # None, "diff", "nondiff"
    # NOTE(review): original comment listed the spectral-norm options
    # (none, diff, nondiff); confirm the valid values for dropout.
    disc_dropout = attr.ib(default=None)
    gen_spectral_norm = attr.ib(default=None)  # None, "diff", "nondiff"
    # The lower, the more discrete but less smooth. 2/3 is taken from
    # http://www.stats.ox.ac.uk/~cmaddis/pubs/concrete.pdf
    temperature = attr.ib(default=2 / 3.0)
    # real/fake node embeddings with fake adjacency matrix
    disc_contrast = attr.ib(
        default="real_fake", validator=attr.validators.in_({"real_fake", "fake_fake"})
    )
    disc_penalty_mode = attr.ib(
        default="interpolate_Adj",
        validator=attr.validators.in_(
            {
                "avg_grads",
                "interpolate_Adj",
                "interpolate_emebbeddings",
                "GNN_layerwise_penatlty",
            }
        ),
    )
    save_dir = attr.ib(default=None)

    # TODO: figure out how to move away from dynamic init
    # seed_batch_shape = None,
    # seed_batch_size = None,
    # seedN = None,
    finetti_trainable = attr.ib(default=True)
    flip_finetti = attr.ib(default=True)
    finetti_train_fix_context = attr.ib(default=True)
    dynamic_finetti_creation = attr.ib(default=False)
    replicated_Z = attr.ib(default=False)
    exp_name = attr.ib(default=None)

    structured_features = attr.ib(default=False)
    k_eigenvals = attr.ib(default=4)
    use_laplacian = attr.ib(default=False)
    large_N_approx = attr.ib(default=False)

    @classmethod
    def with_updates(cls, **kwargs):
        """Return a default-constructed instance with ``kwargs`` overwritten.

        Values are assigned with ``setattr`` after construction, so attrs
        validators are not re-run for them, and names that are not fields are
        stored as plain instance attributes.
        """
        c = cls()  # use cls so subclasses get an instance of themselves
        for k, v in kwargs.items():
            setattr(c, k, v)
        return c


@attr.s
class Parameters:
    """Run configuration: the fields below plus a ``PEAWGAN_HyperParameters``.

    ``to_namespace`` / ``from_namespace`` convert to and from an
    ``argparse.Namespace`` (with ``hyper`` stored as a plain dict).
    """

    node_count_weights = attr.ib()
    model_n = attr.ib()
    base_dir = attr.ib()
    data_dir = attr.ib(default=".")
    filename = attr.ib(default=".")
    hyper = attr.ib(
        factory=PEAWGAN_HyperParameters,
        validator=attr.validators.instance_of(PEAWGAN_HyperParameters),
    )
    # Deep toggle... can be removed after further fixes.
    # Kept for compatibility; deep_disc / deep_gen override this if set.
    deep = attr.ib(default=False)
    deep_disc = attr.ib(default=None)
    deep_gen = attr.ib(default=None)

    def to_namespace(self):
        """Return an ``argparse.Namespace`` holding this object's fields.

        ``hyper`` is converted to a plain dict with ``attr.asdict`` so the
        result can be passed back to :meth:`from_namespace`.
        """
        return Namespace(
            node_count_weights=self.node_count_weights,
            data_dir=self.data_dir,
            filename=self.filename,
            base_dir=self.base_dir,
            model_n=self.model_n,
            hyper=attr.asdict(self.hyper),
            deep=self.deep,
            deep_gen=self.deep_gen,
            deep_disc=self.deep_disc,
        )

    @classmethod
    def from_namespace(cls, ns: Namespace):
        """Build a ``Parameters`` from an ``argparse.Namespace`` or a dict.

        ``hyper`` may be a dict of hyper-parameter fields (as produced by
        :meth:`to_namespace`) or a ``PEAWGAN_HyperParameters`` instance. If
        ``hyper`` is absent (or None), hyper-parameter fields are picked from
        the top level of ``ns`` by name; any not present keep their defaults.

        ``node_count_weights``, ``model_n`` and ``base_dir`` are required;
        the remaining fields fall back to their defaults when missing.
        The object passed in is not modified.

        Raises:
            KeyError: if a required key is missing.
            TypeError: if a ``hyper`` dict contains names that are not fields.
        """
        # Shallow copy: vars(ns) is the Namespace's own __dict__, so writing
        # into it (or into a caller's dict) would mutate the caller's object.
        ns = dict(vars(ns) if isinstance(ns, Namespace) else ns)

        hyper = ns.get("hyper")
        if hyper is None:
            # Flat namespace (e.g. straight from argparse): select the
            # hyper-parameter fields by name and pass them as keywords.
            hyper_names = {f.name for f in attr.fields(PEAWGAN_HyperParameters)}
            hyper = PEAWGAN_HyperParameters(
                **{k: v for k, v in ns.items() if k in hyper_names}
            )
        elif not isinstance(hyper, PEAWGAN_HyperParameters):
            hyper = PEAWGAN_HyperParameters(**hyper)

        # Only forward optional fields that are present, so defaults apply.
        optional = ("data_dir", "filename", "deep", "deep_gen", "deep_disc")
        return cls(
            node_count_weights=ns["node_count_weights"],
            model_n=ns["model_n"],
            base_dir=ns["base_dir"],
            hyper=hyper,
            **{k: ns[k] for k in optional if k in ns},
        )
