import argparse
parser = argparse.ArgumentParser(description="Template")
# Every command-line option as (flags, add_argument keyword arguments), listed
# in the order they appear in --help.  Options without a default are None
# when omitted; store_true options default to False.
_ARGUMENT_SPECS = (
    (("-gpu", "--GPUindex"), {"default": 0, "type": int, "help": "gpu index"}),
    (("-l", "--layers"), {"default": 1, "type": int, "help": "number of layers"}),
    (("-hi", "--hidden"), {"default": 10, "type": int, "help": "number of hidden"}),
    (("-hrf", "--HRFdelay"), {"default": 3, "type": int, "help": "HRF delay"}),
    (("-m", "--modality"), {"type": str, "help": "modality"}),
    (("-w", "--word"), {"type": str, "help": "word"}),
    (("-p", "--partial"), {"default": False, "action": "store_true"}),
    (("-i", "--identifier"), {"type": str, "help": "identifier"}),
    (("-d", "--directory"), {"type": str, "help": "directory"}),
    (("--single-subject",), {"default": False, "action": "store_true"}),
    (("--cross-modal",), {"default": False, "action": "store_true"}),
    (("--cross-subject",), {"default": False, "action": "store_true"}),
    (("--pooled-subject",), {"default": False, "action": "store_true"}),
    (("--pooled-subject-cross-modal",), {"default": False, "action": "store_true"}),
    (("-s", "--svm"), {"default": False, "action": "store_true"}),
    (("--save",), {"default": False, "action": "store_true"}),
)
for _flags, _kwargs in _ARGUMENT_SPECS:
    parser.add_argument(*_flags, **_kwargs)
options = parser.parse_args()

from xml.dom import minidom
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim
from mvpa2.suite import *
import pickle as pkl
from regressors import *
import os
import random
from libsvm.svmutil import *

# Seed every random number generator the script touches (Python, NumPy,
# torch CPU/CUDA, PyMVPA) so runs are repeatable.  aggregate_supertrials
# draws from `random`, so its output depends on this seed.
seed = 12
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# NOTE(review): `mvpa2` is presumably bound by `from mvpa2.suite import *`;
# confirm that the suite module exports that name.
mvpa2.seed(seed)
torch.backends.cudnn.deterministic = True
#torch.use_deterministic_algorithms(True)

# NOTE(review): `infinity` is not used in the visible part of the file.
infinity = float("inf")
# Experiment name; presumably passed as the `experiment` argument of the
# read_* functions, which use it in attribute/atlas paths (confirm at caller).
experiment = "predication"
# Number of subjects, loaded by the pooled readers.
subjects = 12
# Presumably the number of cross-validation folds passed to the
# select_*_set functions; the caller is not visible here.
folds = 8
# Only the first `runs` runs of labels are kept (see read_dataset).
runs = 16
# Width of the label matrices built by the pooled readers and fold selectors.
number_of_detectors = 1
# Number of trials averaged into each supertrial (and the number of copies
# made of each trial) in aggregate_supertrials.
K = 32
# NOTE(review): presumably toggles aggregate_supertrials; the use is not
# visible in this chunk.
supertrials = True
# SGD learning rate used by train_net_on_dataset.
lr = 0.01

