import torch
from torchvision.models import resnet50
from torch.nn import functional as F

from src.image_patch.img_to_patch import img_to_patch
from src.layers.residual import ResidualBlock
from src.models.patch_importance import PatchImportance
import torch.nn as nn


class AttentionPool2d(nn.Module):
    """Collapse a feature map to a vector using a learned spatial attention map.

    A 1x1 convolution produces one logit per spatial position. The logits are
    softmax-normalised over all H*W positions of each sample, and the input is
    summed over space with those weights.
    """

    def __init__(self, in_channels):
        super().__init__()
        # One output channel -> one attention logit per spatial location.
        self.attn = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        """Return the attention-weighted spatial sum of ``x``.

        Args:
            x: feature map of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C).
        """
        batch = x.size(0)
        logits = self.attn(x)  # (B, 1, H, W)
        # Normalise jointly over every spatial position of a sample.
        flat_probs = torch.softmax(logits.view(batch, -1), dim=-1)
        attn_map = flat_probs.view_as(logits)  # back to (B, 1, H, W); sums to 1 per sample
        return (x * attn_map).sum(dim=[2, 3])


class CNN(nn.Module):
    """Aggregation head: residual conv stack -> global max pool -> 2-layer MLP.

    Args:
        hidden_size: number of channels of the input feature map
            (``in_channels`` of the first residual block).
        num_classes: number of output logits.
        channels: width of the residual blocks and of the hidden FC layer.
            Defaults to 512, the previously hard-coded value.
        num_blocks: number of additional ``channels -> channels`` residual
            blocks after the first one. Defaults to 2, the previously
            hard-coded value.
    """

    def __init__(self, hidden_size=256, num_classes=10, channels=512, num_blocks=2):
        super(CNN, self).__init__()
        # First block maps hidden_size -> channels; the rest keep the width.
        self.block1 = ResidualBlock(in_channels=hidden_size, out_channels=channels, stride=1)
        self.blocks = nn.Sequential(
            *[ResidualBlock(in_channels=channels, out_channels=channels, stride=1) for _ in range(num_blocks)]
        )

        # Global *max* pooling over the spatial dimensions.
        self.global_pool = nn.AdaptiveMaxPool2d((1, 1))

        # Classification MLP
        self.fc1 = nn.Linear(channels, channels)
        self.fc2 = nn.Linear(channels, num_classes)

    def forward(self, x):
        """Classify a feature map.

        Args:
            x: feature map with ``hidden_size`` channels, presumably
                (B, hidden_size, H, W) — TODO confirm against ResidualBlock.

        Returns:
            Logits of shape (B, num_classes).
        """
        x = self.block1(x)
        x = self.blocks(x)

        # Global max pooling
        x = self.global_pool(x)  # Shape: (batch_size, channels, 1, 1)
        x = torch.flatten(x, 1)  # Shape: (batch_size, channels)

        # Fully connected layers
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


class PatchImportanceCNN(PatchImportance):
    """Patch-based classifier that aggregates per-patch embeddings with a CNN.

    The input image is cut into ``patch_size`` x ``patch_size`` patches. Each
    patch is embedded by ``self.patch_cnn``, and the embeddings are optionally
    re-weighted by ``self.importance``. The resulting grid of embeddings is
    classified by :class:`CNN`.

    ``patch_cnn``, ``importance``, ``hidden_size``, ``output_size``,
    ``number_of_patches`` and ``patches_in_row`` are not set here. They are
    presumably set by the ``PatchImportance`` base class (verify there).
    """

    def __init__(self, embedding_size, aggregate_operation, num_classes, lr, use_importance, patch_size,
                 importance_method, optimizer):
        super().__init__(embedding_size=embedding_size, aggregate_operation=aggregate_operation,
                         num_classes=num_classes, lr=lr, use_importance=use_importance, patch_size=patch_size,
                         importance_method=importance_method, optimizer=optimizer)
        self.name = 'patch_importance_cnn'
        self.aggregate_cnn = CNN(self.hidden_size, self.output_size)
        self.save_hyperparameters()

    @staticmethod
    def _freeze_module(module: nn.Module) -> None:
        """Stop gradient updates for ``module`` and switch it to eval mode."""
        for param in module.parameters():
            param.requires_grad = False
        module.eval()

    def freeze_weights(self, freeze_importance: bool = False, freeze_patch_cnn: bool = False,
                       reset_patch_cnn: bool = False) -> None:
        """Freeze and/or re-initialise sub-networks.

        Args:
            freeze_importance: disable gradients for ``self.importance`` and put
                it in eval mode.
            freeze_patch_cnn: disable gradients for ``self.patch_cnn`` and put
                it in eval mode.
            reset_patch_cnn: call ``reset_parameters()`` on every submodule of
                ``self.patch_cnn`` that defines it. Each reset is printed.

        NOTE(review): ``eval()`` is not sticky. A later ``self.train()`` call
        (as a Lightning fit loop typically makes) puts the frozen modules back
        in train mode. Dropout and BatchNorm statistics would then update again.
        """
        if freeze_importance:
            self._freeze_module(self.importance)
        if freeze_patch_cnn:
            self._freeze_module(self.patch_cnn)
        if reset_patch_cnn:
            for layer in self.patch_cnn.modules():
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
                    print(f"Resetting parameters of {layer.__class__.__name__}")

    def forward(self, x: torch.Tensor) -> dict:
        """Embed patches, weight them by importance and classify the patch grid.

        Args:
            x: image batch of shape (B, C, H, W). H and W are presumably
                ``patches_in_row * patch_size`` — TODO confirm.

        Returns:
            dict with ``"logits"`` (output of ``self.aggregate_cnn``) and
            ``"importance"`` (raw output of ``self.importance``).
        """
        B, C, _, _ = x.shape
        # Fold patches into the batch dim so patch_cnn embeds each one
        # independently. The channel count comes from the input (previously
        # hard-coded to 3, which made any non-RGB input fail).
        patches = img_to_patch(x, patch_size=self.patch_size, flatten_channels=False).reshape(
            B * self.number_of_patches, C, self.patch_size, self.patch_size)

        patch_embeddings = self.patch_cnn(patches).reshape(B, self.number_of_patches, -1)  # (B, N, D)

        if self.importance_method == 'cnn':
            # Lay the embeddings out as a (B, D, rows, rows) feature map.
            importance_input = patch_embeddings.reshape(
                B, self.patches_in_row, self.patches_in_row, -1).permute(0, 3, 1, 2)
        elif self.importance_method == 'cnn_image':
            # Importance is predicted from the raw image.
            importance_input = x
        else:
            importance_input = patch_embeddings

        importance = self.importance(importance_input)

        features = patch_embeddings * importance if self.use_importance else patch_embeddings

        # (B, N, D) -> (B, D, N) -> (B, hidden_size, rows, rows) grid for the aggregate CNN.
        grid = features.permute(0, 2, 1).reshape(B, self.hidden_size, self.patches_in_row, self.patches_in_row)
        logits = self.aggregate_cnn(grid)

        return {"logits": logits, "importance": importance}
