from PIL import Image
from imagenetLabels import labelToIndex, indexToLabel
import code
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.decomposition import *
from sklearn.manifold import *
import matplotlib.patches as mpatches 
from matplotlib.collections import LineCollection

def stylize_plot(
    ax=None,
    *,
    base_fontsize=12,
    title_fontsize=None,
    tick_fontsize=None,
    rotate_xticks=0,
    grid=False,
    grid_axis="y",
    grid_style="--",
    grid_alpha=0.35,
    grid_linewidth=0.6,
    despine=True,
    spine_width=1.0,
    tick_length=4,
    tick_width=0.8,
    tick_direction="out",
    tighten_layout=True,
    extra_bottom_for_xticks=True,
    legend=True,
    legend_loc="best",
    legend_frame=False,
    line_width=None,        # set an absolute linewidth for all lines
    line_width_scale=None,  # multiply existing widths (e.g., 1.5)
    margins=(0.02, 0.05),
):
    """
    Apply clean, publication-friendly styling to an existing plot.

    Call this AFTER you've drawn your plot elements (lines, bars, etc.); it only
    restyles what is already on the axes.

    Args:
        ax: Axes to style; defaults to ``plt.gca()``.
        base_fontsize: font size of the axis labels; the legend uses
            ``base_fontsize - 1``.
        title_fontsize: title font size; defaults to ``base_fontsize + 2``.
        tick_fontsize: major tick label size; defaults to ``base_fontsize - 1``.
        rotate_xticks: rotation (degrees) for x tick labels; 0 leaves them alone.
            Rotated labels are right-aligned.
        grid, grid_axis, grid_style, grid_alpha, grid_linewidth: optional light grid.
        despine: hide the top and right spines.
        spine_width: line width of the left and bottom spines.
        tick_length, tick_width, tick_direction: major tick geometry. Minor ticks
            use ``max(2, tick_length - 2)`` and ``max(0.6, tick_width - 0.2)``.
        tighten_layout: call ``fig.tight_layout()``.
        extra_bottom_for_xticks: with rotated x ticks (and ``tighten_layout``),
            make sure the bottom subplot margin is at least 0.18 of the figure.
        legend, legend_loc, legend_frame: draw a legend if any artist is labeled.
        line_width: absolute width applied to every Line2D and LineCollection.
        line_width_scale: multiplier applied to the current widths (after
            ``line_width`` when both are given).
        margins: ``(x, y)`` passed to ``ax.margins``; ``None`` leaves them as is.

    Returns:
        (fig, ax)
    """
    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    # Font sizes
    if title_fontsize is None:
        title_fontsize = base_fontsize + 2
    if tick_fontsize is None:
        tick_fontsize = base_fontsize - 1

    # Axis labels & title: re-set the existing text only to change its size.
    if ax.get_xlabel():
        ax.set_xlabel(ax.get_xlabel(), fontsize=base_fontsize)
    if ax.get_ylabel():
        ax.set_ylabel(ax.get_ylabel(), fontsize=base_fontsize)
    if ax.get_title():
        ax.set_title(ax.get_title(), fontsize=title_fontsize)

    # Ticks
    ax.tick_params(axis="both",
                   which="major",
                   labelsize=tick_fontsize,
                   length=tick_length,
                   width=tick_width,
                   direction=tick_direction)
    ax.tick_params(axis="both",
                   which="minor",
                   length=max(2, tick_length - 2),
                   width=max(0.6, tick_width - 0.2),
                   direction=tick_direction)

    # Optional x-tick rotation for dense labels
    if rotate_xticks:
        for lbl in ax.get_xticklabels():
            lbl.set_rotation(rotate_xticks)
            lbl.set_ha("right")

    # Grid (light, on one axis by default)
    if grid:
        ax.grid(True, axis=grid_axis, linestyle=grid_style,
                linewidth=grid_linewidth, alpha=grid_alpha)

    # Despine + strengthen visible spines
    if despine:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
    for side in ("left", "bottom"):
        if ax.spines.get(side) is not None:
            ax.spines[side].set_linewidth(spine_width)

    # Legend only if there are labeled artists
    if legend:
        _, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend(loc=legend_loc, frameon=legend_frame, fontsize=base_fontsize - 1)

    # Small margins to avoid clipping markers/bars at the edges
    if margins is not None:
        ax.margins(x=margins[0], y=margins[1])

    # Standard Line2D artists (from plt.plot, ax.plot, ...)
    for ln in ax.lines:
        if line_width is not None:
            ln.set_linewidth(line_width)
        if line_width_scale is not None:
            ln.set_linewidth(ln.get_linewidth() * line_width_scale)

    # LineCollection-based artists (e.g. errorbar caps, streamplots)
    for coll in ax.collections:
        if isinstance(coll, LineCollection):
            if line_width is not None:
                coll.set_linewidths(line_width)
            if line_width_scale is not None:
                cur = coll.get_linewidths()
                coll.set_linewidths([w * line_width_scale for w in (cur if len(cur) else [1.0])])

    # Layout tuning for rotated x-ticks
    if tighten_layout:
        fig.tight_layout()
        if extra_bottom_for_xticks and rotate_xticks:
            # Only ever grow the bottom margin: tight_layout may already have reserved
            # more than 0.18 for long rotated labels, and forcing 0.18 would clip them.
            fig.subplots_adjust(bottom=max(fig.subplotpars.bottom, 0.18))

    return fig, ax



