"""Functions to run the different parts of VQ-T2G"""

import logging
import os
import pickle as pkl

import time

import torch
import torch_geometric as pyg
import transformers
import yaml

from collections import Counter
from pathlib import Path

from easydict import EasyDict as edict
from rich import print
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from vqt2g.encoder_decoder import GVQEncoder, GVQDecoder
from vqt2g.generate import generate_from_text, plot_real_and_generated_graph
from vqt2g.gvqvae import GVQVAE
from vqt2g.load_dataset import load_vqt2g_dataset
from vqt2g.stats.compute_stats import compute_stats, egos_topic_stats
from vqt2g.train_gvqvae import train_autoencoder
from vqt2g.train_transformer import train_transformer
from vqt2g.tokenizer import VQT2GTokenizerBPE
from vqt2g.utils.graph_load_batch import graph_load_batch
from vqt2g.utils.gvqvae_utils import load_gvqvae
from vqt2g.vq_layer import VectorQuantizer, VectorQuantizerEMA, NoVQBottleneck


_LOG = logging.getLogger("vqt2g_logger")


class VQT2GConfig:
    """Class to hold all the config parameters

    The yaml config file is loaded into an EasyDict, and its top-level sections are also exposed
    directly as attributes of this object.

    Attributes:
      config: The full config (EasyDict)
      dataset: The `dataset` section of the config
      gvqvae: The `gvqvae` section of the config
      transformer: The `transformer` section of the config
      evaluation: The `evaluation` section of the config
      device: The top-level `device` entry of the config
    """

    def __init__(self, config_file):
        """Open config file, load with yaml, convert to easydict

        Args:
          config_file: Path of the yaml config file to load

        """

        # Load config as yaml then easydict
        with open(config_file, "r") as f:
            config = yaml.safe_load(f)
        config = edict(config)

        self.config = config
        self.dataset = config.dataset
        self.gvqvae = config.gvqvae
        self.transformer = config.transformer
        self.evaluation = config.evaluation

        self.device = config.device

    def as_dict(self):
        """Convert the config from an EasyDict object to a regular python dictionary"""
        return self._as_dict()

    def _as_dict(self, obj=None):
        """Recursively convert a (sub-)config to a regular python dictionary

        Args:
          obj: Mapping to convert. The full config is used when this is None (Default value = None)

        Returns:
          A plain dict with the same keys, whose values were passed through `_to_plain`.

        """
        if obj is None:
            obj = self.config
        return {key: self._to_plain(val) for key, val in obj.items()}

    def _to_plain(self, val):
        """Convert a single config value to something `yaml.safe_dump` can represent

        Dicts (including EasyDicts) are converted recursively, lists and tuples become lists with
        converted items, and `Path` objects become strings. Anything else is returned unchanged.
        """
        if isinstance(val, dict):
            return self._as_dict(obj=val)
        # EasyDict also wraps dicts that sit inside lists/tuples, so these need unwrapping too.
        # Tuples are emitted as lists because the safe yaml dumper has no tuple representation.
        if isinstance(val, (list, tuple)):
            return [self._to_plain(item) for item in val]
        if isinstance(val, Path):
            return str(val)
        return val

    def save_config(self, location=".", fname="config.yaml"):
        """Save the config file to a yaml file

        Args:
          location: Path (relative to run dir base) to save config (Default value = ".")
          fname: config file name (Default value = "config.yaml")

        """
        fpath = Path(self.config.this_run_dir, location, fname)
        # Don't rely on the caller having created the sub-folder already
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with open(fpath, "w") as f:
            yaml.safe_dump(self.as_dict(), f, default_flow_style=False)
        _LOG.info(f"Saved config to: {fpath}")


###################################################################################################


