import argparse
# Command-line interface.  Exactly which analysis runs is selected with the
# --cross-subject / --pooled-subject-cross-modal switches; --plot-only and
# --heatmap-only restrict a selected analysis to one of its two stages.
parser = argparse.ArgumentParser(description = "Template")
parser.add_argument("-gpu","--GPUindex",default = 0,type = int,help = "gpu index")
parser.add_argument("-l","--layers",default = 1,type = int,help = "number of layers")
parser.add_argument("-hi","--hidden",default = 10,type = int,help = "number of hidden")
parser.add_argument("-hrf","--HRFdelay",default = 3,type = int,help = "HRF delay")
# The data directory is dereferenced unconditionally further down (the mask,
# the EPI volumes and the checkpoints are all resolved relative to it), so
# require it here instead of failing later with a TypeError from
# os.path.join(None, ...).
parser.add_argument("-d","--directory",type = str,required = True,help = "directory")
parser.add_argument("--cross-subject",default=False,action="store_true",
                    help = "run the cross-subject analysis")
parser.add_argument("--pooled-subject-cross-modal",default=False,action="store_true",
                    help = "run the pooled-subject cross-modal analysis")
parser.add_argument("--plot-only",default=False,action="store_true",
                    help = "skip heatmap generation; only compute and plot the metrics from saved heatmaps")
parser.add_argument("--heatmap-only",default=False,action="store_true",
                    help = "only generate and save heatmaps; skip the metrics and plots")

options = parser.parse_args()

# The twelve two-word detectors analysed below, grouped by the kind of pair.
tasks = [
    # actor + action
    "Dan-pick-up",
    "Dan-put-down",
    "Scott-pick-up",
    "Scott-put-down",
    # actor + object
    "Dan-briefcase",
    "Dan-chair",
    "Scott-briefcase",
    "Scott-chair",
    # action + object
    "pick-up-briefcase",
    "pick-up-chair",
    "put-down-briefcase",
    "put-down-chair",
]

from xml.dom import minidom
import torch
import numpy as np
from mvpa2.suite import *
from regressors import *
import pickle as pkl
import os
import random
import matplotlib.pyplot as plt
import json
from scipy.stats import kendalltau


# Seed every random number generator this script touches so that repeated
# runs (in particular the shuffling done while building supertrials) are
# repeatable.
seed = 12
torch.manual_seed(seed)
# Seed the generators of all visible GPUs rather than only the current
# default device: the GPU that is actually used is picked via --GPUindex.
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
mvpa2.seed(seed)
# Ask cuDNN for deterministic kernels and keep its auto-tuner off, since the
# auto-tuner may select different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# ---- Experiment configuration ------------------------------------------
experiment = "predication"   # experiment folder name used in the design path
subjects = 12                # number of subjects looped over
folds = 8                    # not referenced in this part of the file
runs = 16                    # number of runs kept per subject (see trim_labels)
number_of_detectors = 1      # label columns of the pooled dataset
K = 32                       # trials averaged into one supertrial
supertrials = True           # not referenced in this part of the file

# ---- Sliding-window settings for the weighted Chamfer distance ----------
n_voxel_interest = 500       # only the top-ranked voxels are compared
window_size = 100            # voxels per sliding window
stride = 1                   # window shift between consecutive steps
n_step = n_voxel_interest - window_size
n_window = int(np.ceil(n_step/stride))

# Use the requested GPU when CUDA is usable; otherwise fall back to the CPU
# instead of failing on the first .to(device) call.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{options.GPUindex}")
else:
    device = torch.device("cpu")

# Root folder for everything this script writes; exist_ok makes the call
# idempotent and free of the check-then-create race.
os.makedirs("heatmaps", exist_ok=True)



