
import pytorch_lightning as pl
import timm
import torch
import torchvision.models as models
from pytorch_lightning.metrics import Accuracy

from models.vision_transformer import (LinearClassifier, ViT_linear_eval,
                                       vit_small)
from models.vits_moco import vit_small as vit_small_mocov3
from models.schduler import WarmupCosineLR

from .densenet import densenet121, densenet161, densenet169
from .googlenet import googlenet
from .inception import inception_v3
from .mobilenetv2 import mobilenet_v2
from .resnet import resnet18, resnet34, resnet50
from .vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn

# Registry of the CIFAR-10 classifier constructors defined in this package,
# keyed by the name accepted as ``hparams.classifier``.
all_classifiers = dict(
    vgg11_bn=vgg11_bn,
    vgg13_bn=vgg13_bn,
    vgg16_bn=vgg16_bn,
    vgg19_bn=vgg19_bn,
    resnet18=resnet18,
    resnet34=resnet34,
    resnet50=resnet50,
    densenet121=densenet121,
    densenet161=densenet161,
    densenet169=densenet169,
    mobilenet_v2=mobilenet_v2,
    googlenet=googlenet,
    inception_v3=inception_v3,
)

# DINO self-supervised ViT backbones, handled specially by get_imagenet_model.
dino_vit = dict(dino_vit_small=vit_small)

# MoCo v3 self-supervised ViT backbones, handled specially by get_imagenet_model.
mocov3_vit = dict(mocov3_vit_small=vit_small_mocov3)

def get_imagenet_model(opt, pretrained=True):
    """Build the ImageNet classifier selected by ``opt.classifier``.

    The name is resolved in this order:

    1. ``mocov3_vit`` registry: the constructor is called with no arguments.
    2. ``dino_vit`` registry: the backbone is built with ``patch_size=16`` and
       ``num_classes=0`` (presumably no built-in head — confirm), then wrapped
       in ``ViT_linear_eval`` together with a ``LinearClassifier`` whose input
       width is ``embed_dim * (n_last_blocks + avgpool_patchtokens)``.
    3. Any *callable* model builder exported by ``torchvision.models``.
    4. Otherwise ``timm.create_model``.

    Args:
        opt: options object; ``opt.classifier`` is required. ``opt.pretrained``
            is used for the torchvision/timm branches when present. The DINO
            branch overwrites ``opt.patch_size`` (16) and
            ``opt.avgpool_patchtokens`` (False) on this object.
        pretrained: fallback for the torchvision/timm branches, used only when
            ``opt`` has no ``pretrained`` attribute (previously this argument
            was ignored and a missing attribute raised ``AttributeError``).

    Returns:
        The constructed ``torch.nn.Module``.
    """
    # ``opt.pretrained`` keeps precedence so existing callers behave the same.
    use_pretrained = getattr(opt, "pretrained", pretrained)
    name = opt.classifier

    if name in mocov3_vit:
        return mocov3_vit[name]()

    if name in dino_vit:
        print("using Dino Vit Model")
        opt.patch_size = 16
        n_last_blocks = 4
        opt.avgpool_patchtokens = False
        backbone = dino_vit[name](patch_size=opt.patch_size, num_classes=0)
        embed_dim = backbone.embed_dim * (n_last_blocks + int(opt.avgpool_patchtokens))
        head = LinearClassifier(embed_dim)
        return ViT_linear_eval(backbone, head, n_last_blocks)

    # ``hasattr`` alone also matches torchvision submodules such as
    # ``models.resnet``, which cannot be called; require a callable builder.
    builder = getattr(models, name, None)
    if callable(builder):
        print("using Pytorch Model")
        return builder(pretrained=use_pretrained)

    print("using Timm Model")
    return timm.create_model(name, pretrained=use_pretrained)

