from __future__ import print_function

import os
import sys
import random
from argparse import ArgumentParser
from distutils.util import strtobool
from copy import copy

import numpy as np
import torch
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from pixyz.utils import set_cache_maxsize

from models import JVAE, MVAE, MMVAE, MoPoE, MMJSD, CRMVAE, CRMVAE_Update, CRMVAE_New, CRMVAE_New2, CRMVAE2

# import dataset
sys.path.append('../')
from multimodal_datasets.SVHNMNISTDataset import SVHNMNIST
from multimodal_datasets.CelebADataset import CelebaDataset
from multimodal_datasets.PolyMNIST import PolyMNIST
from multimodal_datasets.CUBImageCaptions import MMCUB
# Root of the shared dataset directory. Every occurrence of 'jmvae_journal' is
# removed from the current working directory before descending into
# multimodal_datasets/data, so the script expects to be launched from inside
# the 'jmvae_journal' folder (matching the sys.path.append('../') above).
_cwd_without_project = os.getcwd().replace('jmvae_journal', '')
dir_data = os.path.join(_cwd_without_project, 'multimodal_datasets/data')

# Registry mapping each --model choice to the class implementing it. Several
# names are aliases for a shared implementation (e.g. MVTCAE and MMJSD_PoE both
# use CRMVAE); the aliases are told apart by option overrides applied after
# argument parsing.
models = dict(
    JVAE=JVAE,
    MVAE=MVAE,
    MMVAE=MMVAE,
    MoPoE=MoPoE,
    MMJSD=MMJSD,
    CRMVAE=CRMVAE,
    MVTCAE=CRMVAE,
    MMJSD_PoE=CRMVAE,
    CRMVAE_Update=CRMVAE_Update,
    MVTCAE_Update=CRMVAE_Update,
    CRMVAE_New=CRMVAE_New,
    CRMVAE_New2=CRMVAE_New2,
    CRMVAE2=CRMVAE2,
)

def _comma_separated(cast):
    """Build an argparse ``type`` that turns "a,b,c" into ``[cast('a'), cast('b'), cast('c')]``."""
    return lambda text: [cast(piece) for piece in text.split(',')]


def _add_typed_options(arg_parser, kind, specs):
    """Register one ``--<name>`` option per ``(name, default)`` pair, all parsed with *kind*."""
    for option_name, option_default in specs:
        arg_parser.add_argument('--' + option_name, type=kind, default=option_default)


parser = ArgumentParser(description='Train models')

# What to train, on which data, and with which optimizer.
parser.add_argument('--model', choices=list(models), default="MVAE")
parser.add_argument('--data', default="SVHNMNIST",
                    choices=["SVHNMNIST", "CelebAText", "PolyMNIST", "TranslatedPolyMNIST", "CUB"])
parser.add_argument('--optimizer_name', default="adam",
                    choices=["rmsprop", "adadelta", "adagrad", "adam"])

# Training schedule and optimizer hyper-parameters.
_add_typed_options(parser, int, [('seed', 1), ('epochs', 200), ('clf_epochs', 0), ('batch_size', 256)])
_add_typed_options(parser, float, [('lr', 1e-3), ('weight_decay', 0)])
_add_typed_options(parser, int, [('test_epoch', 5)])

# Latent dimensions and number of samples.
_add_typed_options(parser, int, [('z_dim', 32), ('s_dim', 0), ('num_sampling', 1)])

# Comma-separated list options, e.g. "--modality_id 0,2".
_add_typed_options(parser, _comma_separated(float), [('rec_weight_all', None)])
_add_typed_options(parser, _comma_separated(int), [('modality_id', [0, 1, 2])])

# Loss weights.
_add_typed_options(parser, float, [('beta', 1.0), ('gamma', 1.0), ('alpha_crmvae', 0.5)])

# Boolean switches. argparse runs string defaults through ``type`` as well, so
# each of these ends up as the int 1 or 0 returned by strtobool.
# NOTE(review): distutils (and thus distutils.util.strtobool) was removed in
# Python 3.12 (PEP 632); a local replacement will be needed to run on 3.12+.
_add_typed_options(parser, strtobool, [
    ('plot_image', "true"),
    ('reconst_unimodal', "true"),
    ('sample_sum', "true"),
    ('kl_set', "false"),
    ('forward_kl', "true"),
    ('eval_train', "false"),
    ('eval_fid', "false"),
    ('eval_reconst', "false"),
    ('save_models', "true"),
    ('use_schedule', "false"),
    ('use_batch_gon', "false"),
    ('fix_elbo_smvae', "false"),
])
_add_typed_options(parser, int, [('kl_annealing', 0), ('kl_annealing_start', -1)])
_add_typed_options(parser, strtobool, [('fix_weight', "false"), ('batch_sample', "false")])

args = parser.parse_args()

# Some --model names are aliases sharing an implementation in the ``models``
# registry; they are distinguished by overriding a few options here.
if args.model in ("MVTCAE", "MMJSD_PoE", "MVTCAE_Update"):
    args.reconst_unimodal = 0
if args.model == "MMJSD_PoE":
    args.forward_kl = False

# Plain-dict snapshot of every option. The dataset-loading code extends it with
# dataset-specific entries before it is passed to the model constructor.
params = copy(vars(args))
print(params)

# Cap pixyz's internal cache size at 100 (see pixyz.utils.set_cache_maxsize).
set_cache_maxsize(100)

# Experiment tracking: the parsed options are passed as the wandb run config.
wandb.init(config=args, project='SMVAEs')
config = wandb.config

# Use the default CUDA device when one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# cuDNN autotuning trades determinism for speed, so GPU runs may still vary
# slightly even with the fixed seeds below.
torch.backends.cudnn.benchmark = True
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(args.seed)