class IndexModel(nn.Module):
    """VGG19 whose hidden activations can be shifted toward (or away from) the
    mean embedding of a set of "index" images.

    Evaluation order:
    ``features`` (VGG19 conv stack + avgpool) -> ``preIndex`` (flatten, then
    ``classifier[0:indexLayer+1]``) -> optional shift -> ``postIndex``
    (``classifier[indexLayer+1:]``).
    """
    # For vgg19 the classifier is as below; good index layers could be -1, 1, 4
    # (-1 = directly on the flattened conv features, before layer 0).
    # (0): Linear(in_features=25088, out_features=4096, bias=True)
    # (1): ReLU(inplace=True)
    # (2): Dropout(p=0.5, inplace=False)
    # (3): Linear(in_features=4096, out_features=4096, bias=True)
    # (4): ReLU(inplace=True)
    # (5): Dropout(p=0.5, inplace=False)
    # (6): Linear(in_features=4096, out_features=1000, bias=True)

    def __init__(self, use_pretrained=True, feature_extract=True, indexLayer=1):
        """Build the split VGG19.

        Args:
            use_pretrained: passed to ``torchvision.models.vgg19(pretrained=...)``.
            feature_extract: if True, freeze every VGG19 parameter.
            indexLayer: position in ``vgg19.classifier`` after which the index
                shift is applied; ``-1`` applies it to the flattened conv features.
        """
        super().__init__()
        self.model_ft = torchvision.models.vgg19(pretrained=use_pretrained)
        IndexModel.set_parameter_requires_grad(self.model_ft, feature_extract)

        self.model_ft.subnets = {}
        self.model_ft.subnets["features"] = nn.Sequential(self.model_ft.features, self.model_ft.avgpool)
        self.features = self.model_ft.subnets["features"]
        self.relu = nn.ReLU()

        # Previously indexLayer == -1 fell into a `pass` branch that left
        # preIndex/postIndex undefined; the generic split handles it.
        self.preIndex, self.postIndex = IndexModel._split_classifier(self.model_ft.classifier, indexLayer)

    @staticmethod
    def _split_classifier(classifier, indexLayer):
        """Split ``classifier`` into ``(preIndex, postIndex)`` around ``indexLayer``.

        ``preIndex`` flattens its input (from dim 1) and runs layers
        ``0..indexLayer``; ``postIndex`` runs the remaining layers. For
        ``indexLayer == -1`` the slice ``classifier[0:0]`` is empty, so
        ``preIndex`` only flattens and ``postIndex`` is the whole classifier.
        """
        preIndex = nn.Sequential(nn.Flatten(), classifier[0:indexLayer + 1])
        postIndex = nn.Sequential(classifier[indexLayer + 1:])
        return preIndex, postIndex

    def forward(self, x, indexData=None, weight=0.5, subIndexData=None, subWeight=0.5, returnFeatures=False):
        """Classify ``x``, optionally shifting its index-layer features.

        With ``f`` the index-layer features of ``x``, ``I`` the batch mean of the
        ``indexData`` features and ``S`` the batch mean of the ``subIndexData``
        features::

            newFeats = relu(f + weight * (I - f) - subWeight * (S - f))

        A term is dropped when its data is ``None``. If both are ``None``, no ReLU
        is applied and ``newFeats = f``.

        Args:
            x: input image batch.
            indexData: images to move toward, or None.
            weight: interpolation weight toward the ``indexData`` mean.
            subIndexData: images to move away from, or None.
            subWeight: weight of the subtracted term.
            returnFeatures: also return ``newFeats``.

        Returns:
            ``output``, or ``(output, newFeats)`` when ``returnFeatures`` is True.
        """
        img = self.preIndex(self.features(x))
        if indexData is None and subIndexData is None:
            newFeats = img
        else:
            # Only add the terms that are present. The earlier version built a zero
            # placeholder on `indexData.device` / `subIndexData.device`, which raised
            # AttributeError whenever exactly one of them was None.
            newFeats = img
            if indexData is not None:
                index = self.preIndex(self.features(indexData)).mean(dim=0)  # mean along the batch dimension
                newFeats = newFeats + weight * (index - img)
            if subIndexData is not None:
                subIndex = self.preIndex(self.features(subIndexData)).mean(dim=0)  # mean along the batch dimension
                newFeats = newFeats - subWeight * (subIndex - img)
            newFeats = self.relu(newFeats)

        output = self.postIndex(newFeats)
        return (output, newFeats) if returnFeatures else output

    @staticmethod
    def set_parameter_requires_grad(model, feature_extracting):
        """Freeze every parameter of ``model`` when ``feature_extracting`` is True."""
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False



