import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Roman numerals used as row/column labels in the generated LaTeX tables.
roman = ["i", "ii", "iii", "iv", "v", "vi", "vii", "viii", "ix", "x",
         "xi", "xii", "xiii", "xiv", "xv", "xvi", "xvii", "xviii", "xix", "xx",
         "xxi", "xxii", "xxiii", "xxiv"]

# Dataset variants; each selects a family of encoding/log files below.
datasets = ["their", "our-block", "our-randomized"]

# Image encoders whose results are compared.
encoders = ["inception_v3", "resnet101", "densenet161", "alexnet"]

# mapping.txt holds one whitespace-separated "<class id> <class index>" pair
# per line; afterwards classes[index] is the class id for that output index.
with open("mapping.txt", "r") as mapping_file:
    # Strip only the newline instead of chopping the last character: if the
    # final line has no trailing newline, chopping would cut a digit off its
    # index.  Blank lines are skipped rather than crashing the split below.
    mapping = [line.rstrip("\n") for line in mapping_file if line.strip()]
classes = [""] * 1000
for pair in mapping:
    fields = pair.split()
    classes[int(fields[1])] = fields[0]

# Human-readable encoder names used in table labels.
encoder_name = {"inception_v3": "Inception v3",
                "resnet101": "ResNet-101",
                "densenet161": "DenseNet-161",
                "alexnet": "AlexNet"}

# LaTeX macros wrapped around highlighted table cells.
highlight1 = "\\red"
highlight2 = "\\green"
def _pca_split(x, k=40):
    """Fit a PCA on the rows of `x` and split it at component `k`.

    Returns (top, rest, top_ratio, rest_ratio): `top` reconstructs `x` from
    the first k principal components and `rest` from the remaining ones (the
    PCA mean is added back to both); the ratios are the summed explained
    variance ratios of those two component groups.
    """
    pca = PCA()
    pca.fit(x)
    # Project once and reuse the projection for both reconstructions.
    projected = pca.transform(x)
    top = projected[:, :k].dot(pca.components_[:k, :]) + pca.mean_
    rest = projected[:, k:].dot(pca.components_[k:, :]) + pca.mean_
    ratios = pca.explained_variance_ratio_
    return top, rest, ratios[:k].sum(), ratios[k:].sum()


def _accuracy(scores, targets, class_list, forty_classes):
    """Return the percentage of rows whose arg-max class id is class_list[target].

    `scores` is a 2-D array whose column i corresponds to class id classes[i].
    When `forty_classes` is true, columns whose class id is not in
    `class_list` are set to -inf first, restricting the arg-max to those
    classes.  Ties resolve to the lowest column index.
    """
    allowed = set(class_list)
    # Computed once per call instead of once per row.
    masked = ([i for i in range(1000) if classes[i] not in allowed]
              if forty_classes else [])
    hits = 0
    for row, target in zip(scores.tolist(), targets):
        for i in masked:
            row[i] = -float("inf")
        if class_list[target] == classes[row.index(max(row))]:
            hits += 1
    return 100*float(hits)/len(targets)


# The 40 evaluated class ids in the order used by "their" dataset; targets
# are indices into this list, so the order matters.
_THEIR_SPAMPINATO_CLASSES = ["n02389026",
                             "n03888257",
                             "n03584829",
                             "n02607072",
                             "n03297495",
                             "n03063599",
                             "n03792782",
                             "n04086273",
                             "n02510455",
                             "n11939491",
                             "n02951358",
                             "n02281787",
                             "n02106662",
                             "n04120489",
                             "n03590841",
                             "n02992529",
                             "n03445777",
                             "n03180011",
                             "n02906734",
                             "n07873807",
                             "n03773504",
                             "n02492035",
                             "n03982430",
                             "n03709823",
                             "n03100240",
                             "n03376595",
                             "n03877472",
                             "n03775071",
                             "n03272010",
                             "n04069434",
                             "n03452741",
                             "n03792972",
                             "n07753592",
                             "n13054560",
                             "n03197337",
                             "n02504458",
                             "n02690373",
                             "n03272562",
                             "n04044716",
                             "n02124075"]