def aggregate_supertrials(raw_dataset, C):
    """Replace trials with "supertrials": averages of K trials of one class.

    Each trial is replicated K times into the bucket of its class, where the
    class is ``int(labels[trial].item())``.  Each bucket is shuffled and cut
    into consecutive groups of K.  Each complete group is averaged into one
    supertrial, and a trailing partial group is discarded.  All supertrials
    are then shuffled together.  Shuffling draws from the module-level
    ``random`` state, so results depend on the seed set at start-up.

    Args:
        raw_dataset: dict with "fmri", "labels", "regions", "region_labels".
        C: number of classes (number of buckets).

    Returns:
        A dict with the same keys.  "fmri" is the stacked supertrials and
        "labels" holds one ``[class]`` row per supertrial.  These are NumPy
        arrays when ``options.svm`` is set and torch tensors otherwise.
    """
    fmri = raw_dataset["fmri"]
    labels = raw_dataset["labels"]
    regions = raw_dataset["regions"]
    region_labels = raw_dataset["region_labels"]

    # K references to every trial's volume, grouped by class.
    buckets = [[] for _ in range(C)]
    for trial in range(fmri.shape[0]):
        bucket = buckets[int(labels[trial].item())]
        bucket.extend(fmri[trial, :] for _ in range(K))

    # Shuffle each bucket, then average each complete run of K volumes.
    averaged = []
    for bucket in buckets:
        random.shuffle(bucket)
        means = []
        for group in range(len(bucket) // K):
            start = group * K
            if options.svm:
                total = np.zeros((len(bucket[0]),), dtype=float)
            else:
                total = torch.zeros((len(bucket[0]),), dtype=torch.float)
            for volume in bucket[start:start + K]:
                total += volume
            means.append(total / K)
        averaged.append(means)

    # Pool all classes, then shuffle samples and labels with one permutation.
    pooled = []
    pooled_labels = []
    for c, means in enumerate(averaged):
        pooled.extend(means)
        pooled_labels.extend([c] for _ in means)
    order = list(range(len(pooled_labels)))
    random.shuffle(order)
    pooled = [pooled[i] for i in order]
    pooled_labels = [pooled_labels[i] for i in order]

    if options.svm:
        return {"fmri": np.stack(pooled),
                "labels": np.array(pooled_labels, dtype=float),
                "regions": regions,
                "region_labels": region_labels}
    return {"fmri": torch.stack(pooled),
            "labels": torch.tensor(pooled_labels, dtype=torch.float32),
            "regions": regions,
            "region_labels": region_labels}

def attributes_to_labels(attributes):
    """Split a flat target sequence into runs of consecutive equal chunks.

    Args:
        attributes: object with parallel ``targets`` and ``chunks`` sequences.

    Returns:
        A list of lists.  A new inner list starts whenever the chunk value
        differs from the previous sample's.
    """
    targets = attributes.targets
    chunks = attributes.chunks
    runs = []
    current = -1
    for index, target in enumerate(targets):
        if chunks[index] != current:
            current = chunks[index]
            runs.append([])
        runs[-1].append(target)
    return runs

def labels_to_attributes(labels):
    """Flatten per-run label lists back into (targets, chunks).

    Each run's position in `labels` becomes the chunk number of its samples,
    so empty runs still consume a chunk number.
    """
    targets = []
    chunks = []
    for chunk, run in enumerate(labels):
        targets.extend(run)
        chunks.extend([chunk] * len(run))
    return targets, chunks

def add_hrf_delay_to_labels(labels, hrf_delay):
    """Shift every run's labels later by `hrf_delay` scans, padding with "rest".

    Each output run has exactly the length of its input run.  The first
    `hrf_delay` entries become "rest" and the last `hrf_delay` labels are
    dropped.  If `hrf_delay` is at least the run length, the whole run
    becomes "rest".  A non-positive `hrf_delay` leaves runs unchanged.

    Args:
        labels: list of runs, each a sequence of labels.
        hrf_delay: number of scans to shift by.

    Returns:
        A new list of lists; the input is not modified.
    """
    # ["rest"] * n is [] for n <= 0, so non-positive delays pad nothing.
    padding = ["rest"] * hrf_delay
    # Truncating to len(label) keeps run length fixed.  The old form
    # label[:len(label)-hrf_delay] went negative when the delay exceeded the
    # run length and produced runs longer than their input.
    return [(padding + list(label))[:len(label)] for label in labels]

def trim_labels(labels, runs):
    """Return the first `runs` runs of `labels` (all of them if fewer exist)."""
    kept = labels[:runs]
    return kept

def add_hrf_delay_and_trim_attributes(attributes, hrf_delay, runs):
    """Delay-shift the attribute labels, keep the first `runs` runs, re-flatten.

    Returns the (targets, chunks) pair produced by labels_to_attributes.
    """
    return labels_to_attributes(
        trim_labels(
            add_hrf_delay_to_labels(attributes_to_labels(attributes), hrf_delay),
            runs))

def _either_side(c, template):
    """Return True if `template`, with "left" or "right" substituted for its
    single %s placeholder, occurs as a substring of `c`."""
    return any(template % side in c for side in ("left", "right"))


def _word_case(word, c):
    """Classify one trial target string `c` for detector `word`.

    Returns 1 (positive) or 0 (negative).  Returns -1 when `c` lacks one of
    the constituents that gate `word`; read_dataset drops those trials.
    Words ending in "1" are ungated and never return -1.

    Raises:
        RuntimeError: if `word` is not a recognized detector name.

    The "...1" detectors answer these questions:
      1. What did Dan/Scott do?            -> <agent>-<verb>1
      2. What did Dan/Scott do something to? -> <agent>-<object>1
      3. Who picked up / put down something? -> <agent>-<verb>1
      4. What was picked up / put down?     -> <verb>-<object>1
      5. Who did something to a briefcase/chair? -> <agent>-<object>1
      6. What was done to a briefcase/chair? -> <verb>-<object>1
    Cases these detectors cannot separate (both true or both false): nothing
    or everything done/picked up/put down; one vs. two briefcases or chairs;
    one vs. two pick ups or put downs.
    """
    agents = ("Dan", "Scott")
    verbs = ("pick-up", "put-down")
    objects = ("briefcase", "chair")
    sides = ("left", "right")

    def gated(constituents, positive):
        # Exclude the trial (-1) unless every constituent occurs in c.
        if not all(constituent in c for constituent in constituents):
            return -1
        return 1 if positive else 0

    # Single constituents: positive whenever the word occurs anywhere.
    if word in agents + verbs + objects:
        return 1 if word in c else 0
    for agent in agents:
        for verb in verbs:
            # e.g. "Dan-on-left-pick-up" or "Dan-on-right-pick-up"
            positive = _either_side(c, agent + "-on-%s-" + verb)
            if word == agent + "-" + verb:
                return gated((agent, verb), positive)
            if word == agent + "-" + verb + "1":
                return 1 if positive else 0
    for verb in verbs:
        for obj in objects:
            # e.g. "pick-up-on-left-briefcase" or "pick-up-on-right-briefcase"
            positive = _either_side(c, verb + "-on-%s-" + obj)
            if word == verb + "-" + obj:
                return gated((verb, obj), positive)
            if word == verb + "-" + obj + "1":
                return 1 if positive else 0
    for agent in agents:
        for obj in objects:
            if word == agent + "-" + obj:
                # Agent and object on the same side of the scene.
                same_side = any(agent + "-on-" + side in c and
                                obj + "-on-" + side in c
                                for side in sides)
                return gated((agent, obj), same_side)
            if word == agent + "-" + obj + "1":
                # Agent performed either verb on the object on its own side,
                # e.g. "Dan-on-left-pick-up-on-left-briefcase".
                acted_on = any("%s-on-%s-%s-on-%s-%s" % (agent, side, verb, side, obj) in c
                               for verb in verbs
                               for side in sides)
                return 1 if acted_on else 0
    if word == "subject":
        return gated(("Dan", "Scott", "pick-up", "put-down"),
                     ("Dan-on-left-pick-up" in c and
                      "Scott-on-right-put-down" in c) or
                     ("Dan-on-right-pick-up" in c and
                      "Scott-on-left-put-down" in c))
    if word == "object":
        return gated(("pick-up", "put-down", "briefcase", "chair"),
                     ("pick-up-on-left-briefcase" in c and
                      "put-down-on-right-chair" in c) or
                     ("pick-up-on-right-briefcase" in c and
                      "put-down-on-left-chair" in c))
    raise RuntimeError("unknown word")


def read_dataset(experiment,
                 hrf_delay,
                 subject,
                 modality,
                 atlas_name,
                 atlas_labels_name,
                 atlas_subject):
    """Load, preprocess and label one subject's fMRI data for ``options.word``.

    Steps:
      1. Read the PyMVPA sample attributes for `subject`/`modality`.
      2. Shift the labels by `hrf_delay` scans and keep the first `runs` runs.
      3. Load the subject's EPI under its full mask.
      4. Polynomially detrend (order 4) and z-score per chunk, estimating the
         z-score parameters on "rest" samples.
      5. Drop "rest" samples.
      6. Check that the atlas of `atlas_subject` covers exactly the same
         voxels.
      7. Label every remaining sample with `_word_case`.  Samples labelled
         -1 are dropped.

    With ``options.partial``, only voxels whose atlas label name contains
    "Broca" or "Visual cortex" are kept.  This now applies to both the SVM
    and the neural-net back ends; previously the SVM path silently used the
    whole brain.

    Returns:
        dict with:
          "fmri": (samples x voxels)
          "labels": (samples x 1)
          "regions": the per-voxel atlas values ``atlas.samples[0]``
          "region_labels": atlas XML ``index`` -> label name
        "fmri" and "labels" are float NumPy arrays when ``options.svm`` is
        set and float32 torch tensors otherwise.

    Raises:
        RuntimeError: "mismatched voxels" if the atlas and data voxels
            differ, or "unknown word" for an unrecognized ``options.word``
            (raised only once at least one sample is present).
    """
    attributes = SampleAttributes("/home/USER/fmri-experiments/%s/design/generated-experiments/subject-%02d/%s/pymvpa/attributes"%(experiment, subject+1, modality))
    targets, chunks = add_hrf_delay_and_trim_attributes(
        attributes, hrf_delay, runs)
    dataset = fmri_dataset(samples="%s/subject-%02d/epi.nii.gz"%(options.directory, subject+1),
                           targets=targets,
                           chunks=chunks,
                           mask="%s/subject-%02d/full-mask.nii.gz"%(options.directory, subject+1))
    poly_detrend(dataset, polyord=4, chunks_attr="chunks")
    zscore(dataset, chunks_attr="chunks", param_est=("targets", ["rest"]))
    dataset = dataset[dataset.sa.targets!="rest"]
    atlas = fmri_dataset(samples="/aux/USER/fmri-datasets/%s/processed/experiments/subject-%02d/%s.nii.gz"%(experiment, atlas_subject+1, atlas_name),
                         mask="/aux/USER/fmri-datasets/%s/processed/experiments/subject-%02d/full-mask.nii.gz"%(experiment, atlas_subject+1))
    if (atlas.fa.voxel_indices!=dataset.fa.voxel_indices).any():
        raise RuntimeError("mismatched voxels")
    data = dataset.samples
    sample_targets = dataset.targets
    # Alternative system-wide location: /usr/share/fsl/5.0/data/atlases/
    mydoc = minidom.parse(
        "/aux/USER/fsl/data/atlases/%s.xml"%atlas_labels_name)
    voxel_name = {
        int(label.attributes["index"].value): label.firstChild.data
        for label in mydoc.getElementsByTagName('label')}
    # Atlas value v > 0 corresponds to XML label index v-1.
    voxels = [voxel
              for voxel, index in enumerate(atlas.samples[0])
              if index>0
              if ("Broca" in voxel_name[index-1] or
                  "Visual cortex" in voxel_name[index-1])]
    used_fmri = []
    used_labels = []
    for sample in range(len(data)):
        case = _word_case(options.word, sample_targets[sample])
        if case < 0:
            continue
        volume = data[sample][voxels] if options.partial else data[sample]
        if options.svm:
            used_fmri.append(np.array(volume, dtype=float))
            used_labels.append(np.array([case], dtype=float))
        else:
            used_fmri.append(torch.tensor(volume, dtype=torch.float32))
            used_labels.append(torch.tensor([case], dtype=torch.float32))
    stack = np.stack if options.svm else torch.stack
    return {"fmri": stack(used_fmri),
            "labels": stack(used_labels),
            "regions": atlas.samples[0],
            "region_labels": voxel_name}

def read_pooled_subject_dataset(
        experiment, hrf_delay, modality, atlas_name, atlas_labels_name):
    """Load every subject's dataset and interleave them sample by sample.

    Row ``i*subjects + s`` of the result is sample ``i`` of subject ``s``.
    Every subject is read against subject 0's atlas.  Only the first
    ``min`` sample count across subjects is used from each subject.
    Previously, a subject shorter than subject 0 raised IndexError and
    longer subjects were truncated to subject 0's length.

    Returns:
        dict with "fmri" (rows x voxels), "labels"
        (rows x number_of_detectors), and subject 0's "regions" and
        "region_labels".  Arrays are NumPy when ``options.svm`` is set and
        float32 torch tensors otherwise.
    """
    datasets = [read_dataset(experiment,
                             hrf_delay,
                             s,
                             modality,
                             atlas_name,
                             atlas_labels_name,
                             # We arbitrarily pick the atlas of subject 0.
                             0)
                for s in range(subjects)]
    samples = min(d["fmri"].shape[0] for d in datasets)
    voxels = datasets[0]["fmri"].shape[1]
    label_columns = datasets[0]["labels"].shape[1]
    rows = len(datasets)*samples
    if options.svm:
        dataset = {"fmri": np.zeros((rows, voxels), dtype=float),
                   "labels": np.zeros((rows, number_of_detectors), dtype=float),
                   "regions": datasets[0]["regions"],
                   "region_labels": datasets[0]["region_labels"]}
        # (samples, subjects, ...) so flattening the first two axes puts
        # sample i of subject s at row i*subjects+s.
        fmri = np.stack([d["fmri"][:samples] for d in datasets], axis=1)
        labels = np.stack([d["labels"][:samples] for d in datasets], axis=1)
    else:
        dataset = {"fmri": torch.zeros([rows, voxels], dtype=torch.float32),
                   "labels": torch.zeros([rows, number_of_detectors],
                                         dtype=torch.float32),
                   "regions": datasets[0]["regions"],
                   "region_labels": datasets[0]["region_labels"]}
        fmri = torch.stack([d["fmri"][:samples] for d in datasets], dim=1)
        labels = torch.stack([d["labels"][:samples] for d in datasets], dim=1)
    # Assign into the preallocated buffers to keep their dtype and width.
    dataset["fmri"][:] = fmri.reshape(rows, voxels)
    dataset["labels"][:] = labels.reshape(rows, label_columns)
    return dataset

def read_pooled_cross_subject_dataset(
        experiment, hrf_delay, subject, modality, atlas_name, atlas_labels_name):
    """Load every subject except `subject` and interleave them sample by sample.

    All datasets are read against `subject`'s atlas.  With the held-out
    subject removed, the remaining subjects are numbered 0..n-1 in order.
    Row ``i*n + k`` is sample ``i`` of the k-th remaining subject.  Only the
    first ``min`` sample count across those subjects is used.  Previously,
    one shorter than the first raised IndexError.

    Returns:
        dict with "fmri", "labels" (rows x number_of_detectors), and the
        first loaded subject's "regions"/"region_labels".  Arrays are NumPy
        when ``options.svm`` is set and float32 torch tensors otherwise.
    """
    datasets = [read_dataset(experiment,
                             hrf_delay,
                             s,
                             modality,
                             atlas_name,
                             atlas_labels_name,
                             subject)
                for s in range(subjects)
                if s!=subject]
    samples = min(d["fmri"].shape[0] for d in datasets)
    voxels = datasets[0]["fmri"].shape[1]
    label_columns = datasets[0]["labels"].shape[1]
    rows = len(datasets)*samples
    if options.svm:
        dataset = {"fmri": np.zeros((rows, voxels), dtype=float),
                   "labels": np.zeros((rows, number_of_detectors), dtype=float),
                   "regions": datasets[0]["regions"],
                   "region_labels": datasets[0]["region_labels"]}
        # (samples, n, ...) so flattening the first two axes interleaves.
        fmri = np.stack([d["fmri"][:samples] for d in datasets], axis=1)
        labels = np.stack([d["labels"][:samples] for d in datasets], axis=1)
    else:
        dataset = {"fmri": torch.zeros([rows, voxels], dtype=torch.float32),
                   "labels": torch.zeros([rows, number_of_detectors],
                                         dtype=torch.float32),
                   "regions": datasets[0]["regions"],
                   "region_labels": datasets[0]["region_labels"]}
        fmri = torch.stack([d["fmri"][:samples] for d in datasets], dim=1)
        labels = torch.stack([d["labels"][:samples] for d in datasets], dim=1)
    # Assign into the preallocated buffers to keep their dtype and width.
    dataset["fmri"][:] = fmri.reshape(rows, voxels)
    dataset["labels"][:] = labels.reshape(rows, label_columns)
    return dataset

def in_fold(sample, samples, fold, folds):
    """Return True if `sample` lies in the `fold`-th of `folds` equal-width,
    contiguous slices of ``range(samples)``.  Boundaries may be fractional."""
    width = samples/folds
    return fold*width <= sample < (fold+1)*width

def _select_samples(dataset, fold, folds, inside):
    """Copy the rows of `dataset` whose ``in_fold`` status equals `inside`.

    Rows are copied into freshly allocated buffers.  These are float NumPy
    arrays when ``options.svm`` is set and float32 torch tensors otherwise.
    "labels" is ``number_of_detectors`` wide.  "regions" and
    "region_labels" are passed through.
    """
    samples = dataset["fmri"].shape[0]
    voxels = dataset["fmri"].shape[1]
    keep = [in_fold(sample, samples, fold, folds) == inside
            for sample in range(samples)]
    count = sum(keep)
    if options.svm:
        selected = {"fmri": np.zeros((count, voxels), dtype=float),
                    "labels": np.zeros((count, number_of_detectors),
                                       dtype=float),
                    "regions": dataset["regions"],
                    "region_labels": dataset["region_labels"]}
        mask = np.array(keep, dtype=bool)
    else:
        selected = {"fmri": torch.zeros([count, voxels], dtype=torch.float32),
                    "labels": torch.zeros([count, number_of_detectors],
                                          dtype=torch.float32),
                    "regions": dataset["regions"],
                    "region_labels": dataset["region_labels"]}
        mask = torch.tensor(keep, dtype=torch.bool)
    # Boolean masks preserve row order, matching a row-by-row copy.
    selected["fmri"][:] = dataset["fmri"][mask]
    selected["labels"][:] = dataset["labels"][mask]
    return selected

def select_training_set(dataset, fold, folds):
    """Return the samples of `dataset` outside fold `fold` of `folds`."""
    return _select_samples(dataset, fold, folds, inside=False)

def select_test_set(dataset, fold, folds):
    """Return the samples of `dataset` inside fold `fold` of `folds`."""
    return _select_samples(dataset, fold, folds, inside=True)

def print_validation_loss(epoch, gpu, net, test_sets):
    """Print the mean per-sample RMSE of `net` on each test set.

    Each sample is fed to `net` one at a time on GPU `gpu`.  An empty test
    set prints "nan" instead of raising ZeroDivisionError.  `net` is put in
    eval mode for the evaluation and is always returned to train mode.
    """
    net.eval()
    try:
        # TODO(needswork): batch the test
        with torch.no_grad():
            for test_set in test_sets:
                test_input = test_set["fmri"].cuda(gpu, non_blocking=False)
                test_target = test_set["labels"].cuda(gpu, non_blocking=False)
                count = test_set["fmri"].size(0)
                total_loss = 0
                for sample in range(count):
                    output = net(test_input[sample])
                    loss = F.mse_loss(output, test_target[sample],
                                      reduction="mean")
                    total_loss += np.sqrt(loss.tolist())
                print("epoch %d, validation loss %g"%(
                    epoch, total_loss/count if count else float("nan")))
    finally:
        net.train()

def train_net_on_dataset(gpu, net, training_set, test_sets, save):
    """Train `net` with mini-batch SGD on GPU `gpu`, then predict the test sets.

    Training setup:
      - 2000 epochs of batches of 30, in dataset order (no shuffling).
      - Learning rate ``lr``, reduced 10x at epoch 1000.
      - Every 100 epochs, prints the mean per-batch RMSE and the validation
        loss.
      - When ``options.save`` is set, the weights are saved to
        ``save + ".pth"``.

    The printed training loss is the mean per-batch RMSE.  It was previously
    divided by the sample count, which understated it by about the batch size.

    Returns:
        A list of ``[prediction, target]`` pairs, using the first output and
        target element of each test sample, over all test sets in order.
    """
    # TODO(needswork): we don't do shuffling
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[1000],
        gamma=0.1)
    samples = training_set["fmri"].size(0)
    batch_size = 30
    batch_starts = range(0, samples, batch_size)
    net.cuda(gpu)
    net.train()
    training_input = training_set["fmri"].cuda(gpu, non_blocking=False)
    training_target = training_set["labels"].cuda(gpu, non_blocking=False)
    for epoch in range(2000):
        total_loss = 0
        for start in batch_starts:
            # Slicing clamps at the end, so the last batch may be short.
            batch_input = training_input[start:start+batch_size]
            batch_target = training_target[start:start+batch_size]
            output = net(batch_input)
            loss = F.mse_loss(output, batch_target, reduction="mean")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += np.sqrt(loss.tolist())
        if epoch%100==0:
            print("epoch %d, loss %g"%(
                epoch,
                total_loss/len(batch_starts) if batch_starts else float("nan")))
            print_validation_loss(epoch, gpu, net, test_sets)
        scheduler.step()
    if options.save:
        torch.save(net.state_dict(), save+".pth")
    net.eval()
    trials = []
    # TODO(needswork): batch the test
    with torch.no_grad():
        for test_set in test_sets:
            test_input = test_set["fmri"].cuda(gpu, non_blocking=False)
            test_target = test_set["labels"].cuda(gpu, non_blocking=False)
            for sample in range(test_set["fmri"].size(0)):
                output = net(test_input[sample])
                target = test_target[sample]
                trials.append([output.tolist()[0],
                               target.tolist()[0]])
    return trials

