from __future__ import division
import argparse
# Command-line interface. Options are registered in the same order as before
# so the generated --help text is unchanged. Arguments that declare neither a
# default nor required=True are None when omitted.
parser = argparse.ArgumentParser(description="Template")

# ---- Dataset options ----
# The script body at the end of this file only recognises the values
# "spampinato" and "image" for --iv.
parser.add_argument("-iv", "--iv", type=str, required=True,
                    help="image/video")
# The ".plk" spelling of the default file name is kept as-is.
parser.add_argument("-rf", "--results-file", type=str,
                    default="results.plk", help="results file")
parser.add_argument("-s", "--subject", type=int, default=0, help="subject")
parser.add_argument("-r", "--run", type=str, default="none", help="run")
parser.add_argument("-ed", "--eeg-dataset", help="EEG dataset path")
parser.add_argument("-sp", "--splits-path", help="splits path")
parser.add_argument("-ip", "--image-path", help="image path")
parser.add_argument("-f", "--fold", type=int, default=5,
                    help="number of folds")

# ---- Training options ----
parser.add_argument("-b", "--batch_size", type=int, default=16,
                    help="batch size")
# The value is looked up by name on torch.optim, e.g. "Adam".
parser.add_argument("-o", "--optim", default="Adam", help="optimizer")
parser.add_argument("-clr", "--joint-learning-rate", type=float,
                    default=1e-6, help="joint learning rate")
parser.add_argument("-eelr", "--eeg-encoder-learning-rate", type=float,
                    default=1e-3, help="EEG encoder learning rate")
parser.add_argument("-elr", "--eeg-learning-rate", type=float,
                    default=1e-5, help="EEG learning rate")
parser.add_argument("-ilr", "--image-learning-rate", type=float,
                    default=1e-6, help="image learning rate")
parser.add_argument("-ce", "--joint-epochs", type=int, default=100,
                    help="joint training epochs")
parser.add_argument("-eee", "--eeg-encoder-epochs", type=int, default=100,
                    help="EEG encoder training epochs")
parser.add_argument("-ee", "--eeg-epochs", type=int, default=500,
                    help="EEG training epochs")
parser.add_argument("-ie", "--image-epochs", type=int, default=500,
                    help="image training epochs")
parser.add_argument("-mt", "--max-tries", type=int, default=10,
                    help="max tries")
parser.add_argument("-gpu", "--GPUindex", type=int, default=0,
                    help="gpu index")
# store_true flags are False unless given on the command line.
parser.add_argument("-ca", "--cached", action="store_true", help="cached")
parser.add_argument("-l", "--large", action="store_true", help="large")
parser.add_argument("-m", "--mode", type=str, default="joint", help="mode")
parser.add_argument("--no-pretraining", action="store_true",
                    help="disable pretraining")

# ---- Backend options ----
parser.add_argument("--no-cuda", action="store_true", help="disable CUDA")
parser.add_argument("-c", "--classifier", required=True,
                    help="inception_v3/resnet101/densenet161/alexnet")

# Parse sys.argv now, before the heavyweight imports that follow; argparse
# exits the process if a required option is missing.
opt = parser.parse_args()

import torch
import random
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from p_values import *
import os
import pickle as pkl

# Seed every random number generator this script draws from (PyTorch CPU and
# CUDA, NumPy, and the stdlib `random` module) with the same value, in the
# same order as before, so repeated runs start from the same state.
_SEED = 12
for _seed_rng in (torch.manual_seed,
                  torch.cuda.manual_seed,
                  np.random.seed,
                  random.seed):
    _seed_rng(_SEED)

# Ask cuDNN for deterministic algorithms and keep PyTorch's intra-op CPU work
# on a single thread.
torch.backends.cudnn.deterministic = True
torch.set_num_threads(1)

