import torch
import torch.nn as nn

import numpy as np
from PIL import Image
import PIL
PIL.Image.ANTIALIAS = PIL.Image.LANCZOS
from torchvision import transforms
import copy

# Benchmark scans: the 24 images used to pick each model's most variable embedding unit.
benchmark = [f'./img/s{idx}.jpg' for idx in (*range(1, 13), *range(501, 513))]
# Number of top-ranked pixels highlighted in each heat map.
how_many = 10

# Trained weights: the model under study and the conventional-autoencoder baseline.
target_model_file = './out_model/B1.pth'
AE_model_file = './AE_model/B1.pth'

# Image to explain and its physician-annotated counterpart
# (the misspelled name `humnan_annotated_img` is kept because later code reads it).
img_file = './img_candidate/s760.jpg'
humnan_annotated_img = './img_candidate/s760b.jpg'

# Figure label (not referenced by the visible script) and the PNG output path.
title = '(d) Group B subject 87'
save_file = "b2.png"



import matplotlib.pyplot as plt
# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")



def mse_loss(y_true, y_pred):
    """Return the mean of the squared differences between `y_true` and `y_pred`.

    Accepts NumPy arrays or scalars. The two arguments are combined with the
    `-` operator, so they must support subtraction with each other.
    """
    return np.mean(np.square(y_true - y_pred))

class Autoencoder(nn.Module):
    """Convolutional autoencoder for 1-channel 128x128 images with a 64-d bottleneck.

    The encoder has four stride-2 convolutions (128 -> 64 -> 32 -> 16 -> 8 px),
    followed by `fc1` down to 64 units. The decoder mirrors it with `fc2` and
    four stride-2 transposed convolutions back to 128x128, then a sigmoid.
    Layer attribute names define the state_dict keys and must not change.
    """

    def __init__(self):
        super().__init__()
        # Encoder.
        self.conv1 = nn.Conv2d(1, 16, 3, 2, 1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, 2, 1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
        self.relu4 = nn.ReLU()

        # Bottleneck: 128 channels x 8 x 8 <-> 64-d embedding.
        self.fc1 = nn.Linear(128 * 8 * 8, 64)
        self.fc2 = nn.Linear(64, 128 * 8 * 8)

        # Decoder.
        self.urelu1 = nn.ReLU()
        self.upconv1 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.urelu2 = nn.ReLU()
        self.upconv2 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
        self.urelu3 = nn.ReLU()
        self.upconv3 = nn.ConvTranspose2d(32, 16, 4, 2, 1)
        self.urelu4 = nn.ReLU()
        self.upconv4 = nn.ConvTranspose2d(16, 1, 4, 2, 1)
        self.sig = nn.Sigmoid()

    def _encode(self, x):
        """Map an (N, 1, 128, 128) batch to its (N, 64) bottleneck embedding."""
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.conv4(x))
        return self.fc1(torch.flatten(x, 1))

    def forward(self, x):
        """Reconstruct `x`; the output has the input's spatial size and values in (0, 1)."""
        # NOTE(review): the embedding is stored as the instance attribute
        # `self.emb`. It shadows the `emb` method below, so `model.emb(x)` raises
        # TypeError after any forward call on the same instance. Kept because
        # other code may read `model.emb` as the latent; confirm before renaming.
        self.emb = self._encode(x)
        x = self.urelu1(self.fc2(self.emb))
        x = x.view(-1, 128, 8, 8)
        x = self.urelu2(self.upconv1(x))
        x = self.urelu3(self.upconv2(x))
        x = self.urelu4(self.upconv3(x))
        x = self.upconv4(x)
        return self.sig(x)

    def emb(self, x):
        """Return the (N, 64) bottleneck embedding of `x` without decoding."""
        return self._encode(x)


def get_most_effective_unit(model, benchmark):
    """Return the index of the embedding unit with the largest variance over `benchmark`.

    Each image is resized to 128x128, converted to a single grayscale channel,
    turned into a tensor and scaled exactly as in the main script. It is then
    embedded with `model.emb`, and the per-unit variance across images is taken.

    Args:
        model: network exposing `emb(batch) -> (batch, n_units)`. It is switched
            to eval mode as a side effect.
        benchmark: non-empty sequence of image file paths.

    Returns:
        int: index of the first unit with maximal variance.

    Raises:
        ValueError: if `benchmark` is empty.
    """
    if len(benchmark) == 0:
        raise ValueError("benchmark must contain at least one image path")

    model.eval()
    # Inputs must live on the same device as the weights (CUDA/MPS or CPU).
    model_device = next(model.parameters()).device
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # Resize image to match model input size
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # ToTensor() already maps pixels to [0, 1]; this extra /255 presumably
        # mirrors the training pipeline. Confirm before removing.
        transforms.Lambda(lambda x: x / 255.0),
    ])

    embeddings = []
    with torch.no_grad():
        for path in benchmark:
            with Image.open(path) as image:
                batch = transform(image).unsqueeze(0).to(model_device)
            # .cpu() is required before .numpy() for CUDA/MPS tensors.
            embeddings.append(model.emb(batch)[0].cpu().numpy())

    # argmax returns the first maximal index, same as list.index(max(...)).
    return int(np.argmax(np.var(embeddings, axis=0)))