def train_svm_on_dataset(training_set, test_sets, save):
    """Fit a linear epsilon-SVR with libsvm, then predict every test set.

    Uses the first label column as the regression target.  When
    ``options.save`` is set, the model is saved to ``save + ".model"``.

    Returns:
        A list of ``(prediction, target)`` tuples over all test sets in order.
    """
    # TODO(needswork): we don't do shuffling
    training_input = training_set["fmri"].tolist()
    training_target = training_set["labels"][:, 0].tolist()
    problem = svm_problem(training_target, training_input)
    # -s 3 epsilon-SVR
    # -t 0 linear
    # -q quiet
    parameters = svm_parameter("-s 3 -t 0 -q")
    model = svm_train(problem, parameters)
    if options.save:
        svm_save_model(save+".model", model)
    trials = []
    # TODO(needswork): batch the test
    for test_set in test_sets:
        test_input = test_set["fmri"].tolist()
        test_target = test_set["labels"][:, 0].tolist()
        output, _, _ = svm_predict(test_target, test_input, model, "-q")
        trials.extend(zip(output, test_target))
    return trials

def train_on_dataset(gpu, net, training_set, test_sets, save):
    """Train on ``training_set`` and evaluate on ``test_sets``.

    Dispatches to the libsvm backend when ``options.svm`` is set (``gpu`` and
    ``net`` are then ignored), otherwise to the neural-network backend.
    Returns whatever list of trials the chosen backend returns.
    """
    if not options.svm:
        return train_net_on_dataset(gpu, net, training_set, test_sets, save)
    return train_svm_on_dataset(training_set, test_sets, save)

