from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import pytorch_lightning as pl
# from pl_bolts.datamodules import CIFAR10DataModule
# from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

from models.architectures.simple_conv import SimpleConvConfig, SimpleConv
from models.training.supervised import SupervisedTraining
from models.utils import get_tb_logger, get_checkpoints_callback


class TestDataset(Dataset):
    """Synthetic image-classification dataset for smoke-testing training.

    Every access samples a fresh random ``(3, 32, 32)`` float tensor and a
    random integer class label. Nothing is stored, so the same index yields
    different data on each call.

    Args:
        size: Number of items reported by ``len()``.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
            The default of 10 matches the 10-unit final layer configured in
            ``test_training`` in this file.

    Raises:
        ValueError: If ``size`` is negative or ``num_classes`` is below 1.
    """

    # The class name starts with "Test"; tell pytest it is a helper, not a
    # test class, so it is not collected (and no collection warning is
    # raised because of the ``__init__`` below).
    __test__ = False

    def __init__(self, size: int = 64, num_classes: int = 10) -> None:
        super().__init__()
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.size = size
        self.num_classes = num_classes

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        """Return a random ``(image, label)`` pair; ``idx`` is ignored."""
        x = torch.rand((3, 32, 32))
        # ``high`` is exclusive in torch.randint, so the bound must be
        # ``num_classes`` (not 1) for more than one label value to occur.
        y = torch.randint(low=0, high=self.num_classes, size=(1,))
        # Return a plain int, as before; the default collate function turns
        # a batch of ints into an integer tensor.
        return x, int(y.item())


def test_training(tmp_path: Path):
    """Smoke-test one supervised training step of ``SimpleConv``.

    Builds a small ``SimpleConv`` model, wraps it in a ``SupervisedTraining``
    subclass that only adds an Adam optimizer, and fits it for one epoch
    limited to a single training batch of synthetic data. There are no
    explicit assertions: the test passes if ``trainer.fit`` completes without
    raising.

    Args:
        tmp_path: pytest-provided temporary directory; the logger and
            checkpoint helpers are given paths beneath it.
    """
    # Lightning module for this test: SupervisedTraining plus an optimizer.
    # See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
    class SimpleConvTraining(SupervisedTraining):
        def configure_optimizers(self):
            params = self.model.parameters()
            return torch.optim.Adam(params, lr=0.02)

    # Model under test: 32-pixel input, three conv stages, 10-unit output.
    network = SimpleConv(
        SimpleConvConfig(
            id="test",
            input_size=32,
            conv_layer_filters=[16, 32, 64],
            fc_layer_units=[1024, 512, 10],
        )
    )
    task = SimpleConvTraining(model=network)

    # Data: a plain DataLoader over the synthetic TestDataset is used here.
    # A previously sketched (currently disabled) alternative attached a
    # pl_bolts CIFAR10DataModule to the task (data under tmp_path / "data",
    # batch size 64, 4 workers) with RandomCrop(32, padding=4) +
    # RandomHorizontalFlip + ToTensor + cifar10_normalization for training
    # and ToTensor + cifar10_normalization for validation/test.
    # See https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html
    loader = DataLoader(TestDataset(), batch_size=64, num_workers=1)

    # Logger and checkpoint callback both receive locations inside tmp_path.
    tb_logger = get_tb_logger(tmp_path / "artifacts/logs/test_training")
    checkpointing = get_checkpoints_callback(
        tmp_path / "artifacts/checkpoints/test_training"
    )

    # One epoch, one batch, no GPUs: just enough to exercise the fit loop.
    # NOTE(review): newer Lightning releases replace `gpus` with
    # `accelerator`/`devices` - confirm against the pinned version.
    trainer = pl.Trainer(
        gpus=0,
        max_epochs=1,
        limit_train_batches=1,
        logger=tb_logger,
        callbacks=[checkpointing],
    )
    trainer.fit(task, loader)
