import numpy as np
from matplotlib import pyplot as plt
import glob
import re

import torch


import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from quantile_datamodule import QuantileDataModule
from quantile_network import QuantileNetwork

from callbacks import EarlyStopping, ModelCheckpoint
import config

# Opt into faster, slightly less precise float32 matmuls. Per the torch docs,
# "high" allows TF32 / bfloat16-based kernels where the hardware supports
# them. Setting this explicitly also stops Lightning from warning about it.
torch.set_float32_matmul_precision("high")

def automatically_assign_version(name_base_model, *, checkpoint_dir=None):
    """Return the next unused version tag ("v0", "v1", ...) for a model.

    Scans the checkpoint directory for entries named
    ``{name_base_model}_quant_v<N>`` (a legacy ``_quant_<N>`` suffix without
    the ``v`` is accepted too) and returns ``"v" + str(max(N) + 1)``. Returns
    ``"v0"`` when no such entry exists. Entries that match the prefix but have
    no trailing version number are ignored.

    Args:
        name_base_model: Base model name used as the run-directory prefix.
        checkpoint_dir: Directory to scan. Defaults to
            ``config.CHECKPOINT_DIR``. It is joined to the model name by
            plain string concatenation, so it should end with a path
            separator.

    Returns:
        The version tag as a string, e.g. ``"v3"``.
    """
    if checkpoint_dir is None:
        checkpoint_dir = config.CHECKPOINT_DIR

    # Escape the literal prefix so glob metacharacters ('*', '?', '[') in the
    # directory or model name are matched literally, not as wildcards.
    pattern = glob.escape(checkpoint_dir + name_base_model) + "_quant_*"

    # Only the trailing number is a version. Other digits in the path (e.g.
    # the "34" in "resnet34") must not be mistaken for one.
    suffix_re = re.compile(r"_quant_v?([0-9]+)$")

    versions = []
    for path in glob.glob(pattern):
        match = suffix_re.search(path.rstrip("/\\"))
        if match is not None:
            versions.append(int(match.group(1)))

    return f"v{max(versions, default=-1) + 1}"


def main():
    """Train, validate and test the quantile network for one base model.

    Picks the next free run version, then builds the quantile model and its
    datamodule, a TensorBoard logger and a checkpoint callback that share
    one run name. It then runs ``fit``, ``validate`` and ``test`` on a
    single Lightning ``Trainer``.

    Raises:
        ValueError: if the base model is not listed in
            ``config.PRETRAINED_MODELS``.
    """
    name_base_model = "resnet34_svhn"
    # Explicit raise instead of assert so the check survives ``python -O``.
    if name_base_model not in config.PRETRAINED_MODELS:
        raise ValueError(f"Model {name_base_model} not implemented.")

    # NOTE(review): with a DDP strategy Lightning may re-run this script in
    # each worker process, each calling automatically_assign_version on its
    # own -- confirm all ranks end up with the same version.
    version = automatically_assign_version(name_base_model)
    # Single run name shared by the logger and the checkpoint directory so
    # the two can never drift apart.
    run_name = f"{name_base_model}_quant_{version}"

    quant_model = QuantileNetwork(name_base_model)
    quant_datamodule = QuantileDataModule(name_base_model)

    logger = TensorBoardLogger("tb_logs", name=run_name)
    checkpoint_callback = ModelCheckpoint(
        dirpath=config.CHECKPOINT_DIR + f"{run_name}/",
        save_top_k=1,
        monitor="train_acc",
        mode="max",
        save_weights_only=True,  # Needed for CyclicLR
    )

    trainer = pl.Trainer(
        strategy="ddp_find_unused_parameters_true",
        logger=logger,
        accelerator=config.ACCERLATOR,
        devices=config.DEVICES,
        min_epochs=1,
        max_epochs=config.EPOCHS,
        precision=config.PRECISION,
        callbacks=[
            checkpoint_callback,
        ],
        log_every_n_steps=10,
        check_val_every_n_epoch=config.CHECK_VAL_EVERY_N_EPOCHS,
        max_time="00:24:00:00",  # DD:HH:MM:SS, i.e. a 24-hour wall-clock cap
    )

    trainer.fit(quant_model, datamodule=quant_datamodule)
    # NOTE(review): a model is passed and no ckpt_path is given, so these
    # presumably evaluate the in-memory weights left at the end of fit, not
    # the top-1 checkpoint saved above -- confirm that is intended.
    trainer.validate(quant_model, datamodule=quant_datamodule)
    trainer.test(quant_model, datamodule=quant_datamodule)


if __name__ == "__main__":
    main()
    