from __future__ import division
import argparse
# Command-line interface. Every option read unconditionally later in the
# script is marked required, and options that select a code path carry
# `choices`, so a bad invocation fails here with a usage message instead of
# crashing deep inside training.
parser = argparse.ArgumentParser(description = "Template")
# Dataset options
parser.add_argument("-iv",
                    "--iv",
                    help = "image/video",
                    type = str,
                    choices = ["spampinato", "image"],
                    required = True)
parser.add_argument("-rf",
                    "--results-file",
                    default = "results.plk",
                    help = "results file",
                    type = str,
                    required = False)
parser.add_argument("-s",
                    "--subject",
                    default = 0,
                    type = int,
                    help = "subject")
parser.add_argument("-r",
                    "--run",
                    default = "none",
                    type = str,
                    help = "run")
parser.add_argument("-ed",
                    "--eeg-dataset",
                    required = True,
                    help = "EEG dataset path")
parser.add_argument("-sp",
                    "--splits-path",
                    required = True,
                    help = "splits path")
parser.add_argument("-ip",
                    "--image-path",
                    required = True,
                    help = "image path")
parser.add_argument("-f",
                    "--fold",
                    default = 5,
                    help = "number of folds",
                    type = int,
                    required = False)
# Training options
parser.add_argument("-b",
                    "--batch_size",
                    default = 16,
                    type = int,
                    help = "batch size")
parser.add_argument("-o",
                    "--optim",
                    default = "Adam",
                    help = "optimizer")  # name of a class in torch.optim
parser.add_argument("-elr",
                    "--eeg-learning-rate",
                    default = 0.00001,
                    type = float,
                    help = "EEG learning rate")
parser.add_argument("-ilr",
                    "--image-learning-rate",
                    default = 0.000001,
                    type = float,
                    help = "image learning rate")
parser.add_argument("-ee",
                    "--eeg-epochs",
                    default = 100,
                    type = int,
                    help = "EEG training epochs")
parser.add_argument("-ie",
                    "--image-epochs",
                    default = 100,
                    type = int,
                    help = "image training epochs")
parser.add_argument("-mt",
                    "--max-tries",
                    default = 10,
                    type = int,
                    help = "max tries")
parser.add_argument("-gpu",
                    "--GPUindex",
                    default = 0,
                    type = int,
                    help = "gpu index")
parser.add_argument("-ca",
                    "--cached",
                    default = False,
                    help = "cached",
                    action = "store_true")
parser.add_argument("-l",
                    "--large",
                    default = False,
                    help = "large",
                    action = "store_true")
# Backend options
parser.add_argument("--no-cuda",
                    default = False,
                    help = "disable CUDA",
                    action = "store_true")
parser.add_argument("-c",
                    "--classifier",
                    required = True,
                    choices = ["inception_v3",
                               "resnet101",
                               "densenet161",
                               "alexnet"],
                    help = "inception_v3/resnet101/densenet161/alexnet")

# Parse arguments
opt = parser.parse_args()

import torch
import random
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from p_values import *
import os
import pickle as pkl

# Seed every random number generator the script uses (torch CPU, torch CUDA,
# NumPy, Python's `random`) with the same fixed value, then request
# deterministic cuDNN kernels and a single intra-op CPU thread.
_SEED = 12
for _seed_fn in (torch.manual_seed,
                 torch.cuda.manual_seed,
                 np.random.seed,
                 random.seed):
    _seed_fn(_SEED)
del _seed_fn
torch.backends.cudnn.deterministic = True
torch.set_num_threads(1)

# Build mapping_classes: mapping_classes[k] is the class-ID string that
# "mapping.txt" (read from the current working directory) lists for output
# index k; unlisted slots stay "". Each non-empty line is expected to be
# "<class-id> <index>" separated by a single space, with 0 <= index < 1000
# (the table has exactly 1000 slots).
with open("mapping.txt", "r") as mapping_file:
    # Strip only the newline: unconditionally slicing off the last
    # character would damage a final line that has no trailing newline.
    mapping = [line.rstrip("\n") for line in mapping_file]
