import abc
from datetime import datetime
import glob
import os

import GPUtil
import pl_bolts
import pytorch_lightning as pl
import torch
import torchmetrics
from torch.nn import functional as F
import torchvision
from torchvision import models, transforms
import timm

# import ffcv
import yaml
import pdbr

import rgd
import utils

# Root directory that holds all datasets. The glob accepts the data living under
# any /export/io* directory; sorting makes the pick deterministic when several
# directories match (glob's own ordering is arbitrary).
_local_data_dir_matches = sorted(glob.glob("/export/io*/data/sslocum/data"))
if not _local_data_dir_matches:
    # Fail with a descriptive error rather than an IndexError from `[0]`.
    raise FileNotFoundError(
        "No data directory matches the pattern /export/io*/data/sslocum/data"
    )
local_data_dir = _local_data_dir_matches[0]
# Debugging aid (disabled): uncomment to force synchronous CUDA kernel launches.
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def make_adam(params, lr, beta1, beta2, **kwargs):
    """
    Build a ``torch.optim.Adam`` optimizer from separate ``beta1``/``beta2``.

    Taking the two betas as individual keyword arguments lets Adam be
    configured from the same flat hyperparameter dict as the other entries of
    the ``optimizers`` table. Any extra keyword arguments (for example
    ``weight_decay``) are forwarded to Adam unchanged.
    """
    betas = (beta1, beta2)
    return torch.optim.Adam(params, lr=lr, betas=betas, **kwargs)


# Optimizer factories selectable by name. Each one is invoked as
# ``factory(net.parameters(), **opt_hps)`` by TestProblem.__init__.
optimizers = dict(
    cm=torch.optim.SGD,  # "cm": presumably classical momentum (SGD + momentum hp)
    adam=make_adam,
    rgd=rgd.RGD,
    pd=rgd.PowerKinetic,
)


class StopOnLambda(pl.callbacks.Callback):
    """
    A utility callback that stops the PyTorch Lightning run when a predicate
    on a logged metric becomes true.

    Args:
        monitor: key in ``trainer.callback_metrics`` whose value is checked.
        lambd: predicate called with the metric value; a truthy result
            requests that training stop.
        on: when to run the check, either ``"train_epoch_end"`` or
            ``"validation_epoch_end"``.

    Raises:
        ValueError: if ``on`` is not one of the supported hook names.
    """

    _VALID_HOOKS = ("train_epoch_end", "validation_epoch_end")

    def __init__(self, monitor, lambd, on):
        super().__init__()
        # A typo in `on` would otherwise make the callback silently never fire.
        if on not in self._VALID_HOOKS:
            raise ValueError(f"on must be one of {self._VALID_HOOKS}, got {on!r}")
        self.monitor = monitor
        self.lambd = lambd
        self.on = on

    def on_train_epoch_end(self, trainer, module):
        if self.on == "train_epoch_end":
            self._run_stopping_check(trainer)

    def on_validation_epoch_end(self, trainer, module):
        # Only act on validation that happens during fit(), and ignore the
        # sanity-check validation pass run before training starts.
        if (
            self.on == "validation_epoch_end"
            and trainer.state == pl.trainer.states.TrainerState.FITTING
            and not trainer.sanity_checking
        ):
            self._run_stopping_check(trainer)

    def _run_stopping_check(self, trainer):
        """Request a stop if the predicate holds for the monitored metric."""
        metric_value = trainer.callback_metrics.get(self.monitor)
        # The metric may not have been logged yet; there is nothing to decide,
        # and predicates such as torch.isfinite would crash on None.
        if metric_value is None:
            return
        # OR with the current flag so a stop requested elsewhere (another
        # callback or the trainer itself) is never cancelled by this check.
        trainer.should_stop = trainer.should_stop or bool(self.lambd(metric_value))