def aggregate_supertrials(raw_dataset, C):
    """Turn single trials into class-wise averaged "supertrials".

    Every trial volume is placed K times (K is the module-level constant) into
    the pool of its class, the pool is shuffled with ``random.shuffle`` and
    consecutive groups of K pool entries are averaged.  Each class therefore
    yields as many supertrials as it has trials.  The supertrials of all
    classes are finally concatenated and shuffled jointly.

    Args:
        raw_dataset: dict with "fmri" (one row per trial) and "labels"
            (one single-element entry per trial holding the class index).
        C: number of classes.

    Returns:
        dict with "fmri" (stacked supertrials) and "labels" (float32 tensor
        with one ``[class]`` row per supertrial), in matching shuffled order.
    """
    volumes = raw_dataset["fmri"]
    classes = raw_dataset["labels"]

    # pools[c] receives K entries for every trial whose class is c.
    pools = [[] for _ in range(C)]
    for trial in range(volumes.shape[0]):
        target_pool = pools[int(classes[trial].item())]
        target_pool.extend([volumes[trial, :]] * K)

    # Shuffle each pool and average it in consecutive chunks of K entries.
    averaged = [[] for _ in range(C)]
    for cls, pool in enumerate(pools):
        random.shuffle(pool)
        usable = (len(pool) // K) * K
        for start in range(0, usable, K):
            # Accumulate sequentially in float32, then divide by K.
            total = torch.zeros((len(pool[0]),), dtype=torch.float)
            for member in pool[start:start + K]:
                total += member
            averaged[cls].append(total / K)

    # Concatenate class by class, remembering the class of every supertrial.
    ordered = []
    ordered_labels = []
    for cls in range(C):
        ordered.extend(averaged[cls])
        ordered_labels.extend([[cls]] * len(averaged[cls]))

    # One joint permutation keeps volumes and labels aligned.
    order = list(range(len(ordered_labels)))
    random.shuffle(order)

    return {
        "fmri": torch.stack([ordered[i] for i in order]),
        "labels": torch.tensor([ordered_labels[i] for i in order],
                               dtype=torch.float32),
    }


def attributes_to_labels(attributes):
    """Split a flat target sequence into one list of targets per chunk.

    A new list is started whenever the chunk value differs from the one of
    the previous sample, so samples are grouped by *consecutive* equal chunk
    values.

    Args:
        attributes: object exposing parallel ``targets`` and ``chunks``
            sequences (one entry per sample).

    Returns:
        List of lists; each inner list holds the targets of one chunk, in the
        original order.  Empty input yields an empty list.
    """
    labels = []
    current_chunk = None
    for target, chunk in zip(attributes.targets, attributes.chunks):
        # "not labels" opens the first group unconditionally, so no sentinel
        # value can ever collide with a real chunk id (e.g. -1).
        if not labels or chunk != current_chunk:
            labels.append([])
            current_chunk = chunk
        labels[-1].append(target)
    return labels

def labels_to_attributes(labels):
    """Flatten per-run target lists into parallel targets/chunks lists.

    Args:
        labels: list of per-run target lists.

    Returns:
        ``(targets, chunks)`` where ``targets`` is the concatenation of all
        runs and ``chunks[i]`` is the zero-based index of the run that
        ``targets[i]`` came from.
    """
    targets = []
    chunks = []
    for run_index, run_targets in enumerate(labels):
        targets.extend(run_targets)
        chunks.extend([run_index] * len(run_targets))
    return targets, chunks

def add_hrf_delay_to_labels(labels, hrf_delay):
    """Shift the labels of every run later in time by ``hrf_delay`` samples.

    The first ``hrf_delay`` samples of each run become "rest" and the last
    ``hrf_delay`` original labels are dropped, so every run keeps its length.

    Args:
        labels: list of per-run label lists.
        hrf_delay: number of samples to shift by.  Values that are not
            positive leave the runs unchanged.

    Returns:
        New list of per-run label lists, each as long as its input run.
    """
    new_labels = []
    for label in labels:
        # Clamp the shift to the run length: an unclamped delay larger than
        # the run would turn the slice bound negative, which lengthens the
        # run and keeps a stimulus label that should have been shifted out.
        shift = min(max(hrf_delay, 0), len(label))
        new_labels.append(["rest"] * shift + list(label[:len(label) - shift]))
    return new_labels

def trim_labels(labels, runs):
    """Return only the first ``runs`` per-run label lists of ``labels``."""
    kept = labels[0:runs]
    return kept

def add_hrf_delay_and_trim_attributes(attributes, hrf_delay, runs):
    """Delay the targets by ``hrf_delay`` and keep only the first ``runs`` runs.

    The attributes are split into per-run label lists, every run is shifted,
    the list of runs is truncated and the result is flattened again.

    Returns:
        ``(targets, chunks)`` as produced by ``labels_to_attributes``.
    """
    per_run = attributes_to_labels(attributes)
    delayed = add_hrf_delay_to_labels(per_run, hrf_delay)
    return labels_to_attributes(trim_labels(delayed, runs))


# Vocabulary from which the detector names and the target substrings are built.
_SIDES = ("left", "right")
_AGENTS = ("Dan", "Scott")
_VERBS = ("pick-up", "put-down")
_OBJECTS = ("briefcase", "chair")

# Detector name -> (first, second) for pairs that are matched through the
# substring "<first>-on-<side>-<second>": actor+action and action+object.
_ADJACENT_PAIRS = {
    "%s-%s" % (first, second): (first, second)
    for first, second in (
        [(a, v) for a in _AGENTS for v in _VERBS]
        + [(v, o) for v in _VERBS for o in _OBJECTS]
    )
}

# Detector name -> (actor, object) for the actor+object pairs.
_AGENT_OBJECT_PAIRS = {
    "%s-%s" % (a, o): (a, o) for a in _AGENTS for o in _OBJECTS
}


def _word_case(word, c):
    """Return the label of target name ``c`` for detector ``word``.

    1 means the detector fires, 0 means it does not, and -1 means the sample
    is not used for this detector.  Raises RuntimeError("unknown word") for
    an unsupported detector name.
    """
    if not isinstance(word, str):
        raise RuntimeError("unknown word")

    # Single-word detectors: plain substring test, every sample is used.
    if word in _AGENTS or word in _VERBS or word in _OBJECTS:
        return 1 if word in c else 0

    # "subject": both actors and both actions must occur; fires when Dan is
    # bound to pick-up and Scott (on the other side) to put-down.
    if word == "subject":
        if ("Dan" in c and "Scott" in c and
                "pick-up" in c and "put-down" in c):
            if (("Dan-on-left-pick-up" in c and
                 "Scott-on-right-put-down" in c) or
                    ("Dan-on-right-pick-up" in c and
                     "Scott-on-left-put-down" in c)):
                return 1
            return 0
        return -1

    # "object": both actions and both objects must occur; fires when pick-up
    # is bound to the briefcase and put-down (other side) to the chair.
    if word == "object":
        if ("pick-up" in c and "put-down" in c and
                "briefcase" in c and "chair" in c):
            if (("pick-up-on-left-briefcase" in c and
                 "put-down-on-right-chair" in c) or
                    ("pick-up-on-right-briefcase" in c and
                     "put-down-on-left-chair" in c)):
                return 1
            return 0
        return -1

    # Pair detectors come in two flavours:
    #   "<pair>"  - gated: samples lacking either word are skipped (-1);
    #   "<pair>1" - ungated: every sample is used and labelled 0 or 1.
    # The ungated detectors correspond to the questions
    #   1. What did Dan/Scott do?                    pick up / put down
    #   2. What did Dan/Scott do something to?       briefcase / chair
    #   3. Who picked something up / put something down?   Dan / Scott
    #   4. What was picked up / put down?            briefcase / chair
    #   5. Who did something to a briefcase/chair?   Dan / Scott
    #   6. What was done to a briefcase/chair?       pick up / put down
    # For each of them the answers "nothing"/"nobody" [both false] and
    # "Dan and Scott" / "briefcase and chair" / "pick up and put down"
    # [both true] are also possible.  One briefcase cannot be distinguished
    # from two briefcases, one chair from two chairs, one pick up from two
    # pick ups, or one put down from two put downs.
    ungated = word.endswith("1")
    pair = word[:-1] if ungated else word

    if pair in _ADJACENT_PAIRS:
        first, second = _ADJACENT_PAIRS[pair]
        linked = any("%s-on-%s-%s" % (first, side, second) in c
                     for side in _SIDES)
        if ungated or (first in c and second in c):
            return 1 if linked else 0
        return -1

    if pair in _AGENT_OBJECT_PAIRS:
        agent, obj = _AGENT_OBJECT_PAIRS[pair]
        if ungated:
            # The actor must perform some action on the object, with actor
            # and object marked with the same side.
            acted = any(
                "%s-on-%s-%s-on-%s-%s" % (agent, side, verb, side, obj) in c
                for verb in _VERBS for side in _SIDES)
            return 1 if acted else 0
        if agent in c and obj in c:
            same_side = any(
                "%s-on-%s" % (agent, side) in c and
                "%s-on-%s" % (obj, side) in c
                for side in _SIDES)
            return 1 if same_side else 0
        return -1

    raise RuntimeError("unknown word")


def read_dataset(
    experiment,
    hrf_delay,
    subject,
    modality,
    word,
):
    """Load one subject's fMRI data and label it for detector ``word``.

    The sample attributes are read from the experiment design folder, shifted
    by ``hrf_delay`` and trimmed to the module-level number of ``runs``.  The
    EPI volume and mask are read from ``options.directory``.  The data are
    detrended per chunk (4th-order polynomial), z-scored per chunk with the
    parameters estimated on the "rest" samples, and the "rest" samples are
    then dropped.

    Args:
        experiment: experiment name used in the design path.
        hrf_delay: number of samples the targets are delayed by.
        subject: zero-based subject index (files use ``subject + 1``).
        modality: modality sub-folder of the design (e.g. "video", "text").
        word: detector name; see ``_word_case`` for the supported names.

    Returns:
        dict with "fmri" (float32, one row per used sample) and "labels"
        (float32, one ``[0.]``/``[1.]`` row per used sample).

    Raises:
        RuntimeError: if ``word`` is not a known detector name.
    """
    attributes = SampleAttributes("../../data/fmri/fmri-experiments/%s/design/generated-experiments/subject-%02d/%s/pymvpa/attributes"%(experiment, subject+1, modality))
    targets, chunks = add_hrf_delay_and_trim_attributes(
        attributes, hrf_delay, runs)
    dataset = fmri_dataset(samples="%s/subject-%02d/epi.nii.gz"%(options.directory, subject+1),
                           targets=targets,
                           chunks=chunks,
                           mask="%s/subject-%02d/full-mask.nii.gz"%(options.directory, subject+1))
    poly_detrend(dataset, polyord=4, chunks_attr="chunks")
    zscore(dataset, chunks_attr="chunks", param_est=("targets", ["rest"]))
    dataset = dataset[dataset.sa.targets!="rest"]

    used_fmri = []
    used_labels = []
    for volume, c in zip(dataset.samples, dataset.targets):
        case = _word_case(word, c)
        # Negative cases mark samples that are irrelevant for this detector.
        if case>=0:
            used_fmri.append(torch.tensor(volume, dtype=torch.float32))
            used_labels.append(torch.tensor([case], dtype=torch.float32))

    return {"fmri": torch.stack(used_fmri),
            "labels": torch.stack(used_labels)}


def read_pooled_subject_dataset(
        experiment, hrf_delay, modality, word):
    """Read every subject's dataset and interleave them into one dataset.

    Row ``i*subjects + s`` of the result holds sample ``i`` of subject ``s``.
    The number of samples and voxels is taken from the first subject.

    Returns:
        dict with "fmri" (``subjects*samples`` rows) and "labels"
        (``subjects*samples`` rows, ``number_of_detectors`` columns), both
        float32.
    """
    per_subject = [
        read_dataset(experiment, hrf_delay, s, modality, word)
        for s in range(subjects)
    ]
    reference = per_subject[0]["fmri"]
    samples = reference.size(0)
    voxels = reference.size(1)

    pooled_fmri = torch.zeros([subjects*samples, voxels],
                              dtype=torch.float32)
    pooled_labels = torch.zeros([subjects*samples, number_of_detectors],
                                dtype=torch.float32)
    for s, data in enumerate(per_subject):
        for i in range(samples):
            row = i*subjects + s
            pooled_fmri[row, :] = data["fmri"][i, :]
            pooled_labels[row, :] = data["labels"][i, :]
    return {"fmri": pooled_fmri, "labels": pooled_labels}

def sliding_nonSquare_Chamfer(X,Y,prev_dist=None,stride=1):
    """Unsquared Chamfer distance between point sets ``X`` and ``Y``.

    Without ``prev_dist`` the full pairwise Euclidean distance matrix is
    computed.  With ``prev_dist`` (the matrix of the previous window) the
    first ``stride`` rows and columns are discarded and only the distances
    involving the last ``stride`` points of ``X`` and of ``Y`` are computed.

    Args:
        X, Y: arrays with one point per row.
        prev_dist: distance matrix returned for the previous window, or None.
        stride: number of points by which both windows moved.

    Returns:
        ``(distance, dist)``: the sum of the two directed mean
        nearest-neighbour distances, and the pairwise distance matrix.
    """
    if prev_dist is not None:
        # Distances between points that stayed inside both windows.
        kept = prev_dist[stride:, stride:]
        # Newly entered X points against all of Y, and all of X against the
        # newly entered Y points.
        bottom = np.linalg.norm(X[-stride:, None, :] - Y[None, :, :], axis=-1)
        right = np.linalg.norm(X[:, None, :] - Y[None, -stride:, :], axis=-1)
        # The bottom strip omits its last `stride` columns because the right
        # strip already provides them.
        left = np.vstack([kept, bottom[:, :-stride]])
        dist = np.hstack([left, right])
    else:
        dist = np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1)

    d_Xy = dist.min(axis=0).mean()
    d_xY = dist.min(axis=1).mean()
    return d_Xy+d_xY, dist