def model_filename(subject, modality):
    """Return the per-subject model path prefix (no extension).

    ``subject`` is a 0-based index; the directory uses it 1-based and
    zero-padded to two digits. The experiment label is taken from the first
    mode flag that is set on ``options``.
    """
    if options.single_subject:
        experiment_label = "single-subject"
    elif options.cross_modal:
        experiment_label = "cross-modal"
    elif options.cross_subject:
        experiment_label = "cross-subject"
    elif options.pooled_subject:
        experiment_label = "pooled-subject"
    else:
        experiment_label = "unknown"
    if options.svm:
        model_kind = "svm"
    else:
        model_kind = "nn-%d-%d" % (options.layers, options.hidden)
    brain_extent = "partial-brain" if options.partial else "whole-brain"
    return "%s/subject-%02d/%s/%s-%s-%s-%d-%s-%s" % (
        options.directory, subject + 1, modality, options.word,
        experiment_label, options.identifier, options.HRFdelay,
        model_kind, brain_extent)

def pooled_model_filename(modality):
    """Return the model path prefix (no extension) for a pooled-subject
    cross-modal run, stored under ``<directory>/pooled-subject/<modality>``.
    """
    if options.svm:
        model_kind = "svm"
    else:
        model_kind = "nn-%d-%d" % (options.layers, options.hidden)
    brain_extent = "partial-brain" if options.partial else "whole-brain"
    fields = (options.directory,
              modality,
              options.word,
              options.identifier,
              options.HRFdelay,
              model_kind,
              brain_extent)
    template = "%s/pooled-subject/%s/%s-pooled-subject-cross-modal-%s-%d-%s-%s"
    return template % fields