# Module-level model used by runDimReduction() (the other run* drivers build their
# own IndexModel). NOTE(review): constructing it loads pretrained VGG19 weights as
# an import-time side effect.
model = IndexModel(use_pretrained=True, feature_extract=True, indexLayer=1)
# plotDimReduction expects imagePaths / indexPaths shaped like
#     {label: {"color": "tab:blue", "images": [path, ...]}}
def plotDimReduction(model, imagePaths=None, indexPaths=None, weightValues=None, figureFolder="./imagesOutputs/", rootLabel="image"):
    """Embed images with ``model`` and save 2-D PCA scatter plots of the embeddings.

    Every image in ``imagePaths`` becomes one point. Every group in ``indexPaths``
    becomes a single point at the mean of its images' features. Two PNGs are
    written to ``figureFolder``: PC0 vs PC1 (``pc-0-1.png``) and PC2 vs PC1
    (``pc-2-1.png``).

    Args:
        model: ``IndexModel``-like module. It is called as
            ``model(batch, indexData=None, weight=0.0, returnFeatures=True)`` and
            must return ``(output, features)``.
        imagePaths: ``{label: {"color": <matplotlib color>, "images": [path, ...]}}``.
            ``None`` means no individual images.
        indexPaths: same structure as ``imagePaths``; each group is averaged.
            ``None`` means no index groups.
        weightValues: unused; accepted for call compatibility.
        figureFolder: output directory, created if missing.
        rootLabel: unused; accepted for call compatibility.

    Raises:
        ValueError: if an index group lists no images, or there is nothing to embed.
    """
    imagePaths = {} if imagePaths is None else imagePaths
    indexPaths = {} if indexPaths is None else indexPaths
    model.eval()

    # Standard ImageNet preprocessing
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def loadTensor(path):
        """Open ``path`` (closing the file), force RGB, and return a 1-image batch."""
        with Image.open(path) as pilImage:
            # convert("RGB"): grayscale/RGBA files would otherwise not match the 3-channel Normalize.
            return transform(pilImage.convert("RGB")).unsqueeze(0)

    points = []
    colors = []
    labels = []
    patches = []

    with torch.no_grad():
        for label, group in imagePaths.items():
            labels.append(label)
            patches.append(mpatches.Patch(color=group["color"], label=label))
            for path in group["images"]:
                _, features = model(loadTensor(path), indexData=None, weight=0.0, returnFeatures=True)
                print(features.shape)
                points.append(features)
                colors.append(group["color"])

        print("____________________")
        for label, group in indexPaths.items():
            labels.append(label)
            patches.append(mpatches.Patch(color=group["color"], label=label))
            if not group["images"]:
                raise ValueError("index group %r has no images" % (label,))
            batch = torch.cat([loadTensor(path) for path in group["images"]], dim=0)
            _, features = model(batch, indexData=None, weight=0.0, returnFeatures=True)
            # One point per index group: the mean of its members' features.
            points.append(features.mean(dim=0, keepdim=True))
            colors.append(group["color"])

    if not points:
        raise ValueError("plotDimReduction: no images to embed")

    os.makedirs(figureFolder, exist_ok=True)

    def plotPoints(pcaPoints, colors, xdim=0, ydim=1, plotName="pc", labels=None):
        """Scatter two principal components and save ``<plotName>-<xdim>-<ydim>.png``."""
        fig = plt.figure()
        try:
            plt.scatter(pcaPoints[:, xdim], pcaPoints[:, ydim], c=colors)
            plt.xlabel("PC %d" % xdim)
            plt.ylabel("PC %d" % ydim)
            if labels is not None:
                plt.legend(handles=labels)
            plt.savefig(os.path.join(figureFolder, "%s-%d-%d.png" % (plotName, xdim, ydim)))
        finally:
            plt.close(fig)  # don't accumulate open figures across calls

    points = torch.cat(points, dim=0).cpu().numpy()

    print("saving pcs")
    print(labels)
    pca = PCA(n_components=3)
    pcaPoints = pca.fit_transform(points)
    plotPoints(pcaPoints, colors, xdim=0, ydim=1, plotName="pc", labels=patches)
    plotPoints(pcaPoints, colors, xdim=2, ydim=1, plotName="pc", labels=patches)
