import os
import platform
from PIL import Image
import numpy as np

if platform.system() == "Linux":
    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision import transforms
else:
    raise Exception(f"Error: ICNet not supported on {platform.system()}.")


class slam(nn.Module):
    """Per-channel gate computed from each channel's spatial layout.

    Every channel of the input is resized to ``spatial_dim`` x ``spatial_dim``,
    flattened, and passed through a small MLP ending in a sigmoid. The
    resulting per-channel weight in (0, 1) rescales the original (un-resized)
    feature map.

    Args:
        spatial_dim: side length the input is resized to before the MLP; the
            MLP's input width is ``spatial_dim ** 2``.
    """

    def __init__(self, spatial_dim):
        super().__init__()
        self.spatial_dim = spatial_dim
        # Keep the Sequential layout (indices 0 and 2 hold the Linear layers):
        # saved checkpoints use keys linear.0.* and linear.2.*.
        self.linear = nn.Sequential(
            nn.Linear(spatial_dim**2, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, feature):
        """Return ``feature`` scaled by a per-channel weight of shape (n, c, 1, 1)."""
        n, c, h, w = feature.shape
        target = (self.spatial_dim, self.spatial_dim)
        # The MLP input size is fixed to spatial_dim**2, so resize whenever
        # *either* side differs (checking only the height crashed on inputs
        # with h == spatial_dim but w != spatial_dim).
        if (h, w) != target:
            x = F.interpolate(feature, size=target, mode="bilinear", align_corners=True)
        else:
            x = feature

        # reshape (unlike view) also accepts non-contiguous inputs.
        weights = self.linear(x.reshape(n, c, -1))  # (n, c, 1)
        # (n, c, 1, 1) broadcasts over the spatial dims of the original map.
        return weights.unsqueeze(3) * feature
        

class to_map(nn.Module):
    """Project ``channels`` feature maps to a single map in (0, 1).

    A 1x1 convolution to one output channel followed by a sigmoid. The layers
    live in ``self.to_map`` (a Sequential) so checkpoint keys stay
    ``to_map.0.weight`` / ``to_map.0.bias``.
    """

    def __init__(self, channels):
        super().__init__()
        head_layers = [
            nn.Conv2d(channels, 1, kernel_size=1, stride=1),
            nn.Sigmoid(),
        ]
        self.to_map = nn.Sequential(*head_layers)

    def forward(self, feature):
        projected = self.to_map(feature)
        return projected


class conv_bn_relu(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size, padding, stride: forwarded to ``nn.Conv2d``; the defaults
            (3, 1, 1) keep the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # Registration order (conv, bn, relu) fixes the state-dict key order.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class up_conv_bn_relu(nn.Module):
    """Bilinear resize to ``up_size`` x ``up_size``, then Conv2d -> BatchNorm2d -> ReLU.

    Args:
        up_size: output height and width of the bilinear resize.
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernal_size: convolution kernel size. The spelling (sic) is kept for
            compatibility with existing keyword callers.
        padding, stride: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, up_size, in_channels, out_channels=64, kernal_size=1, padding=0, stride=1):
        super().__init__()
        target = (up_size, up_size)
        # Attribute names (upSample, conv, bn, act) are checkpoint keys; keep them.
        self.upSample = nn.Upsample(size=target, mode="bilinear", align_corners=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernal_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        for layer in (self.upSample, self.conv, self.bn, self.act):
            x = layer(x)
        return x


class ICNet(nn.Module):
    """Two-branch ResNet-18 network predicting an image-complexity score and map.

    The detail branch runs the ResNet-18 stem, ``layer1`` and ``layer2`` on the
    full ``size1`` x ``size1`` input. The context branch runs the stem through
    ``layer4`` on a copy resized to ``size2`` x ``size2``. Every stage is
    followed by a ``slam`` per-channel gate. Both branches are resized to
    ``size1 // 8`` square with 256 channels each and concatenated. The result
    feeds a map head and a score head.

    Args:
        is_pretrain: passed as ``pretrained=`` to ``torchvision.models.resnet18``
            for both backbones.
        size1: required height/width of the input tensor.
        size2: height/width the context branch resizes the input to.
    """

    def __init__(self, is_pretrain=True, size1=512, size2=256):
        super().__init__()
        # Two separately constructed backbones, so the branches do not share weights.
        resnet18Pretrained1 = torchvision.models.resnet18(pretrained=is_pretrain)
        resnet18Pretrained2 = torchvision.models.resnet18(pretrained=is_pretrain)
        # children(): conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc
        detail_stages = list(resnet18Pretrained1.children())
        context_stages = list(resnet18Pretrained2.children())

        self.size1 = size1
        self.size2 = size2

        ## detail branch
        self.b1_1 = nn.Sequential(*detail_stages[:5])
        self.b1_1_slam = slam(32)

        self.b1_2 = detail_stages[5]
        self.b1_2_slam = slam(32)

        ## context branch
        self.b2_1 = nn.Sequential(*context_stages[:5])
        self.b2_1_slam = slam(32)

        self.b2_2 = context_stages[5]
        self.b2_2_slam = slam(32)

        self.b2_3 = context_stages[6]
        self.b2_3_slam = slam(16)

        self.b2_4 = context_stages[7]
        self.b2_4_slam = slam(8)

        ## upsample both branches to a common (size1 // 8) square grid
        self.upsize = size1 // 8
        self.up1 = up_conv_bn_relu(up_size=self.upsize, in_channels=128, out_channels=256)
        self.up2 = up_conv_bn_relu(up_size=self.upsize, in_channels=512, out_channels=256)

        ## map prediction head
        self.to_map_f = conv_bn_relu(256 * 2, 256 * 2)
        self.to_map_f_slam = slam(32)
        self.to_map = to_map(256 * 2)

        ## score prediction head
        self.to_score_f = conv_bn_relu(256 * 2, 256 * 2)
        self.to_score_f_slam = slam(32)
        self.head = nn.Sequential(
            nn.Linear(256 * 2, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x1):
        """Score a batch of square images.

        Args:
            x1: 4-D image batch whose last two dims must both equal ``size1``
                (checked with ``assert``).

        Returns:
            (score, cly_map): ``score`` is the sigmoid output of the score head,
            squeezed (0-d for a batch of one, shape (n,) otherwise).
            ``cly_map`` is the sigmoid map of shape
            (n, 1, size1 // 8, size1 // 8).
        """
        assert x1.shape[2] == x1.shape[3] == self.size1, (
            f"ICNet expects a square {self.size1}x{self.size1} input, got {tuple(x1.shape)}"
        )
        x2 = F.interpolate(x1, size=(self.size2, self.size2), mode="bilinear", align_corners=True)

        # Detail branch: stem + layer1 -> slam -> layer2 -> slam.
        x1 = self.b1_2_slam(self.b1_2(self.b1_1_slam(self.b1_1(x1))))
        # Context branch: stem + layer1 .. layer4, each stage followed by slam.
        x2 = self.b2_2_slam(self.b2_2(self.b2_1_slam(self.b2_1(x2))))
        x2 = self.b2_4_slam(self.b2_4(self.b2_3_slam(self.b2_3(x2))))

        x1 = self.up1(x1)
        x2 = self.up2(x2)
        x_cat = torch.cat((x1, x2), dim=1)

        cly_map = self.to_map(self.to_map_f_slam(self.to_map_f(x_cat)))

        score_feature = self.to_score_f_slam(self.to_score_f(x_cat))
        score_feature = self.avgpool(score_feature)
        # flatten(1) keeps the batch dim: (n, 512) even for n == 1. The former
        # squeeze() collapsed a single-image batch to (512,), and the head
        # only worked on that 1-D vector by accident.
        score_feature = torch.flatten(score_feature, 1)
        score = self.head(score_feature).squeeze()

        return score, cly_map


# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-image preprocessing:
# - resize to the 512x512 input ICNet.forward asserts on (its default size1);
# - convert to a tensor;
# - normalize with the standard ImageNet mean/std.
inference_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

current_directory = os.path.dirname(os.path.abspath(__file__))
weight_path = os.path.join(current_directory, 'saved_models', 'ck.pth')

# The checkpoint replaces every parameter and buffer, because load_state_dict
# is strict by default. Fetching the ImageNet ResNet-18 weights first would
# only add a network download at import time, so build the model without them.
model = ICNet(is_pretrain=False)
# map_location=device covers both the CPU-only and the CUDA case.
model.load_state_dict(torch.load(weight_path, map_location=device))
model.eval()
model.to(device)


def IC9600_score(input_image):
    """Return the ICNet (IC9600) complexity score of a single image.

    Args:
        input_image: path or file object accepted by ``PIL.Image.open``.

    Returns:
        float: the model's scalar score. Its head ends in a sigmoid, so the
        value lies in (0, 1).
    """
    # Close the file as soon as the RGB copy exists. convert() returns a new,
    # fully loaded image, so nothing reads the file afterwards.
    with Image.open(input_image) as src:
        ori_img = src.convert("RGB")
    img = inference_transform(ori_img).to(device).unsqueeze(0)
    with torch.no_grad():
        ic_score, _ = model(img)
    return ic_score.item()


def black_percentage(input_img_path):
    """Return the share of pixel-array values below a darkness threshold (50).

    The ratio is taken over every element of ``np.array(image)``, i.e. per
    channel sample rather than per pixel. Each RGB pixel contributes three
    values, and an alpha channel, if present, is counted as well.

    Args:
        input_img_path: path or file object accepted by ``PIL.Image.open``.

    Returns:
        numpy.float64 in [0, 1]: (# values < 50) / (# values).
    """
    threshold = 50
    # Some modes' raw array values are not 0-255 intensities, so convert them first:
    # - "1" yields booleans, which are always < threshold, so every pixel
    #   counted as black;
    # - "P"/"PA" yield palette indices.
    intensity_modes = {"1": "L", "P": "RGB", "PA": "RGBA"}
    # The with-block closes the file handle Image.open keeps open.
    with Image.open(input_img_path) as image:
        if image.mode in intensity_modes:
            image = image.convert(intensity_modes[image.mode])
        image_array = np.array(image)
    black_pixels = np.sum(image_array < threshold)
    return black_pixels / image_array.size


def IC_score(input_path):
    """Combine darkness and ICNet complexity into one score.

    Returns the equal-weight mean of ``black_percentage(input_path)`` and
    ``IC9600_score(input_path)``, evaluated in that order.
    """
    dark_part = black_percentage(input_path)
    complexity_part = IC9600_score(input_path)
    return 0.5 * dark_part + 0.5 * complexity_part