# Skip blank lines (e.g. a trailing empty line) instead of crashing on them.
mapping = [pair for pair in mapping if pair]
mapping_classes = [""] * 1000
for pair in mapping:
    fields = pair.split(" ")
    mapping_classes[int(fields[1])] = fields[0]

# spampinato_classes: ordered list of the 40 class IDs used by this dataset.
# The dataset classes below look up spampinato_classes[sample["label"]], so
# the POSITION of an ID in this list is what an integer label means and the
# order of each list is load-bearing -- do not re-sort or edit casually.
# The ordering is picked from the dataset path: paths containing
# "eeg_signals_128_sequential_band_all_with_mean_std" use the second
# (unsorted) list; all other datasets use the first, which is in ascending
# order. Both lists contain the same 40 IDs.
# NOTE(review): confirm each ordering matches how integer labels were
# assigned when the corresponding dataset file was built.
# NOTE(review): opt.eeg_dataset must be a string here; a missing
# --eeg-dataset (None) raises TypeError on the `in` test.
if "eeg_signals_128_sequential_band_all_with_mean_std" not in opt.eeg_dataset:
    spampinato_classes = ["n02106662",
                          "n02124075",
                          "n02281787",
                          "n02389026",
                          "n02492035",
                          "n02504458",
                          "n02510455",
                          "n02607072",
                          "n02690373",
                          "n02906734",
                          "n02951358",
                          "n02992529",
                          "n03063599",
                          "n03100240",
                          "n03180011",
                          "n03197337",
                          "n03272010",
                          "n03272562",
                          "n03297495",
                          "n03376595",
                          "n03445777",
                          "n03452741",
                          "n03584829",
                          "n03590841",
                          "n03709823",
                          "n03773504",
                          "n03775071",
                          "n03792782",
                          "n03792972",
                          "n03877472",
                          "n03888257",
                          "n03982430",
                          "n04044716",
                          "n04069434",
                          "n04086273",
                          "n04120489",
                          "n07753592",
                          "n07873807",
                          "n11939491",
                          "n13054560"]
else:
    # Label order specific to the *_with_mean_std dataset file.
    spampinato_classes = ["n02389026",
                          "n03888257",
                          "n03584829",
                          "n02607072",
                          "n03297495",
                          "n03063599",
                          "n03792782",
                          "n04086273",
                          "n02510455",
                          "n11939491",
                          "n02951358",
                          "n02281787",
                          "n02106662",
                          "n04120489",
                          "n03590841",
                          "n02992529",
                          "n03445777",
                          "n03180011",
                          "n02906734",
                          "n07873807",
                          "n03773504",
                          "n02492035",
                          "n03982430",
                          "n03709823",
                          "n03100240",
                          "n03376595",
                          "n03877472",
                          "n03775071",
                          "n03272010",
                          "n04069434",
                          "n03452741",
                          "n03792972",
                          "n07753592",
                          "n13054560",
                          "n03197337",
                          "n02504458",
                          "n02690373",
                          "n03272562",
                          "n04044716",
                          "n02124075"]

def read_images(eeg_signals_path, image_model, image_path):
    """Load and preprocess every stimulus image into the global ``images``.

    For the k-th name in ``torch.load(eeg_signals_path)["images"]`` the file
    ``<image_path>/<name>.JPEG`` is opened, converted to RGB and stored as
    ``images[k] = [plain_view] + ten_crops`` where:

    * ``plain_view`` (index 0) is resized and center-cropped to the model's
      input size, and
    * ``ten_crops`` (indices 1..10) come from ``TenCrop`` applied after
      resizing to 1.1x the input size.

    All views are converted to tensors and normalized with the mean/std
    constants below.

    Args:
        eeg_signals_path: file loadable by ``torch.load`` whose ``"images"``
            entry lists the image base names.
        image_model: classifier name; ``"inception_v3"`` uses 299 px inputs,
            every other model 224 px.
        image_path: directory containing the ``.JPEG`` files.
    """
    input_size = 299 if image_model == "inception_v3" else 224
    normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
    # Built once and reused for every crop instead of per call.
    crop_to_tensor = transforms.Compose([transforms.ToTensor(), normalize])
    positive_transform = transforms.Compose([
        transforms.Resize(int(1.1*input_size)),
        transforms.TenCrop(input_size),
        transforms.Lambda(
            lambda crops: [crop_to_tensor(crop) for crop in crops])])
    negative_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize])
    eeg_signals = torch.load(eeg_signals_path)
    global images
    images = []
    for image_name in eeg_signals["images"]:
        file_name = os.path.join(image_path, image_name + ".JPEG")
        # Close the underlying file as soon as the RGB copy is made, rather
        # than leaving one handle per image to the garbage collector.
        with Image.open(file_name) as raw_image:
            image = raw_image.convert('RGB')
        images.append([negative_transform(image)]+positive_transform(image))

