from enum import Enum
from torchvision.models import vgg16, VGG16_Weights, vit_b_16, ViT_B_16_Weights
import timm
import detectors


from nn_compression.networks._utils import recursively_find_named_children
from ._datasets import cifar, Normalisation, imagenet
import os

# Root directory of the ImageNet data, taken from the environment once at
# import time. None when the variable is unset; the ImageNet-backed models
# check for that before building their dataset.
IMAGENET_PATH = os.getenv("IMAGENET_PATH")


class CvModel(str, Enum):
    """Computer-vision models available for the compression experiments.

    The ``resnet*_cifar10`` members are built with ``timm`` using the member
    value as the model name. ``vgg16`` and ``vit_b_16`` are built with the
    torchvision constructors. Because the enum mixes in ``str``, members
    compare equal to their string values.
    """

    RESNET18_CIFAR10 = "resnet18_cifar10"
    RESNET34_CIFAR10 = "resnet34_cifar10"
    RESNET50_CIFAR10 = "resnet50_cifar10"
    VGG16 = "vgg16"
    VIT_B_16 = "vit_b_16"

    def transforms(self):
        """Return the torchvision preprocessing transforms for this model.

        Raises:
            ValueError: For the ResNet (CIFAR-10) models, which have no
                associated weight transforms, or for an unmapped model.
        """
        if self.value.startswith("resnet"):
            raise ValueError("ResNet models do not require transforms.")
        elif self.value == "vgg16":
            return VGG16_Weights.IMAGENET1K_V1.transforms()
        elif self.value == "vit_b_16":
            return ViT_B_16_Weights.IMAGENET1K_V1.transforms()
        raise ValueError(f"Unknown model: {self.value}")

    def load(self, pretrained: bool = True, train: bool = False):
        """Build the model. Defaults to pretrained weights and eval mode.

        ResNet members are created with ``timm.create_model``. The
        ``*_cifar10`` names are presumably registered with timm by the
        ``detectors`` import; verify this. VGG16 and ViT-B/16 come from
        torchvision.

        Args:
            pretrained: If True, load pretrained weights (timm's for ResNets,
                torchvision ``IMAGENET1K_V1`` otherwise). If False, the
                weights are randomly initialised.
            train: Passed to ``model.train()``. False puts the model in
                eval mode.

        Raises:
            ValueError: If the member has no loader mapping.
        """
        if self.value.startswith("resnet"):
            model = timm.create_model(self.value, pretrained=pretrained)
        elif self.value == "vgg16":
            # Honour `pretrained`; weights=None gives a randomly initialised net.
            weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
            model = vgg16(weights=weights)
        elif self.value == "vit_b_16":
            weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
            model = vit_b_16(weights=weights)
            for name, layer in recursively_find_named_children(model):
                # Excluding these layers keeps them from counting when
                # compressing with DeepCABAC; they are not actually called in
                # the forward pass in our experiments.
                if "out_proj" in name:
                    layer.quantisable = False  # type: ignore
        else:
            raise ValueError(f"Unknown model: {self.value}")
        model.train(train)
        return model

    def get_dataset(self, shuffle: bool = True):
        """Return the dataset matching this model.

        ResNet models use CIFAR-10 with ``Normalisation.CIFAR10_EDALTOCG``.
        VGG16 and ViT-B/16 use ImageNet from ``IMAGENET_PATH`` with the
        model's own preprocessing (see :meth:`transforms`).

        Args:
            shuffle: Passed through to the dataset helper.

        Raises:
            ValueError: If an ImageNet model is requested but
                ``IMAGENET_PATH`` was not set when the module was imported,
                or if the model has no dataset mapping.
        """
        if self.value.startswith("resnet"):
            return cifar("10", shuffle, Normalisation.CIFAR10_EDALTOCG)
        if self.value in ("vgg16", "vit_b_16"):
            if IMAGENET_PATH is None:
                raise ValueError(
                    "IMAGENET_PATH not set. Use export IMAGENET_PATH=/path/to/imagenet before you run the script."
                )
            return imagenet(IMAGENET_PATH, shuffle, self.transforms())
        # TODO: Expand for further datasets and models. Fail loudly instead of
        # silently handing a new model the wrong (CIFAR-10) data.
        raise ValueError(f"Unknown model: {self.value}")

    @staticmethod
    def from_string(s: str) -> "CvModel":
        """Parse a model identifier such as ``"vgg16"`` into a member.

        Enum value lookup keeps this in sync with the members automatically.

        Raises:
            ValueError: If ``s`` is not the value of any member.
        """
        try:
            return CvModel(s)
        except ValueError:
            raise ValueError(f"Unknown model: {s}") from None