def runOnImageSub(model, imagePath, indexPaths=None, subIndexPaths=None, weightValues=None, indexType="Farm", figureFolder="./imagesOutputs/", rootLabel="image"):
    """Sweep the added and subtracted index weights and plot top-1 confidence.

    For every ``subWeight`` in ``weightValues`` one PDF is saved. It shows the
    top-1 softmax confidence versus the add-index ``weight`` (also swept over
    ``weightValues``). Each point is annotated with the first 15 characters of
    the predicted label.

    Example (mirrors the call in runSubIndexClassification)::

        runOnImageSub(model, "images/cowBeach.jpg", indexPaths=farmIndexPaths,
                      subIndexPaths=beachIndexPaths, weightValues=weightValues,
                      indexType="Add Farm | sub Beach",
                      figureFolder="./imagesOutputsSub/", rootLabel="cowBeach")

    Args:
        model: ``IndexModel``-like module, called as
            ``model(image, indexData=..., weight=..., subIndexData=..., subWeight=...)``.
        imagePath: image to classify.
        indexPaths: images whose mean embedding is added; empty/None passes
            ``indexData=None``.
        subIndexPaths: images whose mean embedding is subtracted; empty/None
            passes ``subIndexData=None``.
        weightValues: weights swept for both ``weight`` and ``subWeight``;
            defaults to ``np.arange(0, 1.1, 0.025)``.
        indexType: description used in the file name. If it contains "Farm" or
            "City", that word is used in the x-axis label.
        figureFolder: output directory, created if missing.
        rootLabel: file-name prefix (``|`` replaced by ``-``).
    """
    if weightValues is None:
        weightValues = [ii for ii in np.arange(0, 1.1, 0.025)]
    model.eval()

    # Standard ImageNet preprocessing
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def loadTensor(path):
        """Open ``path`` (closing the file), force RGB, and return a 1-image batch."""
        with Image.open(path) as pilImage:
            # convert("RGB"): grayscale/RGBA files would otherwise not match the 3-channel Normalize.
            return transform(pilImage.convert("RGB")).unsqueeze(0)

    def stackImages(paths):
        """Batch the images at ``paths`` along dim 0, or return None when there are none."""
        if not paths:
            return None
        return torch.cat([loadTensor(path) for path in paths], dim=0)

    image = loadTensor(imagePath)
    indexImages = stackImages(indexPaths)
    subIndexImages = stackImages(subIndexPaths)

    # Axis-label wording does not change across the sweep, so compute it once.
    subType = "Beach"
    if "Farm" in indexType:
        addType = "Farm"
    elif "City" in indexType:
        addType = "City"
    else:
        addType = None

    os.makedirs(figureFolder, exist_ok=True)
    sm = nn.Softmax(dim=1)
    print("--------------------------------------------")
    with torch.no_grad():
        for subWeight in weightValues:
            confValues = []
            confLabels = []
            for weight in weightValues:
                output = model(image, indexData=indexImages, weight=weight, subIndexData=subIndexImages, subWeight=subWeight)
                probs = sm(output)
                predProb, predClass = torch.max(probs, 1)
                class_index = int(predClass.item())
                confValues.append(predProb.item())
                confLabels.append(indexToLabel[class_index][:15])

                print("%s | weight %f | Class Index %d | Class Label %s" % (imagePath, weight, class_index, indexToLabel[class_index]))

            fig, ax = plt.subplots(1)
            try:
                ax.set_aspect(0.35)
                for w, conf, lbl in zip(weightValues, confValues, confLabels):
                    # Predicted class name just above each point of the curve.
                    ax.text(w, conf + 0.03, lbl, rotation=45, fontsize=15)
                ax.plot(weightValues, confValues)
                ax.set_xlabel(f"Injected ({addType}) CONTXT α")
                ax.set_title(f"Removed ({subType}) CONTXT α {-1 * subWeight:.1f}")
                ax.set_ylabel("Confidence")
                ax.set_ylim((0, 1))

                stylize_plot(
                    ax,
                    base_fontsize=22,
                    title_fontsize=22,
                    spine_width=2.0,
                    line_width=5,  # absolute linewidth for all lines
                )

                fig.tight_layout()
                fig.savefig(os.path.join(figureFolder, "%s-%s-SubWeight-%f.pdf" % (rootLabel.replace("|","-"), indexType.replace("|","-"), subWeight)))
            finally:
                # One figure per subWeight: close it so long sweeps don't pile up open figures.
                plt.close(fig)