#torch.autograd.set_detect_anomaly(True)

# load dataset
def _make_loaders(train_set, test_set, num_workers, shuffle_test):
    """Return ``(train_loader, test_loader)`` for the given datasets.

    Both loaders use the global ``--batch_size`` and ``drop_last=True``. The
    training loader is always shuffled; the test loader only when
    *shuffle_test* is true.
    """
    train = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                       num_workers=num_workers, drop_last=True)
    test = DataLoader(test_set, batch_size=args.batch_size, shuffle=shuffle_test,
                      num_workers=num_workers, drop_last=True)
    return train, test


# Each branch defines ``train_loader``/``test_loader``. When --rec_weight_all was
# not given, each branch also stores default per-modality reconstruction
# weights in ``params``.
# NOTE(review): SVHNMNIST and CUB shuffle their test loader while also using
# drop_last=True, so the evaluated subset may differ between epochs. The other
# datasets evaluate in a fixed order.
if params["data"] == "SVHNMNIST":
    # Flattened modality sizes: 28x28 image, 3x32x32 image, and a third
    # modality of size 71 (presumably the text encoding -- confirm).
    d_size_m1 = 28**2
    d_size_m2 = 3*32**2
    d_size_m3 = 71

    if params["rec_weight_all"] is None:
        # Scale each modality by (size of modality 2) / (its own size).
        params["rec_weight_all"] = [d_size_m2/d_size_m1, 1.0, d_size_m2/d_size_m3]

    svhnmnist_train = SVHNMNIST(dir_data, len_sequence=8, train=True,
                                data_multiplications=20, two_modality=False)
    svhnmnist_test = SVHNMNIST(dir_data, len_sequence=8, train=False,
                               data_multiplications=20, two_modality=False)

    train_loader, test_loader = _make_loaders(svhnmnist_train, svhnmnist_test,
                                              num_workers=8, shuffle_test=True)
    params["alphabet"] = test_loader.dataset.alphabet

elif params["data"] == "CelebAText":
    # 64x64x3 image and a 256-sized second modality (presumably text).
    d_size_m1 = 64*64*3
    d_size_m2 = 256

    if params["rec_weight_all"] is None:
        params["rec_weight_all"] = [1.0, d_size_m1/d_size_m2]

    train_dataset = CelebaDataset(dir_data, partition=0, use_text=True)
    eval_dataset = CelebaDataset(dir_data, partition=1, use_text=True)

    train_loader, test_loader = _make_loaders(train_dataset, eval_dataset,
                                              num_workers=8, shuffle_test=False)

elif params["data"] in ("PolyMNIST", "TranslatedPolyMNIST"):
    if params["rec_weight_all"] is None:
        params["rec_weight_all"] = [1.0] * 5

    # Both variants use the same PolyMNIST loader; only the sub-directory differs.
    data_subdir = "MMNIST" if params["data"] == "PolyMNIST" else "translated_polymnist_new"
    data_path = os.path.join(dir_data, data_subdir)
    train_dataset = PolyMNIST(data_path, split="train")
    eval_dataset = PolyMNIST(data_path, split="test")

    train_loader, test_loader = _make_loaders(train_dataset, eval_dataset,
                                              num_workers=1, shuffle_test=False)

elif params["data"] == "CUB":
    d_size_m1 = 64*64*3
    d_size_m2 = 32

    if params["rec_weight_all"] is None:
        # Equal weights. An alternative size-ratio weighting would be
        # [1.0, d_size_m1/d_size_m2].
        params["rec_weight_all"] = [1.0, 1.0]

    train_dataset = MMCUB(dir_data, split="train")
    eval_dataset = MMCUB(dir_data, split="test")

    train_loader, test_loader = _make_loaders(train_dataset, eval_dataset,
                                              num_workers=1, shuffle_test=True)

else:
    # argparse ``choices`` should make this unreachable. Fail loudly instead of
    # continuing with train_loader/test_loader undefined.
    raise ValueError("Unsupported dataset: {}".format(params["data"]))

# init models
# Optimizer hyper-parameters travel to the model inside ``params``. The model
# exposes ``model.optimizer``, presumably built from these values.
params["optimizer_params"] = dict(lr=args.lr, weight_decay=args.weight_decay)
model_cls = models[args.model]
model = model_cls(params, device=device)
print(model)

# Multiply the learning rate by 0.97 every 5 scheduler steps. The training loop
# only advances this schedule when --use_schedule is set.
scheduler = StepLR(model.optimizer, step_size=5, gamma=0.97)

# load classifier (implemented by the model class)
model.load_clf()

# start experiments
print("Train models")

for epoch in tqdm(range(args.epochs)):
    train_dict = model.train(epoch, train_loader)
    wandb.log({"train": train_dict}, step=epoch)

    # Step the LR scheduler *after* the epoch's optimizer updates (presumably
    # performed inside model.train). Since PyTorch 1.1, stepping it first
    # skips the initial LR value and triggers an ordering warning.
    if args.use_schedule:
        scheduler.step()

    # Evaluate every ``test_epoch`` epochs, and always after the final epoch so
    # the trained model is evaluated even when ``epochs - 1`` is not a multiple
    # of ``test_epoch``. A test_epoch of 0 means "final epoch only" (the
    # original modulo raised ZeroDivisionError).
    is_last_epoch = epoch == args.epochs - 1
    is_periodic_test = args.test_epoch != 0 and epoch % args.test_epoch == 0
    if is_last_epoch or is_periodic_test:
        test_dict = model.test(epoch, test_loader)
        wandb.log({"val": test_dict}, step=epoch)

    # Presumably checkpoints the models into the wandb run directory; whether
    # --save_models is honoured is up to the model class.
    model.log_models(wandb.run.dir)