def model_fold_filename(subject, modality, fold):
    """Return the per-subject, per-fold model path prefix (no extension).

    Same layout as the per-subject model name with ``-<fold>`` appended;
    ``subject`` is 0-based and rendered 1-based, zero-padded to two digits.
    """
    # First mode flag that is set wins; evaluated lazily like an elif chain.
    experiment_label = next(
        (label
         for attribute, label in (("single_subject", "single-subject"),
                                  ("cross_modal", "cross-modal"),
                                  ("cross_subject", "cross-subject"),
                                  ("pooled_subject", "pooled-subject"))
         if getattr(options, attribute)),
        "unknown")
    model_kind = ("svm" if options.svm
                  else "nn-%d-%d" % (options.layers, options.hidden))
    brain_extent = "partial-brain" if options.partial else "whole-brain"
    return "%s/subject-%02d/%s/%s-%s-%s-%d-%s-%s-%d" % (
        options.directory, subject + 1, modality, options.word,
        experiment_label, options.identifier, options.HRFdelay,
        model_kind, brain_extent, fold)

def pooled_model_fold_filename(modality, fold):
    """Return the per-fold model path prefix (no extension) for a
    pooled-subject run, stored under ``<directory>/pooled-subject/<modality>``.
    """
    if options.svm:
        model_kind = "svm"
    else:
        model_kind = "nn-%d-%d" % (options.layers, options.hidden)
    brain_extent = "partial-brain" if options.partial else "whole-brain"
    fields = (options.directory,
              modality,
              options.word,
              options.identifier,
              options.HRFdelay,
              model_kind,
              brain_extent,
              fold)
    return "%s/pooled-subject/%s/%s-pooled-subject-%s-%d-%s-%s-%d" % fields