def runOnImage(model, imagePath, indexPaths=None, weightValues=None, indexType="Farm", figureFolder="./imagesOutputs/", rootLabel="image"):
    """Sweep the index weight for one image and plot top-1 confidence.

    Saves ``<rootLabel>-<indexType>.png`` in ``figureFolder``. The top panel shows
    confidence versus weight, annotated with the first 15 characters of the
    predicted label. The bottom panel shows the input image.

    Args:
        model: ``IndexModel``-like module, called as
            ``model(image, indexData=..., weight=...)``.
        imagePath: image to classify.
        indexPaths: images whose mean embedding the features are moved toward;
            empty/None passes ``indexData=None``.
        weightValues: index weights to sweep; defaults to ``np.arange(0, 1.1, 0.025)``.
        indexType: name used in the x-axis label and in the file name.
        figureFolder: output directory, created if missing.
        rootLabel: file-name prefix.
    """
    if weightValues is None:
        weightValues = [ii for ii in np.arange(0, 1.1, 0.025)]
    model.eval()

    # Standard ImageNet preprocessing
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def loadRGB(path):
        """Open ``path``, return an in-memory RGB copy, and close the file."""
        with Image.open(path) as raw:
            # convert("RGB"): grayscale/RGBA files would otherwise not match the 3-channel Normalize.
            return raw.convert("RGB")

    pilImage = loadRGB(imagePath)  # kept for the imshow panel below
    image = transform(pilImage).unsqueeze(0)  # Add batch dimension

    if indexPaths:
        indexImages = torch.cat([transform(loadRGB(path)).unsqueeze(0) for path in indexPaths], dim=0)
    else:
        indexImages = None

    sm = nn.Softmax(dim=1)
    confValues = []
    confLabels = []
    print("--------------------------------------------")
    with torch.no_grad():
        for weight in weightValues:
            output = model(image, indexData=indexImages, weight=weight)
            probs = sm(output)
            predProb, predClass = torch.max(probs, 1)
            class_index = int(predClass.item())
            confValues.append(predProb.item())
            confLabels.append(indexToLabel[class_index][:15])

            print("%s | weight %f | Class Index %d | Class Label %s" % (imagePath, weight, class_index, indexToLabel[class_index]))

    os.makedirs(figureFolder, exist_ok=True)
    fig, axs = plt.subplots(2)
    try:
        for w, conf, lbl in zip(weightValues, confValues, confLabels):
            axs[0].text(w, conf + 0.03, lbl, rotation=45)
        axs[0].plot(weightValues, confValues)
        axs[0].set_xlabel("Index Weight | %s" % indexType)
        axs[0].set_ylabel("Confidence")
        axs[0].set_ylim((0, 1))
        axs[1].imshow(np.asarray(pilImage))
        fig.tight_layout()
        fig.savefig(os.path.join(figureFolder, "%s-%s.png" % (rootLabel, indexType)))
    finally:
        # Called once per (image, index set) by the drivers: don't leak figures.
        plt.close(fig)