def read_images(eeg_signals_path, image_model, image_path):
    """Load and preprocess every stimulus image into the global ``images``.

    The file at ``eeg_signals_path`` is read with ``torch.load`` and its
    ``"images"`` entry is iterated as a sequence of file stems; each stem is
    opened as ``image_path + "/" + stem + ".JPEG"`` and converted to RGB.

    After the call, ``images[k]`` is a list of 11 normalized tensors for the
    k-th image: element 0 is a plain resize + center crop (used as the
    "negative" view elsewhere in this file) and elements 1..10 are the ten
    crops of the image resized to 1.1x the network input size (the
    "positive" views).

    Args:
        eeg_signals_path: path of the torch-serialized EEG dataset.
        image_model: name of the image network; "inception_v3" selects a
            299-pixel input size, anything else 224.
        image_path: directory that contains the JPEG files.

    Returns:
        None. The result is published through the module-level ``images``.
    """
    global images
    input_size = 299 if image_model == "inception_v3" else 224

    # Shared tail of both pipelines: PIL image -> normalized tensor.
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.485, 0.456, 0.406],
                             std = [0.229, 0.224, 0.225])])

    def crops_to_tensors(crops):
        # TenCrop yields a tuple of PIL crops; convert each one separately.
        return [to_normalized_tensor(crop) for crop in crops]

    positive_transform = transforms.Compose([
        transforms.Resize(int(1.1*input_size)),
        transforms.TenCrop(input_size),
        transforms.Lambda(crops_to_tensors)])
    negative_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        to_normalized_tensor])

    eeg_signals = torch.load(eeg_signals_path)
    images = []
    for image_name in eeg_signals["images"]:
        picture = Image.open(image_path+"/"+image_name+".JPEG").convert('RGB')
        views = [negative_transform(picture)]
        views.extend(positive_transform(picture))
        images.append(views)

