import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class IntSoftMax(nn.Module):
    """Interval softmax: map (center, raw radius) logits to lower/upper probabilities.

    The last dimension of ``inputs`` is split into two equal halves. The first
    ``Nc`` entries are interval centers ``c``. The last ``Nc`` are raw radii,
    which are made positive with softplus to give ``r``. For each class ``i``:

        lo_i = exp(c_i - r_i) / (sum_{j != i} exp(c_j) + exp(c_i - r_i))
        hi_i = exp(c_i + r_i) / (sum_{j != i} exp(c_j) + exp(c_i + r_i))

    The output is ``cat([lo, hi], dim=-1)`` and has the same shape as
    ``inputs``. Any number of leading (batch) dimensions is accepted. Because
    ``r > 0``, ``lo <= hi`` element-wise.

    The bounds are evaluated in log space as
    ``sigmoid(c_i -/+ r_i - log(sum_{j != i} exp(c_j)))``. This avoids the
    overflow of the direct formula, which yields ``inf / inf = nan`` for
    float32 logits above roughly 88.

    Raises:
        ValueError: if the last dimension of ``inputs`` is not a positive even number.
    """

    def forward(self, inputs):
        size = inputs.size(-1)
        if size < 2 or size % 2:
            raise ValueError(
                f"IntSoftMax expects a positive, even last dimension, got {size}"
            )
        Nc = size // 2

        # First half: centers; second half: raw radii made positive via softplus.
        center = inputs[..., :Nc]
        radius_nonneg = F.softplus(inputs[..., Nc:])

        # log(sum_{j != i} exp(c_j)) for every class i.
        log_others = self._log_sum_exp_others(center)

        # exp(a) / (S_others + exp(a)) == sigmoid(a - log S_others).
        lo = torch.sigmoid(center - radius_nonneg - log_others)
        hi = torch.sigmoid(center + radius_nonneg - log_others)

        return torch.cat([lo, hi], dim=-1)

    @staticmethod
    def _log_sum_exp_others(x):
        """Return ``log(sum_{j != i} exp(x_j))`` for every ``i`` along the last dim."""
        n = x.size(-1)
        if n == 1:
            # Empty sum -> log(0) = -inf, which makes both bounds exactly 1.
            return torch.full_like(x, float("-inf"))

        m, argmax = x.max(dim=-1, keepdim=True)
        is_max = torch.arange(n, device=x.device) == argmax

        # For i != argmax, the remaining sum still contains exp(m - m) = 1.
        # Subtracting exp(x_i - m) therefore cannot cancel catastrophically.
        shifted = torch.exp(x - m)
        others = shifted.sum(dim=-1, keepdim=True) - shifted
        # The argmax entry is replaced below. Give it a dummy value of 1 so that
        # log() and its gradient stay finite: torch.where backpropagates through
        # both branches, and 0 * inf would produce NaN.
        others = others.masked_fill(is_max, 1.0)
        log_others = m + torch.log(others)

        # For the argmax entry, drop it explicitly and take an ordinary logsumexp.
        log_rest = torch.logsumexp(
            x.masked_fill(is_max, float("-inf")), dim=-1, keepdim=True
        )
        return torch.where(is_max, log_rest, log_others)


class CreNetRES50(nn.Module):
    """ResNet-50 backbone followed by an MLP head and an interval softmax.

    For each sample the network returns ``2 * classes`` values, produced by
    ``IntSoftMax`` applied to the batch-normalised head output. The first
    ``classes`` values are the lower bounds and the last ``classes`` are the
    upper bounds.

    Args:
        classes: number of output classes.
        weights: accepted but currently unused (see NOTE in ``__init__``).
    """

    def __init__(self, classes, weights):
        super().__init__()

        # NOTE(review): ``weights`` is ignored; the backbone always loads the
        # 'IMAGENET1K_V2' checkpoint (torchvision may download it on first use).
        # Confirm what callers pass before wiring the argument through.
        backbone = torchvision.models.resnet50(weights='IMAGENET1K_V2', progress=False)
        # Replace the ImageNet classifier so the backbone emits its pooled
        # features (2048-dim for ResNet-50, matching fc1 below).
        backbone.fc = nn.Identity()
        self.base = backbone

        # Every input is bilinearly resized to 224x224 before the backbone.
        # Presumably a 4-D (N, 3, H, W) image batch -- TODO confirm with callers.
        self.upsample = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False)

        # Head: 2048 -> 1024 -> 512 -> 2*classes. The last layer's outputs are
        # split by IntSoftMax into centers (first half) and raw radii (second half).
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 2 * classes)
        # BatchNorm1d in training mode needs more than one sample per batch.
        self.batch_norm = nn.BatchNorm1d(2 * classes)
        self.activation = IntSoftMax()

    def forward(self, x):
        """Return the ``(N, 2 * classes)`` lower/upper probability bounds for ``x``."""
        features = self.base(self.upsample(x))
        hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
        interval_logits = self.batch_norm(self.fc3(hidden))
        return self.activation(interval_logits)