import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.model_summary import ModelSummary
import wandb

from omegaconf import DictConfig, OmegaConf
import hydra
from hydra.core.config_store import ConfigStore

from utils import dataset_rats
from utils.eval import load_modality_clfs

from MyRATSWSLConfig import MyRATSWSLConfig
from MyRATSWSLConfig import LogConfig
from MyRATSWSLConfig import ModelConfig
from MyRATSWSLConfig import DRPMModelConfig
from MyRATSWSLConfig import JointModelConfig
from MyRATSWSLConfig import MixedPriorModelConfig
from MyRATSWSLConfig import DataConfig
from MyRATSWSLConfig import LFPDataConfig
from MyRATSWSLConfig import SPIKEDataConfig
from MyRATSWSLConfig import EvalConfig


from mv_vaes.lfp_mixedprior_vae import LFPMixedPriorVAE as LFPMixedPriorVAE
from mv_vaes.spike_mixedprior_vae import SPIKEMixedPriorVAE as SPIKEMixedPriorVAE

# Register the structured-config schemas with the ConfigStore singleton so
# that they are known by group/name when the experiment config is composed.
cs = ConfigStore.instance()

# (group, name, node) for every config-group option, in registration order.
_GROUP_OPTIONS = (
    ("log", "log", LogConfig),
    ("model", "drpm", DRPMModelConfig),
    ("model", "joint", JointModelConfig),
    ("model", "mixedprior", MixedPriorModelConfig),
    ("eval", "eval", EvalConfig),
    ("dataset", "LFP", LFPDataConfig),
    ("dataset", "SPIKE", SPIKEDataConfig),
    ("dataset", "dataset", DataConfig),
)
for _group, _name, _node in _GROUP_OPTIONS:
    cs.store(group=_group, name=_name, node=_node)

# The root schema is stored without a group, under the name "base_config".
cs.store(name="base_config", node=MyRATSWSLConfig)


def _select_model_class(dataset_name, model_name):
    """Return the model class for a (dataset, model) pair.

    Args:
        dataset_name: Value of ``cfg.dataset.name`` ("LFP" or "SPIKE").
        model_name: Value of ``cfg.model.name`` ("drpm", "joint" or
            "mixedprior").

    Returns:
        The class to instantiate with the experiment config.

    Raises:
        NotImplementedError: The combination is a known one, but this script
            does not import the corresponding class.
        ValueError: The dataset or model name is not recognised.
    """
    # Classes that are actually imported by this script.
    available = {
        ("LFP", "mixedprior"): LFPMixedPriorVAE,
        ("SPIKE", "mixedprior"): SPIKEMixedPriorVAE,
    }
    # Known combinations whose classes are NOT imported here. Referencing
    # them directly would fail with a NameError, so report them explicitly.
    not_imported = {
        ("LFP", "drpm"): "LFPDRPMVAE",
        ("LFP", "joint"): "LFPJointVAE",
        ("SPIKE", "drpm"): "SPIKEDRPMVAE",
        ("SPIKE", "joint"): "SPIKEJointVAE",
    }

    key = (dataset_name, model_name)
    if key in available:
        return available[key]
    if key in not_imported:
        raise NotImplementedError(
            f"Model '{model_name}' on dataset '{dataset_name}' requires "
            f"{not_imported[key]}, which is not imported in this script."
        )

    known_datasets = sorted({d for d, _ in (*available, *not_imported)})
    known_models = sorted({m for _, m in (*available, *not_imported)})
    if dataset_name not in known_datasets:
        raise ValueError(
            f"Unknown dataset name '{dataset_name}'; expected one of "
            f"{known_datasets}."
        )
    raise ValueError(
        f"Unknown model name '{model_name}'; expected one of {known_models}."
    )


@hydra.main(version_base=None, config_path="conf", config_name="config")
def run_experiment(cfg: MyRATSWSLConfig):
    """Train the configured VAE on the configured dataset.

    Seeds the run, loads the train/validation loaders, builds the model
    selected by ``cfg.dataset.name`` and ``cfg.model.name``, and fits it with
    a Lightning ``Trainer`` that logs to Weights & Biases and checkpoints on
    ``cfg.checkpoint_metric``.

    Args:
        cfg: The composed experiment configuration.

    Raises:
        NotImplementedError: The selected model class is not imported.
        ValueError: The dataset or model name is not recognised.
    """
    print(cfg)
    pl.seed_everything(cfg.seed, workers=True)

    # Validate the selection before any data is loaded so that a bad
    # configuration fails fast. Explicit exceptions are used instead of
    # `assert`, which is stripped when Python runs with -O.
    model_cls = _select_model_class(cfg.dataset.name, cfg.model.name)

    # Load the data first and construct the model second (the original
    # order), so the sequence of seeded RNG draws is unchanged.
    train_loader, _train_dst, val_loader, _val_dst = dataset_rats.get_dataset(cfg)
    model = model_cls(cfg)

    summary = ModelSummary(model, max_depth=2)
    print(summary)

    # Keep the checkpoint with the lowest monitored metric, plus the last one.
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.log.dir_logs,
        monitor=cfg.checkpoint_metric,
        mode="min",
        save_last=True,
    )
    wandb_logger = WandbLogger(
        name=cfg.log.wandb_run_name,
        # Log the fully resolved config; fail if any value is still missing.
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        project=cfg.log.wandb_project_name,
        group=cfg.log.wandb_group,
        offline=cfg.log.wandb_offline,
        entity=cfg.log.wandb_entity,
        save_dir=cfg.log.dir_logs,
    )
    trainer = pl.Trainer(
        max_epochs=cfg.model.epochs,
        devices=1,
        # The config says "cuda"; the Trainer is given "gpu" in that case.
        accelerator="gpu" if cfg.model.device == "cuda" else cfg.model.device,
        logger=wandb_logger,
        check_val_every_n_epoch=1,
        deterministic=True,
        callbacks=[checkpoint_callback],
    )

    trainer.logger.watch(model, log="all")
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Script entry point. run_experiment is invoked without arguments; its `cfg`
# parameter is expected to be supplied by the @hydra.main decorator.
if __name__ == "__main__":
    run_experiment()