class Dataset_for_joint_training:
    """Triplet dataset for joint EEG / image encoder training.

    Each item is ``(eeg, positive_image, negative_image)``:

    * ``eeg`` is the trial's recording, optionally standardized with the
      dataset-wide ``means`` / ``stddevs``, transposed, and cropped to rows
      20..459.
    * ``positive_image`` is one of views 1..10 of the image the trial refers
      to (``sample["image"]``) in the module-level ``images`` list that
      ``read_images`` fills.
    * ``negative_image`` is view 0 of a *different* image, chosen by trying
      shuffled candidates until one scores a higher dot-product
      compatibility with the EEG encoding than the positive image does (the
      last candidate tried is used if none does).

    With ``opt.large`` every trial is exposed ten times, once per positive
    view; otherwise the positive view is drawn at random. With
    ``opt.cached`` the encodings are read from the tables built by
    ``compute_cache`` (which the caller must invoke first) and every
    candidate is tried; without it encodings are computed on the fly and at
    most ``opt.max_tries`` candidates are tried.
    """

    def __init__(self,
                 eeg_signals_path,
                 splits_path,
                 split_num,
                 split_name,
                 eeg_encoder,
                 image_encoder,
                 opt):
        self.large = opt.large
        self.cached = opt.cached
        self.max_tries = opt.max_tries
        # Device settings come from the options object handed in, so the
        # methods below do not depend on a module-level `opt`.
        self.no_cuda = opt.no_cuda
        self.gpu_index = opt.GPUindex
        eeg_signals = torch.load(eeg_signals_path)
        splits = torch.load(splits_path)
        self.data = eeg_signals["dataset"]
        # Drop trials whose recording has fewer than 480 entries along
        # dimension 1 (presumably the time axis -- TODO confirm).
        self.split_idx = [i for i in splits["splits"][split_num][split_name]
                          if 480<=self.data[i]["eeg"].size(1)]
        # Normalization statistics are optional in the dataset file.
        try:
            self.means = eeg_signals["means"]
            self.stddevs = eeg_signals["stddevs"]
        except KeyError:
            self.means = None
            self.stddevs = None
        if self.large:
            self.size = 10*len(self.split_idx)
        else:
            self.size = len(self.split_idx)
        self.eeg_encoder = eeg_encoder
        self.image_encoder = image_encoder
        self.eeg_encodings = None
        self.image_encodings = None

    def _eeg_window(self, sample):
        """Return the (optionally standardized) transposed, cropped EEG."""
        eeg = sample["eeg"].float()
        if self.means is not None and self.stddevs is not None:
            eeg = (eeg-self.means)/self.stddevs
        # The original author marks this crop with "WARNING: This is
        # different." -- it keeps rows 20..459 of the transposed recording.
        return eeg.t()[20:460, :]

    def _encode(self, encoder, tensor):
        """Run `encoder` on `tensor` as a batch of one, without gradients."""
        batch = tensor.unsqueeze(0)
        if not self.no_cuda:
            # `non_blocking` replaced the `async` keyword, which no longer
            # parses now that `async` is reserved in Python.
            batch = batch.cuda(self.gpu_index, non_blocking = True)
        with torch.no_grad():
            return encoder(batch)

    def compute_cache(self):
        """Encode every trial of the split and every view of every image.

        ``eeg_encodings[p]`` belongs to the p-th trial of ``split_idx`` and
        ``image_encodings[k][v]`` to view ``v`` of ``images[k]``.
        """
        self.eeg_encodings = [
            self._encode(self.eeg_encoder,
                         self._eeg_window(self.data[index]).contiguous())
            for index in self.split_idx]
        # Indexed by image, not by whatever trial was processed last.
        self.image_encodings = [
            [self._encode(self.image_encoder, view) for view in views]
            for views in images]

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        """Return ``(eeg, positive_image, negative_image)`` for item ``i``.

        Raises:
            ValueError: if no negative image can be tried (fewer than two
                images, or ``max_tries`` below 1 in uncached mode).
        """
        if self.large:
            position = i//10
            positive_crop_index = i%10+1
        else:
            position = i
            positive_crop_index = random.randrange(1, 11)
        sample = self.data[self.split_idx[position]]
        eeg = self._eeg_window(sample)
        positive_index = sample["image"]
        positive_image = images[positive_index][positive_crop_index]
        if self.cached:
            # The cache holds one EEG encoding per trial, so in `large` mode
            # it is indexed by trial position rather than by item index.
            eeg_encoding = self.eeg_encodings[position]
            positive_image_encoding = (
                self.image_encodings[positive_index][positive_crop_index])
        else:
            eeg_encoding = self._encode(self.eeg_encoder, eeg.contiguous())
            positive_image_encoding = self._encode(self.image_encoder,
                                                   positive_image)
        # https://discuss.pytorch.org/t/dot-product-batch-wise/9746
        positive_compatability = (
            (eeg_encoding*positive_image_encoding).sum(1))[0]
        # Candidate negatives are all images except the trial's own image.
        indices = [j for j in range(len(images)) if j!=positive_index]
        indices = random.sample(indices, len(indices))
        tries = len(indices)
        if not self.cached:
            tries = min(self.max_tries, tries)
        if tries<1:
            raise ValueError(
                "no negative image can be tried for item %d" % i)
        for negative_index in indices[:tries]:
            negative_image = images[negative_index][0]
            if self.cached:
                negative_image_encoding = (
                    self.image_encodings[negative_index][0])
            else:
                negative_image_encoding = self._encode(self.image_encoder,
                                                       negative_image)
            # https://discuss.pytorch.org/t/dot-product-batch-wise/9746
            negative_compatability = (
                (eeg_encoding*negative_image_encoding).sum(1))[0]
            # Stop at the first negative the encoders currently rank above
            # the positive image.
            if negative_compatability>positive_compatability:
                break
        return eeg, positive_image, negative_image