def setup_new_run(config: VQT2GConfig) -> None:
    """Create the folder for a new experiment run, split the dataset and save the config

    The run folder name is the configured `run_name` followed by a timestamp, which guards
    against accidentally overwriting an earlier run. The original name is kept in the config as
    `run_name_short`.

    Args:
      config: Config object for the experiment. Updated in place with `run_name_short`,
        `run_name` and `this_run_dir`.

    """
    base_cfg = config.config

    # Timestamped name, e.g. "<run_name>_Jan_01_12-30"
    short_name = base_cfg.run_name
    stamp = time.strftime("%b_%d_%H-%M")
    full_name = f"{short_name}_{stamp}"
    base_cfg["run_name_short"] = short_name
    base_cfg["run_name"] = full_name

    run_folder = Path(base_cfg.run_dir_root, full_name)
    base_cfg["this_run_dir"] = str(run_folder)
    os.makedirs(run_folder, exist_ok=True)

    _LOG.info(f"Setting up new run - {run_folder}")
    train_test_split_dataset(config)

    # Keep a copy of the config inside the run folder
    config.save_config()


###################################################################################################


def _record_dataset_dims(config: VQT2GConfig, train_graphs) -> tuple:
    """Write sizes read from the first training graph into the config and return them

    Sets `dataset.max_nodes` and `dataset.num_node_features`, and replaces a
    `gvqvae.model.codes_per_graph` of 0 with `max_nodes`.
    """
    # NOTE(review): only the first graph is inspected - presumably all graphs share the same
    # (padded) size; confirm against `load_vqt2g_dataset`
    first_graph = train_graphs[0]
    max_nodes = first_graph.num_nodes
    num_feats = first_graph.x.size(1)
    config.dataset["max_nodes"] = max_nodes
    config.dataset["num_node_features"] = num_feats

    if config.gvqvae.model.codes_per_graph == 0:
        config.gvqvae.model.codes_per_graph = max_nodes

    return max_nodes, num_feats


def train_test_split_dataset(config: VQT2GConfig) -> None:
    """Load dataset from single .pkl, train/test split, then save the split

    If `dataset.is_already_split` is set, the saved training graphs are loaded instead and only
    the size entries of the config are filled in. In both cases the config is updated in place
    with `dataset.dataset_save_path`, `dataset.max_nodes` and `dataset.num_node_features`, and a
    `gvqvae.model.codes_per_graph` of 0 is replaced by `max_nodes`.

    Args:
      config: Config object for the experiment

    Returns:
        None. Data splits are saved to files.

    Raises:
      ValueError: If `is_already_split` is set but the saved training graphs are missing.

    """

    # Create folder to save dataset after splitting
    data_dir = Path(config.dataset.dataset_dir, config.dataset.dataset_name)
    config.dataset["dataset_save_path"] = str(data_dir)
    os.makedirs(data_dir, exist_ok=True)

    # Don't recompute splits if already done
    if config.dataset.is_already_split:
        _LOG.info("Train/test split already done, loading data from file")
        try:
            with open(Path(data_dir, "train_graphs.pkl"), "rb") as f:
                train_graphs = pkl.load(f)
        except FileNotFoundError as err:
            raise ValueError(
                "Pre-split dataset not found, set `is_already_split = False` in config to fix"
            ) from err
        # Same bookkeeping as for a fresh split, including the `codes_per_graph` default
        _record_dataset_dims(config, train_graphs)
        return

    # Load raw dataset
    if config.dataset.dataset_name == "proteins":
        data = graph_load_batch(
            data_dir=config.dataset.raw_dataset,
            min_num_nodes=100,
            max_num_nodes=500,
            name="DD",
        )
    else:
        with open(config.dataset.raw_dataset, "rb") as f:
            data = pkl.load(f)

    # Do train/test split
    _LOG.info("Loading raw dataset: Adding node features and doing train/test split")
    train_graphs, test_graphs, train_texts, test_texts, indices = load_vqt2g_dataset(
        data=data,
        proportion_or_count=config.dataset.prop_or_count,
        test_prop=config.dataset.test_prop,
        test_num=config.dataset.test_num,
        add_node_attrs=config.dataset.add_node_features,
        attr_type=config.dataset.node_feature_type,
        shuffle=config.dataset.shuffle,
        seed=config.config.seed,
        max_dataset_size=config.dataset.max_dataset_size,
        one_hot_degree=config.dataset.degree_feature_one_hot,
    )

    max_nodes, num_feats = _record_dataset_dims(config, train_graphs)
    _LOG.info(f"Loaded dataset '{config.dataset.dataset_name}' and done train/test split")
    _LOG.info(f"Dataset max nodes = {max_nodes}; num node features = {num_feats}")

    # Save each split separately
    splits = {
        "train_graphs.pkl": train_graphs,
        "test_graphs.pkl": test_graphs,
        "train_texts.pkl": train_texts,
        "test_texts.pkl": test_texts,
        "train_test_indices.pkl": indices,
    }
    for fname, split in splits.items():
        with open(Path(data_dir, fname), "wb") as f:
            pkl.dump(split, f)
    _LOG.info(f"Saved dataset train/test splits to: {data_dir}")