def pkl_filename(subject, modality):
    """Return the path of the per-subject pickled trials file (``.pkl``).

    ``subject`` is 0-based and rendered 1-based, zero-padded to two digits.
    The experiment label is taken from the first mode flag set on ``options``.
    """
    # First mode flag that is set wins; evaluated lazily like an elif chain.
    experiment_label = next(
        (label
         for attribute, label in (("single_subject", "single-subject"),
                                  ("cross_modal", "cross-modal"),
                                  ("cross_subject", "cross-subject"),
                                  ("pooled_subject", "pooled-subject"))
         if getattr(options, attribute)),
        "unknown")
    model_kind = ("svm" if options.svm
                  else "nn-%d-%d" % (options.layers, options.hidden))
    brain_extent = "partial-brain" if options.partial else "whole-brain"
    return "%s/subject-%02d/%s/%s-%s-%s-%d-%s-%s.pkl" % (
        options.directory, subject + 1, modality, options.word,
        experiment_label, options.identifier, options.HRFdelay,
        model_kind, brain_extent)

def pooled_pkl_filename(modality):
    """Return the path of the pooled-subject pickled trials file (``.pkl``),
    stored under ``<directory>/pooled-subject/<modality>``.
    """
    if options.pooled_subject:
        experiment_label = "pooled-subject"
    elif options.pooled_subject_cross_modal:
        experiment_label = "pooled-subject-cross-modal"
    else:
        experiment_label = "unknown"
    if options.svm:
        model_kind = "svm"
    else:
        model_kind = "nn-%d-%d" % (options.layers, options.hidden)
    brain_extent = "partial-brain" if options.partial else "whole-brain"
    return "%s/pooled-subject/%s/%s-%s-%s-%d-%s-%s.pkl" % (
        options.directory, modality, options.word, experiment_label,
        options.identifier, options.HRFdelay, model_kind, brain_extent)