# The "our" datasets index the same 40 class ids in sorted order.
_OUR_SPAMPINATO_CLASSES = sorted(_THEIR_SPAMPINATO_CLASSES)

# table maps (metric, modality, variant, joint, before, encoder, dataset,
# forty_classes, no_pretraining) -> percentage.  metric is either
# "PCA explained variance" (variant "top 40" / "bottom 960") or
# "PCA accuracy" (variant "original" / "top 40" / "bottom 960").
table = {}
for joint in [True, False]:
    for before in [True, False]:
        for encoder in encoders:
            for dataset in datasets:
                for forty_classes in [True, False]:
                    for no_pretraining in [False, True]:
                        # "No pretraining" is only evaluated after joint
                        # training.
                        if no_pretraining and (before or not joint):
                            continue
                        spampinato_classes = (
                            _OUR_SPAMPINATO_CLASSES if "our" in dataset
                            else _THEIR_SPAMPINATO_CLASSES)

                        stem = (dataset
                                if dataset in ("our-block", "our-randomized")
                                else "their")
                        if before:
                            stage = "before-"
                        elif joint:
                            stage = "after-joint-"
                        else:
                            stage = "after-separate-"
                        suffix = ("-no-pretraining"
                                  if joint and no_pretraining else "")
                        name = "%s%s-%s-encodings%s.pkl" % (
                            stage, stem, encoder, suffix)

                        # NOTE: pickle.load can execute code embedded in the
                        # file; only load trusted encoding files.
                        with open(name, "rb") as encodings_file:
                            encodings = pkl.load(encodings_file)
                        # encodings[0] / [1] / [2]: EEG encodings, image
                        # encodings, targets.
                        targets = np.array(encodings[2])
                        key_tail = (joint, before, encoder, dataset,
                                    forty_classes, no_pretraining)

                        variants = []
                        for modality, raw in (("EEG", encodings[0]),
                                              ("image", encodings[1])):
                            original = np.array(raw)
                            top, rest, top_ratio, rest_ratio = (
                                _pca_split(original))
                            table[("PCA explained variance", modality,
                                   "top 40") + key_tail] = 100*top_ratio
                            table[("PCA explained variance", modality,
                                   "bottom 960") + key_tail] = 100*rest_ratio
                            variants.append((modality,
                                             (("original", original),
                                              ("top 40", top),
                                              ("bottom 960", rest))))

                        for modality, kinds in variants:
                            for kind, scores in kinds:
                                table[("PCA accuracy", modality, kind)
                                      + key_tail] = _accuracy(
                                          scores, targets,
                                          spampinato_classes, forty_classes)

# triplet_loss maps (joint, encoder, dataset, no_pretraining) to the summed
# "joint" loss divided by the summed sample count parsed from the matching
# training log.
triplet_loss = {}
for joint in [True, False]:
    for encoder in encoders:
        for dataset in datasets:
            for no_pretraining in [False, True]:
                # Runs without pretraining exist only for joint training.
                if no_pretraining and not joint:
                    continue
                stem = (dataset if dataset in ("our-block", "our-randomized")
                        else "their")
                if joint:
                    name = "after-joint-%s-%s-encodings%s.txt" % (
                        stem, encoder,
                        "-no-pretraining" if no_pretraining else "")
                    # Joint logs: only epoch-100 lines count; loss and sample
                    # count are space-separated fields 9 and 11.
                    loss_field, samples_field = 9, 11
                else:
                    name = "after-separate-%s-%s-encodings.txt" % (
                        stem, encoder)
                    # Separate logs: loss and sample count are fields 6 and 8.
                    loss_field, samples_field = 6, 8

                entries = []
                with open(name, "r") as log:
                    for line in log:
                        if "joint" not in line:
                            continue
                        if joint and "epoch 100" not in line:
                            continue
                        fields = line.split(" ")
                        entries.append((float(fields[loss_field]),
                                        int(fields[samples_field])))
                if not entries:
                    raise RuntimeError("no triplet loss lines in %s" % name)
                # Total logged loss over total logged samples (the
                # denominator must sum the sample counts, not the losses).
                triplet_loss[(joint, encoder, dataset, no_pretraining)] = (
                    sum(loss for loss, _ in entries) /
                    sum(samples for _, samples in entries))