class Dataset_for_joint_training:
    """EEG / positive-image / negative-image triplets for one data split.

    Item ``i`` is ``(eeg, positive_image, negative_image)``:

    * ``eeg`` is the (optionally normalized) EEG of one sample, transposed
      and cropped to rows 20:460;
    * ``positive_image`` is one of the ten crops of the image shown for that
      sample (``images[sample["image"]][1..10]``);
    * ``negative_image`` is the plain view (``images[j][0]``) of some other
      image ``j``. Candidates are tried in random order and the first whose
      dot-product compatibility with the EEG encoding exceeds the
      positive's is returned; otherwise the last candidate tried is.

    Relies on the module-level ``images`` list built by ``read_images``.
    """

    def __init__(self,
                 eeg_signals_path,
                 splits_path,
                 split_num,
                 split_name,
                 eeg_encoder,
                 image_encoder,
                 opt):
        self.large = opt.large
        self.cached = opt.cached
        self.max_tries = opt.max_tries
        # Device settings come from the opt passed in, not the global opt.
        self.no_cuda = opt.no_cuda
        self.gpu_index = opt.GPUindex
        eeg_signals = torch.load(eeg_signals_path)
        splits = torch.load(splits_path)
        self.data = eeg_signals["dataset"]
        # Keep only samples with at least 480 entries along dim 1.
        self.split_idx = [i for i in splits["splits"][split_num][split_name]
                          if 480<=self.data[i]["eeg"].size(1)]
        # Normalization statistics are optional in the dataset file.
        try:
            self.means = eeg_signals["means"]
            self.stddevs = eeg_signals["stddevs"]
        except KeyError:
            self.means = None
            self.stddevs = None
        # In "large" mode every sample appears 10 times, once per crop.
        if self.large:
            self.size = 10*len(self.split_idx)
        else:
            self.size = len(self.split_idx)
        self.eeg_encoder = eeg_encoder
        self.image_encoder = image_encoder
        self.eeg_encodings = None
        self.image_encodings = None

    def _prepare_eeg(self, sample):
        """Return ``sample["eeg"]`` as float, normalized with the stored
        means/stddevs when present, transposed and cropped to rows 20:460."""
        eeg = sample["eeg"].float()
        if self.means is not None and self.stddevs is not None:
            eeg = (eeg-self.means)/self.stddevs
        # WARNING: This is different.
        return eeg.t()[20:460, :]

    def _to_device(self, tensor):
        """Move ``tensor`` to the configured GPU unless CUDA is disabled."""
        if self.no_cuda:
            return tensor
        return tensor.cuda(self.gpu_index, non_blocking = True)

    def compute_cache(self):
        """Pre-encode every EEG sample of the split and every image view.

        Afterwards ``eeg_encodings[k]`` encodes the sample at split position
        ``k`` and ``image_encodings[k][c]`` encodes ``images[k][c]``.
        """
        self.eeg_encodings = []
        for idx in self.split_idx:
            eeg = self._prepare_eeg(self.data[idx])
            eeg_unsqueeze = self._to_device(eeg.contiguous().unsqueeze(0))
            with torch.no_grad():
                self.eeg_encodings.append(self.eeg_encoder(eeg_unsqueeze))
        self.image_encodings = []
        for views in images:
            encoded_views = []
            for view in views:
                view_unsqueeze = self._to_device(view.unsqueeze(0))
                with torch.no_grad():
                    encoded_views.append(self.image_encoder(view_unsqueeze))
            self.image_encodings.append(encoded_views)

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        if self.large:
            position = i//10
            positive_crop_index = i%10+1
        else:
            position = i
            positive_crop_index = random.randrange(1, 11)
        sample = self.data[self.split_idx[position]]
        eeg = self._prepare_eeg(sample)
        positive_index = sample["image"]
        positive_image = images[positive_index][positive_crop_index]
        if self.cached:
            # The EEG cache is indexed by split position, not by item index.
            eeg_encoding = self.eeg_encodings[position]
            positive_image_encoding = (
                self.image_encodings[positive_index][positive_crop_index])
        else:
            eeg_unsqueeze = self._to_device(eeg.contiguous().unsqueeze(0))
            positive_image_unsqueeze = self._to_device(
                positive_image.unsqueeze(0))
            with torch.no_grad():
                eeg_encoding = self.eeg_encoder(eeg_unsqueeze)
                positive_image_encoding = (
                    self.image_encoder(positive_image_unsqueeze))
        # https://discuss.pytorch.org/t/dot-product-batch-wise/9746
        positive_compatability = (
                (eeg_encoding*positive_image_encoding).sum(1))[0]
        # Candidate negatives are all images except the positive one.
        indices = [j for j in range(len(images)) if j!=positive_index]
        indices = random.sample(indices, len(indices))
        tries = len(indices)
        if not self.cached:
            tries = min(self.max_tries, tries)
        if tries < 1:
            raise RuntimeError("no candidate negative image available")
        negative_image = None
        for negative_index in indices[:tries]:
            negative_image = images[negative_index][0]
            if self.cached:
                negative_image_encoding = (
                    self.image_encodings[negative_index][0])
            else:
                negative_image_unsqueeze = self._to_device(
                    negative_image.unsqueeze(0))
                with torch.no_grad():
                    negative_image_encoding = (
                        self.image_encoder(negative_image_unsqueeze))
            # https://discuss.pytorch.org/t/dot-product-batch-wise/9746
            negative_compatability = (
                (eeg_encoding*negative_image_encoding).sum(1))[0]
            # Stop at the first negative scoring above the positive.
            if negative_compatability>positive_compatability:
                break
        return eeg, positive_image, negative_image

