import torch
from torch import nn, optim
from torch.nn import functional as F

from src.adjacency.importance_cnn import ImportanceCNN
from src.adjacency.importance_resnet_image import ImportanceResnetImage
from src.image_patch.img_to_patch import img_to_patch, compute_num_patches
from src.layers.patch_cnn import PatchCNN
from src.models.base_model import BaseModule


class PatchImportance(BaseModule):
    """Patch-based classifier that can re-weight patch embeddings by a learned importance.

    The image is split into non-overlapping ``patch_size`` x ``patch_size`` patches.
    Each patch is embedded by a :class:`PatchCNN`, and an importance module produces
    weights. The embeddings (optionally multiplied by those weights) are pooled over
    the patch axis and passed through a two-layer MLP head.

    Args:
        embedding_size: Patch-embedding size; also the width of the hidden MLP layer.
        aggregate_operation: Pooling over patches, ``'mean'`` or ``'sum'``
            (checked in :meth:`forward`).
        num_classes: Forwarded to ``BaseModule``.
        lr: Learning rate used by :meth:`configure_optimizers`.
        use_importance: If true, patch embeddings are multiplied by the importance
            output before pooling. Otherwise importance is still computed and
            returned, but does not affect the logits.
        patch_size: Side length of a square patch, forwarded to ``BaseModule``.
        importance_method: ``'cnn'`` (importance from the grid of patch embeddings)
            or ``'cnn_image'`` (importance from the raw input image).
        optimizer: ``'adam'`` (AdamW) or ``'sgd'`` (SGD with momentum plus StepLR).

    Raises:
        ValueError: If ``importance_method`` is not recognised.
    """

    def __init__(self, embedding_size, aggregate_operation, num_classes, lr, use_importance, patch_size,
                 importance_method, optimizer='adam'):
        super().__init__(patch_size=patch_size, num_classes=num_classes)
        self.name = f'patch_importance_{importance_method}'
        self.lr = lr
        self.importance_method = importance_method
        # self.patch_size is presumably set by BaseModule.__init__ — TODO confirm.
        # patches_in_row is used below as the side of a square patch grid.
        self.number_of_patches, self.patches_in_row = compute_num_patches(self.patch_size, overlap=0)
        # NOTE(review): the return value is discarded; this only matters if
        # get_default_loss() has side effects on the strategy — confirm intent.
        self.strategy.get_default_loss()
        self.use_importance = use_importance
        self.hidden_size = embedding_size
        self.patch_cnn = PatchCNN(embedding_size)
        if self.importance_method == 'cnn':
            self.importance = ImportanceCNN(embedding_size, out_dim=self.number_of_patches)
        elif self.importance_method == 'cnn_image':
            # Operates on the raw image, hence 3 input channels.
            self.importance = ImportanceResnetImage(3, out_dim=self.number_of_patches)
        else:
            raise ValueError(f'Unknown importance method: {importance_method}')
        self.aggregate_operation = aggregate_operation
        self.lin1 = nn.Linear(embedding_size, embedding_size)
        # self.output_size is presumably derived from num_classes by BaseModule — TODO confirm.
        self.lin2 = nn.Linear(embedding_size, self.output_size)
        self.optimizer = optimizer
        self.save_hyperparameters()

    def forward(self, x):
        """Compute class logits and patch importance for a batch of images.

        Args:
            x: 4-D image batch. Patches are reshaped to 3 channels, so this is
                presumably ``(B, 3, H, W)`` — TODO confirm.

        Returns:
            A dict with ``"logits"`` (output of the MLP head) and ``"importance"``
            (raw output of the importance module).

        Raises:
            ValueError: If ``aggregate_operation`` is neither ``'mean'`` nor ``'sum'``.
        """
        B, _, _, _ = x.shape
        patches = img_to_patch(x, patch_size=self.patch_size, flatten_channels=False).reshape(
            B * self.number_of_patches, 3, self.patch_size, self.patch_size
        )

        # Embed every patch independently, then regroup per image:
        # (B, number_of_patches, embedding_dim).
        patches_embeddings = self.patch_cnn(patches).reshape(B, self.number_of_patches, -1)

        if self.importance_method == 'cnn':
            # Lay the embeddings out as a channels-first spatial grid:
            # (B, rows, rows, D) -> (B, D, rows, rows).
            importance_input = patches_embeddings.reshape(
                B, self.patches_in_row, self.patches_in_row, -1
            ).permute(0, 3, 1, 2)
        elif self.importance_method == 'cnn_image':
            importance_input = x
        else:
            # Unreachable for methods accepted by __init__; kept as a fallback.
            importance_input = patches_embeddings

        importance = self.importance(importance_input)

        # NOTE(review): assumes `importance` broadcasts against
        # (B, number_of_patches, D), e.g. shape (B, number_of_patches, 1) — TODO confirm.
        x = patches_embeddings * importance if self.use_importance else patches_embeddings

        # Pool over the patch axis.
        if self.aggregate_operation == 'mean':
            x = x.mean(dim=1)
        elif self.aggregate_operation == 'sum':
            x = x.sum(dim=1)
        else:
            raise ValueError(f'Unknown aggregate operation: {self.aggregate_operation}')

        x = F.relu(x)
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)

        return {"logits": x, "importance": importance}

    def training_step(self, batch, batch_idx):
        """Run one training step, update and log epoch-level training accuracy, and return the loss.

        Relies on ``BaseModule._shared_step``. Judging by its use here, that method
        returns a dict with ``"preds"`` and ``"loss"``; ``batch["y"]`` holds targets.
        """
        step = self._shared_step(batch, "train")
        self.metrics.train_acc(step["preds"], batch["y"])
        self.log('train/acc', self.metrics.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return step["loss"]

    def test_step(self, batch, batch_idx):
        """Delegate to ``BaseModule.test_step`` and return its result (previously discarded)."""
        return super().test_step(batch, batch_idx)

    def configure_optimizers(self):
        """Build the optimizer (and scheduler) selected by ``self.optimizer``.

        Returns:
            For ``'adam'``: an ``AdamW`` optimizer.
            For ``'sgd'``: ``([SGD], [StepLR])``. SGD uses momentum 0.9 and weight
            decay 1e-4; the learning rate is multiplied by 0.1 every 30 scheduler steps.

        Raises:
            ValueError: For any other optimizer name.
        """
        if self.optimizer == 'adam':
            # 'adam' selects AdamW (decoupled weight decay), not plain Adam.
            return optim.AdamW(self.parameters(), lr=self.lr)
        if self.optimizer == 'sgd':
            optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0.9, weight_decay=1e-4)
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
            return [optimizer], [scheduler]
        raise ValueError(f"Unknown optimizer: {self.optimizer}")