# mse_loss maps (modality, encoder, dataset), modality in {"EEG", "image"},
# to the summed epoch-100 loss divided by the summed sample count parsed
# from the separate-training log.
mse_loss = {}
for encoder in encoders:
    for dataset in datasets:
        stem = (dataset if dataset in ("our-block", "our-randomized")
                else "their")
        name = "after-separate-%s-%s-encodings.txt" % (stem, encoder)
        entries = {"EEG": [], "image": []}
        with open(name, "r") as log:
            for line in log:
                if "epoch 100" not in line:
                    continue
                fields = line.split(" ")
                # Plain substring match: a line mentioning both modalities
                # counts for both.  Loss and sample count are fields 6 and 8.
                for modality in ("EEG", "image"):
                    if modality in line:
                        entries[modality].append((float(fields[6]),
                                                  int(fields[8])))
        for modality in ("EEG", "image"):
            if not entries[modality]:
                raise RuntimeError("no %s loss lines in %s" % (modality, name))
            mse_loss[(modality, encoder, dataset)] = (
                sum(loss for loss, _ in entries[modality]) /
                sum(samples for _, samples in entries[modality]))

def make_table1(pathname, condition1, condition2):
    """Write a LaTeX tabular of PCA explained variance (top 40 components).

    One row per (dataset, encoder).  For each modality (EEG, image) there are
    four columns: before training, after joint training with pretraining,
    after joint training without pretraining, and after separate training.
    Values are the ``table`` entries with forty_classes=False, printed with
    "%5.1f".

    Args:
        pathname: output .tex path (overwritten).
        condition1, condition2: callables ``(column, row) -> bool``.  A cell
            for which condition1 holds is wrapped in ``highlight1``;
            otherwise, if condition2 holds, in ``highlight2``.  ``column``
            runs 0-7 across both modalities and ``row`` counts data rows
            from 0.

    Raises:
        RuntimeError: "unequal before" if the before-training value differs
            between the joint and separate entries (both are loaded from the
            same "before-" encodings file).
    """
    # (joint, before, no_pretraining) for the four columns of each modality.
    settings = ((True, True, False),
                (True, False, False),
                (True, False, True),
                (False, False, False))
    labels = {"our-block": "our block", "our-randomized": "our randomized"}

    def value(modality, encoder, dataset, joint, before, no_pretraining):
        return table[("PCA explained variance", modality, "top 40", joint,
                      before, encoder, dataset, False, no_pretraining)]

    # `with` guarantees the file is closed even if a check below raises.
    with open(pathname, "w") as f:
        f.write("\\begin{tabular}{ll|l|rrrr|rrrr|}\n")
        f.write("&&&\\multicolumn{4}{c|}{EEG}&\\multicolumn{4}{c|}{image}\\\\\n")
        f.write("&&&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}\\\\\n")
        f.write("&&&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}\\\\\n")
        f.write("&&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&\\\\\n")
        f.write("\\hline\n")
        f.write("&&&i&ii&iii&iv&v&vi&vii&viii\\\\\n")
        row = 0
        for dataset in datasets:
            name = labels.get(dataset, "their")
            for encoder in encoders:
                if encoder == "inception_v3":
                    # First encoder of a dataset group: rule off and label it.
                    f.write("\\hline\n")
                    f.write("%s&%s&%s" % (
                        name, encoder_name[encoder], roman[row]))
                else:
                    f.write("&%s&%s" % (encoder_name[encoder], roman[row]))
                column = 0
                for modality in ["EEG", "image"]:
                    if (value(modality, encoder, dataset, True, True, False) !=
                            value(modality, encoder, dataset,
                                  False, True, False)):
                        raise RuntimeError("unequal before")
                    for joint, before, no_pretraining in settings:
                        f.write("&%s{%5.1f}" % (
                            highlight1 if condition1(column, row) else
                            highlight2 if condition2(column, row) else "",
                            value(modality, encoder, dataset,
                                  joint, before, no_pretraining)))
                        column += 1
                f.write("\\\\\n")
                row += 1
        f.write("\\end{tabular}\n")