class Dataset_for_training_eeg_encoder:
    """EEG samples of one split paired with 1000-dimensional one-hot targets.

    Item ``i`` is ``(eeg, target)``. ``eeg`` is the sample's EEG as float,
    normalized with the file's ``means``/``stddevs`` when both are present,
    transposed and cropped to rows 20:460. ``target`` has a single 1.0 at
    the position of ``spampinato_classes[sample["label"]]`` within the
    module-level ``mapping_classes``.
    """

    def __init__(self,
                 eeg_signals_path,
                 classes,
                 splits_path,
                 split_num,
                 split_name):
        eeg_signals = torch.load(eeg_signals_path)
        splits = torch.load(splits_path)
        self.data = eeg_signals["dataset"]
        # Keep only samples with at least 480 entries along dim 1.
        self.split_idx = [i for i in splits["splits"][split_num][split_name]
                          if 480<=self.data[i]["eeg"].size(1)]
        # Normalization statistics are optional in the dataset file; a
        # missing key means "use raw EEG", while any other failure (e.g. a
        # shape mismatch) is no longer silently hidden.
        try:
            self.means = eeg_signals["means"]
            self.stddevs = eeg_signals["stddevs"]
        except KeyError:
            self.means = None
            self.stddevs = None
        self.size = len(self.split_idx)
        self.classes = classes  # stored for API parity; unused here

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        sample = self.data[self.split_idx[i]]
        eeg = sample["eeg"].float()
        if self.means is not None and self.stddevs is not None:
            eeg = (eeg-self.means)/self.stddevs
        # WARNING: This is different.
        eeg = eeg.t()[20:460, :]
        target = torch.zeros(1000)
        target[mapping_classes.index(spampinato_classes[sample["label"]])] = 1.0
        return eeg, target