def runIndexClassification():
    """Run runOnImage for each target cow image against the Farm and City index sets.

    The index weight is swept over ``np.arange(0, 1.1, 0.025)``. Figures are
    written to ``./imagesOutputs/``.
    """
    net = IndexModel(use_pretrained=True, feature_extract=True, indexLayer=1)

    # Insertion order (Farm, then City) fixes the order the runs happen in.
    indexSets = {
        "Farm": [
            "/bazhlab/edelanois/msProj/9/images/farm1.jpg",
            "/bazhlab/edelanois/msProj/9/images/farm2.jpg",
            "/bazhlab/edelanois/msProj/9/images/farm3.jpg",
            "/bazhlab/edelanois/msProj/9/images/grass1.jpg",
            "/bazhlab/edelanois/msProj/9/images/grass2.jpg",
            "/bazhlab/edelanois/msProj/9/images/grass3.jpg",
        ],
        "City": ["./images/city%d.jpg" % k for k in range(1, 7)],
    }
    sweep = [float(w) for w in np.arange(0, 1.1, 0.025)]

    # Other target sets used before:
    #   ["cowBeach", "cowCity", "cowFarm", "ox", "oxBeach", "oxBeach2"]
    #   ["cowInWoods", "cowInWoods2", "cowInRoad1", "cowInRoad0", "cowCity1", "cowCity2"]
    targets = ["cowCity", "cowInRoad0"]

    for target in targets:
        targetPath = "images/%s.jpg" % target
        for setName, setPaths in indexSets.items():
            runOnImage(
                net,
                targetPath,
                indexPaths=setPaths,
                weightValues=sweep,
                indexType=setName,
                figureFolder="./imagesOutputs/",
                rootLabel=target,
            )