def train_single_subject(gpu,
                         experiment,
                         hrf_delay,
                         subject,
                         modality,
                         layers,
                         number_of_hidden,
                         atlas_name,
                         atlas_labels_name):
    """Cross-validate a regressor within one subject and one modality.

    The subject's dataset is split into ``folds`` train/test partitions via
    ``select_training_set``/``select_test_set``. For each fold a fresh model
    is trained and evaluated on that fold's test partition. The model is an
    SVM when ``options.svm`` is set, otherwise a ``regress_net``. Trials from
    all folds are concatenated and pickled to
    ``pkl_filename(subject, modality)``. Per-fold models are named with
    ``model_fold_filename``.
    """
    raw_dataset = read_dataset(experiment,
                               hrf_delay,
                               subject,
                               modality,
                               atlas_name,
                               atlas_labels_name,
                               subject)
    trials = []
    for fold in range(folds):
        training_set = select_training_set(raw_dataset, fold, folds)
        test_set = select_test_set(raw_dataset, fold, folds)
        if supertrials:
            training_set = aggregate_supertrials(training_set, 2)
            test_set = aggregate_supertrials(test_set, 2)
        # The SVM backend builds its own model; only the NN path needs a net.
        # Input width is taken from dim 1 of the fMRI tensor.
        if options.svm:
            net = None
        else:
            net = regress_net(
                layers, training_set["fmri"].size(1), number_of_hidden,
                [options.word])
        trials += train_on_dataset(gpu,
                                   net,
                                   training_set,
                                   [test_set],
                                   model_fold_filename(subject, modality, fold))
    # Context manager closes the file even if pickling raises.
    with open(pkl_filename(subject, modality), "wb") as trials_file:
        pkl.dump(trials, trials_file)

def train_cross_modal(gpu,
                      experiment,
                      hrf_delay,
                      subject,
                      source_modality,
                      target_modality,
                      layers,
                      number_of_hidden,
                      atlas_name,
                      atlas_labels_name):
    """Train on one modality of a subject and test on another modality.

    The whole ``source_modality`` dataset is the training set and the whole
    ``target_modality`` dataset is the test set (no folds). The model is
    named under the source modality. The trials are pickled under the
    target modality.
    """
    training_set = read_dataset(experiment,
                                hrf_delay,
                                subject,
                                source_modality,
                                atlas_name,
                                atlas_labels_name,
                                subject)
    test_set = read_dataset(experiment,
                            hrf_delay,
                            subject,
                            target_modality,
                            atlas_name,
                            atlas_labels_name,
                            subject)
    if supertrials:
        training_set = aggregate_supertrials(training_set, 2)
        test_set = aggregate_supertrials(test_set, 2)
    # The SVM backend builds its own model; only the NN path needs a net.
    if options.svm:
        net = None
    else:
        net = regress_net(
            layers, training_set["fmri"].size(1), number_of_hidden,
            [options.word])
    trials = train_on_dataset(gpu,
                              net,
                              training_set,
                              [test_set],
                              model_filename(subject, source_modality))
    # Context manager closes the file even if pickling raises.
    with open(pkl_filename(subject, target_modality), "wb") as trials_file:
        pkl.dump(trials, trials_file)

def train_cross_subject(gpu,
                        experiment,
                        hrf_delay,
                        subject,
                        modality,
                        layers,
                        number_of_hidden,
                        atlas_name,
                        atlas_labels_name):
    """Train on a pooled dataset and test on one subject's dataset.

    The training set comes from ``read_pooled_cross_subject_dataset`` for
    ``subject``. Presumably that pool holds the other subjects' data; confirm
    in that helper. The test set is ``subject``'s own data for ``modality``.
    Trials are pickled to ``pkl_filename(subject, modality)``.
    """
    training_set = read_pooled_cross_subject_dataset(
        experiment, hrf_delay, subject, modality, atlas_name, atlas_labels_name)
    test_set = read_dataset(experiment,
                            hrf_delay,
                            subject,
                            modality,
                            atlas_name,
                            atlas_labels_name,
                            subject)
    if supertrials:
        training_set = aggregate_supertrials(training_set, 2)
        test_set = aggregate_supertrials(test_set, 2)
    # The SVM backend builds its own model; only the NN path needs a net.
    if options.svm:
        net = None
    else:
        net = regress_net(
            layers, training_set["fmri"].size(1), number_of_hidden,
            [options.word])
    trials = train_on_dataset(gpu,
                              net,
                              training_set,
                              [test_set],
                              model_filename(subject, modality))
    # Context manager closes the file even if pickling raises.
    with open(pkl_filename(subject, modality), "wb") as trials_file:
        pkl.dump(trials, trials_file)