###################################################################################################


def setup_gvqvae_model(config: VQT2GConfig, device):
    """Build the encoder, decoder and bottleneck class, and assemble the GVQVAE

    Args:
      config: Config object for the experiment
      device: Device the assembled model is moved to

    Returns:
      The `GVQVAE` model, on `device`.

    """

    ### Move this to gvqvae.py

    gvq_cfg = config.gvqvae
    model_cfg = config.gvqvae.model

    # Encoder
    encoder_kwargs = dict(
        in_channels=config.dataset.num_node_features,
        hid_channels_1=model_cfg.encoder_channels_1,
        hid_channels_2=model_cfg.encoder_channels_2,
        output_dim=model_cfg.codebook_dim,
        max_nodes=config.dataset.max_nodes,
        conv_type=model_cfg.gnn_conv_type,
        conv_aggr=model_cfg.gnn_conv_aggr,
        num_random_feature=model_cfg.num_random_feature,
        random_feature_sd=model_cfg.random_feature_sd,
        random_feature_only=model_cfg.random_feature_only,
        pre_vq_batchnorm=model_cfg.pre_vq_batchnorm,
        use_linear_layers=model_cfg.encoder_linear_layers,
        codes_per_graph=model_cfg.codes_per_graph,
        linear_layer_dropout=model_cfg.encoder_dropout,
    )
    encoder = GVQEncoder(**encoder_kwargs)
    _LOG.debug("Created encoder")

    # Decoder
    decoder_kwargs = dict(
        in_latent_dim=model_cfg.codebook_dim,
        codes_per_graph=model_cfg.codes_per_graph,
        hidden_size_1=model_cfg.decoder_size_1,
        hidden_size_2=model_cfg.decoder_size_2,
        output_node_dim=model_cfg.output_node_dim,
        max_nodes=config.dataset.max_nodes,
    )
    decoder = GVQDecoder(**decoder_kwargs)
    _LOG.debug("Created decoder")

    # Pick the bottleneck class (the class itself is handed to GVQVAE, not an instance)
    bottleneck_enabled = model_cfg.use_vq_bottleneck
    ema_updates = model_cfg.ema
    if bottleneck_enabled:
        bottleneck_cls = VectorQuantizerEMA if ema_updates else VectorQuantizer
    else:
        bottleneck_cls = NoVQBottleneck
    _LOG.debug(f"Created VQ layer: {bottleneck_cls.__name__}")

    # Full GVQVAE model
    gvqvae = GVQVAE(
        encoder=encoder,
        decoder=decoder,
        vq_bottleneck=bottleneck_cls,
        embedding_dim=model_cfg.codebook_dim,
        codebook_size=model_cfg.codebook_size,
        commitment_cost=gvq_cfg.train.commitment_cost,
        codebook_init_sd=model_cfg.codebook_init_sd,
    )
    return gvqvae.to(device)


###################################################################################################