class Dataset_for_training_eeg_classifier:
    """Dataset of ``(eeg, label)`` pairs for training an EEG classifier.

    ``eeg`` is the trial's recording, standardized with the dataset-wide
    ``means`` / ``stddevs`` when the dataset file provides them, transposed,
    and cropped to rows 20..459. ``label`` is the position of the trial's
    ``"label"`` value inside ``classes``.
    """

    def __init__(self,
                 eeg_signals_path,
                 classes,
                 splits_path,
                 split_num,
                 split_name):
        eeg_signals = torch.load(eeg_signals_path)
        splits = torch.load(splits_path)
        self.data = eeg_signals["dataset"]
        # Drop trials whose recording has fewer than 480 entries along
        # dimension 1 (presumably the time axis -- TODO confirm).
        self.split_idx = [i for i in splits["splits"][split_num][split_name]
                          if 480<=self.data[i]["eeg"].size(1)]
        # Normalization statistics are optional; record their absence
        # explicitly instead of relying on a later AttributeError.
        try:
            self.means = eeg_signals["means"]
            self.stddevs = eeg_signals["stddevs"]
        except KeyError:
            self.means = None
            self.stddevs = None
        self.size = len(self.split_idx)
        self.classes = classes

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        """Return ``(eeg, label)`` for the i-th trial of the split.

        Raises:
            ValueError: from ``classes.index`` if the trial's label is not
                one of ``classes``.
        """
        sample = self.data[self.split_idx[i]]
        eeg = sample["eeg"].float()
        if self.means is not None and self.stddevs is not None:
            eeg = (eeg-self.means)/self.stddevs
        # The original author marks this crop with "WARNING: This is
        # different." -- it keeps rows 20..459 of the transposed recording.
        eeg = eeg.t()[20:460, :]
        label = self.classes.index(sample["label"])
        return eeg, label

class Dataset_for_training_image_classifier:
    """Dataset of ``(image view, label)`` pairs for an image classifier.

    The image is one of views 1..10 stored for the trial's image in the
    module-level ``images`` list. When the module-level ``opt.large`` flag is
    set, every trial appears ten times, once per view; otherwise each trial
    appears once with a randomly drawn view.
    """

    def __init__(self,
                 eeg_signals_path,
                 classes,
                 splits_path,
                 split_num,
                 split_name):
        # NOTE: the flag is read from the module-level `opt`; this
        # constructor takes no options argument.
        self.large = opt.large
        recordings = torch.load(eeg_signals_path)
        split_table = torch.load(splits_path)
        self.data = recordings["dataset"]
        # Same trial filter as the EEG datasets: at least 480 entries along
        # dimension 1 of the recording.
        candidates = split_table["splits"][split_num][split_name]
        self.split_idx = [index for index in candidates
                          if self.data[index]["eeg"].size(1) >= 480]
        views_per_trial = 10 if self.large else 1
        self.size = views_per_trial*len(self.split_idx)
        self.classes = classes

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        """Return ``(image view, label)`` for item ``i``."""
        # In `large` mode item i maps to trial i//10 and view i%10+1.
        position = i//10 if self.large else i
        sample = self.data[self.split_idx[position]]
        view = i%10+1 if self.large else random.randrange(1, 11)
        chosen_image = images[sample["image"]][view]
        return chosen_image, self.classes.index(sample["label"])

class Dataset_for_validation:
    """Dataset of ``(eeg, image, label)`` triples used for evaluation.

    ``eeg`` is the trial's recording, standardized with the dataset-wide
    ``means`` / ``stddevs`` when the dataset file provides them, transposed,
    and cropped to rows 20..459. ``image`` is view 0 of the trial's image in
    the module-level ``images`` list (the un-enlarged center crop that
    ``read_images`` stores first). ``label`` is the position of the trial's
    ``"label"`` value inside ``classes``.
    """

    def __init__(self,
                 eeg_signals_path,
                 classes,
                 splits_path,
                 split_num,
                 split_name):
        eeg_signals = torch.load(eeg_signals_path)
        splits = torch.load(splits_path)
        self.data = eeg_signals["dataset"]
        # Drop trials whose recording has fewer than 480 entries along
        # dimension 1 (presumably the time axis -- TODO confirm).
        self.split_idx = [i for i in splits["splits"][split_num][split_name]
                          if 480<=self.data[i]["eeg"].size(1)]
        # Normalization statistics are optional; record their absence
        # explicitly instead of relying on a later AttributeError.
        try:
            self.means = eeg_signals["means"]
            self.stddevs = eeg_signals["stddevs"]
        except KeyError:
            self.means = None
            self.stddevs = None
        self.size = len(self.split_idx)
        self.classes = classes

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        """Return ``(eeg, image, label)`` for the i-th trial of the split.

        Raises:
            ValueError: from ``classes.index`` if the trial's label is not
                one of ``classes``.
        """
        sample = self.data[self.split_idx[i]]
        eeg = sample["eeg"].float()
        if self.means is not None and self.stddevs is not None:
            eeg = (eeg-self.means)/self.stddevs
        # The original author marks this crop with "WARNING: This is
        # different." -- it keeps rows 20..459 of the transposed recording.
        eeg = eeg.t()[20:460, :]
        positive_index = sample["image"]
        positive_image = images[positive_index][0]
        label = self.classes.index(sample["label"])
        return eeg, positive_image, label

