# Train a CMVAE/CHolderplus model on the CUBICC dataset
import os
import shutil
import argparse
import sys
import json
from pathlib import Path
import numpy as np
import torch
from torch import optim
import models
import objectives as objectives
from utils import Logger, Timer, save_model_light, unpack_data
from torch.utils.data import DataLoader, Subset
import wandb
from test_functions_CUBICC import calculate_inception_features_for_gen_evaluation, calculate_fid
from dataset_CUBICC import CUBICCDataset


def build_parser(default_model_type):
    """Build the command-line parser for the CMVAE/CHolderplus CUBICC experiment.

    Args:
        default_model_type: value of ``--model-type`` when the flag is omitted;
            must be ``'cmvae'`` or ``'cholderplus'``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='CMVAE/CHolderplus CUBICC Experiment')
    parser.add_argument('--model-type', type=str, default=default_model_type, choices=['cmvae', 'cholderplus'],
                        help='model variant to train')
    parser.add_argument('--cuda-device', type=str, default='',
                        help='override cuda device (example: cuda or cuda:1)')
    parser.add_argument('--experiment', type=str, default='', metavar='E',
                        help='experiment name')
    parser.add_argument('--obj', type=str, default='dreg', choices=['iwae', 'dreg'],
                        help='objective to use')
    parser.add_argument('--K', type=int, default=10,
                        help='number of samples when resampling in the latent space (default: 10)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='batch size for data (default: 32)')
    parser.add_argument('--epochs', type=int, default=400, metavar='E',
                        help='number of epochs to train (default: 400)')
    parser.add_argument('--latent-dim-w', type=int, default=32, metavar='L',
                        help='dimensionality of latent w (default: 32)')
    parser.add_argument('--latent-dim-z', type=int, default=64, metavar='L',
                        help='dimensionality of latent z (default: 64)')
    parser.add_argument('--latent-dim-c', type=int, default=35, metavar='L',
                        help='dimensionality of latent c (default: 35)')
    # --learn-prior-c is a store_true flag whose default is already True, so on
    # its own it can never be switched off; --no-learn-prior-c writes False to
    # the same destination (args.learn_prior_c) to allow disabling it.
    parser.add_argument('--learn-prior-c', action='store_true', default=True,
                        help='learn model prior parameters for c (default: enabled)')
    parser.add_argument('--no-learn-prior-c', dest='learn_prior_c', action='store_false',
                        help='do not learn model prior parameters for c')
    parser.add_argument('--print-freq', type=int, default=50, metavar='f',
                        help='print training stats every f iterations; 0 disables (default: 50)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disable CUDA use')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--beta', type=float, default=1.0, help='beta hyperparameter (default: 1.0)')
    parser.add_argument('--llik_scaling_sent', type=float, default=5.0,
                        help='likelihood scaling factor for sentences (default: 5.0)')
    parser.add_argument('--datadir', type=str, default='data',
                        help='directory where data is stored and samples used for FID calculation are saved')
    parser.add_argument('--outputdir', type=str, default='outputs',
                        help='output directory')
    parser.add_argument('--inception_path', type=str, default='data/pt_inception-2015-12-05-6726825d.pth',
                        help='path to inception module for FID calculation')
    parser.add_argument('--priorposterior', type=str, default='Normal', choices=['Normal', 'Laplace'],
                        help='distribution choice for prior and posterior')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
    parser.add_argument('--use-disen', action='store_true', default=False,
                        help='use disentangled q(w|x,z) posterior for shared/private latents')
    return parser


def resolve_device(args):
    """Return the ``torch.device`` to run on.

    CPU is used whenever ``args.cuda`` is falsy. Otherwise the explicit
    ``args.cuda_device`` string is used when non-empty, falling back to the
    default ``"cuda"`` device.
    """
    if args.cuda:
        device_name = args.cuda_device if args.cuda_device else "cuda"
    else:
        device_name = "cpu"
    return torch.device(device_name)


def main(default_model_type='cmvae'):
    """Parse CLI arguments, then train and periodically evaluate a CUBICC model.

    Every epoch runs one training pass. Every 10th epoch additionally computes
    the validation loss, logs self/cross-modal and unconditional generations to
    wandb, and computes FID scores. Every 20th epoch a checkpoint is saved.

    Args:
        default_model_type: model variant used when ``--model-type`` is not
            given on the command line (``'cmvae'`` or ``'cholderplus'``).
    """
    parser = build_parser(default_model_type)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = resolve_device(args)
    print(device)

    if args.model_type == 'cholderplus':
        model_cls = models.CHolderplus
        args.model_name = "CHolderplus"
    else:
        model_cls = models.CMVAE
        args.model_name = "CMVAE"
    model = models.CUB_Image_Sentence(args, model_cls=model_cls).to(device)

    if not args.experiment:
        args.experiment = model.modelName

    # Run identifier built from objective, latent sizes, beta and seed.
    runId = "_".join([
        args.obj,
        str(args.latent_dim_w),
        str(args.latent_dim_z),
        str(args.beta),
        str(args.seed),
    ])
    experiment_dir = Path(os.path.join(args.outputdir, args.experiment, "checkpoints"))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    runPath = os.path.join(str(experiment_dir), runId)
    # A previous run with the same id is wiped so the run directory starts clean.
    if os.path.exists(runPath):
        shutil.rmtree(runPath)
    os.makedirs(runPath)
    sys.stdout = Logger('{}/run.log'.format(runPath))
    print('Expt:', runPath)
    print('RunID:', runId)

    num_vaes = len(model.vaes)

    fid_tag = 'disen_' if args.use_disen else ''
    # runId is the last component of runPath; using it directly avoids
    # splitting the path on '/', which is not portable across OS separators.
    fid_path = os.path.join(args.datadir, 'fids_CUBICC_' + args.model_type + '_' + fid_tag + runId)

    datadirCUBICC = os.path.join(args.datadir, "CUBICC")

    with open('{}/args.json'.format(runPath), 'w') as fp:
        json.dump(args.__dict__, fp)
    torch.save(args, '{}/args.rar'.format(runPath))

    wandb.login()
    wandb.init(
        project=args.experiment,
        config=args,
    )

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=args.lr, amsgrad=True)

    dataset = CUBICCDataset(datadir=datadirCUBICC)
    train_dataset = Subset(dataset, dataset.train_split)
    validation_dataset = Subset(dataset, dataset.validation_split)

    # `device` is a torch.device, so its .type must be compared (comparing the
    # device object itself to the string 'cuda' is never true). These values
    # are the DataLoader defaults.
    kwargs = {'num_workers': 0, 'pin_memory': False} if device.type == 'cuda' else {}

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    validation_loader = DataLoader(validation_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    objective_prefix = 'cholderplus_' if args.model_type == 'cholderplus' else 'cmvae_'
    objective = getattr(objectives, objective_prefix + args.obj)
    t_objective = objective

    def train(epoch):
        """Run one optimisation pass over the training set and log the mean loss."""
        model.train()
        b_loss = 0
        for i, dataT in enumerate(train_loader):
            data, _ = unpack_data(dataT, device=device)
            optimizer.zero_grad()
            loss = -objective(model, data, K=args.K)
            loss.backward()
            optimizer.step()
            b_loss += loss.item()
            if args.print_freq > 0 and i % args.print_freq == 0:
                # Presumably the objective is summed over the batch, hence the
                # division to report a per-sample value.
                print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
        epoch_loss = b_loss / len(train_loader.dataset)
        wandb.log({"Loss/train": epoch_loss}, step=epoch)
        print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, epoch_loss))

    def test(epoch):
        """Compute the validation loss and log cross-modal generations of the first batch.

        Puts the model in eval mode (``train`` switches it back next epoch).
        """
        model.eval()
        b_loss = 0
        with torch.no_grad():
            for batch_idx, dataT in enumerate(validation_loader):
                data, _ = unpack_data(dataT, device=device)
                loss = -t_objective(model, data, K=args.K)
                b_loss += loss.item()
                if batch_idx == 0:
                    cg_imgs = model.self_and_cross_modal_generation(data, 10, 10)
                    for src in range(num_vaes):
                        for tgt in range(num_vaes):
                            wandb.log({'Cross_Generation/m{}/m{}'.format(src, tgt): wandb.Image(cg_imgs[src][tgt])},
                                      step=epoch)
        epoch_loss = b_loss / len(validation_loader.dataset)
        wandb.log({"Loss/test": epoch_loss}, step=epoch)
        print('====>             Test loss: {:.4f}'.format(epoch_loss))

    def reset_dir(path):
        """Delete ``path`` if it exists, then recreate it empty (parents included)."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    def calculate_fid_routine(datadir, fid_path, num_fid_samples, epoch):
        """ Calculate FID scores for unconditional and conditional generation """
        model.eval()
        target_idx = 0
        reset_dir(os.path.join(fid_path, 'test', 'm{}'.format(target_idx)))
        reset_dir(os.path.join(fid_path, 'random', 'm{}'.format(target_idx)))
        for source_idx in (0, 1):
            reset_dir(os.path.join(fid_path, 'm{}'.format(source_idx), 'm{}'.format(target_idx)))
        with torch.no_grad():
            # Unconditional samples, saved under fid_path in tranches of 100.
            for tranche in range(num_fid_samples // 100):
                model.generate_unconditional(N=100, indexes_to_prune=None, indexes_to_select=None, random=True,
                                             coherence_calculation=False, fid_calculation=True,
                                             savePath=fid_path, tranche=tranche)
            # Conditional generations and real samples, batch by batch, until
            # at least num_fid_samples have been written.
            total_cond = 0
            for i, dataT in enumerate(validation_loader):
                if total_cond >= num_fid_samples:
                    break
                data, _ = unpack_data(dataT, device=device)
                model.self_and_cross_modal_generation_for_fid_calculation(data, fid_path, i)
                model.save_test_samples_for_fid_calculation(data, fid_path, i)
                total_cond += data[0].size(0)
            calculate_inception_features_for_gen_evaluation(args.inception_path, device,
                                                            fid_path, datadir)
            modality_target = 'm{}'.format(target_idx)
            file_activations_real = os.path.join(fid_path, 'test',
                                                 'real_activations_{}.npy'.format(modality_target))
            feats_real = np.load(file_activations_real)
            file_activations_randgen = os.path.join(fid_path, 'random',
                                                    modality_target + '_activations.npy')
            feats_randgen = np.load(file_activations_randgen)
            fid_randval = calculate_fid(feats_real, feats_randgen)
            wandb.log({"FID/Random/{}".format(modality_target): fid_randval}, step=epoch)
            for modality_source in ['m{}'.format(m) for m in (0, 1)]:
                file_activations_gen = os.path.join(fid_path, modality_source,
                                                    modality_target + '_activations.npy')
                feats_gen = np.load(file_activations_gen)
                fid_val = calculate_fid(feats_real, feats_gen)
                wandb.log({"FID/{}/{}".format(modality_source, modality_target): fid_val}, step=epoch)
        # Drop the generated images/activations, leaving an empty fid_path.
        if os.path.exists(fid_path):
            reset_dir(fid_path)

    with Timer('MM-VAE'):
        for epoch in range(1, args.epochs + 1):
            train(epoch)
            if epoch % 10 == 0:
                # test() leaves the model in eval mode for the generation below.
                test(epoch)
                with torch.no_grad():
                    gen_samples = model.generate_unconditional(N=100, indexes_to_prune=None, indexes_to_select=None,
                                                               coherence_calculation=False, fid_calculation=False,
                                                               random=True)
                for j in range(num_vaes):
                    wandb.log({'Generations/m{}'.format(j): wandb.Image(gen_samples[j])}, step=epoch)
                calculate_fid_routine(datadirCUBICC, fid_path, 10000, epoch)
            if epoch % 20 == 0:
                save_model_light(model, runPath + '/model_' + str(epoch) + '.rar')


if __name__ == '__main__':
    # Script entry point: the CMVAE variant is trained unless --model-type overrides it.
    main(default_model_type='cmvae')