class Dataset_for_training_image_encoder:
    """Image crops of one split paired with 1000-dimensional one-hot targets.

    Item ``i`` is ``(image, target)`` where ``image`` is one of the ten crops
    (``images[sample["image"]][1..10]``) from the module-level ``images``
    list and ``target`` has a single 1.0 at the position of
    ``spampinato_classes[sample["label"]]`` within ``mapping_classes``.

    Args:
        large: when true, every sample is expanded into 10 items (one per
            crop); otherwise each item uses a random crop. ``None`` (the
            default) falls back to the command-line flag ``opt.large``.
    """

    def __init__(self,
                 eeg_signals_path,
                 classes,
                 splits_path,
                 split_num,
                 split_name,
                 large = None):
        # Explicit argument wins; the global CLI flag keeps old callers working.
        self.large = opt.large if large is None else large
        eeg_signals = torch.load(eeg_signals_path)
        splits = torch.load(splits_path)
        self.data = eeg_signals["dataset"]
        # Keep only samples with at least 480 entries along dim 1, matching
        # the EEG datasets so all loaders see the same samples.
        self.split_idx = [i for i in splits["splits"][split_num][split_name]
                          if 480<=self.data[i]["eeg"].size(1)]
        if self.large:
            self.size = 10*len(self.split_idx)
        else:
            self.size = len(self.split_idx)
        self.classes = classes  # stored for API parity; unused here

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        if self.large:
            sample = self.data[self.split_idx[i//10]]
            positive_crop_index = i%10+1
        else:
            sample = self.data[self.split_idx[i]]
            positive_crop_index = random.randrange(1, 11)
        positive_image = images[sample["image"]][positive_crop_index]
        target = torch.zeros(1000)
        target[mapping_classes.index(spampinato_classes[sample["label"]])] = 1.0
        return positive_image, target

class Dataset_for_validation:
    """(eeg, image, label) triples of one split for evaluating the encoders.

    ``eeg`` is the sample's EEG as float, normalized with the file's
    ``means``/``stddevs`` when both are present, transposed and cropped to
    rows 20:460. ``image`` is the plain (un-cropped) view
    ``images[sample["image"]][0]`` from the module-level ``images`` list.
    ``label`` is the position of ``sample["label"]`` within ``classes``.
    """

    def __init__(self,
                 eeg_signals_path,
                 classes,
                 splits_path,
                 split_num,
                 split_name):
        eeg_signals = torch.load(eeg_signals_path)
        splits = torch.load(splits_path)
        self.data = eeg_signals["dataset"]
        # Keep only samples with at least 480 entries along dim 1.
        self.split_idx = [i for i in splits["splits"][split_num][split_name]
                          if 480<=self.data[i]["eeg"].size(1)]
        # Normalization statistics are optional in the dataset file.
        try:
            self.means = eeg_signals["means"]
            self.stddevs = eeg_signals["stddevs"]
        except KeyError:
            self.means = None
            self.stddevs = None
        self.size = len(self.split_idx)
        self.classes = classes

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        sample = self.data[self.split_idx[i]]
        eeg = sample["eeg"].float()
        if self.means is not None and self.stddevs is not None:
            eeg = (eeg-self.means)/self.stddevs
        # WARNING: This is different.
        eeg = eeg.t()[20:460, :]
        positive_image = images[sample["image"]][0]
        label = self.classes.index(sample["label"])
        return eeg, positive_image, label

class EEGChannelNet(nn.Module):
    """EEG encoder mapping a batch of EEG windows to 1000 outputs.

    Pipeline: five parallel dilated temporal convolutions, then four
    parallel spatial convolutions of decreasing height, then four strided
    residual blocks, a final convolution and a linear layer to 1000 units.

    Args:
        spatial: size of the spatial axis of the input; must be one of
            128, 96, 64, 32, 16, 8 (it selects the fc1 input size).
        temporal: size of the temporal axis of the input; must be one of
            1024, 512, 440, 256, 200, 128, 100, 50.

    ``forward`` takes ``x`` shaped ``(batch, temporal, spatial)`` and
    returns ``(batch, 1000)``.
    """

    def __init__(self, spatial, temporal):
        super(EEGChannelNet, self).__init__()
        # All layer groups live in nn.ModuleList (not plain Python lists) so
        # their parameters are registered: they appear in .parameters() and
        # are therefore optimized, they follow .train()/.eval(), and they
        # move with .cuda()/.to().
        #
        # Temporal block: (1, 33) kernels along the time axis with dilation
        # d and padding 16*d. With stride 2 every branch yields the same
        # output length, so the five 10-channel maps concatenate into 50.
        self.temporal_layers = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 1,
                                    out_channels = 10,
                                    kernel_size = (1, 33),
                                    stride = (1, 2),
                                    dilation = (1, d),
                                    padding = (0, 16*d)),
                          nn.BatchNorm2d(10),
                          nn.ReLU())
            for d in (1, 2, 4, 8, 16)])
        # Spatial block: (k, 1) kernels across the spatial axis with
        # padding k/2 - 1; the four 50-channel maps concatenate into 200.
        self.spatial_layers = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 50,
                                    out_channels = 50,
                                    kernel_size = (k, 1),
                                    stride = (2, 1),
                                    padding = (k//2-1, 0)),
                          nn.BatchNorm2d(50),
                          nn.ReLU())
            for k in (128, 64, 32, 16)])
        # Residual blocks: each halves both spatial dims (stride 2).
        self.residual_layers = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 200,
                                    out_channels = 200,
                                    kernel_size = 3,
                                    stride = 2,
                                    padding = 1),
                          nn.BatchNorm2d(200),
                          nn.ReLU(),
                          nn.Conv2d(in_channels = 200,
                                    out_channels = 200,
                                    kernel_size = 3,
                                    stride = 1,
                                    padding = 1),
                          nn.BatchNorm2d(200))
            for _ in range(4)])
        # 1x1 stride-2 projections matching each residual block's output.
        self.shortcuts = nn.ModuleList([
            nn.Sequential(nn.Conv2d(in_channels = 200,
                                    out_channels = 200,
                                    kernel_size = 1,
                                    stride = 2),
                          nn.BatchNorm2d(200))
            for _ in range(4)])
        if spatial in (128, 96):
            spatial_kernel = 3
        elif spatial == 64:
            spatial_kernel = 2
        else:
            spatial_kernel = 1
        temporal_kernel = 2 if temporal == 50 else 3
        self.final_conv = nn.Conv2d(in_channels = 200,
                                    out_channels = 50,
                                    kernel_size = (spatial_kernel,
                                                   temporal_kernel),
                                    stride = 1,
                                    dilation = 1,
                                    padding = 0)
        # Output height/width of final_conv for each supported input size;
        # an unsupported size raises ValueError from list.index.
        spatial_sizes = [128, 96, 64, 32, 16, 8]
        spatial_outs = [2, 1, 1, 1, 1, 1]
        temporal_sizes = [1024, 512, 440, 256, 200, 128, 100, 50]
        temporal_outs = [30, 14, 12, 6, 5, 2, 2, 1]
        inp_size = (50*
                    spatial_outs[spatial_sizes.index(spatial)]*
                    temporal_outs[temporal_sizes.index(temporal)])
        self.fc1 = nn.Linear(inp_size, 1000)

    def forward(self, x):
        # (batch, temporal, spatial) -> (batch, 1, spatial, temporal)
        x = x.unsqueeze(1).transpose(2, 3)
        x = torch.cat([layer(x) for layer in self.temporal_layers], 1)
        x = torch.cat([layer(x) for layer in self.spatial_layers], 1)
        for shortcut, residual in zip(self.shortcuts, self.residual_layers):
            x = F.relu(shortcut(x)+residual(x))
        x = self.final_conv(x)
        x = x.view(x.size(0), -1)
        return self.fc1(x)

    def cuda(self, gpuIndex = None):
        """Move all parameters and buffers to GPU ``gpuIndex``; return self.

        Thin wrapper over ``nn.Module.cuda`` kept so existing
        ``cuda(index)`` callers keep working.
        """
        return super(EEGChannelNet, self).cuda(gpuIndex)

