import lightning
import wandb
from datamodules import CIFAR10, CIFAR100, TinyImageNet
from get_arch import get_architecture, get_cnn_architecture
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from pc_s import PCS


def train_model(logger, run_config):
    # Make sure to always generate the *exact* same datasets & batches
    lightning.seed_everything(run_config["seed"], workers=True)

    # 1: load dataset as Lightning DataModule
    batch_size = run_config["batch_size"]
    if run_config["dataset"] == "CIFAR10":
        datamodule = CIFAR10(batch_size, is_test=run_config["is_test"])
    elif run_config["dataset"] == "CIFAR100":
        datamodule = CIFAR100(batch_size, is_test=run_config["is_test"])
    elif run_config["dataset"] == "tiny-imagenet":
        datamodule = TinyImageNet(batch_size, is_test=run_config["is_test"])
    print("Training on", datamodule.dataset_name)

    # 2: Set up Lightning trainer
    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = Trainer(
        accelerator="gpu",
        devices=1,
        logger=logger,
        callbacks=[lr_monitor],
        max_epochs=run_config["nm_epochs"],
        inference_mode=False,  # inference_mode would interfere with the state backward pass
        limit_predict_batches=1,  # enable 1-batch prediction
        num_sanity_val_steps=0,
    )

    # 3: Get architecture that belongs to this dataset
    architecture = get_cnn_architecture(run_config["model"], datamodule.dataset_name,run_config["act_fn"])

    # 4: Initiate model and train it
    datamodule.setup("fit")
    pc = PCS(
        architecture,
        iters=run_config["iters"],
        s_lr=run_config["s_lr"],
        w_lr=run_config["w_lr"],
        w_decay=run_config["w_decay"],
        s_momentum=run_config["s_momentum"],
        output_loss=run_config["output_loss"],
        nm_batches=len(datamodule.train_dataloader()),
        nm_epochs=run_config["nm_epochs"],
    )
    trainer.fit(pc, datamodule=datamodule)

    # 5: Test results
    trainer.test(pc, datamodule=datamodule)

    # 6: Release all CUDA memory that you can
    pc = None
    trainer = None
    lightning.pytorch.utilities.memory.garbage_collection_cuda()



if __name__ == "__main__":
    config = {
        "type": "State",
        "seed": 42,
        "batch_size": 128,
        "nm_epochs": 50,
        "iters": 8,
        "s_lr": 0.014708110891388754,
        "s_momentum": 0.0,
        "w_lr": 0.0002641020393177285,
        "w_decay": 1.2107509557065504e-05,
        "output_loss": "ce",
        "model": "VGG5",
        "act_fn": "gelu",
        "dataset": "CIFAR10",
        "is_test": True,
    }

    wandb.init(project="ErrorPC")
    logger = WandbLogger(project="ErrorPC", mode="online")
    logger_config = logger.experiment.config

    # overwrite config with logger config if it exists
    for key, value in logger_config.items(): 
        config[key] = value
    logger.experiment.config.update(config) # update wandb config with the full config

    train_model(logger, config)
    wandb.finish()