class TestProblem(pl.LightningModule):
    """
    Base class for an optimizer benchmark problem.

    Subclasses provide the network (``get_backbone``), the data
    (``get_datamodule``), optionally a checkpoint callback
    (``get_mc_callback``), and an ``epochs`` class attribute. ``run`` wires
    these together into a Lightning ``Trainer`` and fits the model.
    """

    def __init__(self, opt_name, opt_hps):
        """
        Args:
            opt_name: key into the module-level ``optimizers`` table.
            opt_hps: keyword arguments forwarded to the optimizer factory.
        """
        super().__init__()
        self.net = self.get_backbone()
        # Built eagerly so configure_optimizers can simply hand it back.
        self.optimizer = optimizers[opt_name](self.net.parameters(), **opt_hps)

    @abc.abstractmethod
    def get_backbone(self):
        """Return the network to train; its ``parameters()`` feed the optimizer."""
        # Fail clearly instead of returning None and crashing on `.parameters()`.
        raise NotImplementedError(f"{type(self).__name__} must implement get_backbone()")

    def configure_optimizers(self):
        return self.optimizer

    @classmethod
    def get_datamodule(cls, num_workers=4):
        """Return the datamodule that ``run`` passes to ``trainer.fit``."""
        # Default matches the num_workers=4 default used by every subclass.
        raise NotImplementedError(f"{cls.__name__} must implement get_datamodule()")

    @classmethod
    def get_mc_callback(cls):
        """Return a model-checkpoint callback, or None to skip checkpointing."""
        return None

    @classmethod
    def run(cls, opt_name, opt_hps, results_dir):
        """
        Train this problem with the named optimizer, logging to ``results_dir``.

        Does nothing (after printing an error) if ``results_dir`` already
        exists, so earlier runs are never overwritten.
        """
        # Check before any expensive setup (network, data, trainer).
        if os.path.exists(results_dir):
            print(f"Error: Run output already exists at {results_dir}.")
            print("Please delete it if you would like to overwrite it.")
            return

        model = cls(opt_name, opt_hps)
        dm = cls.get_datamodule()

        # Abort as soon as the training loss is NaN/inf.
        stop_on_bad_loss = StopOnLambda(
            "train_loss", lambda x: not torch.isfinite(x), on="train_epoch_end"
        )
        callbacks = [stop_on_bad_loss]
        mc_callback = cls.get_mc_callback()
        # The base class returns None; Lightning does not accept None callbacks.
        if mc_callback is not None:
            callbacks.append(mc_callback)

        trainer = pl.Trainer(
            max_epochs=cls.epochs,
            callbacks=callbacks,
            track_grad_norm=2,
            default_root_dir=results_dir,
            logger=pl.loggers.CSVLogger(save_dir=results_dir),
            # strategy="ddp",
            # precision=16,
            gpus=1,
        )

        # dirname is "" for a bare directory name; os.makedirs("") would raise.
        parent_dir = os.path.dirname(results_dir)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        print(f"Logging run in {results_dir}")

        trainer.fit(model, dm)


