import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

# Channel Attention block definition
class ChannelAttention(nn.Module):
    """Per-channel attention gate computed from pooled feature descriptors.

    The input is reduced to one value per channel twice (global average pool
    and global max pool). Both descriptors pass through the same two-layer
    1x1-convolution bottleneck, are summed, and are squashed with a sigmoid.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, clamped to a minimum of one channel.
    """

    def __init__(self, in_planes: int, ratio: int = 16) -> None:
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Integer division yields 0 when in_planes < ratio, which would build
        # a bottleneck with no channels at all; keep at least one.
        hidden_planes = max(1, in_planes // ratio)

        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def _bottleneck(self, pooled: torch.Tensor) -> torch.Tensor:
        """Apply the shared fc1 -> ReLU -> fc2 stack to a pooled descriptor."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return attention weights with the spatial dimensions pooled to 1x1.

        The result is meant to be broadcast-multiplied with ``x`` by the
        caller; this module does not apply the weights itself.
        """
        avg_out = self._bottleneck(self.avg_pool(x))
        max_out = self._bottleneck(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

# Custom EfficientNet Model
class CustomEfficientNet(nn.Module):
    """Three-branch EfficientNet-B0 classifier fused with channel attention.

    Three ``efficientnet-b0`` backbones receive the same 9-channel input.
    Their feature maps are summed, gated by a ``ChannelAttention`` block,
    concatenated with the ungated sum, reduced back to 1280 channels by a
    1x1 convolution + batch norm + ReLU, globally average-pooled and
    classified by a linear layer.

    The first two backbones are initialised from the checkpoints
    ``./pth/Effib0_hsv.pth`` and ``./pth/Effib0_ycbcr.pth``; the third
    backbone gets no checkpoint in this constructor. Each stem convolution is
    replaced *after* loading, so the stem weights stored in the checkpoints
    are not used.

    For Grad-CAM, every forward pass stores the first backbone's feature map
    in ``self.activations``; its gradient is captured into ``self.gradients``
    during the following backward pass.

    Args:
        num_classes: Number of output logits.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when a checkpoint
        file is missing or does not match the backbone.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Three parallel EfficientNet models. They are created first, then
        # loaded, then patched, so random initialisation happens in a fixed
        # order: all backbones before any replacement stem.
        backbones = [EfficientNet.from_name('efficientnet-b0') for _ in range(3)]

        # Load pre-trained weights for the first two backbones only; zip()
        # stops after the two paths, leaving the third backbone untouched.
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints that come from a trusted source.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint_paths = ('./pth/Effib0_hsv.pth', './pth/Effib0_ycbcr.pth')
        for backbone, path in zip(backbones, checkpoint_paths):
            checkpoint = torch.load(path, map_location=device)
            backbone.load_state_dict(checkpoint['model_state_dict'])

        # Modify the first conv layer of each model to accept 9-channel input
        # (for ColorCube). These stems are freshly initialised.
        for backbone in backbones:
            backbone._conv_stem = nn.Conv2d(
                9, 32, kernel_size=3, stride=2, padding=1, bias=False
            )

        # Ensure weights are not frozen
        for backbone in backbones:
            backbone.requires_grad_(True)

        self.backbone1, self.backbone2, self.backbone3 = backbones

        # Channel Attention block
        self.channel_attention = ChannelAttention(1280)

        # Residual connection layers: 2560 = fused features (1280) stacked
        # with their attention-weighted copy (1280).
        self.residual_conv = nn.Conv2d(2560, 1280, kernel_size=1, stride=1, bias=False)
        self.residual_bn = nn.BatchNorm2d(1280)
        self.relu = nn.ReLU(inplace=True)

        # Fully connected layer for classification
        self.fc = nn.Linear(1280, num_classes)

        # Variables to store activations and gradients for Grad-CAM
        self.gradients = None
        self.activations = None

    def forward(self, x):
        """Compute class logits for a batch of 9-channel inputs.

        Args:
            x: Input batch fed unchanged to all three backbones.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        # Run three backbones in parallel
        features1 = self.backbone1.extract_features(x)
        features2 = self.backbone2.extract_features(x)
        features3 = self.backbone3.extract_features(x)

        # Save the activations for Grad-CAM and arrange for their gradient to
        # be captured on backward. The stored gradient is cleared first so it
        # can never belong to an earlier forward pass. A hook can only be
        # attached to a tensor that requires grad (not under torch.no_grad()).
        self.activations = features1
        self.gradients = None
        if features1.requires_grad:
            features1.register_hook(self.save_gradients)

        # Fuse the outputs of the three models
        fused_features = features1 + features2 + features3

        # Apply channel attention
        ca_features = self.channel_attention(fused_features)

        # Concatenate features (skip connection) and apply residual connection
        concat_features = torch.cat([fused_features, fused_features * ca_features], dim=1)
        residual_features = self.residual_conv(concat_features)
        residual_features = self.residual_bn(residual_features)
        residual_features = self.relu(residual_features)

        # Global Average Pooling (functional form: no module built per call)
        gap_features = nn.functional.adaptive_avg_pool2d(residual_features, 1).flatten(1)

        # Fully connected layer for classification
        return self.fc(gap_features)

    def save_gradients(self, grad):
        """Tensor hook: remember the gradient flowing into the saved activations."""
        self.gradients = grad

    def get_activations_gradients(self):
        """Return ``(activations, gradients)`` from the latest forward/backward.

        ``gradients`` is ``None`` until a backward pass has run after the most
        recent forward pass, or when that forward pass ran without autograd.
        """
        return self.activations, self.gradients
