"""Main file for training and testing VQ-T2G"""

import os
import time

import typer

from pathlib import Path

from rich import print

from vqt2g.run_vqt2g import (
    VQT2GConfig,
    setup_new_run,
    gvqvae_runner,
    transformer_runner,
    eval_runner,
    # sampling_runner,
)

from vqt2g.utils.utils import set_seeds, start_logger


# Typer application object; every function decorated with @app.command() below is
# registered on it as a CLI subcommand, and it is invoked in the __main__ guard.
app: typer.Typer = typer.Typer()


###################################################################################################

gvqvae_config = typer.Argument(..., help="Config file for GVQVAE training")
gvqvae_log_level = typer.Option("INFO", help="Python logging level")
gvqvae_comment = typer.Option(None, help="Comment for this GVQVAE run")


@app.command()
def gvqvae_train(
    config_file: Path = gvqvae_config,
    log_level: str = gvqvae_log_level,
    comment: str = gvqvae_comment,  # None when no comment is given (option default)
):
    """First part: Train the GVQVAE

    Sets up a new run from the config file, seeds the RNGs, logs to
    log_gvqvae.txt in the run directory and then trains the GVQVAE.
    """

    config = VQT2GConfig(config_file)
    set_seeds(config.config.seed)
    # Must run before anything is written into this_run_dir below.
    # NOTE(review): presumably setup_new_run creates that directory — confirm.
    setup_new_run(config)
    print(config.config)

    log_file = Path(config.config.this_run_dir, "log_gvqvae.txt")
    start_logger(log_file, log_level)

    if comment is not None:
        comment_file = Path(config.config.this_run_dir, "gvqvae_comment.txt")
        # Explicit UTF-8: without it the encoding depends on the platform locale and
        # a non-ASCII comment can raise UnicodeEncodeError (e.g. on cp1252 Windows).
        comment_file.write_text(comment, encoding="utf-8")

    gvqvae_runner(config)


###################################################################################################

transformer_config = typer.Argument(..., help="Config file for transformer training")
transformer_log_level = typer.Option("INFO", help="Python logging level")
transformer_comment = typer.Option(None, help="Comment for this transformer run")


@app.command()
def transformer_train(
    config_file: Path = transformer_config,
    log_level: str = transformer_log_level,
    comment: str = transformer_comment,  # None when no comment is given (option default)
):
    """Second part: Train the transformer

    Loads the config, seeds the RNGs, logs to log_transformer.txt in the run
    directory named by the config and then trains the transformer.
    """

    config = VQT2GConfig(config_file)
    set_seeds(config.config.seed)
    # No setup_new_run() here: this stage writes into the existing this_run_dir from
    # the config. NOTE(review): presumably created by the GVQVAE stage — confirm.
    log_file = Path(config.config.this_run_dir, "log_transformer.txt")
    start_logger(log_file, log_level)

    if comment is not None:
        comment_file = Path(config.config.this_run_dir, "transformer_comment.txt")
        # Explicit UTF-8: without it the encoding depends on the platform locale and
        # a non-ASCII comment can raise UnicodeEncodeError (e.g. on cp1252 Windows).
        comment_file.write_text(comment, encoding="utf-8")

    transformer_runner(config)


###################################################################################################


eval_config = typer.Argument(..., help="Config file for evaluation run")
eval_log_level = typer.Option("INFO", help="Python logging level")
eval_comment = typer.Option(None, help="Comment for this evaluation")
eval_edge_threshold = typer.Option(0.8, help="Threshold for edge adding")
# Boolean flag: the help text states which method the flag turns on, since
# run_vqt2g_evaluation passes edge_topk=not edge_sampling to eval_runner.
eval_edge_sampling = typer.Option(
    False, help="Sample graph edges instead of using the top-k method"
)
eval_tfmr_sampling = typer.Option(
    False, help="Do non-deterministic sampling from transformer"
)
eval_tfmr_temp = typer.Option(0.7, help="Temperature parameter for transformer sampling")
eval_tfmr_top_p = typer.Option(0.7, help="Top-p parameter for transformer sampling")
eval_extra_edge_randomness = typer.Option(False, help="Extra randomness for edge sampling")
eval_do_train_set = typer.Option(False, help="Generate from train set instead of test set")


@app.command()
def run_vqt2g_evaluation(
    config_file: Path = eval_config,
    log_level: str = eval_log_level,
    comment: str = eval_comment,  # None when no comment is given (option default)
    threshold: float = eval_edge_threshold,
    edge_sampling: bool = eval_edge_sampling,
    transformer_sampling: bool = eval_tfmr_sampling,
    transformer_temperature: float = eval_tfmr_temp,
    transformer_top_p: float = eval_tfmr_top_p,
    extra_edge_randomness: bool = eval_extra_edge_randomness,
    do_train_set: bool = eval_do_train_set,
):
    """Third part: Generate graphs from a trained model and compute MMD statistics

    A timestamped log file for each evaluation is written to the "test"
    subdirectory of the run directory.
    """

    config = VQT2GConfig(config_file)

    # Make the test folder (the log file below is written into it)
    test_dir = Path(config.config.this_run_dir, "test")
    os.makedirs(test_dir, exist_ok=True)

    set_seeds(config.config.seed)

    # Put log into test dir instead of the run dir. Seconds are part of the name so
    # evaluations started within the same minute don't share one log file.
    eval_fname = f"log_model_test_{time.strftime('%b_%d_%H-%M-%S')}.txt"
    log_file = Path(test_dir, eval_fname)
    start_logger(log_file, log_level)

    # The CLI exposes a "sampling" switch; eval_runner takes the complementary top-k one.
    edge_topk = not edge_sampling
    eval_runner(
        config=config,
        comment=comment,
        edge_threshold=threshold,
        edge_topk=edge_topk,
        extra_edge_randomness=extra_edge_randomness,
        transformer_sampling=transformer_sampling,
        temperature=transformer_temperature,
        top_p=transformer_top_p,
        do_train_set=do_train_set,
    )


###################################################################################################


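# Sampling command: unfinished and kept commented out until `sampling_runner` (commented
# out in the imports above) is available. NOTE: the draft body below also uses
# `log_level`, which is not one of its parameters yet.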

# sample_config = typer.Argument(..., help="Config file for trained model to generate from")
# sample_num = typer.Option(1, help="Number of graphs to generate")
# sample_text = typer.Option(None, help="Text to condition generated graph")
# sample_seed = typer.Option(None, help="Optionally set random seed for generation")
##sample_from_file = typer.Option(
##    None, help="Text file with texts to generate from. Line breaks=new text")
#
#
# @app.command()
# def sample_from_model(
#    model_config: Path=sample_config,
#    num_graphs: int=sample_num,
#    text: str=sample_text,
#    set_seed: int=sample_seed,
# ):
#    """Sample from a trained model """
#
#    config = VQT2GConfig(model_config)
#    if set_seed is not None:
#        set_seeds(set_seed)
#
##    log_file = Path(config.config.this_run_dir, "log_sampling.txt")
##    start_logger(log_file, log_level)
#
##    sampling_runner(config=config, num_graphs=num_graphs, text=text)
#    raise NotImplementedError




if __name__ == "__main__":
    # Script entry point: hand control to the Typer CLI defined above.
    app()