class ClassificationProblem(TestProblem):
    """
    A TestProblem whose network outputs class scores trained with
    cross-entropy, tracking top-1 and top-5 accuracy per stage.
    """

    def __init__(self, opt_name, opt_hps):
        super().__init__(opt_name, opt_hps)
        # Separate metric objects per stage so their accumulated state never mixes.
        self.train_acc = torchmetrics.Accuracy()
        self.train_top5_acc = torchmetrics.Accuracy(top_k=5)

        self.val_acc = torchmetrics.Accuracy()
        self.val_top5_acc = torchmetrics.Accuracy(top_k=5)

        self.test_acc = torchmetrics.Accuracy()
        self.test_top5_acc = torchmetrics.Accuracy(top_k=5)

    def _step(self, batch, acc, top5_acc):
        """Return the cross-entropy loss on ``batch``; update both accuracy metrics."""
        inputs, targets = batch
        logits = self.net(inputs)
        loss = F.cross_entropy(logits, targets)
        probabilities = F.softmax(logits, dim=-1)
        acc(probabilities, targets)
        top5_acc(probabilities, targets)
        return loss

    def _step_and_log(self, batch, acc, top5_acc, loss_key, acc_key, top5_key):
        """Run ``_step`` and log the loss plus both metrics under the given keys."""
        loss = self._step(batch, acc, top5_acc)
        self.log(loss_key, loss)
        self.log(acc_key, acc)
        self.log(top5_key, top5_acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step_and_log(
            batch,
            self.train_acc,
            self.train_top5_acc,
            "train_loss",
            "train_accuracy",
            "train_top5_accuracy",
        )

    def validation_step(self, batch, batch_idx):
        return self._step_and_log(
            batch,
            self.val_acc,
            self.val_top5_acc,
            "val_loss",
            "val_accuracy",
            "val_top5_accuracy",
        )

    def test_step(self, batch, batch_idx):
        return self._step_and_log(
            batch,
            self.test_acc,
            self.test_top5_acc,
            "test_loss",
            "test_accuracy",
            "test_top5_accuracy",
        )

    @classmethod
    def get_mc_callback(cls):
        """Checkpoint the model with the highest validation top-1 accuracy."""
        checkpoint = pl.callbacks.ModelCheckpoint(mode="max", monitor="val_accuracy")
        return checkpoint


class CIFAR100_Resnet32(ClassificationProblem):
    """ResNet-32 (from ``utils``) on CIFAR-100 with a fixed weight decay of 1e-4."""

    epochs = 350
    batch_size = 256

    def __init__(self, opt_name, opt_hps):
        # Weight decay is part of the problem definition: it overrides any
        # value the caller supplied, without mutating the caller's dict.
        merged_hps = dict(opt_hps)
        merged_hps["weight_decay"] = 1e-4
        super().__init__(opt_name, merged_hps)

    def get_backbone(self):
        backbone = utils.ResNet32()
        return backbone

    @classmethod
    def get_datamodule(cls, num_workers=4):
        datamodule = utils.cifar100_datamodule(
            local_data_dir, cls.batch_size, num_workers
        )
        return datamodule


class CIFAR100_Resnet34(ClassificationProblem):
    """ResNet-34 (from ``utils``) on CIFAR-100 with a fixed weight decay of 1e-4."""

    epochs = 350
    batch_size = 128

    def __init__(self, opt_name, opt_hps):
        # Weight decay is part of the problem definition: it overrides any
        # value the caller supplied, without mutating the caller's dict.
        merged_hps = dict(opt_hps)
        merged_hps["weight_decay"] = 1e-4
        super().__init__(opt_name, merged_hps)

    def get_backbone(self):
        backbone = utils.ResNet34()
        return backbone

    @classmethod
    def get_datamodule(cls, num_workers=4):
        datamodule = utils.cifar100_datamodule(
            local_data_dir, cls.batch_size, num_workers
        )
        return datamodule


class Imagenet_Resnet(ClassificationProblem):
    """
    Shared setup for ImageNet classification problems.

    Fixes the epoch budget, batch size and weight decay, and supplies the
    ImageNet data pipeline; subclasses only provide ``get_backbone``.
    """

    epochs = 120
    batch_size = 64

    # NOTE: deliberately not decorated with @abc.abstractmethod. Concrete
    # subclasses such as Imagenet_Resnet50 inherit this __init__ unchanged, so
    # marking it abstract would (under an ABC metaclass) make them
    # uninstantiable, and without one the decorator does nothing.
    def __init__(self, opt_name, opt_hps):
        # Weight decay is fixed by the problem and overrides any caller value.
        opt_hps = {**opt_hps, "weight_decay": 1e-4}
        super().__init__(opt_name, opt_hps)

    @classmethod
    def write_imagenet_to_ffcv(cls):
        """
        Convert the ImageFolder-format ImageNet copy to FFCV ``.beton`` files.

        Writes ``train``/``val`` files from a seeded (42) split of the official
        train set, with 50,000 images held out for ``val``. The official val
        set is written as ``test``. Output goes to
        ``{local_data_dir}/ImageNet-FFCV``.
        """
        # Imported lazily: the module-level ffcv import is disabled so the rest
        # of this file works without ffcv installed.
        import ffcv.fields
        import ffcv.writer

        def write_ds(dataset, split):
            write_path = f"{local_data_dir}/ImageNet-FFCV/{split}.beton"
            writer = ffcv.writer.DatasetWriter(
                write_path,
                {
                    "image": ffcv.fields.RGBImageField(
                        max_resolution=256,
                        jpeg_quality=100,
                    ),
                    "label": ffcv.fields.IntField(),
                },
                num_workers=4,
            )
            writer.from_indexed_dataset(dataset)

        os.makedirs(f"{local_data_dir}/ImageNet-FFCV", exist_ok=True)

        train = torchvision.datasets.ImageNet(
            f"{local_data_dir}/ImageNet-ImageFolderFormat", "train"
        )
        test = torchvision.datasets.ImageNet(
            f"{local_data_dir}/ImageNet-ImageFolderFormat", "val"
        )
        # Fixed seed so the train/val split is reproducible across runs.
        train, val = torch.utils.data.random_split(
            train,
            lengths=[len(train) - 50000, 50000],
            generator=torch.Generator().manual_seed(42),
        )

        write_ds(train, "train")
        write_ds(val, "val")
        write_ds(test, "test")

    @classmethod
    def get_datamodule(cls, num_workers=4):
        return pl_bolts.datamodules.ImagenetDataModule(
            data_dir=f"{local_data_dir}/ImageNet-ImageFolderFormat",
            image_size=224,
            batch_size=cls.batch_size,
            num_workers=num_workers,
        )


class Imagenet_Resnet50(Imagenet_Resnet):
    """ImageNet with torchvision's ResNet-50 built from default arguments."""

    def get_backbone(self):
        # No pretrained weights are requested here.
        backbone = models.resnet50()
        return backbone


class CIFAR100_ViT(ClassificationProblem):
    """CIFAR-100 with timm's ``vit_small_patch32_224`` and a 100-class head."""

    epochs = 350
    batch_size = 512

    def get_backbone(self):
        backbone = timm.create_model("vit_small_patch32_224", num_classes=100)
        return backbone

    @classmethod
    def get_datamodule(cls, num_workers=4):
        # img_size=224 matches the model's input resolution; presumably the
        # datamodule resizes the CIFAR images accordingly -- confirm in utils.
        datamodule = utils.cifar100_datamodule(
            local_data_dir,
            cls.batch_size,
            num_workers,
            img_size=224,
        )
        return datamodule


class WMT14_EN_DE_Transformer(TestProblem):
    """
    Placeholder for a Transformer on WMT14 English-German.

    Incomplete: it defines no backbone, datamodule or ``epochs``, so it cannot
    be instantiated or run yet.
    """

    batch_size = 512  # TODO: fill in the rest of the problem definition.


# Registry of problem classes, keyed by the name used to select a problem.
# TODO: register "imagenet_efficientnet" once an Imagenet_EfficientNet problem
# class is defined above; referencing that undefined name here made importing
# this module fail with NameError.
problems = {
    "cifar100_resnet32": CIFAR100_Resnet32,
    "cifar100_resnet34": CIFAR100_Resnet34,
    "cifar100_vit": CIFAR100_ViT,
    "imagenet_resnet50": Imagenet_Resnet50,
    "wmt14-en-de_transformer": WMT14_EN_DE_Transformer,
}