def make_table2(pathname, condition1, condition2):
    """Write a LaTeX tabular of PCA classification accuracy (percent).

    One row per (dataset, modality, encoder).  Columns are grouped by class
    set (40 classes, then 1000 classes), then by encoding variant (original,
    top 40, bottom 960), and within each group hold four settings: before
    training, after joint training with pretraining, after joint training
    without pretraining, and after separate training.  Values come from the
    "PCA accuracy" entries of ``table``, printed with "%5.1f".

    Args:
        pathname: output .tex path (overwritten).
        condition1, condition2: callables ``(column, row) -> bool``.  A cell
            for which condition1 holds is wrapped in ``highlight1``;
            otherwise, if condition2 holds, in ``highlight2``.  ``column``
            runs 0-23 and ``row`` counts data rows from 0.

    Raises:
        RuntimeError: "unequal before" if the before-training value differs
            between the joint and separate entries.
    """
    # (joint, before, no_pretraining) for the four columns of each group.
    settings = ((True, True, False),
                (True, False, False),
                (True, False, True),
                (False, False, False))
    labels = {"our-block": "our block", "our-randomized": "our randomized"}

    def value(modality, kind, encoder, dataset, forty_classes,
              joint, before, no_pretraining):
        return table[("PCA accuracy", modality, kind, joint, before,
                      encoder, dataset, forty_classes, no_pretraining)]

    # `with` guarantees the file is closed even if a lookup or check raises.
    with open(pathname, "w") as f:
        f.write("\\begin{tabular}{lll|l|rrrr|rrrr|rrrr|rrrr|rrrr|rrrr|}\n")
        f.write("&&&&\\multicolumn{12}{c|}{40 classes}&\\multicolumn{12}{c|}{1000 classes}\\\\\n")
        f.write("&&&&\\multicolumn{4}{c|}{original}&\\multicolumn{4}{c|}{top 40}&\\multicolumn{4}{c|}{bottom 960}&\\multicolumn{4}{c|}{original}&\\multicolumn{4}{c|}{top 40}&\\multicolumn{4}{c|}{bottom 960}\\\\\n")
        f.write("&&&&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}&\\multicolumn{1}{c}{before}&\\multicolumn{2}{c}{after}&\\multicolumn{1}{c|}{after}\\\\\n")
        f.write("&&&&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}&&\\multicolumn{2}{c}{joint}&\\multicolumn{1}{c|}{separate}\\\\\n")
        f.write("&&&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&&&\\multicolumn{1}{c}{pretraining}&\\multicolumn{1}{c}{no pretraining}&\\\\\n")
        f.write("\\hline\n")
        f.write("&&&&i&ii&iii&iv&v&vi&vii&viii&ix&x&xi&xii&xiii&xiv&xv&xvi&xvii&xviii&xix&xx&xxi&xxii&xxiii&xxiv\\\\\n")
        row = 0
        for dataset in datasets:
            name = labels.get(dataset, "their")
            for modality in ["EEG", "image"]:
                for encoder in encoders:
                    if encoder == "inception_v3":
                        if modality == "EEG":
                            # First row of a dataset group: full rule + label.
                            f.write("\\hline\n")
                            f.write("%s&%s&%s&%s" % (
                                name, modality, encoder_name[encoder],
                                roman[row]))
                        else:
                            # First row of a modality sub-group.
                            f.write("\\cline{2-28}\n")
                            f.write("&%s&%s&%s" % (
                                modality, encoder_name[encoder], roman[row]))
                    else:
                        f.write("&&%s&%s" % (encoder_name[encoder],
                                              roman[row]))
                    column = 0
                    for forty_classes in [True, False]:
                        for kind in ["original", "top 40", "bottom 960"]:
                            if (value(modality, kind, encoder, dataset,
                                      forty_classes, True, True, False) !=
                                    value(modality, kind, encoder, dataset,
                                          forty_classes, False, True, False)):
                                raise RuntimeError("unequal before")
                            for joint, before, no_pretraining in settings:
                                f.write("&%s{%5.1f}" % (
                                    highlight1 if condition1(column, row) else
                                    highlight2 if condition2(column, row)
                                    else "",
                                    value(modality, kind, encoder, dataset,
                                          forty_classes, joint, before,
                                          no_pretraining)))
                                column += 1
                    f.write("\\\\\n")
                    row += 1
        f.write("\\end{tabular}\n")

