from PIL import Image
from imagenetLabels import labelToIndex, indexToLabel
import code
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np



class IndexModel(nn.Module):
    """VGG19 whose classifier head is split at ``indexLayer`` so the image
    representation at that point can be interpolated towards a context
    ("index") representation before the rest of the head runs.

    For vgg19 the classifier is as below; good index layers could be -1, 1, 4
    (-1 places the index on the flattened output of features + avgpool,
    i.e. before classifier[0]):
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
    """

    def __init__(self, use_pretrained=True, feature_extract=True, indexLayer=1):
        """Build the split VGG19.

        Args:
            use_pretrained: forwarded to ``torchvision.models.vgg19`` as
                ``pretrained``.
            feature_extract: if true, set ``requires_grad = False`` on every
                parameter of the backbone.
            indexLayer: position in ``vgg19.classifier`` of the last module that
                runs before the index is applied; -1 applies it directly to the
                flattened feature map.
        """
        super().__init__()
        self.model_ft = torchvision.models.vgg19(pretrained=use_pretrained)
        IndexModel.set_parameter_requires_grad(self.model_ft, feature_extract)

        # Name -> sub-network lookup kept on the backbone. A plain dict is not
        # registered by nn.Module; the assignment to self.features below is.
        self.model_ft.subnets = {}
        self.model_ft.subnets["features"] = nn.Sequential(self.model_ft.features, self.model_ft.avgpool)
        self.features = self.model_ft.subnets["features"]

        # Previously indexLayer == -1 fell into a bare `pass`, leaving
        # preIndex/postIndex undefined so forward() raised AttributeError.
        self.preIndex, self.postIndex = IndexModel._split_classifier(self.model_ft.classifier, indexLayer)

    @staticmethod
    def _split_classifier(classifier, indexLayer):
        """Split ``classifier`` into ``(preIndex, postIndex)`` after ``indexLayer``.

        preIndex flattens its input and runs ``classifier[0:indexLayer + 1]``;
        postIndex runs ``classifier[indexLayer + 1:]``. For ``indexLayer == -1``
        the preIndex slice is empty (an empty nn.Sequential returns its input),
        so only Flatten runs before the index and the whole classifier after it.
        """
        preIndex = nn.Sequential(nn.Flatten(), classifier[0:indexLayer + 1])
        postIndex = nn.Sequential(classifier[indexLayer + 1:])
        return preIndex, postIndex

    def forward(self, x, indexData=None, weight=0.5):
        """Classify ``x``, optionally pulling its representation towards an index.

        Args:
            x: batch passed through ``self.features`` (presumably images of
                shape (B, 3, H, W) — confirm against callers).
            indexData: optional batch of context inputs. Its representation at
                the index layer is averaged over dim 0 into a single index
                vector that broadcasts against every row of ``x``.
            weight: interpolation factor. The representation becomes
                ``img + weight * (index - img)``: 0 keeps ``img``, 1 gives
                ``index``, and values outside [0, 1] extrapolate.

        Returns:
            The output of ``self.postIndex``.
        """
        if indexData is None:
            return self.postIndex(self.preIndex(self.features(x)))

        # Earlier "simple" variant added the index vector directly:
        #   postIndex(preIndex(features(x)) + index)
        # Here we instead move the image representation towards the index,
        # not in the direction the index itself points.
        # The index is computed before the image to keep the original
        # evaluation order (matters for Dropout RNG in train mode).
        index = self.preIndex(self.features(indexData)).mean(dim=0)  # mean over the batch dimension
        img = self.preIndex(self.features(x))
        diff = weight * (index - img)
        return self.postIndex(img + diff)

    @staticmethod
    def set_parameter_requires_grad(model, feature_extracting):
        """Freeze every parameter of ``model`` when ``feature_extracting`` is
        true; otherwise leave ``requires_grad`` untouched."""
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False



# Split VGG19 with the index applied right after classifier[1] (the first
# ReLU), switched to eval mode so the classifier's Dropout layers are inactive.
# nn.Module.eval() returns the module itself, so it can be chained here.
model = IndexModel(
    use_pretrained=True,
    feature_extract=True,
    indexLayer=1,
).eval()

# Preprocessing for the pretrained network: resize the shorter side to 256,
# center-crop 224x224, convert to a [0, 1] tensor, normalise per channel.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Image to classify. Other images tried:
#   images/cowBeach.jpg   # PERFORMANCE WORKS
#   images/oxBeach2.jpg   # PERFORMANCE NOT WORKS
#   images/ox.jpg, images/cowFarm.jpg, images/oxBeach.jpg
imgPath = "images/cowCity.jpg"  # PERFORMANCE WORKS

# convert("RGB") guarantees the 3 channels Normalize expects (grayscale, RGBA
# or CMYK files would otherwise fail); the `with` closes the file handle.
with Image.open(imgPath) as pil_image:
    image = transform(pil_image.convert("RGB"))
image = image.unsqueeze(0)  # Add batch dimension

# Context ("index") images; the model averages their representations over
# the batch dimension and pulls the query image's representation towards it.

# Farm context
INDEX_IMAGE_DIR = "/bazhlab/edelanois/msProj/9/images"
indexPaths = [
    "%s/%s.jpg" % (INDEX_IMAGE_DIR, name)
    for name in ("farm1", "farm2", "farm3", "grass1", "grass2", "grass3")
]
# City context
# indexPaths = ["./images/city%d.jpg" % i for i in range(1, 7)]

indexTensors = []
for path in indexPaths:
    # convert("RGB") guarantees 3 channels for Normalize; `with` closes the file.
    with Image.open(path) as pil_index_image:
        indexTensors.append(transform(pil_index_image.convert("RGB")))
# Stack the per-image (C, H, W) tensors into one (N, C, H, W) batch.
indexImages = torch.stack(indexTensors, dim=0)

print("______________________________________")
# Classify the image once per interpolation weight. np.arange excludes its
# stop value, so a stop of 1.1 keeps 1.0 in the sweep (and a few steps just
# above it, which extrapolate past the index representation).
weights = [float(w) for w in np.arange(0, 1.1, 0.025)]
with torch.no_grad():
    for weight in weights:
        output = model(image, indexData=indexImages, weight=weight)
        predicted = torch.max(output, 1).indices
        class_index = int(predicted.item())
        label = indexToLabel[class_index]
        print()
        print("%s | weight %f | Class Index %d | Class Label %s"
              % (imgPath, weight, class_index, label))