import logging
import time
import warnings

import hydra
from data import PyArrowTTDataset

from hydra.core.config_store import ConfigStore

from misc_utils import time_formatter
from model.vae import VAE, VAEConfig, VAEDecoder, VAEEncoder


with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    from torchmetrics.aggregation import MeanMetric

from trainer.emb_trainer import EmbeddingTrainer, EmbTrainConfig

# Register the EmbTrainConfig dataclass as a structured-config schema under the
# name "base_train_emb_config" in Hydra's global ConfigStore.
# NOTE(review): presumably referenced from a `defaults` list in
# ../conf/train_emb_config.yaml for type-checked config composition — verify.
cs = ConfigStore.instance()
cs.store(node=EmbTrainConfig, name="base_train_emb_config")


@hydra.main(version_base=None, config_path="../conf", config_name="train_emb_config")
def launch_pretraining(cfg: EmbTrainConfig) -> None:
    """Build the VAE, load and split the dataset, then run embedding training.

    Invoked through ``hydra.main``. Hydra composes ``cfg`` from
    ``../conf/train_emb_config`` plus any command-line overrides, so the script
    entry point calls this function with no arguments.

    Args:
        cfg: Training configuration. This function reads ``cfg.device``,
            ``cfg.data_path``, ``cfg.optimizer`` and ``cfg.scheduler`` directly,
            and hands the whole object to ``EmbeddingTrainer``.
    """
    log = logging.getLogger(__name__)

    # Model sizes come from VAEConfig's defaults, not from the Hydra config.
    model_cfg = VAEConfig()
    # Encoder and decoder are constructed from the same dimension triple.
    vae_dims = (model_cfg.input_dim, model_cfg.hidden_dims, model_cfg.latent_dim)
    vae_encoder = VAEEncoder(*vae_dims).to(cfg.device)
    vae_decoder = VAEDecoder(*vae_dims).to(cfg.device)
    vae_model = VAE(vae_encoder, vae_decoder).to(cfg.device)

    # Dataset
    log.info("Creating dataset")
    # perf_counter is monotonic, so the measured duration is immune to wall-clock jumps.
    loading_start = time.perf_counter()
    dataset = PyArrowTTDataset(cfg.data_path)
    # NOTE(review): 0.9 is presumably the training fraction (train first, eval second).
    train_dts, eval_dts = dataset.split_dataset(0.9)

    elapsed_time = time_formatter(time.perf_counter() - loading_start, show_ms=False)
    log.info("[Finished loading dataset] [Elapsed Time: %s]", elapsed_time)
    log.info(
        "[Dataset Size] [Training: %d] [Evaluation: %d]",
        len(train_dts),
        len(eval_dts),
    )

    # Optimizer and LR scheduler are built from their Hydra config nodes;
    # the model parameters and the optimizer are injected as extra kwargs.
    optimizer = hydra.utils.instantiate(cfg.optimizer, params=vae_model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)

    # Start training
    trainer = EmbeddingTrainer(
        cfg,
        vae_model,
        optimizer,
        scheduler,
        train_dts,
        eval_dts,
    )

    trainer.train()


if __name__ == "__main__":
    # No arguments: the @hydra.main decorator parses the command line,
    # composes the config and passes it in as `cfg`.
    launch_pretraining()