def make_table3(pathname, condition1, condition2):
    """Write the triplet-loss LaTeX tabular to ``pathname`` (overwritten).

    One row is written per entry of the module-level ``encoders`` list,
    labelled with ``encoder_name[encoder]`` and a roman numeral. The nine data
    columns (i-ix) are three groups of three:

    * joint training with pretraining,
    * joint training without pretraining,
    * separate training.

    Within each group, one column is written per entry of ``datasets``.
    Each cell is ``triplet_loss[(joint, encoder, dataset, no_pretraining)]``
    formatted with three decimals and wrapped as ``<macro>{value}``.

    Args:
        pathname: path of the .tex file to write.
        condition1: callable ``(column, row) -> bool``. When true, the cell's
            macro is ``highlight1``.
        condition2: callable ``(column, row) -> bool``. Consulted only when
            ``condition1`` is false; when true, the macro is ``highlight2``.
            If neither fires, the value is emitted as a plain ``{value}``
            group.
    """
    # ``with`` guarantees the file is closed even if a lookup below raises.
    with open(pathname, "w") as f:
        f.write("\\begin{tabular}{l|l|rrr|rrr|rrr|}\n")
        f.write("&&\\multicolumn{6}{c|}{joint}&\\multicolumn{3}{c|}{separate}\\\\\n")
        f.write("&&\\multicolumn{3}{c|}{pretraining}&\\multicolumn{3}{c|}{no pretraining}&\\multicolumn{3}{c|}{}\\\\\n")
        f.write("&&\\multicolumn{1}{c}{their}&\\multicolumn{1}{c}{our block}&\\multicolumn{1}{c|}{our randomized}&\\multicolumn{1}{c}{their}&\\multicolumn{1}{c}{our block}&\\multicolumn{1}{c|}{our randomized}&\\multicolumn{1}{c}{their}&\\multicolumn{1}{c}{our block}&\\multicolumn{1}{c|}{our randomized}\\\\\n")
        f.write("\\hline\n")
        f.write("&&i&ii&iii&iv&v&vi&vii&viii&ix\\\\\n")
        f.write("\\hline\n")
        for row, encoder in enumerate(encoders):
            f.write("%s&%s"%(encoder_name[encoder], roman[row]))
            column = 0
            for joint in [True, False]:
                for no_pretraining in [False, True]:
                    # Separate (non-joint) training has no "no pretraining"
                    # variant in this table, so that group is skipped.
                    if no_pretraining and not joint:
                        continue
                    for dataset in datasets:
                        f.write("&%s{%.3f}"%(
                            highlight1 if condition1(column, row) else
                            highlight2 if condition2(column, row) else "",
                            triplet_loss[(joint,
                                          encoder,
                                          dataset,
                                          no_pretraining)]))
                        column += 1
            f.write("\\\\\n")
        f.write("\\end{tabular}\n")

# Row-index sets reused by several make_table2 highlight variants below.
_TABLE2_ROWS_1_3_5 = [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23]
_TABLE2_ROWS_0_2 = [0, 1, 2, 3, 8, 9, 10, 11]


def _never(column, row):
    """Highlight predicate that never fires (plain, unhighlighted table)."""
    return False


make_table1("table1.tex", _never, _never)
make_table1("table1a.tex",
            lambda column, row: column == 4,
            _never)
