import argparse
import datetime
import sys
import json
from collections import defaultdict
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import torch
from torch import optim

import models
import objectives_together
from utils import Logger, Timer, save_model_light, save_vars, unpack_data_polymnist, unpack_data_polymnist_train

import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from statistics import mean
from test_functions import calculate_inception_features_for_gen_evaluation, calculate_fid, \
    classify_latent_representations
import random

# TensorBoard
from torch.utils.tensorboard import SummaryWriter

from sklearn.linear_model import LogisticRegression

# Command-line interface.
# Help strings use argparse's ``%(default)s`` placeholder so that the default shown by
# ``--help`` always matches the value actually configured here.
parser = argparse.ArgumentParser(description='Multi-Modal VAEs')
parser.add_argument('--experiment', type=str, default='', metavar='E',
                    help='experiment name')
parser.add_argument('--model', type=str, default='polymnist_5modalities', metavar='M',
                    choices=[s[4:] for s in dir(models) if 'VAE_' in s],
                    help='model name (default: %(default)s)')
parser.add_argument('--obj', type=str, default='dreg', metavar='O',
                    choices=['elbo_naive', 'elbo', 'iwae', 'dreg'],
                    help='objective to use (default: %(default)s)')
parser.add_argument('--K', type=int, default=10, metavar='K',
                    help='number of particles to use for iwae/dreg (default: %(default)s)')
# NOTE(review): store_true combined with default=True means this flag can never be
# switched off from the command line; left as-is to keep current behaviour.
parser.add_argument('--looser', action='store_true', default=True,
                    help='use the looser version of IWAE/DREG')
parser.add_argument('--llik_scaling', type=float, default=0.,
                    help='likelihood scaling for cub images/svhn modality when running in '
                         'multimodal setting, set as 0 to use default value')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='batch size for data (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=50, metavar='E',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--latent-dim-w', type=int, default=32, metavar='L',
                    help='dimensionality of the latent variable w (default: %(default)s)')
parser.add_argument('--latent-dim-u', type=int, default=32, metavar='L',
                    help='dimensionality of the latent variable u (default: %(default)s)')
parser.add_argument('--pre-trained', type=str, default="",
                    help='path to pre-trained model (train from scratch if empty)')
parser.add_argument('--learn-prior', action='store_true', default=False,
                    help='learn model prior parameters for w')
parser.add_argument('--logp', action='store_true', default=False,
                    help='estimate tight marginal likelihood on completion')
parser.add_argument('--print-freq', type=int, default=100, metavar='f',
                    help='frequency with which to print stats, 0 disables printing '
                         '(default: %(default)s)')
parser.add_argument('--no-analytics', action='store_true', default=False,
                    help='disable plotting analytics')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disable CUDA use')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--w-from-prior', type=str, default='single', metavar='W',
                    choices=['single', 'natural', 'natural-nograd'],
                    help='w for reconstruction options (default: %(default)s)')
parser.add_argument('--generate-multiple-ws', action='store_true', default=False,
                    help='generate samples from the same u while only resampling w')
# NOTE(review): same store_true/default=True pattern as --looser: always enabled.
parser.add_argument('--learn-prior-w-polymnist', action='store_true', default=True,
                    help='learn model prior parameters for w')
parser.add_argument('--beta', type=float, default=1.0, metavar='S',
                    help='beta hyperparameter (default: %(default)s)')
parser.add_argument('--recon-option', type=str, default='jointprior', metavar='O',
                    choices=['natural', 'stdprior', 'jointprior', 'tunedsingleprior', 'None'],
                    help='cross-recon-option to use (default: %(default)s)')
parser.add_argument('--recon-option-factor', type=float, default=1.0, metavar='F',
                    help='cross-recon-option-factor to use (default: %(default)s)')
parser.add_argument('--tmpdir', type=str, default='../data')
parser.add_argument('--outputdir', type=str, default='../outputs')
# args
args = parser.parse_args()
# Latent sizes as given on the command line (captured before ``args`` may be replaced
# by the arguments stored next to a pre-trained model).
flags_clf_lr = {'latdimu': args.latent_dim_u,
                'latdimw': args.latent_dim_w}