def weighted_unsquared_Chamfer_distance(r1, r2):
    """Weighted sum of windowed Chamfer distances between two rankings.

    Voxels are ordered from the highest to the lowest value of ``r1`` and of
    ``r2``.  A window of ``window_size`` voxels slides over both orderings in
    steps of ``stride``; for every window the Chamfer distance between the
    two sets of voxel coordinates is computed, and the per-window distances
    are combined with exponentially decaying weights.

    Relies on the module-level ``n_step``, ``n_window``, ``stride``,
    ``window_size`` and ``voxel_coordinates``.
    """
    # Voxel indices from the most to the least important.
    order_1 = np.flip(r1.argsort())
    order_2 = np.flip(r2.argsort())

    distances = []
    cached = None  # distance matrix of the previous window
    for start in range(0, n_step, stride):
        stop = start + window_size
        distance, cached = sliding_nonSquare_Chamfer(
            voxel_coordinates[order_1[start:stop]],
            voxel_coordinates[order_2[start:stop]],
            cached,
            stride
        )
        distances.append(distance)

    # Earlier windows (more important voxels) get exponentially larger weight.
    decay = np.exp(-np.linspace(0,10,n_window))
    return np.array(distances).dot(decay)




# Read the brain mask of subject 01, which serves as the mask for everyone.
mask = nib.load(
    os.path.join(options.directory, "subject-01/full-mask.nii.gz")
).get_fdata()
# One row of array indices per voxel inside the mask (60732 voxels).
voxel_coordinates = np.array(np.where(mask > 0)).transpose()