class CIFAR10Module(pl.LightningModule):
    """Lightning wrapper that trains/evaluates an image classifier with cross-entropy.

    ``forward`` consumes a whole ``(images, labels)`` batch and returns
    ``(loss, accuracy * 100)``; the training/validation/test steps log those
    values. Subclasses change the batch layout by overriding ``forward`` and
    the backbone by overriding ``get_model``.

    Args:
        hparams: hyper-parameter container read via attribute access
            (``classifier``, ``learning_rate``, ``weight_decay``,
            ``max_epochs``) — presumably an ``argparse.Namespace``; confirm
            against the caller.
    """

    def __init__(self, hparams):
        super().__init__()
        # Assigning ``self.hparams`` directly was deprecated and later removed
        # from Lightning; ``save_hyperparameters`` stores the same values, still
        # reachable as ``self.hparams.<name>``, across Lightning versions.
        self.save_hyperparameters(hparams)

        self.criterion = torch.nn.CrossEntropyLoss()
        self.accuracy = Accuracy()
        # Built last because ``get_model`` (possibly overridden) reads hparams.
        self.model = self.get_model()

    def get_model(self):
        """Return the backbone named by ``hparams.classifier``.

        Constructors registered in ``all_classifiers`` take precedence; any
        other name is loaded, with ``pretrained=True``, from the
        ``chenyaofo/pytorch-cifar-models`` torch.hub repository.
        """
        name = self.hparams.classifier
        if name in all_classifiers:
            return all_classifiers[name]()
        return torch.hub.load("chenyaofo/pytorch-cifar-models", name, pretrained=True)

    def forward(self, batch):
        """Return ``(loss, accuracy * 100)`` for an ``(images, labels)`` batch.

        The accuracy comes from the shared ``Accuracy`` metric (assumed to be a
        fraction in [0, 1], hence the scaling to percent).
        """
        images, labels = batch
        predictions = self.model(images)
        loss = self.criterion(predictions, labels)
        accuracy = self.accuracy(predictions, labels)
        return loss, accuracy * 100

    def training_step(self, batch, batch_nb):
        """Log training loss/accuracy and return the loss for backpropagation."""
        loss, accuracy = self.forward(batch)
        self.log("loss/train", loss)
        self.log("acc/train", accuracy)
        return loss

    def validation_step(self, batch, batch_nb):
        """Log validation loss and accuracy."""
        loss, accuracy = self.forward(batch)
        self.log("loss/val", loss)
        self.log("acc/val", accuracy)

    def test_step(self, batch, batch_nb):
        """Log test accuracy (the loss is computed but not logged)."""
        loss, accuracy = self.forward(batch)
        self.log("acc/test", accuracy)

    def configure_optimizers(self):
        """Nesterov SGD with a warmup-cosine schedule stepped every batch.

        Because the scheduler uses ``"interval": "step"``, its
        ``warmup_epochs``/``max_epochs`` arguments are passed in optimizer
        steps: the total is ``max_epochs * len(train_dataloader())`` and the
        first 30% of that is warmup. ``self.train_dataloader()`` must be
        available on this module (TODO confirm when a DataModule is used).
        """
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            momentum=0.9,
            nesterov=True,
        )
        total_steps = self.hparams.max_epochs * len(self.train_dataloader())
        scheduler = {
            "scheduler": WarmupCosineLR(
                optimizer, warmup_epochs=total_steps * 0.3, max_epochs=total_steps
            ),
            "interval": "step",
            "name": "learning_rate",
        }
        return [optimizer], [scheduler]
    

class CIFAR100ExplainModule(CIFAR10Module):
    """CIFAR-100 module that also records per-sample attributes and labels.

    Batches are ``(images, labels, attribute)`` triples. Every ``forward`` call
    extends ``self.attributes`` with the batch's attributes and appends the
    batch's label tensor to ``self.labels``, so both lists keep growing across
    all steps that invoke ``forward``.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        self.attributes = []
        self.labels = []

    def get_model(self):
        """Load ``hparams.classifier`` from the chenyaofo CIFAR torch.hub repo."""
        # NOTE(review): unlike CIFAR10Module.get_model, ``pretrained=True`` is
        # not passed, so the hub entry point's own default applies — confirm
        # that weights are loaded elsewhere (e.g. from a checkpoint).
        repo = "chenyaofo/pytorch-cifar-models"
        return torch.hub.load(repo, self.hparams.classifier)

    def forward(self, batch):
        """Record attributes/labels, then return ``(loss, accuracy * 100)``."""
        images, labels, attrs = batch
        self.attributes.extend(attrs)
        self.labels.append(labels)
        logits = self.model(images)
        return self.criterion(logits, labels), self.accuracy(logits, labels) * 100
    
class ImageNetExplainModule(CIFAR10Module):
    """ImageNet module that also records per-sample attributes and labels.

    The backbone comes from ``get_imagenet_model(self.hparams)``. Batches are
    ``(images, labels, attribute)`` triples; every ``forward`` call extends
    ``self.attributes`` and appends the batch's label tensor to ``self.labels``.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        self.attributes = []
        self.labels = []

    def get_model(self):
        """Build the ImageNet classifier described by ``self.hparams``.

        For DINO models ``get_imagenet_model`` also writes ``patch_size`` and
        ``avgpool_patchtokens`` onto ``self.hparams``.
        """
        return get_imagenet_model(self.hparams)

    def forward(self, batch):
        """Record attributes/labels, then return ``(loss, accuracy * 100)``."""
        images, targets, attrs = batch
        self.attributes.extend(attrs)
        self.labels.append(targets)
        # Runs under whatever grad mode the calling step is in.
        outputs = self.model(images)
        loss = self.criterion(outputs, targets)
        acc = self.accuracy(outputs, targets)
        return loss, acc * 100
    