def _pixel_sensitivity(net, image_tensor, unit):
    """Return an (H, W) map of how strongly each pixel moves one embedding unit.

    For every pixel, the image is evaluated twice: once with that pixel set
    to 0 and once with it set to 1. The map entry is the squared difference of
    embedding unit `unit` between the two. The perturbed copies for one image
    row are evaluated as a single batch instead of one forward pass per pixel.

    Args:
        net: model exposing `emb(batch) -> (batch, n_units)`.
        image_tensor: input of shape (1, C, H, W); the batch dimension must be 1.
        unit: index of the embedding unit to probe.
    """
    net_device = next(net.parameters()).device
    _, _, height, width = image_tensor.shape
    batch_idx = torch.arange(width)
    sensitivity = np.zeros((height, width))
    with torch.no_grad():
        for row in range(height):
            dark = image_tensor.repeat(width, 1, 1, 1)
            bright = dark.clone()
            # Batch element k perturbs only pixel (row, k), in every channel.
            dark[batch_idx, :, row, batch_idx] = 0
            bright[batch_idx, :, row, batch_idx] = 1
            unit_dark = net.emb(dark.to(net_device))[:, unit]
            unit_bright = net.emb(bright.to(net_device))[:, unit]
            sensitivity[row] = ((unit_dark - unit_bright) ** 2).cpu().numpy()
    return sensitivity


def _min_max_normalize(matrix):
    """Scale `matrix` linearly to [0, 1].

    A constant matrix maps to all zeros; the unguarded division would yield NaN everywhere.
    """
    low = matrix.min()
    span = matrix.max() - low
    if span == 0:
        return np.zeros_like(matrix, dtype=float)
    return (matrix - low) / span


def _highlight_top_pixels(normalized, count, radius=2):
    """Build the overlay that marks the `count` strongest pixels of `normalized`.

    Every pixel whose value is at least the `count`-th largest value adds
    `value + 0.2` to each cell of the (2*radius+1)^2 window around it, plus a
    further `value + 0.1` to its own cell. All other cells stay 0. The 0.2/0.1
    offsets presumably keep highlighted cells visible even when value is 0.
    The window is clipped at the image border instead of wrapping around via
    negative indices or raising IndexError at the far edges.
    """
    height, width = normalized.shape
    count = max(1, min(count, normalized.size))
    threshold = np.partition(normalized.ravel(), -count)[-count]
    heat = np.zeros((height, width))
    for row, col in zip(*np.nonzero(normalized >= threshold)):
        value = normalized[row, col]
        heat[max(row - radius, 0):row + radius + 1,
             max(col - radius, 0):col + radius + 1] += value + 0.2
        heat[row, col] += value + 0.1
    return heat


with torch.no_grad():

    model = Autoencoder().to(device)
    model.load_state_dict(torch.load(target_model_file, map_location=torch.device('cpu')))
    model.eval()

    model_AE = Autoencoder().to(device)
    model_AE.load_state_dict(torch.load(AE_model_file, map_location=torch.device('cpu')))
    model_AE.eval()

    # Each model is probed at its own most-variable embedding unit.
    model_target_unit = get_most_effective_unit(model, benchmark)
    model_AE_target_unit = get_most_effective_unit(model_AE, benchmark)

    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # Resize image to match model input size
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # ToTensor() already maps pixels to [0, 1]; this extra /255 presumably
        # mirrors the training pipeline. Confirm before removing.
        transforms.Lambda(lambda x: x / 255.0),
    ])
    with Image.open(img_file) as source_image:
        image_tensor = transform(source_image).unsqueeze(0)  # Add batch dimension
    with Image.open(humnan_annotated_img) as annotated_file:
        # copy() loads the pixels so the image outlives the closed file handle.
        annotated_image = annotated_file.copy()

    heatmap_data = _highlight_top_pixels(
        _min_max_normalize(_pixel_sensitivity(model, image_tensor, model_target_unit)),
        how_many,
    )
    # The original reused model_target_unit (the other model's unit) here.
    heatmap_data_AE = _highlight_top_pixels(
        _min_max_normalize(_pixel_sensitivity(model_AE, image_tensor, model_AE_target_unit)),
        how_many,
    )

    # The 8x stride assumes a ~1024-px source so the background is ~128 px. The
    # heat maps are drawn with the background's extent so they stay aligned
    # for other sizes; for a 128x128 background this equals the default extent.
    image = plt.imread(img_file)[::8, ::8]
    bg_height, bg_width = image.shape[:2]
    overlay_extent = (-0.5, bg_width - 0.5, bg_height - 0.5, -0.5)

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 3.8), tight_layout=True)

    axes[0].imshow(annotated_image, cmap='gray')
    axes[0].set_title('Physician Annotation')

    for ax, heat, label in ((axes[1], heatmap_data, 'Our Model'),
                            (axes[2], heatmap_data_AE, 'Conventional Autoencoder')):
        ax.imshow(image, cmap='viridis')
        ax.imshow(heat, cmap='hot', alpha=0.6, interpolation='nearest', extent=overlay_extent)
        ax.set_title(label)

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    plt.savefig(save_file, bbox_inches='tight')
    plt.savefig("annotationB2.pdf", format="pdf")
    plt.show()