###################### Pooled Subject Cross Modal ######################
if options.pooled_subject_cross_modal:

    if not options.plot_only:
        # Idempotent, and also creates the parent folder when it is missing.
        os.makedirs("heatmaps/pooled_subject_cross_modal", exist_ok=True)

        print("Pooled Subject Cross Modal: Generating heatmaps of the 60732 voxels and store them into .npy files")

        ## This part generates the heatmaps and saves them to .npy files
        # Loop over 12 tasks
        for word in tasks:
            print("Generating heatmaps for task: %s"%word.replace("-", " "))

            # Loop over 2 modalities
            for target_modality in ["video", "text"]:

                # Load the test set for the given modality and build
                # two-class supertrials from it
                test_set = read_pooled_subject_dataset(
                    experiment, options.HRFdelay, target_modality, word+"1"
                )
                test_set = aggregate_supertrials(test_set, 2)

                # Load the model.
                # NOTE(review): the checkpoint name hard-codes "3-nn-1-10",
                # which looks like HRF delay 3 / 1 layer / 10 hidden units
                # (the CLI defaults) - confirm before using other options.
                net = regress_net(
                    options.layers, 60732, options.hidden,
                    [word]).to(device)
                # map_location places the stored tensors on the device that
                # is actually in use, whichever device they were saved from.
                net.load_state_dict(torch.load(os.path.join(
                    options.directory, "pooled-subject", target_modality,
                    "%s1-pooled-subject-cross-modal-iclr2024-3-nn-1-10-whole-brain.pth"%word),
                    map_location=device))
                net.eval()

                # Only the supertrials in which the stimulus is present
                stimulated = torch.where(test_set["labels"] == 1)[0]
                inputs = test_set["fmri"][stimulated].to(device)

                with torch.no_grad():
                    # Weights of the linear model.
                    # Assumes the first layer has a single output unit so
                    # that the squeezed weight broadcasts against the
                    # inputs - TODO confirm against regress_net.
                    weight = net.layers[0].weight.squeeze()

                    # The absolute importance of voxels.  torch.abs keeps
                    # the result a tensor, which the tensor methods below
                    # rely on.
                    importance = torch.abs(weight*inputs).detach().cpu()

                    # Individual rankings (argsort of argsort yields ranks),
                    # scaled to [0, 1]
                    ranking = importance.argsort(1).argsort(1).float()
                    ranking /= ranking.max()

                    # Averaged rankings across samples
                    ranking_avg = ranking.mean(0).numpy()

                    # Save the averaged rankings
                    np.save("heatmaps/pooled_subject_cross_modal/%s_%s.npy"%(target_modality, word), ranking_avg)

        print("All heatmaps have been generated!")

    if not options.heatmap_only:
        # Start generating Kendall correlation and WUCD results
        print("Start generating Kendall correlation and WUCD results")
        Taus = []
        WUCDs = []

        ## This part uses the generated heatmaps to perform analysis
        # Loop over 12 tasks
        for word in tasks:
            ranking_video = np.load("heatmaps/pooled_subject_cross_modal/video_%s.npy"%(word))
            ranking_text = np.load("heatmaps/pooled_subject_cross_modal/text_%s.npy"%(word))

            # Kendall correlation between the two modalities
            tau, _ = kendalltau(ranking_video,ranking_text)
            Taus.append(tau)

            # Weighted unsquared Chamfer distance between the two modalities
            WUCDs.append(weighted_unsquared_Chamfer_distance(ranking_video, ranking_text))

        # Bar positions and tick labels shared by both figures
        words = [task.replace("-", " ") for task in tasks]
        positions = np.linspace(0,1.5,len(words))

        # Plot the Kendall correlation results
        plt.figure(dpi = 300)
        plt.bar(positions, np.array(Taus), color="b", width=0.1)
        plt.xticks(positions, words, rotation=45, ha="right")
        plt.title("Pooled Subject Cross Modal Kendall Correlation")
        plt.ylabel("Kendall Correlation")
        plt.tight_layout()
        plt.savefig("heatmaps/pooled_subject_cross_modal/Kendall.png",bbox_inches="tight")
        plt.show()

        # Save the Kendall correlation results
        Kendall = dict(zip(words, Taus))
        with open("heatmaps/pooled_subject_cross_modal/Kendall.json", "w") as file:
            json.dump(Kendall, file)


        # Plot the WUCD results
        plt.figure(dpi = 300)
        plt.bar(positions, WUCDs, color="b", width=0.1)
        plt.xticks(positions, words, rotation=45, ha="right")  # Rotate the xticks for better readability
        plt.ylabel("Weighted Unsquared Chamfer Distance")
        plt.title("Pooled Subject Cross Modal WUCD")
        plt.tight_layout()
        plt.savefig("heatmaps/pooled_subject_cross_modal/WUCD.png",bbox_inches="tight")
        plt.show()

        # Save the WUCD results
        Chamfer = dict(zip(words, WUCDs))
        with open("heatmaps/pooled_subject_cross_modal/WUCDs.json", "w") as file:
            json.dump(Chamfer, file)