make_table1("table1b.tex",
            lambda column, row: column == 5 and row in [0, 1, 2, 3],
            _never)
make_table1("table1c.tex",
            lambda column, row: column == 7,
            _never)
make_table1("table1d.tex",
            lambda column, row: column in [3, 7] and row in [0, 1, 2, 3],
            _never)

make_table2("table2.tex", _never, _never)
make_table2("table2a.tex",
            lambda column, row: column == 0 and row in _TABLE2_ROWS_1_3_5,
            _never)
make_table2("table2b.tex",
            lambda column, row: column == 12 and row in _TABLE2_ROWS_1_3_5,
            _never)
make_table2("table2c.tex",
            lambda column, row: column in [4, 16] and row in _TABLE2_ROWS_1_3_5,
            _never)
make_table2("table2d.tex",
            lambda column, row: column in [8, 20] and row in _TABLE2_ROWS_1_3_5,
            _never)
make_table2("table2e.tex",
            lambda column, row: column in [1, 5, 9, 13, 17, 21] and row in _TABLE2_ROWS_1_3_5,
            lambda column, row: column in [0, 4, 8, 12, 16, 20] and row in _TABLE2_ROWS_1_3_5)
make_table2("table2f.tex",
            lambda column, row: column in [3, 7, 15, 19] and row in _TABLE2_ROWS_0_2,
            _never)
make_table2("table2g.tex",
            lambda column, row: column in [11, 23] and row in [8, 9, 10, 11],
            _never)
make_table2("table2h.tex",
            lambda column, row: column in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23] and row in _TABLE2_ROWS_1_3_5,
            _never)
make_table2("table2i.tex",
            lambda column, row: column in [3, 7, 11, 15, 19, 23] and row in _TABLE2_ROWS_0_2,
            _never)
# NOTE(review): the last column was 20. That breaks the step-of-4 pattern
# (1, 5, ..., 17, 21) used for the same column group in table2e, and it mixes
# in a column from the 0-mod-4 group. Treated as a typo for 21; confirm
# against the intended highlighting.
make_table2("table2j.tex",
            lambda column, row: column in [1, 5, 9, 13, 17, 21] and row in [16, 17, 18, 19],
            _never)
make_table2("table2k.tex",
            lambda column, row: column in [2, 6, 10, 14, 18, 22],
            _never)

# make_table3 columns: 0-2 joint/pretraining, 3-5 joint/no pretraining,
# 6-8 separate.
make_table3("table3.tex", _never, _never)
make_table3("table3a.tex",
            lambda column, row: column in [6, 7, 8],
            lambda column, row: column in [0, 1, 2])
make_table3("table3b.tex",
            lambda column, row: column in [3, 4, 5],
            lambda column, row: column in [0, 1, 2])

# Table 4: for each encoder (rows), the values of mse_loss[(modality, encoder,
# dataset)] for the EEG and image modalities across the three datasets
# (columns i-vi).
with open("table4.tex", "w") as f:
    f.write("\\begin{tabular}{l|l|rrr|rrr|}\n")
    # Previously written as "&\multicolumn": "\m" is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). "\\m" emits exactly the same bytes.
    f.write("&&\\multicolumn{3}{c|}{EEG}&\\multicolumn{3}{c|}{image}\\\\\n")
    f.write("&&\\multicolumn{1}{c}{their}&\\multicolumn{1}{c}{our block}&\\multicolumn{1}{c|}{our randomized}&\\multicolumn{1}{c}{their}&\\multicolumn{1}{c}{our block}&\\multicolumn{1}{c|}{our randomized}\\\\\n")
    f.write("\\hline\n")
    f.write("&&i&ii&iii&iv&v&vi\\\\\n")
    f.write("\\hline\n")
    row = 0
    for encoder in encoders:
        f.write("%s&%s"%(encoder_name[encoder], roman[row]))
        for modality in ["EEG", "image"]:
            for dataset in datasets:
                f.write("&%.3f"%mse_loss[(modality, encoder, dataset)])
        f.write("\\\\\n")
        row += 1
    f.write("\\end{tabular}\n")