# random seed
# https://pytorch.org/docs/stable/notes/randomness.html
# NOTE(review): cudnn.benchmark=True trades determinism for speed (see the page above);
# runs on GPU may therefore still differ slightly despite the seeding below.
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# The stdlib RNG picks the test images shown in TensorBoard (get_10_mnist_samples),
# so it has to be seeded as well for runs to be repeatable.
random.seed(args.seed)

# load args from disk if pretrained model path is given
pretrained_path = ""
if args.pre_trained:
    pretrained_path = args.pre_trained
    # NOTE(review): args.rar is a pickled object; only load it from trusted locations.
    args = torch.load(args.pre_trained + '/args.rar')

args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")

# Resolve the model class by name, e.g. ``models.VAE_polymnist_5modalities``.
modelC = getattr(models, 'VAE_{}'.format(args.model))
model = modelC(args).to(device)

if pretrained_path:
    print('Loading model {} from {}'.format(model.modelName, pretrained_path))
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa) instead of failing during deserialisation.
    model.load_state_dict(torch.load(pretrained_path + '/model.rar', map_location=device))
    # NOTE(review): self-assignment kept from the original; it looks like a no-op.
    model._pz_params = model._pz_params

# Fall back to the model's own name when no experiment name was supplied.
if not args.experiment:
    args.experiment = model.modelName
# set up run path
# The run id concatenates both latent sizes, beta, the seed and the current timestamp.
runId = str(args.latent_dim_w) + '_' + str(args.latent_dim_u) + '_' + str(args.beta) + '_' + str(args.seed) + \
        '_' + datetime.datetime.now().isoformat()
experiment_dir = Path(os.path.join(args.outputdir, 'experiment', args.experiment))
experiment_dir.mkdir(parents=True, exist_ok=True)
# mkdtemp creates a fresh, uniquely named directory whose name starts with the run id.
runPath = mkdtemp(prefix=runId, dir=str(experiment_dir))
# From here on everything printed goes through the project's Logger, which is handed
# the path of run.log inside the run directory.
sys.stdout = Logger('{}/run.log'.format(runPath))
print('Expt:', runPath)
print('RunID:', runId)

NUM_VAES = len(model.vaes)
# Scratch directory for FID samples, named after the last '/'-separated component of
# the run directory.
fid_path = os.path.join(args.tmpdir, 'fids_' + (runPath.rsplit('/')[-1]))
datadir = os.path.join(args.tmpdir, "PolyMNIST")

# save args to run
with open('{}/args.json'.format(runPath), 'w') as fp:
    json.dump(args.__dict__, fp)
# -- also save object because we want to recover these for other things
torch.save(args, '{}/args.rar'.format(runPath))

# TensorBoard
tensorboard_log_dir = args.outputdir + "/runs/" + args.experiment + "/" + runId
writer = SummaryWriter(log_dir=tensorboard_log_dir)

objectives = objectives_together

# preparation for training
# Only parameters that require gradients are handed to the optimizer.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=1e-3, amsgrad=True)

model.setTmpDir(args.tmpdir)

train_loader, test_loader = model.getDataLoaders(args.batch_size, device=device)
# The training objective is looked up by name: an 'm_' prefix when the model exposes
# `vaes`, the chosen objective, and a '_looser' suffix unless the objective is 'elbo'
# (e.g. 'm_dreg_looser' with the default arguments).
objective = getattr(objectives,
                    ('m_' if hasattr(model, 'vaes') else '')
                    + args.obj
                    + ('_looser' if (args.looser and args.obj != 'elbo') else ''))
# The evaluation objective is always the IWAE variant ('m_iwae_test' for models with `vaes`).
t_objective = getattr(objectives,
                      ('m_' if hasattr(model, 'vaes') else '') + 'iwae' + ('_test' if hasattr(model, 'vaes') else ''))