def gvqvae_runner(config: VQT2GConfig) -> None:
    """Train the GVQVAE: set up folders, data, model, optimizer and scheduler, then train

    The output folders are recorded in the `gvqvae` section of the config, and the config is
    saved both before and after training.

    Args:
      config: Config object for the experiment

    """

    base_cfg = config.config
    gvq_cfg = config.gvqvae

    device = torch.device(base_cfg.device)

    # Output folders
    gvqvae_dir = Path(base_cfg.this_run_dir, "gvqvae")
    checkpoint_dir = Path(gvqvae_dir, "checkpoints")
    plots_dir = Path(gvqvae_dir, "plots")
    for folder in (gvqvae_dir, checkpoint_dir):
        os.makedirs(folder, exist_ok=True)
    if gvq_cfg.train.do_embedding_plots:
        os.makedirs(plots_dir, exist_ok=True)
    gvq_cfg["gvqvae_dir"] = str(gvqvae_dir)
    gvq_cfg["checkpoint_dir"] = str(checkpoint_dir)
    gvq_cfg["plots_dir"] = str(plots_dir)
    config.save_config()

    # Load graph datasets
    data_root = config.dataset.dataset_save_path

    def _unpickle(fname):
        with open(Path(data_root, fname), "rb") as fh:
            return pkl.load(fh)

    train_graphs = _unpickle("train_graphs.pkl")
    test_graphs = _unpickle("test_graphs.pkl")

    _LOG.info("Train/test graphs loaded")
    _LOG.info(f"Num train, test graphs = {len(train_graphs)}, {len(test_graphs)}")

    # Create model
    model = setup_gvqvae_model(config, device)
    print(model)
    _LOG.info("GVQVAE initialised")

    # Create optimizer
    use_vq = gvq_cfg.model.use_vq_bottleneck
    lr = gvq_cfg.train.learning_rate
    codebook_lr = lr * gvq_cfg.train.codebook_lr_factor
    weight_decay = gvq_cfg.train.weight_decay

    # Decide which LR (if any) the bottleneck gets in its own param group:
    #   no VQ -> 0.0, VQ with a factor other than 1 -> codebook_lr, otherwise no separate group
    if not use_vq:
        bottleneck_lr = 0.0
    elif gvq_cfg.train.codebook_lr_factor != 1.0:
        bottleneck_lr = codebook_lr
    else:
        bottleneck_lr = None

    if bottleneck_lr is None:
        # Normal optimizer, all params have same LR
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        _LOG.info(f"Encoder+decoder+codebook learning rate = {lr}")
    else:
        param_groups = [
            {"params": model.encoder.parameters()},
            {"params": model.decoder.parameters()},
            {"params": model.vq.parameters(), "lr": bottleneck_lr},
        ]
        optimizer = torch.optim.Adam(param_groups, lr=lr, weight_decay=weight_decay)
        if use_vq:
            _LOG.info(
                f"Encoder+decoder learning rate = {lr}, codebook learning rate = {codebook_lr}"
            )
        else:
            _LOG.info(f"Encoder+decoder learning rate = {lr}, no codebook")

    # Create learning rate scheduler
    scheduler_kwargs = dict(
        factor=gvq_cfg.train.lr_decay_factor,
        patience=gvq_cfg.train.lr_decay_patience,
        threshold=1e-5,  # Default 1e-4
        cooldown=0,
        verbose=True,
        min_lr=gvq_cfg.train.min_lr,
    )
    scheduler = ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    _LOG.info("Optimizer and scheduler initialised")

    # When resuming, restore state from the last checkpoint
    if gvq_cfg.train.resume_training:
        load_gvqvae(config, model, device, optimizer, scheduler)
        _LOG.info(f"Loaded checkpoint from {config.gvqvae.test.last_checkpoint}")

    train_autoencoder(
        config=config,
        model=model,
        train_graphs=train_graphs,
        test_graphs=test_graphs,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    # Save into the gvqvae folder, and rewrite the base config too
    for save_location in ("gvqvae", "."):
        config.save_config(save_location)
    _LOG.info("Training finished and config saved")


###################################################################################################


def transformer_runner(config: VQT2GConfig) -> None:
    """Train the transformer on texts paired with GVQVAE-encoded graphs

    Loads the saved splits and the trained GVQVAE, encodes every graph as codebook IDs, trains
    and saves the tokenizer, trains the transformer, then records the model and tokenizer paths
    in the config and saves it.

    Args:
      config: Config object for the experiment

    """

    base_cfg = config.config
    gvq_cfg = config.gvqvae
    tfmr_cfg = config.transformer

    device = torch.device(base_cfg.device)

    # Output folders
    out_dir = Path(base_cfg.this_run_dir, "transformer")
    model_out_dir = Path(out_dir, "transformer")
    for folder in (out_dir, model_out_dir):
        os.makedirs(folder, exist_ok=True)

    # Load graphs + texts
    data_root = config.dataset.dataset_save_path

    def _unpickle(fname):
        with open(Path(data_root, fname), "rb") as fh:
            return pkl.load(fh)

    train_graphs = _unpickle("train_graphs.pkl")
    test_graphs = _unpickle("test_graphs.pkl")
    train_texts = _unpickle("train_texts.pkl")
    test_texts = _unpickle("test_texts.pkl")

    # Load trained GVQVAE model
    gvqvae_model = setup_gvqvae_model(config, device)
    load_gvqvae(config, gvqvae_model, device)
    _LOG.info(f"Loaded GVQVAE checkpoint from {config.gvqvae.test.last_checkpoint}")

    def _to_ids(graph):
        return gvqvae_model.encode_graph(graph, to_numpy=False, to_ids=True, device=device)

    # Encode graphs as list of codebook IDs. Each pass encodes the whole train set and then the
    # whole test set, so the encodings line up with `texts * num_copies` further down.
    num_copies = tfmr_cfg.train.num_graph_encodings
    enc_train_graphs, enc_test_graphs = [], []
    for _ in range(num_copies):
        enc_train_graphs.extend(_to_ids(graph) for graph in train_graphs)
        enc_test_graphs.extend(_to_ids(graph) for graph in test_graphs)

    # Text-conditioned or unconditioned run
    if not tfmr_cfg.model.only_use_graphs:
        text_vocab_size = tfmr_cfg.model.text_vocab_size
    else:
        # Ignore texts, replace with empty strings
        text_vocab_size = 0
        train_texts = [""] * len(train_texts)
        test_texts = [""] * len(test_texts)

    # Train the tokenizer and save
    tokenizer = VQT2GTokenizerBPE(
        train_texts + test_texts,
        text_vocab_size=text_vocab_size,
        graph_vocab_size=gvq_cfg.model.codebook_size,
        max_graph_len=gvq_cfg.model.codes_per_graph,
        normalise_text=False,
    )
    _LOG.info("Trained tokenizer")

    tokenizer_path = Path(out_dir, "tokenizer.pkl")
    tokenizer.save_tokeniser(tokenizer_path)

    # Concatenate graphs and texts to use in transformer
    train_tokenized = tokenizer.tg_dataset(train_texts * num_copies, enc_train_graphs)
    test_tokenized = tokenizer.tg_dataset(test_texts * num_copies, enc_test_graphs)
    _LOG.info(f"Transformer train/test set size: {len(train_tokenized)}, {len(test_tokenized)}")

    # Train transformer
    training_kwargs = dict(
        output_dir=model_out_dir,
        vocab_size=len(tokenizer),
        model_max_length=tokenizer.total_len,
        epochs=tfmr_cfg.train.epochs,
        batch_size=tfmr_cfg.train.batch_size,
        eval_steps=tfmr_cfg.train.eval_steps,
        learning_rate=tfmr_cfg.train.learning_rate,
        checkpoint=tfmr_cfg.train.checkpoint,
        model_embedding_size=tfmr_cfg.model.embedding_dim,
        model_num_layers=tfmr_cfg.model.num_layers,
        model_num_heads=tfmr_cfg.model.num_heads,
        max_checkpoints=tfmr_cfg.train.max_checkpoints,
    )
    train_transformer(train_tokenized, test_tokenized, **training_kwargs)

    # Record where the trained model and tokenizer are, then save both copies of the config
    tfmr_cfg.test.model = str(model_out_dir)
    tfmr_cfg.test.tokenizer = str(tokenizer_path)
    for save_location in ("transformer", "."):
        config.save_config(save_location)

    _LOG.info("Finished training transformer")


###################################################################################################


def eval_runner(
    config: VQT2GConfig,
    comment: str = None,
    edge_threshold: float = 0.8,
    edge_topk: bool = True,
    extra_edge_randomness: bool = False,
    transformer_sampling: bool = False,
    temperature: float = 0.7,
    top_p: float = 0.9,
    do_train_set: bool = False,
) -> None:
    """Run evaluation of trained model

    Generates one graph per text of the chosen data split, plots each generated graph next to the
    real one, saves the generated graphs with their texts, and optionally computes and saves MMD
    stats. Everything is written to a new timestamped folder inside `<this_run_dir>/test`.

    Args:
      config: Config object for the experiment
      comment: Free text written to the top of `evaluation_comment.txt` (Default value = None)
      edge_threshold: Passed to `generate_from_text` as `thresh` (Default value = 0.8)
      edge_topk: Passed to `generate_from_text` (Default value = True)
      extra_edge_randomness: Passed to `generate_from_text` (Default value = False)
      transformer_sampling: Passed to `generate_from_text` (Default value = False)
      temperature: Passed to `generate_from_text` (Default value = 0.7)
      top_p: Passed to `generate_from_text` (Default value = 0.9)
      do_train_set: Evaluate on the train split instead of the test split (Default value = False)

    Returns:
        None. Plots, generated graphs and stats are saved to files.

    """

    # May spam logs if this isn't set here
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    cfg_base = config.config
    cfg_tfmr = config.transformer
    cfg_eval = config.evaluation

    # NOTE(review): `cpu_device` is built from the configured device, not "cpu", which doesn't
    # match its name or the comment below. Left unchanged - confirm the intended placement.
    cpu_device = torch.device(cfg_base.device)
    device = torch.device(cfg_base.device)

    test_dir = Path(cfg_base.this_run_dir, "test")
    os.makedirs(test_dir, exist_ok=True)

    # Load test graphs, texts
    data_dir = config.dataset.dataset_save_path
    data_split = "train" if do_train_set else "test"
    with open(Path(data_dir, f"{data_split}_graphs.pkl"), "rb") as f:
        test_graphs = pkl.load(f)
    with open(Path(data_dir, f"{data_split}_texts.pkl"), "rb") as f:
        test_texts = pkl.load(f)

    # Load gvqvae. Put on CPU to start then only move decoder to GPU
    gvqvae_model = setup_gvqvae_model(config, cpu_device)
    load_gvqvae(config, gvqvae_model, device)
    gvqvae_model.decoder = gvqvae_model.decoder.to(device)

    _LOG.info(f"Loaded GVQVAE checkpoint from {config.gvqvae.test.last_checkpoint}")

    # Load transformer + tokenizer
    with open(cfg_tfmr.test.tokenizer, "rb") as f:
        tokenizer = pkl.load(f)
    _LOG.info(f"Loaded tokenizer from {cfg_tfmr.test.tokenizer}")

    tfmr_model = transformers.GPT2LMHeadModel.from_pretrained(cfg_tfmr.test.model, pad_token_id=0)
    _LOG.info(f"Loaded transformer checkpoint from {cfg_tfmr.test.model}")

    # Make folders
    eval_run_name = f"eval_{time.strftime('%b_%d_%H-%M-%S')}"
    this_eval_dir = Path(test_dir, eval_run_name)
    os.makedirs(this_eval_dir, exist_ok=True)

    plots_dir = Path(this_eval_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    # Always make a comment file and list the generation params from CLI
    if comment is None:
        comment = ""
    else:
        comment += "\n\n"
    comment += (
        f"GENERATION CONFIG:\nEdge threshold = {edge_threshold}\nEdge top-k = {edge_topk}\n"
        f"Extra edge randomness = {extra_edge_randomness}\nTransformer sampling = "
        f"{transformer_sampling}\n"
    )
    if transformer_sampling:
        comment += f"Transformer temperature = {temperature}\nTransformer top-p = {top_p}"
    comment_file = Path(this_eval_dir, "evaluation_comment.txt")
    with open(comment_file, "w") as f:
        f.write(comment)

    # Number of samples to generate - can't be more than the number of items in the split
    test_set_size = len(test_graphs)
    if cfg_eval.full_test_gen:
        num_samples = test_set_size
    else:
        num_samples = cfg_eval.num_gen
        if num_samples > test_set_size:
            _LOG.warning(
                f"Requested {num_samples} samples but the {data_split} set only has "
                f"{test_set_size} items, generating {test_set_size} instead"
            )
            num_samples = test_set_size

    generated_graphs = []
    gen_graph_texts = []
    _LOG.info(f"Generating {num_samples} graphs")
    for sample_idx in tqdm(range(num_samples)):

        real_graph = test_graphs[sample_idx]
        text = test_texts[sample_idx]

        # Generate the graph and make it a networkx.Graph object
        generated_graph = generate_from_text(
            text=text,
            gvqvae=gvqvae_model,
            transformer=tfmr_model,
            tokenizer=tokenizer,
            thresh=edge_threshold,
            real_graph=real_graph,
            keep_num_nodes=True,  # For consistency with baselines
            edge_topk=edge_topk,
            extra_edge_randomness=extra_edge_randomness,
            transformer_sampling=transformer_sampling,
            temperature=temperature,
            top_p=top_p,
        )
        generated_graphs.append(generated_graph)
        # Keep the conditioning text alongside its generated graph
        gen_graph_texts.append(text)

        plot_fname = f"sample_{sample_idx:04d}.png"
        plot_real_and_generated_graph(
            real_graph=real_graph,
            generated_graph=generated_graph,
            text=text,
            save_folder=plots_dir,
            file_name=plot_fname,
        )

    # Save generated graphs as pkl
    samples_fname = "generated_graphs.pkl"
    samples_path = Path(this_eval_dir, samples_fname)

    with open(samples_path, "wb") as f:
        pkl.dump([generated_graphs, gen_graph_texts], f)

    # If skipping stats
    if not cfg_eval.compute_mmd:
        return

    _LOG.info(f"Finished generating {len(generated_graphs)} graphs. Starting MMD stats")
    gvqvae_model = gvqvae_model.to("cpu")
    _LOG.debug("Moved GVQVAE model back to CPU")

    # Compute MMD stats
    # NOTE(review): the stats use every graph/text of the split, even when fewer graphs were
    # generated - confirm the stats functions are fine with different-sized inputs
    sigma = cfg_eval.sigma
    kernel = cfg_eval.kernel
    test_graphs_nx = [pyg.utils.to_networkx(graph, to_undirected=True) for graph in test_graphs]

    if cfg_eval.wiki_ego_run:
        topic_map_file = cfg_eval.wiki_ego_topic_map
        with open(topic_map_file, "r") as f:
            topic_map = yaml.safe_load(f)
        _LOG.info(
            f"Starting topic-wise stats: num test = {len(test_graphs_nx)}, num gen = "
            f"{len(generated_graphs)}"
        )

        stats = egos_topic_stats(
            graphs_test=test_graphs_nx,
            graphs_gen=generated_graphs,
            texts=test_texts,
            topic_map=topic_map,
            kernel=kernel,
            sigma=sigma,
        )

    else:
        _LOG.info(
            f"Starting stats: num test = {len(test_graphs_nx)}, num gen = {len(generated_graphs)}"
        )
        stats = compute_stats(
            graphs1=test_graphs_nx,
            graphs2=generated_graphs,
            kernel=kernel,
            sigma=sigma,
            remove_isolated=True,
        )

    stats_path = Path(this_eval_dir, "stats.pkl")
    with open(stats_path, "wb") as f:
        pkl.dump(stats, f)


###################################################################################################


def sampling_runner(config, num_graphs, text):
    """Sample graph(s) from a trained model, conditioned on `text` (not implemented yet)

    Args:
      config: Currently unused
      num_graphs: Currently unused
      text: Currently unused

    Raises:
      NotImplementedError: Always, this function is a placeholder.

    """

    # TODO: finish later
    raise NotImplementedError()