def train_pooled_subject(gpu,
                         experiment,
                         hrf_delay,
                         modality,
                         layers,
                         number_of_hidden,
                         atlas_name,
                         atlas_labels_name):
    """Cross-validate a regressor on the all-subject pooled dataset.

    The pooled dataset for ``modality`` is split into ``folds`` partitions.
    A fresh model is trained and tested per fold. Per-fold models are named
    with ``pooled_model_fold_filename``. All trials are pickled to
    ``pooled_pkl_filename(modality)``.
    """
    raw_dataset = read_pooled_subject_dataset(
        experiment, hrf_delay, modality, atlas_name, atlas_labels_name)
    trials = []
    for fold in range(folds):
        training_set = select_training_set(raw_dataset, fold, folds)
        test_set = select_test_set(raw_dataset, fold, folds)
        if supertrials:
            training_set = aggregate_supertrials(training_set, 2)
            test_set = aggregate_supertrials(test_set, 2)
        # The SVM backend builds its own model; only the NN path needs a net.
        if options.svm:
            net = None
        else:
            net = regress_net(
                layers, training_set["fmri"].size(1), number_of_hidden,
                [options.word])
        trials += train_on_dataset(gpu,
                                   net,
                                   training_set,
                                   [test_set],
                                   pooled_model_fold_filename(modality, fold))
    # Context manager closes the file even if pickling raises.
    with open(pooled_pkl_filename(modality), "wb") as trials_file:
        pkl.dump(trials, trials_file)

def train_pooled_subject_cross_modal(gpu,
                                     experiment,
                                     hrf_delay,
                                     source_modality,
                                     target_modality,
                                     layers,
                                     number_of_hidden,
                                     atlas_name,
                                     atlas_labels_name):
    """Train on the pooled dataset of one modality and test on another.

    Both sets come from ``read_pooled_subject_dataset``: the training set
    for ``source_modality`` and the test set for ``target_modality``.
    There are no folds. Trials are pickled to
    ``pooled_pkl_filename(target_modality)``.
    """
    training_set = read_pooled_subject_dataset(
        experiment, hrf_delay, source_modality, atlas_name, atlas_labels_name)
    test_set = read_pooled_subject_dataset(
        experiment, hrf_delay, target_modality, atlas_name, atlas_labels_name)
    if supertrials:
        training_set = aggregate_supertrials(training_set, 2)
        test_set = aggregate_supertrials(test_set, 2)
    # The SVM backend builds its own model; only the NN path needs a net.
    if options.svm:
        net = None
    else:
        net = regress_net(
            layers, training_set["fmri"].size(1), number_of_hidden,
            [options.word])
    # NOTE(review): the model is named after target_modality, whereas
    # train_cross_modal names its model after source_modality. Kept as-is
    # because downstream loaders may depend on this path; confirm intent.
    trials = train_on_dataset(gpu,
                              net,
                              training_set,
                              [test_set],
                              pooled_model_filename(target_modality))
    # Context manager closes the file even if pickling raises.
    with open(pooled_pkl_filename(target_modality), "wb") as trials_file:
        pkl.dump(trials, trials_file)

print("hrf_delay = %d"%options.HRFdelay)
print("layers = %d"%options.layers)
print("hidden = %d"%options.hidden)
print("K = %d"%K)
print("supertrials = %r"%supertrials)
print("newfangled replacement method")
print("lr = %f"%lr)

# Atlas passed to every experiment mode below.
_ATLAS_NAME = "Juelich-maxprob-thr50"
_ATLAS_LABELS_NAME = "Juelich"
# Both directions of each cross-modal experiment as (source, target) pairs,
# in the order they are run.
_CROSS_MODAL_PAIRS = (("video", "text"), ("text", "video"))

# Exactly one mode runs: the first flag that is set, checked in this order.
if options.single_subject:
    print("Single subject")
    print(options.modality)
    for subject in range(subjects):
        print("subject: %d"%subject)
        train_single_subject(options.GPUindex,
                             experiment,
                             options.HRFdelay,
                             subject,
                             options.modality,
                             options.layers,
                             options.hidden,
                             _ATLAS_NAME,
                             _ATLAS_LABELS_NAME)
elif options.cross_modal:
    print("Cross modal")
    for subject in range(subjects):
        print("subject: %d"%subject)
        for source_modality, target_modality in _CROSS_MODAL_PAIRS:
            train_cross_modal(options.GPUindex,
                              experiment,
                              options.HRFdelay,
                              subject,
                              source_modality,
                              target_modality,
                              options.layers,
                              options.hidden,
                              _ATLAS_NAME,
                              _ATLAS_LABELS_NAME)
elif options.cross_subject:
    print("Cross subject")
    print(options.modality)
    for subject in range(subjects):
        print("subject: %d"%subject)
        train_cross_subject(options.GPUindex,
                            experiment,
                            options.HRFdelay,
                            subject,
                            options.modality,
                            options.layers,
                            options.hidden,
                            _ATLAS_NAME,
                            _ATLAS_LABELS_NAME)
elif options.pooled_subject:
    print("Pooled subject")
    print(options.modality)
    train_pooled_subject(options.GPUindex,
                         experiment,
                         options.HRFdelay,
                         options.modality,
                         options.layers,
                         options.hidden,
                         _ATLAS_NAME,
                         _ATLAS_LABELS_NAME)
elif options.pooled_subject_cross_modal:
    # This is the pooled-subject mode, not --cross-subject; label it so.
    print("Pooled subject cross modal")
    for source_modality, target_modality in _CROSS_MODAL_PAIRS:
        train_pooled_subject_cross_modal(options.GPUindex,
                                         experiment,
                                         options.HRFdelay,
                                         source_modality,
                                         target_modality,
                                         options.layers,
                                         options.hidden,
                                         _ATLAS_NAME,
                                         _ATLAS_LABELS_NAME)
else:
    raise RuntimeError("You need to specify one of --single-subject, --cross-modal, --cross-subject, --pooled-subject, or --pooled-subject-cross-modal")