def _to_device(tensor, opt):
    """Return ``tensor`` on GPU ``opt.GPUindex`` unless ``opt.no_cuda``."""
    if opt.no_cuda:
        return tensor
    return tensor.cuda(opt.GPUindex, non_blocking = True)

def _build_image_encoder(image_model):
    """Return the pretrained torchvision model named ``image_model``.

    Raises:
        ValueError: for a name other than inception_v3, resnet101,
            densenet161 or alexnet.
    """
    if image_model=="inception_v3":
        return models.inception_v3(pretrained = True, aux_logits = False)
    if image_model=="resnet101":
        return models.resnet101(pretrained = True)
    if image_model=="densenet161":
        return models.densenet161(pretrained = True)
    if image_model=="alexnet":
        return models.alexnet(pretrained = True)
    raise ValueError("unsupported image model: %r" % (image_model,))

def _train_encoder(name, encoder, optimizer, loader, epochs, opt):
    """Fit ``encoder`` to one-hot targets with MSE loss for ``epochs`` epochs,
    printing the accumulated loss after every epoch."""
    for epoch in range(1, epochs+1):
        training_loss = 0.0
        training_samples = 0
        encoder.train()
        for inputs, target in loader:
            inputs = _to_device(inputs, opt)
            target = _to_device(target, opt)
            loss = F.mse_loss(encoder(inputs), target)
            # mse_loss is an element-wise mean; weight it by the batch size
            # so the printed total is a per-sample sum.
            training_loss += loss.item()*inputs.size(0)
            training_samples += inputs.size(0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print("%s encoder training epoch %d loss %g with %d samples"
              % (name, epoch, training_loss, training_samples))

def trainer(channel,
            length,
            eeg_dataset,
            splits_path,
            classes,
            split_num,
            image_model,
            image_path,
            batch_size,
            opt):
    """Train both encoders on split ``split_num`` and encode its val/test data.

    The EEG encoder (EEGChannelNet) and the image classifier are each fitted
    to 1000-way one-hot targets on the "train" split. The triplet loss on
    the joint dataset is then measured, and finally the val and test splits
    are encoded.

    Returns:
        ``(eeg_encodings, image_encodings, targets)``: parallel Python lists
        over the val then test samples; encodings are lists of floats and
        targets are positions of the sample labels within ``classes``.
    """
    eeg_encoder = EEGChannelNet(channel, length)
    image_encoder = _build_image_encoder(image_model)
    # Optimizers are created before .cuda(); Module.cuda moves parameters
    # in place, so the optimizers keep referencing the right tensors.
    eeg_optimizer = getattr(torch.optim, opt.optim)(
        [{"params": eeg_encoder.parameters()}],
        lr = opt.eeg_learning_rate)
    image_optimizer = getattr(torch.optim, opt.optim)(
        [{"params": image_encoder.parameters()}],
        lr = opt.image_learning_rate)
    if not opt.no_cuda:
        eeg_encoder.cuda(opt.GPUindex)
        image_encoder.cuda(opt.GPUindex)
    joint_dataset = Dataset_for_joint_training(eeg_dataset,
                                               splits_path,
                                               split_num,
                                               "train",
                                               eeg_encoder,
                                               image_encoder,
                                               opt)
    joint_loader = DataLoader(joint_dataset,
                              batch_size = batch_size,
                              drop_last = False,
                              shuffle = True)
    eeg_encoder_loader = DataLoader(
        Dataset_for_training_eeg_encoder(eeg_dataset,
                                         classes,
                                         splits_path,
                                         split_num,
                                         "train"),
        batch_size = batch_size,
        drop_last = False,
        shuffle = True)
    image_encoder_loader = DataLoader(
        Dataset_for_training_image_encoder(eeg_dataset,
                                           classes,
                                           splits_path,
                                           split_num,
                                           "train"),
        batch_size = batch_size,
        drop_last = False,
        shuffle = True)
    validation_loaders = {split: DataLoader(
        Dataset_for_validation(eeg_dataset,
                               classes,
                               splits_path,
                               split_num,
                               split),
        batch_size = batch_size,
        drop_last = False,
        shuffle = True) for split in ("val", "test")}
    _train_encoder("EEG", eeg_encoder, eeg_optimizer,
                   eeg_encoder_loader, opt.eeg_epochs, opt)
    _train_encoder("image", image_encoder, image_optimizer,
                   image_encoder_loader, opt.image_epochs, opt)
    # Compute triplet loss
    if opt.cached:
        joint_dataset.compute_cache()
    training_loss = 0.0
    training_samples = 0
    # NOTE(review): the triplet loss below is only measured (no backward or
    # optimizer step), yet both encoders are in train mode, so BatchNorm
    # running statistics are still updated by these forward passes before
    # validation. Confirm this rather than a joint training step is intended.
    eeg_encoder.train()
    image_encoder.train()
    # We don't know what Palazzo et al. (2020) do. Here we enlarge the
    # positive image by 1.1 and ten crop. And we potentially process all
    # ten crops in different batches. We don't enlarge and ten crop the
    # negative images. Each ten crop of each epoch gets a distinct
    # negative image. The negative images change every epoch.
    for eeg, positive_image, negative_image in joint_loader:
        eeg = _to_device(eeg, opt)
        positive_image = _to_device(positive_image, opt)
        negative_image = _to_device(negative_image, opt)
        with torch.no_grad():
            eeg_encoding = eeg_encoder(eeg)
            positive_image_encoding = image_encoder(positive_image)
            negative_image_encoding = image_encoder(negative_image)
            # https://discuss.pytorch.org/t/dot-product-batch-wise/9746
            positive_compatability = (
                (eeg_encoding*positive_image_encoding).sum(1))
            negative_compatability = (
                (eeg_encoding*negative_image_encoding).sum(1))
            loss = F.relu(
                negative_compatability-positive_compatability).sum(0)
            # ``loss`` is already summed over the batch; do not scale it
            # by the batch size again.
            training_loss += loss.item()
            training_samples += eeg.size(0)
    print("joint EEG and image encoder loss %g with %d samples"
          % (training_loss, training_samples))
    # Compute EEG and image encodings on validation and test sets.
    eeg_encodings = []
    image_encodings = []
    targets = []
    eeg_encoder.eval()
    image_encoder.eval()
    for split in ("val", "test"):
        # We don't know what Palazzo et al. (2020) do. Here we don't enlarge
        # by 1.1 and ten crop.
        for eeg, image, target in validation_loaders[split]:
            eeg = _to_device(eeg, opt)
            image = _to_device(image, opt)
            with torch.no_grad():
                eeg_encodings += eeg_encoder(eeg).tolist()
                image_encodings += image_encoder(image).tolist()
            targets += target.tolist()
    return eeg_encodings, image_encodings, targets

def analysis(channel,
             length,
             eeg_dataset,
             splits_path,
             classes,
             fold,
             image_model,
             image_path,
             batch_size,
             opt):
    """Run ``trainer`` on splits 0..fold-1 and pickle the pooled results.

    Images are loaded once with ``read_images``. The val/test encodings and
    targets of every split are concatenated and written to
    ``opt.results_file`` as a pickled tuple
    ``(eeg_encodings, image_encodings, targets)``.
    """
    read_images(eeg_dataset, image_model, image_path)
    eeg_encodings = []
    image_encodings = []
    targets = []
    for split_num in range(fold):
        eeg_encoding, image_encoding, target = trainer(channel,
                                                       length,
                                                       eeg_dataset,
                                                       splits_path,
                                                       classes,
                                                       split_num,
                                                       image_model,
                                                       image_path,
                                                       batch_size,
                                                       opt)
        eeg_encodings += eeg_encoding
        image_encodings += image_encoding
        targets += target
    # The context manager closes the file even if pickling fails.
    with open(opt.results_file, "wb") as f:
        pkl.dump((eeg_encodings, image_encodings, targets), f)

# Per-setup sizes selected by --iv: EEGChannelNet's temporal size
# (`length`), its spatial size (`channel`), and the number of classes.
if opt.iv=="spampinato":
    length = 440
    channel = 128
    n_classes = 40
elif opt.iv=="image":
    length = 440
    channel = 96
    n_classes = 40
else:
    # Fail with a usage message instead of a later NameError on n_classes.
    parser.error("unsupported --iv value: %r" % (opt.iv,))
classes = range(n_classes)

analysis(channel,
         length,
         opt.eeg_dataset,
         opt.splits_path,
         classes,
         opt.fold,
         opt.classifier,
         opt.image_path,
         opt.batch_size,
         opt)