class CIFAR10ExplainModule(CIFAR10Module):
    """CIFAR-10 module that also records per-sample attributes and labels.

    Uses the inherited ``get_model``. Batches are ``(images, labels,
    attribute)`` triples; each ``forward`` call extends ``self.attributes``
    and appends the batch's label tensor to ``self.labels``.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        self.attributes = []
        self.labels = []

    def forward(self, batch):
        """Record attributes/labels, then return ``(loss, accuracy * 100)``."""
        imgs, y, attrs = batch
        self.attributes.extend(attrs)
        self.labels.append(y)
        logits = self.model(imgs)
        batch_loss = self.criterion(logits, y)
        batch_acc = self.accuracy(logits, y)
        return batch_loss, batch_acc * 100

class CIFAR10KNNModule(CIFAR10Module):
    """CIFAR-10 module that keeps every label tensor passed through ``forward``.

    Batches are plain ``(images, labels)`` pairs; each call appends that
    batch's ``labels`` tensor to ``self.labels`` (presumably consumed by a
    k-NN evaluation elsewhere — TODO confirm).
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        # One entry per forward call: that batch's label tensor.
        self.labels = []

    def forward(self, batch):
        """Store the batch labels, then return ``(loss, accuracy * 100)``."""
        images, targets = batch
        self.labels.append(targets)
        logits = self.model(images)
        loss = self.criterion(logits, targets)
        acc = self.accuracy(logits, targets)
        return loss, acc * 100
    

class CIFAR10PGDExplainModule(CIFAR10Module):
    """Evaluate the model on PGD-perturbed inputs while recording attributes.

    Batches are ``(images, labels, attribute)`` triples. ``forward`` perturbs
    the images with :meth:`pgd_attack` while the registered hook(s) are
    disabled, then re-enables them and runs the model on the perturbed images.
    Hooks are registered with :meth:`set_hook`; each must expose ``enable()``
    and ``disable()``. When no hook is set, enabling/disabling is a no-op.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        # Attributes of every sample seen by ``forward``, in order.
        self.attributes = []
        # Hook object or list/tuple of hooks set via ``set_hook``; None = none.
        self.hook = None

    def set_hook(self, hook):
        """Register the hook (or list/tuple of hooks) toggled around the attack."""
        self.hook = hook

    def _hooks(self):
        """Return the registered hook(s) as a sequence (empty if none)."""
        hook = getattr(self, "hook", None)
        if hook is None:
            return ()
        return hook if isinstance(hook, (list, tuple)) else (hook,)

    def enable_hook(self):
        """Call ``enable()`` on every registered hook."""
        for h in self._hooks():
            h.enable()

    def disable_hook(self):
        """Call ``disable()`` on every registered hook."""
        for h in self._hooks():
            h.disable()

    def forward(self, batch):
        """Attack the batch, record attributes, and return ``(loss, accuracy * 100)``."""
        # Input gradients are needed even when Lightning runs val/test steps
        # under ``torch.no_grad()``. The context manager restores the caller's
        # grad mode on exit instead of leaving grad globally switched on.
        with torch.enable_grad():
            # Layers such as BatchNorm/Dropout (if present) use inference
            # behaviour for both the attack and the final prediction.
            self.model.eval()
            images, labels, attribute = batch
            # Keep hooks quiet during the attack's own forward passes.
            self.disable_hook()
            images = self.pgd_attack(self.model, images, labels)

            self.attributes += attribute
            self.enable_hook()
            predictions = self.model(images)
            loss = self.criterion(predictions, labels)
            accuracy = self.accuracy(predictions, labels)
        return loss, accuracy * 100

    def pgd_attack(self, model, images, labels, alpha=0.01, iters=20, eps=None):
        """
        Iterative sign-gradient perturbation of a batch of images.

        Args:
            model (nn.Module): The neural network model to attack.
            images (torch.Tensor): The batch of input images to attack. The
                caller's tensor is not modified.
            labels (torch.Tensor): The corresponding true labels of the input images.
            alpha (float): The step size of each iteration.
            iters (int): The number of iterations to run.
            eps (float, optional): If given, each iterate is projected onto the
                L-infinity ball of this radius around the original images.
                ``None`` (default) leaves the perturbation unconstrained, as
                before.

        Returns:
            torch.Tensor: The perturbed images (detached from the graph).
        """
        ori_images = images.detach()
        # A separate detached alias, so toggling ``requires_grad`` below does
        # not alter the tensor that the caller passed in.
        adv = images.detach()

        for _ in range(iters):
            adv.requires_grad_(True)
            outputs = model(adv)

            loss = self.criterion(outputs, labels)
            grad = torch.autograd.grad(loss, adv)[0]
            # NOTE(review): this steps *against* the loss gradient, i.e. it
            # lowers cross-entropy on the true labels; a standard untargeted
            # PGD attack adds ``alpha * grad.sign()``. Confirm this is intended.
            stepped = adv - alpha * grad.sign()
            eta = stepped - ori_images
            if eps is not None:
                eta = torch.clamp(eta, min=-eps, max=eps)
            adv = (ori_images + eta).detach()
        return adv