# cuda stuff
# Without CUDA, checkpoints are loaded with a map_location callback that returns each
# storage unchanged, i.e. leaves it where deserialisation put it instead of moving it
# to the device it was saved from.
needs_conversion = not args.cuda
conversion_kwargs = {'map_location': lambda st, loc: st} if needs_conversion else {}


def get_10_mnist_samples(svhnmnist, num_testing_images):
    """Pick one random example per digit (0-9) and batch them per modality.

    Items are drawn uniformly at random (with the stdlib ``random`` module) from the
    first ``num_testing_images`` entries of the dataset until one with the wanted
    label is found.

    Args:
        svhnmnist: indexable dataset whose items are ``(imgs, target)``, where
            ``imgs`` is a sequence with one image tensor per modality.
        num_testing_images: number of indexable items to sample from.

    Returns:
        A list with one tensor per modality; each tensor stacks the ten selected
        images along a new leading dimension, ordered by digit. Tensors live on the
        module-level ``device``.

    Note:
        The search does not terminate if some digit never occurs in the dataset.
    """
    samples = []
    for digit in range(10):
        while True:
            imgs, target = svhnmnist[random.randint(0, num_testing_images - 1)]
            if target == digit:
                # Works for any number of modalities instead of exactly five.
                samples.append([img.to(device) for img in imgs])
                break
    # Transpose "digit x modality" into "modality x digit" and stack each modality.
    return [torch.stack(per_modality, dim=0) for per_modality in zip(*samples)]


def unpack_and_send_to_device_coherence(dataT, device):
    """Split a ``(data, targets)`` batch and move all of its tensors to ``device``.

    Args:
        dataT: pair ``(data, targets)`` where ``data`` is a sequence of tensors (one
            per modality) and ``targets`` is a tensor.
        device: destination accepted by ``Tensor.to``.

    Returns:
        ``(data_on_device, targets_on_device)`` where ``data_on_device`` is a new
        list. The container passed in is left untouched, so tuples are accepted too.
    """
    data, targets = dataT
    return [data_mod.to(device) for data_mod in data], targets.to(device)


class Flatten(torch.nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        # reshape returns a view whenever view would have worked and only copies for
        # layouts (e.g. transposed tensors) on which view raises.
        return x.reshape(x.size(0), -1)


class ClfImg(nn.Module):
    """
    Convolutional image-to-digit classifier returning log-probabilities over 10 classes.
    Roughly based on the encoder from:
    https://colab.research.google.com/github/smartgeometry-ucl/dl4g/blob/master/variational_autoencoder.ipynb
    """

    def __init__(self):
        super().__init__()
        # Expected input: (3, 28, 28) images. The layer order must not change, since
        # the positions inside the Sequential determine the state-dict keys.
        stack = [
            nn.Conv2d(3, 10, kernel_size=4, stride=2, padding=1),  # (3, 28, 28) -> (10, 14, 14)
            nn.Dropout2d(0.5),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=4, stride=2, padding=1),  # (10, 14, 14) -> (20, 7, 7)
            nn.Dropout2d(0.5),
            nn.ReLU(),
            Flatten(),  # (20, 7, 7) -> (980)
            nn.Linear(980, 128),  # (980) -> (128)
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(128, 10),  # (128) -> (10)
        ]
        self.encoder = nn.Sequential(*stack)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        logits = self.encoder(x)
        return F.log_softmax(logits, dim=-1)


# One pre-trained digit classifier per modality, restored from disk and frozen in
# evaluation mode; they are used to score the coherence of generated images.
clfs = []
for idx in range(len(model.vaes)):
    clf = ClfImg()
    clf_state = torch.load("../data/trained_clfs_polyMNIST/pretrained_img_to_digit_clf_m"+str(idx),
                           **conversion_kwargs)
    clf.load_state_dict(clf_state, strict=False)
    clf.eval()
    if args.cuda:
        clf.cuda()
    clfs.append(clf)