###################### Cross Subject ######################
if options.cross_subject:

    if not options.plot_only:

        # Create every per-subject/per-modality output folder.  makedirs with
        # exist_ok also completes a partially existing tree, which a check on
        # the top-level folder alone would silently skip.
        for subject in range(subjects):
            for modality in ["video", "text"]:
                os.makedirs("heatmaps/cross_subject/subject-%02d/%s"%(
                    subject+1, modality), exist_ok=True)


        print("Cross Subject: Generating heatmaps of the 60732 voxels and store them into .npy files")

        ## This part generates the heatmaps and saves them to .npy files

        # Loop over 12 tasks:
        for word in tasks:
            print("Generating heatmaps for task: %s"%word.replace("-", " "))

            # Loop over 2 modalities
            for modality in ["video", "text"]:
                print("Modality: %s"%modality)

                # Loop over 12 subjects:
                for subject in range(subjects):
                    print("Subject-%02d"%(subject+1))

                    # Load the test set of the corresponding subject and
                    # build two-class supertrials from it
                    test_set = read_dataset(
                        experiment,
                        options.HRFdelay,
                        subject,
                        modality,
                        word+"1"
                    )
                    test_set = aggregate_supertrials(test_set, 2)

                    # Load the model.
                    # NOTE(review): the checkpoint name hard-codes
                    # "3-nn-1-10", which looks like HRF delay 3 / 1 layer /
                    # 10 hidden units (the CLI defaults) - confirm before
                    # using other options.
                    net = regress_net(
                        options.layers, 60732, options.hidden,
                        [word]).to(device)
                    # map_location places the stored tensors on the device
                    # that is actually in use, whichever device they were
                    # saved from.
                    net.load_state_dict(
                        torch.load(os.path.join(
                            options.directory,
                            "subject-%02d/%s/%s1-cross-subject-None-3-nn-1-10-whole-brain.pth"%(
                                subject+1,modality,word)),
                            map_location=device))
                    net.eval()

                    # Only the supertrials in which the stimulus is present
                    stimulated = torch.where(test_set["labels"] == 1)[0]
                    inputs = test_set["fmri"][stimulated].to(device)

                    with torch.no_grad():
                        # Weights of the linear model.
                        # Assumes the first layer has a single output unit so
                        # that the squeezed weight broadcasts against the
                        # inputs - TODO confirm against regress_net.
                        weight = net.layers[0].weight.squeeze()

                        # The absolute importance of voxels
                        importance = torch.abs(weight*inputs).detach().cpu()

                        # Individual rankings (argsort of argsort yields
                        # ranks), scaled to [0, 1]
                        ranking = importance.argsort(1).argsort(1).float()
                        ranking /= ranking.max()

                        # Averaged rankings across samples
                        ranking_avg = ranking.mean(0).numpy()

                        # Save the averaged rankings
                        np.save("heatmaps/cross_subject/subject-%02d/%s/%s.npy"%(
                            subject+1, modality, word), ranking_avg)

    if not options.heatmap_only:

        ## This part uses the generated heatmaps to perform analysis

        # Loop over 2 modalities
        for modality in ["video", "text"]:

            Taus = []
            WUCDs = []

            # Loop over 12 tasks
            for word in tasks:

                # Load every subject's ranking once instead of once per pair
                rankings = [
                    np.load("heatmaps/cross_subject/subject-%02d/%s/%s.npy"%(
                        s+1, modality, word))
                    for s in range(subjects)
                ]

                # Only the strict upper triangle (subject_1 < subject_2) is
                # filled; the diagonal and the lower triangle stay zero.
                Tau = np.zeros((subjects,subjects))
                WUCD = np.zeros((subjects,subjects))

                # Loop over 12C2=66 subject pairs
                for subject_1 in range(subjects):
                    for subject_2 in range(subject_1+1,subjects):
                        ranking_1 = rankings[subject_1]
                        ranking_2 = rankings[subject_2]

                        # Save the Kendall correlation results
                        tau, _ = kendalltau(ranking_1,ranking_2)
                        Tau[subject_1,subject_2] = tau

                        # Save the WUCD results
                        WUCD[subject_1,subject_2] = weighted_unsquared_Chamfer_distance(ranking_1, ranking_2)


                Taus.append(Tau)
                WUCDs.append(WUCD)

            # Save the metric results of the entire modality
            np.save("heatmaps/cross_subject/%s_Kendall.npy"%modality, Taus)
            np.save("heatmaps/cross_subject/%s_WUCD.npy"%modality, WUCDs)


        ## This part plots the results:

        # Tick positions and 1-based subject labels shared by all panels
        subject_ticks = np.arange(subjects)
        subject_names = np.arange(subjects)+1

        # Kendall correlations
        Ax = []
        fig, ax = plt.subplots(1,2,figsize=(10,5))
        for i, modality in enumerate(["video","text"]):
            # Load the statistics and average them over the tasks
            Taus = np.load("heatmaps/cross_subject/%s_Kendall.npy"%modality)
            Taus = Taus.mean(0)

            Ax.append(ax[i].imshow(Taus, vmin = -1, vmax = 1, cmap='bwr'))
            ax[i].set_xticks(subject_ticks, subject_names)
            ax[i].set_yticks(subject_ticks, subject_names)
            ax[i].set_xlabel("Subject", fontsize = 15)
            ax[i].set_ylabel("Subject", fontsize = 15)
            ax[i].set_title("%s"%(modality.capitalize()), fontsize = 15)

        # Both panels share one colour scale, so one colour bar serves both.
        fig.subplots_adjust(right=0.82)
        cbar_ax = fig.add_axes([0.835, 0.2, 0.02, 0.6])
        fig.colorbar(Ax[1], orientation="vertical", cax=cbar_ax)
        plt.savefig("heatmaps/cross_subject/Kendall.png", bbox_inches="tight")
        plt.show()


        # WUCDs
        Ax = []
        fig, ax = plt.subplots(1,2,figsize=(10,5))
        for i, modality in enumerate(["video","text"]):
            # Load the statistics and average them over the tasks
            WUCDs = np.load("heatmaps/cross_subject/%s_WUCD.npy"%modality)
            WUCDs = WUCDs.mean(0)

            # NOTE(review): vmax = 811.70 is a fixed upper bound for the
            # colour scale; presumably chosen from earlier results - confirm.
            Ax.append(ax[i].imshow(WUCDs, vmin = 0, vmax = 811.70, cmap='bwr'))
            ax[i].set_xticks(subject_ticks, subject_names)
            ax[i].set_yticks(subject_ticks, subject_names)
            ax[i].set_xlabel("Subject", fontsize = 15)
            ax[i].set_ylabel("Subject", fontsize = 15)
            ax[i].set_title("%s"%(modality.capitalize()), fontsize = 15)

        fig.subplots_adjust(right=0.82)
        cbar_ax = fig.add_axes([0.835, 0.2, 0.02, 0.6])
        fig.colorbar(Ax[1], orientation="vertical", cax=cbar_ax)
        plt.savefig("heatmaps/cross_subject/WUCD.png", bbox_inches="tight")
        plt.show()