class EEGChannelNet(nn.Module):
    """Convolutional EEG encoder that maps a recording to a 1000-d vector.

    Pipeline (all convolutions are 2-D over a single-channel "image" whose
    height is the ``spatial`` axis and whose width is the ``temporal`` axis):

    1. five parallel temporal convolutions (kernel 1x33, stride 2 along the
       width, dilations 1/2/4/8/16), concatenated to 50 feature maps;
    2. four parallel spatial convolutions (kernel heights 128/64/32/16,
       stride 2 along the height), concatenated to 200 feature maps;
    3. four stride-2 residual blocks with 1x1 stride-2 shortcuts;
    4. a final unpadded convolution to 50 maps and a linear layer to 1000.

    All sub-layers are held in ``nn.ModuleList`` containers so that
    ``parameters()``, ``train()`` / ``eval()``, ``state_dict()`` and device
    moves reach them; plain Python lists are invisible to ``nn.Module``.

    Args:
        spatial: size of the spatial axis; one of 128, 96, 64, 32, 16, 8.
        temporal: size of the temporal axis; one of 1024, 512, 440, 256,
            200, 128, 100, 50.

    Raises:
        ValueError: if ``spatial`` or ``temporal`` is not a supported size.
    """

    def __init__(self, spatial, temporal):
        super(EEGChannelNet, self).__init__()
        # Supported input sizes and the matching feature-map size left
        # after the final convolution (used to size the linear layer).
        spatial_sizes = [128, 96, 64, 32, 16, 8]
        spatial_outs = [2, 1, 1, 1, 1, 1]
        temporal_sizes = [1024, 512, 440, 256, 200, 128, 100, 50]
        temporal_outs = [30, 14, 12, 6, 5, 2, 2, 1]
        if spatial not in spatial_sizes:
            raise ValueError("unsupported spatial size %r" % (spatial,))
        if temporal not in temporal_sizes:
            raise ValueError("unsupported temporal size %r" % (temporal,))

        # Sub-layers are created in the same order as the original
        # hand-unrolled code so seeded weight initialisation is unchanged.
        # Padding 16*dilation keeps the output width independent of the
        # dilation, so the five branches can be concatenated.
        self.temporal_layers = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 1,
                                    out_channels = 10,
                                    kernel_size = (1, 33),
                                    stride = (1, 2),
                                    dilation = (1, dilation),
                                    padding = (0, 16*dilation)),
                          nn.BatchNorm2d(10),
                          nn.ReLU())
            for dilation in (1, 2, 4, 8, 16)])
        # Padding height//2-1 gives every branch the same output height.
        self.spatial_layers = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 50,
                                    out_channels = 50,
                                    kernel_size = (height, 1),
                                    stride = (2, 1),
                                    padding = (height//2-1, 0)),
                          nn.BatchNorm2d(50),
                          nn.ReLU())
            for height in (128, 64, 32, 16)])
        self.residual_layers = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 200,
                                    out_channels = 200,
                                    kernel_size = 3,
                                    stride = 2,
                                    padding = 1),
                          nn.BatchNorm2d(200),
                          nn.ReLU(),
                          nn.Conv2d(in_channels = 200,
                                    out_channels = 200,
                                    kernel_size = 3,
                                    stride = 1,
                                    padding = 1),
                          nn.BatchNorm2d(200))
            for _ in range(4)])
        self.shortcuts = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 200,
                                    out_channels = 200,
                                    kernel_size = 1,
                                    stride = 2),
                          nn.BatchNorm2d(200))
            for _ in range(4)])

        # Final kernel: 3 for the larger inputs, shrinking for small ones
        # whose feature map is already below 3 after the residual blocks.
        spatial_kernel = {128: 3, 96: 3, 64: 2}.get(spatial, 1)
        temporal_kernel = 2 if temporal == 50 else 3
        self.final_conv = nn.Conv2d(in_channels = 200,
                                    out_channels = 50,
                                    kernel_size = (spatial_kernel,
                                                   temporal_kernel),
                                    stride = 1,
                                    dilation = 1,
                                    padding = 0)
        inp_size = (50*
                    spatial_outs[spatial_sizes.index(spatial)]*
                    temporal_outs[temporal_sizes.index(temporal)])
        self.fc1 = nn.Linear(inp_size, 1000)

    def forward(self, x):
        """Encode a batch of recordings.

        Args:
            x: 3-D tensor; it is reshaped to ``(d0, 1, d2, d1)``, i.e. a
                singleton channel axis is inserted and the last two axes
                are swapped. The datasets in this file pass
                ``(batch, 440, channels)``.

        Returns:
            Tensor of shape ``(d0, 1000)``.
        """
        x = x.unsqueeze(0).permute(1, 0, 3, 2)
        x = torch.cat([branch(x) for branch in self.temporal_layers], 1)
        x = torch.cat([branch(x) for branch in self.spatial_layers], 1)
        for shortcut, residual in zip(self.shortcuts, self.residual_layers):
            x = F.relu(shortcut(x)+residual(x))
        x = self.final_conv(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        return x

    def cuda(self, gpuIndex=None):
        """Move the whole network to CUDA device ``gpuIndex``; return self.

        Kept for callers that pass ``gpuIndex`` by name; every sub-layer is
        a registered module, so the stock ``nn.Module.cuda`` does the work.
        """
        return super(EEGChannelNet, self).cuda(gpuIndex)

class encoding_classifier(nn.Module):
    """Linear read-out that turns a 1000-d encoding into 40 class scores."""

    def __init__(self):
        super(encoding_classifier, self).__init__()
        self.fc2 = nn.Linear(1000, 40)

    def forward(self, x):
        """Return raw class scores ``fc2(relu(x))`` (no softmax applied).

        Design notes carried over from the original author:

        * https://github.com/adambielski/siamese-triplet/networks.py ends its
          EmbeddingNet with a fully connected layer and starts its
          ClassificationNet with PReLU, so the embedding is not confined to
          the first quadrant. The torchvision Inception3, ResNet, DenseNet
          and AlexNet forward passes
          (https://github.com/pytorch/vision/blob/master/torchvision/models/)
          likewise all end with a fully connected layer.
        * The same repository's ClassificationNet ends with log softmax, and
          https://github.com/adambielski/siamese-triplet/blob/master/Experiments_MNIST.ipynb
          ("Baseline: Classification with softmax") uses
          loss_fn = torch.nn.NLLLoss(), whereas this code uses
          F.cross_entropy.
        """
        rectified = F.relu(x)
        return self.fc2(rectified)

    def cuda(self, gpuIndex):
        """Move the linear layer to CUDA device ``gpuIndex``; return self."""
        moved_layer = self.fc2.cuda(gpuIndex)
        self.fc2 = moved_layer
        return self

def trainer(channel,
            length,
            eeg_dataset,
            splits_path,
            classes,
            split_num,
            image_model,
            image_path,
            batch_size,
            opt):
    """Train the encoders on one split and encode its held-out trials.

    An ``EEGChannelNet`` and a torchvision image network are built. When
    ``opt.mode`` is ``"joint"`` they are trained together on
    ``Dataset_for_joint_training`` with the hinge loss
    ``relu(negative_compatibility - positive_compatibility)`` for up to
    ``opt.joint_epochs`` epochs, stopping early once an epoch's loss is
    exactly zero. The "val" and "test" trials of the split are then encoded
    with both networks in evaluation mode.

    Args:
        channel, length: spatial and temporal input sizes for EEGChannelNet.
        eeg_dataset: path of the torch-serialized EEG dataset.
        splits_path: path of the torch-serialized splits.
        classes: sequence of class labels (used to index targets).
        split_num: which split to use.
        image_model: "inception_v3", "resnet101", "densenet161" or "alexnet".
        image_path: unused here; images are read from the module-level
            ``images`` list filled by ``read_images``.
        batch_size: batch size of every data loader.
        opt: parsed command-line options.

    Returns:
        ``(eeg_encodings, image_encodings, targets)``: three parallel lists
        covering the "val" trials followed by the "test" trials, each in the
        (shuffled) order the loader produced them.

    Raises:
        ValueError: if ``image_model`` is not one of the supported names.
    """
    def to_device(tensor):
        # `non_blocking` replaced the `async` keyword, which no longer
        # parses now that `async` is reserved in Python.
        if opt.no_cuda:
            return tensor
        return tensor.cuda(opt.GPUindex, non_blocking = True)

    pretrained = not opt.no_pretraining
    eeg_encoder = EEGChannelNet(channel, length)
    if image_model=="inception_v3":
        image_encoder = models.inception_v3(pretrained = pretrained,
                                            aux_logits = False)
    elif image_model=="resnet101":
        image_encoder = models.resnet101(pretrained = pretrained)
    elif image_model=="densenet161":
        image_encoder = models.densenet161(pretrained = pretrained)
    elif image_model=="alexnet":
        image_encoder = models.alexnet(pretrained = pretrained)
    else:
        raise ValueError("unsupported image model %r (expected inception_v3, "
                         "resnet101, densenet161 or alexnet)" % (image_model,))
    # Not used below. The two classifier heads are still instantiated so the
    # global torch RNG advances exactly as before and seeded runs keep the
    # same shuffle order. The optimizers and data loaders that only served
    # these heads were dropped: they were never used and each one re-read
    # the EEG dataset from disk.
    eeg_classifier = encoding_classifier()
    image_classifier = encoding_classifier()
    joint_optimizer = getattr(torch.optim,
                              opt.optim)([
                                  {"params": eeg_encoder.parameters()},
                                  {"params": image_encoder.parameters()}],
                                         lr = opt.joint_learning_rate)
    if not opt.no_cuda:
        eeg_encoder.cuda(opt.GPUindex)
        image_encoder.cuda(opt.GPUindex)
    joint_dataset = Dataset_for_joint_training(eeg_dataset,
                                               splits_path,
                                               split_num,
                                               "train",
                                               eeg_encoder,
                                               image_encoder,
                                               opt)
    joint_loader = DataLoader(joint_dataset,
                              batch_size = batch_size,
                              drop_last = False,
                              shuffle = True)
    validation_loaders = {split:DataLoader(
        Dataset_for_validation(eeg_dataset,
                               classes,
                               splits_path,
                               split_num,
                               split),
        batch_size = batch_size,
        drop_last = False,
        shuffle = True) for split in ["val", "test"]}
    if opt.mode=="joint":
        # Jointly train EEG and image encoders
        for epoch in range(1, opt.joint_epochs+1):
            if opt.cached:
                # Refresh the cached encodings with the current weights.
                joint_dataset.compute_cache()
            training_loss = 0.0
            training_samples = 0
            eeg_encoder.train()
            image_encoder.train()
            # We don't know what Palazzo et al. (2020) do. Here we enlarge the
            # positive image by 1.1 and ten crop. And we potentially process all
            # ten crops in different batches. We don't enlarge and ten crop the
            # negative images. Each ten crop of each epoch gets a distinct
            # negative image. The negative images change every epoch.
            for eeg, positive_image, negative_image in joint_loader:
                eeg = to_device(eeg)
                positive_image = to_device(positive_image)
                negative_image = to_device(negative_image)
                eeg_encoding = eeg_encoder(eeg)
                positive_image_encoding = image_encoder(positive_image)
                negative_image_encoding = image_encoder(negative_image)
                # https://discuss.pytorch.org/t/dot-product-batch-wise/9746
                positive_compatability = (
                    (eeg_encoding*positive_image_encoding).sum(1))
                negative_compatability = (
                    (eeg_encoding*negative_image_encoding).sum(1))
                loss = F.relu(
                    negative_compatability-positive_compatability).sum(0)
                # `loss` is already summed over the batch, so it is added
                # as-is; scaling by the batch size would count it twice.
                training_loss += loss.item()
                training_samples += eeg.size(0)
                joint_optimizer.zero_grad()
                loss.backward()
                joint_optimizer.step()
            print("joint EEG and image encoder training epoch %d loss %g with %d samples"%(epoch, training_loss, training_samples))
            if training_loss==0.0:
                break
    # Compute EEG and image encodings on validation and test sets.
    eeg_encodings = []
    image_encodings = []
    targets = []
    eeg_encoder.eval()
    image_encoder.eval()
    for split in ("val", "test"):
        # We don't know what Palazzo et al. (2020) do. Here we don't enlarge
        # by 1.1 and ten crop.
        for eeg, image, target in validation_loaders[split]:
            eeg = to_device(eeg)
            image = to_device(image)
            with torch.no_grad():
                eeg_encodings += eeg_encoder(eeg).tolist()
                image_encodings += image_encoder(image).tolist()
                targets += target.tolist()
    return eeg_encodings, image_encodings, targets

def analysis(channel,
             length,
             eeg_dataset,
             splits_path,
             classes,
             fold,
             image_model,
             image_path,
             batch_size,
             opt):
    """Run ``trainer`` on every split and pickle the pooled encodings.

    Images are loaded once with ``read_images``; ``trainer`` is then called
    for ``split_num`` in ``range(fold)`` and its three result lists are
    concatenated across splits. The tuple
    ``(eeg_encodings, image_encodings, targets)`` is written with pickle to
    ``opt.results_file``.

    Args:
        channel, length, eeg_dataset, splits_path, classes, image_model,
            image_path, batch_size, opt: forwarded unchanged to ``trainer``
            (``eeg_dataset``, ``image_model`` and ``image_path`` also go to
            ``read_images``).
        fold: number of splits to process.

    Returns:
        None.
    """
    read_images(eeg_dataset, image_model, image_path)
    eeg_encodings = []
    image_encodings = []
    targets = []
    for split_num in range(fold):
        eeg_encoding, image_encoding, target = trainer(channel,
                                                       length,
                                                       eeg_dataset,
                                                       splits_path,
                                                       classes,
                                                       split_num,
                                                       image_model,
                                                       image_path,
                                                       batch_size,
                                                       opt)
        eeg_encodings += eeg_encoding
        image_encodings += image_encoding
        targets += target
    # The context manager closes the file even if pickling raises.
    with open(opt.results_file, "wb") as f:
        pkl.dump((eeg_encodings, image_encodings, targets), f)

# Script entry: pick the input geometry for the requested dataset variant and
# run the full cross-validated analysis.
# (length, channel, n_classes) for every supported --iv value.
_IV_SETTINGS = {"spampinato": (440, 128, 40),
                "image": (440, 96, 40)}
if opt.iv not in _IV_SETTINGS:
    # Fail with a clear message instead of a NameError on `n_classes` below.
    raise ValueError("unsupported --iv value %r (expected one of: %s)"
                     % (opt.iv, ", ".join(sorted(_IV_SETTINGS))))
length, channel, n_classes = _IV_SETTINGS[opt.iv]
# A concrete list, so `classes.index(...)` behaves the same on Python 2 and 3.
classes = list(range(n_classes))

analysis(channel,
         length,
         opt.eeg_dataset,
         opt.splits_path,
         classes,
         opt.fold,
         opt.classifier,
         opt.image_path,
         opt.batch_size,
         opt)