def train(epoch, agg):
    """Run one optimisation pass over ``train_loader`` and log the epoch loss.

    The loss of every batch is the negated training objective. The epoch loss is the
    sum of the batch losses divided by the number of training examples; it is written
    to TensorBoard under ``Loss/train`` and appended to ``agg['train_loss']``.

    Args:
        epoch: current epoch number, used as the TensorBoard step and in log output.
        agg: mapping of metric name to list of per-epoch values (mutated in place).
    """
    model.train()
    b_loss = 0
    # A leftover debugging guard used to stop this loop after the first batch, so each
    # "epoch" trained on a single batch while the loss was still normalised by the
    # full dataset size. Every batch is used now.
    for i, dataT in enumerate(train_loader):
        # The labels returned by the unpacker are not needed for the objective.
        data, _ = unpack_data_polymnist_train(dataT, device=device)
        optimizer.zero_grad()
        loss = -objective(model, data, K=args.K)
        loss.backward()
        optimizer.step()
        b_loss += loss.item()
        if args.print_freq > 0 and i % args.print_freq == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
    epoch_loss = b_loss / len(train_loader.dataset)
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    agg['train_loss'].append(epoch_loss)
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))


def test(epoch, agg):
    """Evaluate the model on ``test_loader`` and log loss and cross-generations.

    The test loss (negated test objective, summed over batches and divided by the
    number of test examples) is written to TensorBoard under ``Loss/test`` and
    appended to ``agg['test_loss']``. While processing the first batch, the
    cross-modal generations for ten randomly selected test digits are logged as
    images and, unless analytics are disabled, ``model.analyse`` is run on that batch.

    Args:
        epoch: current epoch number, used as the TensorBoard step.
        agg: mapping of metric name to list of per-epoch values (mutated in place).
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        dataset = test_loader.dataset
        test_selected_samples = get_10_mnist_samples(dataset, num_testing_images=len(dataset))
        for batch_idx, dataT in enumerate(test_loader):
            data = unpack_data_polymnist(dataT, device=device)
            total_loss += (-t_objective(model, data, K=args.K)).item()
            if batch_idx != 0:
                continue
            # Visualisations are produced once per call, alongside the first batch.
            cg_imgs = model.cross_generate_tb(test_selected_samples)
            for src in range(NUM_VAES):
                for tgt in range(NUM_VAES):
                    writer.add_image(tag='Cross_Generation/m{}/m{}'.format(src, tgt),
                                     img_tensor=cg_imgs[src][tgt],
                                     global_step=epoch)
            if not args.no_analytics:
                model.analyse(data, runPath, epoch)
    epoch_loss = total_loss / len(test_loader.dataset)
    writer.add_scalar("Loss/test", epoch_loss, epoch)
    agg['test_loss'].append(epoch_loss)
    print('====>             Test loss: {:.4f}'.format(agg['test_loss'][-1]))


def cross_coherence_2mods():
    """Measure how often cross-modal generations are classified as the true digit.

    For every (source, target) pair of modalities, the mean of the reconstruction
    distribution obtained with the "jointprior" option is classified by the target
    modality's pre-trained classifier and compared with the ground-truth label.

    Returns:
        A tuple ``(corrs, means_target, overall)``:
        ``corrs[src][tgt]`` is the accuracy over the test set, ``means_target[tgt]``
        averages that accuracy over all sources other than ``tgt``, and ``overall``
        is the mean of ``means_target``.
    """
    model.eval()
    n_mods = len(model.vaes)
    hits = [[0] * n_mods for _ in range(n_mods)]
    seen = 0
    with torch.no_grad():
        for batch in test_loader:
            data, targets = unpack_and_send_to_device_coherence(batch, device)  # needs to be sent to device
            seen += targets.size(0)
            _, px_zs, _ = model.reconstruct_options_forw(data, "jointprior", factor=1.0)
            for src in range(n_mods):
                for tgt in range(n_mods):
                    generated = px_zs[src][tgt].mean.squeeze(0)
                    predicted = torch.argmax(clfs[tgt](generated), dim=-1)
                    hits[src][tgt] += (predicted == targets).sum().item()
        # Turn hit counts into accuracies.
        corrs = [[count / seen for count in row] for row in hits]

        # Per target modality: average over every *other* source modality.
        means_target = [
            mean([corrs[src][tgt] for src in range(n_mods) if src != tgt])
            for tgt in range(n_mods)
        ]
    return corrs, means_target, mean(means_target)


def unconditional_coherence_and_lr(clf_lr):
    """Compute unconditional coherence and latent-classification accuracies.

    Unconditional coherence is the fraction of unconditionally generated samples for
    which the pre-trained classifiers of all modalities predict the same digit.

    When ``clf_lr`` is given, every test batch is additionally encoded by each
    modality's encoder, the sampled latent is split into its ``w`` and ``u`` parts,
    and ``classify_latent_representations`` scores the representations. Per-batch
    accuracies are averaged over batches (unweighted).

    Args:
        clf_lr: classifiers as produced by ``train_clf_lr``, or ``None`` to skip the
            latent-classification part.

    Returns:
        ``(uncond_coherence, accuracies_lr)``. ``accuracies_lr`` maps
        ``'m<k>_uw'``, ``'m<k>_w'``, ``'m<k>_u'`` and ``'_mean_uw'``, ``'_mean_w'``,
        ``'_mean_u'`` to accuracies; it is empty when ``clf_lr`` is ``None`` (this
        case previously raised ``StatisticsError``).
    """
    model.eval()
    n_mods = len(model.vaes)
    rep_names = ('uw', 'w', 'u')
    correct = 0
    total = 0
    batch_accs = defaultdict(list)
    with torch.no_grad():
        for dataT in test_loader:
            # Unconditional coherence
            data, targets = unpack_and_send_to_device_coherence(dataT, device)
            b_size = data[0].size(0)
            uncond_gens = [elem.to(device) for elem in model.generate_for_coherence(b_size)]
            # One column of predicted digits per modality.
            preds = torch.stack(
                [torch.argmax(clfs[m](uncond_gens[m]), dim=-1) for m in range(n_mods)],
                dim=-1)
            total += b_size
            # A sample is coherent when every modality agrees with the first one.
            correct += (preds == preds[:, :1]).all(dim=-1).sum().item()

            # Lr
            if clf_lr is None:
                continue
            labels = nn.functional.one_hot(targets, num_classes=10).float()
            labels = labels.cpu().data.numpy().reshape(b_size, 10)
            latent_reps = []
            for v, vae in enumerate(model.vaes):
                zs_v = vae.qz_x(*vae.enc(data[v])).rsample()
                ws_v, us_v = torch.split(zs_v, [args.latent_dim_w, args.latent_dim_u], dim=-1)
                latent_reps.append([zs_v.cpu().data.numpy(),
                                    ws_v.cpu().data.numpy(),
                                    us_v.cpu().data.numpy()])
            accuracies = classify_latent_representations(clf_lr, latent_reps, labels)
            for rep in rep_names:
                for m in range(n_mods):
                    key = 'm{}_{}'.format(m, rep)
                    batch_accs[key].append(np.mean(accuracies[key]))

    uncond_coherence = correct / total

    accuracies_lr = {}
    if clf_lr is not None:
        for rep in rep_names:
            for m in range(n_mods):
                key = 'm{}_{}'.format(m, rep)
                accuracies_lr[key] = mean(batch_accs[key])
        for rep in rep_names:
            accuracies_lr['_mean_' + rep] = mean(
                [accuracies_lr['m{}_{}'.format(m, rep)] for m in range(n_mods)])

    return uncond_coherence, accuracies_lr


def _reset_dir(path):
    """Make ``path`` an empty directory, discarding any previous content."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def calculate_fid_routine(datadir, fid_path, num_fid_samples, epoch):
    """Generate samples, compute FID scores and log them to TensorBoard.

    Unconditional samples are written in tranches of 100, conditional samples are
    produced from test batches until ``num_fid_samples`` inputs have been used.
    Inception activations are then extracted and, for every target modality, the FID
    against the real test activations is logged for random generation and for
    generation conditioned on each source modality. ``fid_path`` is removed at the end.

    Args:
        datadir: dataset directory handed to the inception-feature extraction.
        fid_path: scratch directory that receives the generated samples/activations.
        num_fid_samples: number of samples to generate per setting.
        epoch: TensorBoard step for the logged scalars.
    """
    modalities = ['m{}'.format(m) for m in range(len(model.vaes))]
    # Start from empty output directories for random and conditional generation.
    for target in modalities:
        _reset_dir(os.path.join(fid_path, 'random', target))
        for source in modalities:
            _reset_dir(os.path.join(fid_path, target, source))
    with torch.no_grad():
        # Generate unconditional fid samples
        for tranche in range(num_fid_samples // 100):
            model.generate_for_fid_tb(fid_path, 100, tranche)
        # Generate conditional fid samples
        total_cond = 0
        for i, dataT in enumerate(test_loader):
            if total_cond >= num_fid_samples:
                # Enough samples: no need to keep loading the remaining batches.
                break
            data = unpack_data_polymnist(dataT, device=device)
            model.reconstruct_for_fid_tb(data, fid_path, i)
            total_cond += data[0].size(0)
        calculate_inception_features_for_gen_evaluation('../data/pt_inception-2015-12-05-6726825d.pth', device, fid_path, datadir)
        # FID calculation
        fid_randm_list = []
        fid_condgen_list = []
        for modality_target in modalities:
            # NOTE(review): real activations are read from args.tmpdir rather than
            # from `datadir`; both coincide for the caller in this file.
            file_activations_real = os.path.join(args.tmpdir, 'PolyMNIST', 'test',
                                                 'real_activations_{}.npy'.format(modality_target))
            feats_real = np.load(file_activations_real)
            file_activations_randgen = os.path.join(fid_path, 'random',
                                                    modality_target + '_activations.npy')
            feats_randgen = np.load(file_activations_randgen)
            fid_randval = calculate_fid(feats_real, feats_randgen)
            writer.add_scalar("FID/{}/{}".format('random', modality_target), fid_randval, epoch)
            fid_randm_list.append(fid_randval)
            fid_condgen_target_list = []
            for modality_source in modalities:
                file_activations_gen = os.path.join(fid_path, modality_source,
                                                    modality_target + '_activations.npy')
                feats_gen = np.load(file_activations_gen)
                fid_val = calculate_fid(feats_real, feats_gen)
                writer.add_scalar("FID/{}/{}".format(modality_source, modality_target), fid_val, epoch)
                fid_condgen_target_list.append(fid_val)
            fid_condgen_list.append(mean(fid_condgen_target_list))
        mean_fid_condgen = mean(fid_condgen_list)
        mean_fid_randm = mean(fid_randm_list)
        writer.add_scalar("FID/random_meanall", mean_fid_randm, epoch)
        # Tag spelling kept as-is so existing TensorBoard series are continued.
        writer.add_scalar("FID/condgen_meanll", mean_fid_condgen, epoch)
    # The generated images are only needed for the computation above.
    if os.path.exists(fid_path):
        shutil.rmtree(fid_path)


def train_clf_lr(dl):
    """Fit logistic-regression digit classifiers on the model's latent codes.

    Every batch of ``dl`` is encoded by each modality's encoder; one latent sample is
    drawn and split into its ``w`` and ``u`` parts. For each modality three
    classifiers are fitted: on the full latent (``uw``), on ``w`` and on ``u``.

    Args:
        dl: data loader yielding batches accepted by ``unpack_data_polymnist_train``.

    Returns:
        Dict mapping ``'m<k>_uw'``, ``'m<k>_w'`` and ``'m<k>_u'`` to the fitted
        ``LogisticRegression`` instances, for every modality ``k`` of the model.
    """
    n_mods = len(model.vaes)
    # One set of accumulators per modality, however many the model has.
    latent_rep = [{'zs': [], 'us': [], 'ws': []} for _ in range(n_mods)]
    labels_all = []
    for dataT_lr in dl:
        data, labels_batch = unpack_data_polymnist_train(dataT_lr, device=device)
        # Class indices are used directly; encoding them one-hot only to argmax them
        # back afterwards yields the same values.
        labels_all.append(labels_batch.detach().cpu().numpy().reshape(-1))
        with torch.no_grad():
            for v, vae in enumerate(model.vaes):
                zs_v = vae.qz_x(*vae.enc(data[v])).rsample()
                ws_v, us_v = torch.split(zs_v, [args.latent_dim_w, args.latent_dim_u], dim=-1)
                latent_rep[v]['zs'].append(zs_v.cpu().numpy())
                latent_rep[v]['us'].append(us_v.cpu().numpy())
                latent_rep[v]['ws'].append(ws_v.cpu().numpy())
    gt = np.concatenate(labels_all, axis=0).astype(int)

    clf_lr = {}
    for v in range(n_mods):
        # (key suffix, accumulator name): full latent first, then w, then u.
        for suffix, field in (('uw', 'zs'), ('w', 'ws'), ('u', 'us')):
            features = np.concatenate(latent_rep[v][field], axis=0)
            clf = LogisticRegression(random_state=0, solver='lbfgs', multi_class='auto', max_iter=1000)
            clf.fit(features, gt)
            clf_lr['m' + str(v) + '_' + suffix] = clf
    return clf_lr


def estimate_log_marginal(K):
    """Compute an IWAE estimate of the log-marginal likelihood of test data.

    The negated test objective is summed over all test batches, divided by the number
    of test examples, logged to TensorBoard as ``Marginal_llik`` and printed.

    Args:
        K: number of particles passed to the test objective.
    """
    model.eval()
    with torch.no_grad():
        per_batch = [
            -t_objective(model, unpack_data_polymnist(dataT, device=device), K).item()
            for dataT in test_loader
        ]

    marginal_loglik = sum(per_batch) / len(test_loader.dataset)
    writer.add_scalar("Marginal_llik", marginal_loglik)
    print('Marginal Log Likelihood (IWAE, K = {}): {:.4f}'.format(K, marginal_loglik))


if __name__ == '__main__':
    # Train for args.epochs epochs; every 50th epoch additionally evaluates the model,
    # stores a checkpoint and logs generation, coherence, latent-classification and
    # FID metrics to TensorBoard.
    # NOTE(review): --logp is parsed but estimate_log_marginal is never called here.
    with Timer('MM-VAE') as t:
        agg = defaultdict(list)
        try:
            for epoch in range(1, args.epochs + 1):
                train(epoch, agg)
                if epoch % 50 == 0:
                    test(epoch, agg)
                    clf_lr = train_clf_lr(train_loader)
                    save_model_light(model, runPath + '/model_' + str(epoch) + '.rar')
                    save_vars(agg, runPath + '/losses_' + str(epoch) + '.rar')
                    gen_samples = model.generate_tb()
                    for j in range(NUM_VAES):
                        writer.add_image(tag='Generation_m{}'.format(j), img_tensor=gen_samples[j],
                                         global_step=epoch)
                    cors, means_tgt, mt = cross_coherence_2mods()
                    writer.add_scalar("Conditional_coherence_meanall", mt, global_step=epoch)
                    for i in range(NUM_VAES):
                        writer.add_scalar("Conditional_coherence_target_m{}".format(i), means_tgt[i],
                                          global_step=epoch)
                        for j in range(NUM_VAES):
                            writer.add_scalar("Conditional_coherence_m{}xm{}".format(i, j), cors[i][j],
                                              global_step=epoch)
                    uncond_coher, accuracies_lr = unconditional_coherence_and_lr(clf_lr)
                    writer.add_scalar("Unconditional_coherence", uncond_coher, global_step=epoch)
                    for key in accuracies_lr:
                        writer.add_scalar("Latclassacc_" + key, accuracies_lr[key], global_step=epoch)
                    calculate_fid_routine(datadir, fid_path, 10000, epoch)
                    if args.generate_multiple_ws:
                        model.generate_multiple_ws(runPath, epoch)
        finally:
            # Flush and close the event file even when training or evaluation fails,
            # so that everything logged up to that point is kept.
            writer.flush()
            writer.close()