import torch
from torchvision.models import resnet18, resnet50
import torch.nn as nn
import torch.optim as optim

from src.adjacency.importance_resnet_image import ImportanceResnetImage
from src.image_patch.img_to_patch import compute_num_patches
from src.models.base_model import BaseModule



class ImportanceModel(BaseModule):
    """Model that maps an image to a vector with one entry per patch.

    The forward pass is delegated to an ``ImportanceResnetImage`` network
    whose output dimension equals the number of patches computed for the
    configured patch size.
    """

    def __init__(self, embedding_size, aggregate_operation, num_classes, lr, use_importance, patch_size,
                 importance_method):
        """Configure the patch grid and build the importance network.

        Args:
            embedding_size: Stored as ``self.hidden_size``.
            aggregate_operation: Accepted for interface compatibility; not
                used by this constructor.
            num_classes: Forwarded to ``BaseModule``.
            lr: Learning rate, stored as ``self.lr``.
            use_importance: Flag stored as ``self.use_importance``.
            patch_size: Forwarded to ``BaseModule``; ``self.patch_size`` is
                then used to derive the patch grid.
            importance_method: Stored and appended to the model name.
        """
        super().__init__(patch_size=patch_size, num_classes=num_classes)
        self.name = f'patch_importance_{importance_method}'
        self.lr = lr
        self.importance_method = importance_method
        self.use_importance = use_importance
        self.hidden_size = embedding_size

        # Patch grid for non-overlapping patches: total count and count per row.
        self.number_of_patches, self.patches_in_row = compute_num_patches(self.patch_size, overlap=0)

        # Keep the strategy's default loss on the instance, mirroring
        # ResNetImportanceModel, instead of calling the getter for nothing.
        self.loss_fn = self.strategy.get_default_loss()

        # First argument is presumably the number of input channels (RGB) -- TODO confirm.
        self.importance = ImportanceResnetImage(3, out_dim=self.number_of_patches)

    def forward(self, x):
        """Return the importance network's output for ``x``."""
        return self.importance(x)


class ResNetImportanceModel(BaseModule):
    """ResNet classifier that can mask low-importance patches of its input.

    A frozen ``ImportanceModel`` loaded from a checkpoint scores the input.
    When ``threshold`` is positive, every patch whose score is not above the
    threshold is overwritten with a constant fill value before the image is
    passed to the ResNet backbone.
    """

    def __init__(self, resnet, num_classes, lr, path, threshold, optimizer):
        """Load the frozen importance model and build the ResNet backbone.

        Args:
            resnet: Backbone name, ``'resnet18'`` or ``'resnet50'``.
            num_classes: Forwarded to ``BaseModule`` and stored.
            lr: Learning rate used by ``configure_optimizers``.
            path: Checkpoint path of a trained ``ImportanceModel``.
            threshold: Importance cut-off; masking is applied only if > 0.
            optimizer: Optimizer name, ``'sgd'`` or ``'adam'``.

        Raises:
            ValueError: If ``resnet`` is not a supported backbone name.
        """
        super().__init__(num_classes=num_classes)

        # Reject an unknown backbone before paying for the checkpoint load.
        backbones = {'resnet18': resnet18, 'resnet50': resnet50}
        if resnet not in backbones:
            raise ValueError("Invalid resnet model")

        self.num_classes = num_classes
        self.lr = lr
        self.name = resnet
        self.threshold = threshold
        self.optimizer = optimizer

        # The importance model is used for inference only.
        self.importance = ImportanceModel.load_from_checkpoint(path, strict=False)
        self.importance.eval()
        self.importance.freeze()

        self.model = backbones[resnet]()
        # Replace the classification head so it matches this task's output size.
        self.model.fc = nn.Linear(self.model.fc.in_features, self.output_size)
        self.loss_fn = self.strategy.get_default_loss()
        self.save_hyperparameters()

    def train(self, mode=True):
        """Set train/eval mode but always keep the frozen importance model in eval.

        ``nn.Module.train`` is recursive, so without this override any call to
        ``.train()`` on this module would also switch the importance model
        back to training mode.
        """
        super().train(mode)
        importance = getattr(self, "importance", None)
        if importance is not None:
            importance.eval()
        return self

    def forward(self, x):
        """Classify ``x``, optionally masking patches with low importance.

        Args:
            x: Image batch; assumed to be ``(B, 3, H, W)`` with ``H`` and
                ``W`` equal to ``patches_in_row * patch_size`` -- TODO confirm.

        Returns:
            dict with ``"logits"`` (backbone output), ``"masked_input"`` (the
            tensor actually fed to the backbone) and ``"importance"`` (raw
            output of the importance model).
        """
        importance_base = self.importance(x)

        if self.threshold is not None and self.threshold > 0:
            batch_size = x.shape[0]
            grid = self.importance.patches_in_row
            patch = self.importance.patch_size

            # Per-patch keep mask, expanded so each patch covers patch x patch pixels.
            keep = (importance_base > self.threshold).reshape(batch_size, grid, grid)
            keep = keep.repeat_interleave(patch, dim=1).repeat_interleave(patch, dim=2)

            # Fill value is -mean/std per channel, i.e. a zero pixel after
            # normalization (constants look like the ImageNet statistics).
            # Created in x's dtype so torch.where does not change precision.
            normalized_zero = torch.tensor(
                [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
                device=x.device,
                dtype=x.dtype,
            ).view(1, 3, 1, 1)

            # The (B, 1, H, W) mask broadcasts over the channel dimension.
            x = torch.where(keep.unsqueeze(1), x, normalized_zero)

        return {"logits": self.model(x), "masked_input": x, "importance": importance_base}

    def configure_optimizers(self):
        """Build the optimizer (and scheduler for SGD) selected at construction.

        Returns:
            ``([optimizer], [scheduler])`` for ``'sgd'`` (SGD with momentum and
            a StepLR schedule) or a single ``AdamW`` optimizer for ``'adam'``.

        Raises:
            ValueError: If the configured optimizer name is not supported.
        """
        if self.optimizer == 'sgd':
            optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0.9, weight_decay=1e-4)
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
            return [optimizer], [scheduler]
        if self.optimizer == 'adam':
            # Note: the 'adam' option selects AdamW.
            return torch.optim.AdamW(self.parameters(), lr=self.lr)
        raise ValueError(f"Invalid optimizer: {self.optimizer!r}")

    def test_step(self, batch, batch_idx):
        """Run the shared step on a test batch and log epoch-level accuracy."""
        step = self._shared_step(batch, "test")
        self.metrics.test_acc(step["preds"], batch["y"])
        self.log('test/acc', self.metrics.test_acc, on_step=False, on_epoch=True)