def runSubIndexClassification():
    """Run runOnImageSub: add the Farm or City index while subtracting the Beach index.

    Both weights are swept over ``np.arange(0, 1.1, 0.1)``. Figures are written to
    ``./imagesOutputsSub/``.
    """
    net = IndexModel(use_pretrained=True, feature_extract=True, indexLayer=1)

    # Insertion order (Farm, then City) fixes the order the runs happen in.
    addSets = {
        "Farm": [
            "/bazhlab/edelanois/msProj/9/images/farm1.jpg",
            "/bazhlab/edelanois/msProj/9/images/farm2.jpg",
            "/bazhlab/edelanois/msProj/9/images/farm3.jpg",
            "/bazhlab/edelanois/msProj/9/images/grass1.jpg",
            "/bazhlab/edelanois/msProj/9/images/grass2.jpg",
            "/bazhlab/edelanois/msProj/9/images/grass3.jpg",
        ],
        "City": ["./images/city%d.jpg" % k for k in range(1, 7)],
    }
    beachIndexPaths = ["./images/beach%d.jpeg" % k for k in range(1, 7)]
    sweep = [float(w) for w in np.arange(0, 1.1, 0.1)]

    # Other target sets used before:
    #   ["cowBeach", "cowCity", "cowFarm", "ox", "oxBeach", "oxBeach2"]
    #   ["cowInWoods", "cowInWoods2", "cowInRoad1", "cowInRoad0", "cowCity1", "cowCity2"]
    #   ["cowBeach", "cowInRoad0"]
    targets = ["cowBeach"]

    for target in targets:
        targetPath = "images/%s.jpg" % target
        for setName, setPaths in addSets.items():
            runOnImageSub(
                net,
                targetPath,
                indexPaths=setPaths,
                subIndexPaths=beachIndexPaths,
                weightValues=sweep,
                indexType="Add %s | sub Beach" % setName,
                figureFolder="./imagesOutputsSub/",
                rootLabel=target,
            )
        
        
def runDimReduction():
    """Plot a PCA of the module-level ``model``'s embeddings.

    Covers the "OOD Cow" and "ID Cow" image groups plus the mean embeddings of
    the "Farm" and "city" index groups. Each group is described as
    ``{"color": <matplotlib color>, "images": [path, ...]}``.
    """
    oodCowImages = [
        "./images/cowBeach.jpg",
        "./images/cowCity.jpg",
        "./images/cowCity1.jpg",
        "./images/cowCity2.jpg",
        "./images/cowInRoad0.jpg",
        "./images/cowInRoad1.jpg",
        "./images/oxBeach.jpg",
        "./images/oxBeach2.jpg",
    ]
    idCowImages = [
        "./images/cowFarm.jpg",
        "./images/cowInWoods.jpg",
        "./images/cowInWoods2.jpg",
        "./images/ox.jpg",
    ]
    farmImages = [
        "./images/farm1.jpg",
        "./images/farm2.jpg",
        "./images/farm3.jpg",
        "./images/grass1.jpg",
        "./images/grass2.jpg",
        "./images/grass3.jpg",
    ]
    cityImages = [
        "./images/city1.jpg",
        "./images/city2.jpg",
        "./images/city3.jpg",
        "./images/city4.jpg",
        "./images/city5.jpg",
        "./images/city6.jpg",
    ]

    imagePaths = {
        "OOD Cow": {"color": "tab:red", "images": oodCowImages},
        "ID Cow": {"color": "tab:blue", "images": idCowImages},
    }
    indexPaths = {
        "Farm": {"color": "tab:green", "images": farmImages},
        "city": {"color": "tab:orange", "images": cityImages},
    }

    plotDimReduction(
        model,
        imagePaths=imagePaths,
        indexPaths=indexPaths,
        weightValues=list(np.arange(0, 1.1, 0.025)),
        figureFolder="./imagesOutputs/",
        rootLabel="image",
    )
    
# Script entry point: uncomment the experiment(s) to run.
# NOTE(review): these calls also execute whenever this module is imported; consider
# wrapping them in an `if __name__ == "__main__":` guard.
# runIndexClassification()
# runDimReduction()
runSubIndexClassification()