Search.setIndex({"docnames": ["code/config_cls", "code/dataloader", "code/datasets", "code/decoders", "code/encoders", "code/eval_cdsprites", "code/infer", "code/mmvae_base", "code/mmvae_models", "code/objectives", "code/trainer", "code/vae", "index", "tutorials/adddataset", "tutorials/addmodel"], "filenames": ["code/config_cls.rst", "code/dataloader.rst", "code/datasets.rst", "code/decoders.rst", "code/encoders.rst", "code/eval_cdsprites.rst", "code/infer.rst", "code/mmvae_base.rst", "code/mmvae_models.rst", "code/objectives.rst", "code/trainer.rst", "code/vae.rst", "index.rst", "tutorials/adddataset.rst", "tutorials/addmodel.rst"], "titles": ["Config class", "DataLoader", "Dataset Classes", "Decoders", "Encoders", "Evaluate on CdSprites+ dataset", "Inference module", "Multimodal VAE Base Class", "Multimodal VAE models", "Objectives", "MultimodalVAE class", "VAE class", "Multimodal VAE Comparison Toolkit", "Add a new dataset", "Add a new model"], "terms": {"multimodal_compar": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], "model": [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13], "config_cl": 0, "parser": [0, 10], "eval_onli": 0, "fals": [0, 5, 7, 8, 10, 11, 14], "base": [0, 1, 2, 3, 4, 6, 8, 9, 10, 11, 12, 13], "object": [0, 1, 2, 5, 7, 8, 10, 11, 12, 13, 14], "manag": 0, "_define_param": 0, "set": [0, 4, 9, 10, 11, 12], "up": [0, 10, 11], "variabl": [0, 8, 14], "from": [0, 2, 5, 7, 8, 9, 10, 11, 13, 14], "retriev": [0, 14], "modal": [0, 1, 2, 7, 8, 9, 10, 11, 13, 14], "specif": [0, 2, 6, 7, 12, 14], "info": [0, 12], "_get_mods_config": 0, "make": [0, 1, 2, 7, 10, 11, 13], "list": [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 14], "all": [0, 1, 2, 7, 9, 10, 11, 14], "dict": [0, 1, 4, 5, 7, 8, 9, 10, 11, 13, 14], "self": [0, 9, 11, 13, 14], "modality_1": [0, 13, 14], "modality_n": 0, "load": [0, 1, 2, 5, 6, 13], "label": [0, 1, 2, 10, 13, 14], "provid": [0, 5, 6, 10, 12], "_load_config": 0, "pth": [0, 2, 13], "_parse_arg": 0, "yml": [0, 13, 14], "specifi": [0, 10, 13, 14], "cfg": [0, 10, 14], "argument": [0, 10, 13, 14], "ani": [0, 2, 8, 13, 14], "addit": 0, "overrid": [0, 14], "valu": [0, 7, 8, 10, 14], "return": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "rtype": [0, 2, 5, 8, 9, 11, 13, 14], "_setup_savedir": 0, "creat": [0, 6, 10, 13, 14], "directori": [0, 6, 13, 14], "result": 0, "folder": [0, 5, 13], "save": [0, 2, 5, 8, 10, 13], "copi": [0, 13], "change_se": 0, "seednum": 0, "dump_config": 0, "find_vers": 0, "get_vis_dir": 0, "path": [0, 2, 5, 6, 8, 10, 13, 14], "": [0, 9], "visualis": 0, "type": [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 13, 14], "str": [0, 1, 2, 4, 5, 7, 8, 9, 10, 11, 13], "parse_param": 0, "get": [0, 1, 7, 13], "pars": 0, "param": [0, 5, 8, 9, 11, 13, 14], "paramet": [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 14], "argpars": [0, 10], "argumentpars": [0, 10], "class": [1, 3, 4, 6, 8, 9, 12, 14], "datamodul": [1, 6, 10, 13], "config": [1, 6, 10, 12, 14], "lightningdatamodul": 1, "check_load_testdata": 1, "check_testdata_avail": 1, "collate_fn": 1, "batch": [1, 3, 4, 9], "custom": [1, 12, 13], "collat": [1, 13], "function": [1, 2, 7, 8, 9, 11, 12, 13, 14], "put": [1, 13], "data": [1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 14], "dictionari": [1, 7, 8, 9, 10, 11, 14], "prepar": [1, 2, 7, 13], "mask": [1, 2, 3, 4, 13, 14], "need": [1, 9, 13, 14], "input": [1, 2, 5, 7, 8, 9, 10, 11, 13, 14], "get_dataset_class": 1, "dataset": [1, 10, 12], "accord": [1, 2, 7, 8, 10], "name": [1, 7, 8, 10, 11, 14], "get_label_for_indic": 1, "indic": [1, 5, 7], "split": [1, 2, 9, 10, 13], "given": [1, 2, 5, 7, 13], "get_label": [1, 2], "avail": [1, 2, 6, 13], "train": [1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14], "val": [1, 6, 10], "test": [1, 2, 6, 10, 13], "depend": [1, 13, 14], "get_num_sampl": 1, "num_sampl": [1, 10], "predict_dataload": 1, "togeth": 1, "make_mask": [1, 2], "sequenti": [1, 13], "torch": [1, 2, 3, 4, 7, 8, 9, 10, 11, 13, 14], "tensor": [1, 2, 3, 4, 7, 8, 9, 10, 11, 13, 14], "batch_siz": [1, 13, 14], "size": [1, 8, 9, 13, 14], "prepare_data_class": 1, "prepare_singlemod": 1, "mod_index": 1, "singlemod": 1, "int": [1, 2, 4, 5, 7, 8, 9, 10, 11, 14], "index": 1, "order": [1, 13], "setup": [1, 9], "stage": 1, "none": [1, 3, 4, 5, 7, 8, 9, 10, 11, 13, 14], "appropri": [1, 7], "test_dataload": 1, "train_dataload": 1, "val_dataload": 1, "basedataset": [2, 13], "testpth": [2, 13], "mod_typ": [2, 13, 14], "abstract": [2, 4, 7], "share": [2, 7, 8, 9, 14], "_mod_specific_load": [2, 13], "assign": [2, 13], "preprocess": 2, "_mod_specific_sav": [2, 13], "postprocess": [2, 13], "_postprocess": 2, "output_data": 2, "output": [2, 3, 5, 7, 8, 9, 10, 11, 13, 14], "_postprocess_all2img": [2, 13], "convert": [2, 13], "kind": 2, "imag": [2, 5, 13, 14], "travers": [2, 10, 11], "visual": [2, 10, 13], "process": [2, 6, 13], "_preprocess": 2, "_preprocess_imag": [2, 13], "dimens": [2, 8, 10, 13], "gener": [2, 7, 8, 9, 10, 11, 12, 13], "them": [2, 10, 13, 14], "feature_dim": [2, 10, 11, 13], "_preprocess_text_onehot": [2, 13], "text": [2, 5, 9, 13], "string": [2, 5, 10, 13], "one": [2, 13], "hot": [2, 13], "encod": [2, 7, 8, 10, 11, 12, 13, 14], "current_datatyp": 2, "whther": 2, "current": [2, 8, 12, 14], "point": 2, "eval_statistics_fn": 2, "option": [2, 3, 9, 10, 14], "run": [2, 6, 14], "systemat": [2, 12], "evalu": [2, 6, 10, 12, 13], "get_data": 2, "get_data_raw": [2, 13], "raw": 2, "get_processed_recon": 2, "recons_raw": 2, "came": 2, "decod": [2, 7, 9, 11, 12, 13, 14], "reconstruct": [2, 3, 5, 7, 8, 9, 10, 11, 13, 14], "get_test_data": 2, "whole": 2, "save_travers": 2, "recon": [2, 9, 13], "num_dim": 2, "grid": 2, "number": [2, 5, 7, 10, 12], "latent": [2, 3, 7, 8, 9, 10, 11, 14], "cdspritesplu": 2, "_postprocess_imag": [2, 13], "_postprocess_text": [2, 13], "_preprocess_text": [2, 13], "64": [2, 13], "3": [2, 8, 13, 14], "45": 2, "27": [2, 13], "1": [2, 3, 4, 5, 7, 8, 9, 11, 13, 14], "extract": 2, "level": [2, 5, 12, 13], "save_recon": [2, 13], "mod_nam": [2, 13], "set_vis_image_shap": 2, "celeba": [2, 13], "_postprocess_att": 2, "_preprocess_att": 2, "att": [2, 5], "4": [2, 3, 4], "2": [2, 3, 4, 5, 8, 9, 13, 14], "cub": [2, 13], "our": [2, 13, 14], "version": [2, 9, 13], "caltech": [2, 13], "ucsd": [2, 13], "bird": [2, 13], "we": [2, 8, 12, 13, 14], "us": [2, 5, 6, 8, 9, 10, 11, 12, 13, 14], "origin": [2, 13], "repres": [2, 13], "sequenc": [2, 4, 13], "each": [2, 9, 10, 12, 13, 14], "charact": [2, 5, 13], "incl": [2, 13], "space": [2, 7, 11, 13, 14], "246": [2, 13], "No": [2, 13], "t": [2, 10, 13], "snae": [2, 13], "fashionmnist": [2, 13], "_postprocess_label": 2, "_process_imag": 2, "_process_label": 2, "28": 2, "10": [2, 10, 14], "mnist_svhn": [2, 13, 14], "mnist": [2, 13, 14], "svhn": [2, 13, 14], "bimod": [2, 12], "can": [2, 6, 12, 13, 14], "also": [2, 5, 10, 12, 13, 14], "unimod": [2, 9, 11, 14], "_postprocess_mnist": 2, "_postprocess_svhn": 2, "_process_mnist": 2, "_process_svhn": 2, "32": [2, 14], "polymnist": [2, 13], "m0": 2, "m1": 2, "m2": 2, "m3": 2, "m4": 2, "sprite": [2, 13], "_postprocess_act": 2, "_postprocess_attribut": 2, "_postprocess_fram": 2, "action": [2, 13], "9": 2, "attribut": [2, 5, 12, 13], "6": 2, "frame": 2, "8": [2, 3, 4], "get_act": 2, "get_attribut": [2, 5], "get_fram": 2, "iter_over_input": 2, "out": [2, 10, 13], "f": [2, 13], "0": [2, 3, 4, 8, 9, 13, 14], "shape": [2, 8, 13], "anim": 2, "gif": 2, "dec_cnn": 3, "latent_dim": [3, 4], "data_dim": [3, 4, 11], "latent_priv": [3, 4], "vaedecod": 3, "_is_full_backward_hook": [3, 4, 7, 8, 11], "bool": [3, 4, 5, 7, 8, 10, 11], "forward": [3, 4, 7, 8, 10, 11, 14], "z": [3, 9, 14], "pass": [3, 4, 7, 8, 10, 11, 14], "sampl": [3, 5, 7, 8, 9, 10, 11, 12, 14], "vector": [3, 5, 10], "log": [3, 4, 8, 9, 14], "varianc": [3, 4], "tupl": [3, 4, 5, 7, 8, 9, 11, 14], "dec_fnn": 3, "dec_mnist": 3, "dec_mnist2": 3, "num_hidden_lay": [3, 4], "dec_polymnist": 3, "dec_svhn": 3, "dec_svhn2": 3, "dec_transform": 3, "ff_size": [3, 4], "1024": [3, 4], "num_lay": [3, 4], "num_head": [3, 4], "dropout": [3, 4], "activ": [3, 4], "gelu": [3, 4], "boolean": [3, 4, 8, 13, 14], "desir": 3, "length": [3, 13], "dec_transformerimg": 3, "dec_txttransform": 3, "dec_videogpt": 3, "n_res_lay": [3, 4], "x": [3, 4, 8, 9, 11, 13, 14], "net_typ": [3, 4], "networktyp": [3, 4], "vaecompon": [3, 4], "extra_hidden_lay": [3, 4], "hidden_dim": [3, 4], "enc_cnn": 4, "vaeencod": 4, "mean": [4, 5, 7, 8, 9, 11, 14], "enc_fnn": 4, "enc_mnist": 4, "enc_mnist2": 4, "enc_polymnist": 4, "enc_svhn": 4, "enc_svhn2": 4, "enc_transform": 4, "transform": [4, 13], "vae": [4, 5, 9, 10, 14], "implement": [4, 7, 11, 12, 14], "http": [4, 8, 9, 14], "github": [4, 8, 9, 12, 13, 14], "com": [4, 8, 9, 14], "mathux": 4, "actor": 4, "enc_transformerimg": 4, "enc_txttransform": 4, "enc_videogpt": 4, "downsampl": 4, "unspecifi": 4, "net_rol": 4, "networkrol": 4, "modul": [4, 7, 8, 10, 11, 12, 14], "_backward_hook": [4, 11], "callabl": [4, 11], "_buffer": [4, 11], "_forward_hook": [4, 11], "_forward_pre_hook": [4, 11], "_load_state_dict_post_hook": [4, 11], "_load_state_dict_pre_hook": [4, 11], "_modul": [4, 11], "_non_persistent_buffers_set": [4, 11], "_paramet": [4, 11], "_state_dict_hook": [4, 11], "eval": [5, 6], "eval_cdsprit": 5, "calculate_cross_coher": 5, "model_exp": 5, "classifi": 5, "calcul": [5, 6, 7, 8, 9, 11, 14], "cross": [5, 7, 9, 10, 12], "coher": 5, "accuraci": 5, "img": 5, "txt": 5, "multimod": [5, 9, 10, 11, 13, 14], "calculate_joint_coher": 5, "joint": [5, 7, 8, 10, 12, 14], "check_cross_sample_correct": 5, "testtext": 5, "m_exp": 5, "reconimag": 5, "recontext": 5, "detect": 5, "featur": [5, 10], "check": [5, 9], "thei": [5, 13, 14], "ar": [5, 7, 8, 13, 14], "ground": [5, 9], "truth": [5, 9], "ndarrai": 5, "whether": [5, 8, 10, 11], "i": [5, 7, 8, 9, 10, 12, 13, 14], "complet": 5, "correct": 5, "how": [5, 9, 11, 13, 14], "mani": [5, 9, 11], "ok": 5, "letter": 5, "float32": 5, "count_same_lett": 5, "b": [5, 9], "count": 5, "same": [5, 13], "two": [5, 9, 13], "match": [5, 13], "eval_al": 5, "eval_cdsprites_over_se": 5, "parent_dir": 5, "eval_single_model": 5, "eval_with_classifi": 5, "fill_cat": 5, "text_imag": 5, "image_text": 5, "find_in_list": 5, "target": [5, 9], "sourc": [5, 8, 9], "get_all_classifi": 5, "get_attribute_from_recon": 5, "get_mean_stat": 5, "list_of_stat": 5, "percentag": 5, "true": [5, 6, 7, 10, 11, 13, 14], "nest": 5, "multipl": [5, 6], "report": 5, "percent": 5, "fraction": 5, "get_mod_map": 5, "mod_dict": 5, "image_to_text": 5, "where": [5, 8, 10, 14], "load_classifi": 5, "class_typ": 5, "load_imag": [5, 13], "png": [5, 13], "dir": 5, "manhattan_dist": 5, "manharran": 5, "distanc": 5, "between": [5, 9], "vec": 5, "float": [5, 9], "search_att": 5, "idx": [5, 14], "text_to_imag": 5, "try_retrieve_att": 5, "multimodalvaeinf": 6, "includ": [6, 7, 14], "method": [6, 13, 14], "direct": 6, "The": [6, 7, 9, 12, 13, 14], "user": 6, "thi": [6, 7, 9, 12, 13, 14], "own": [6, 12, 13, 14], "outsid": [6, 9, 10], "dataload": [6, 10, 12], "compar": [6, 12, 14], "etc": [6, 13], "eval_statist": 6, "offici": [6, 12], "routin": 6, "defin": [6, 13, 14], "trainer": [6, 10], "py": [6, 13, 14], "If": [6, 10, 13, 14], "applic": [6, 14], "statist": 6, "get_base_path": 6, "find": [6, 12], "get_config": 6, "instanc": [6, 10, 13, 14], "get_datamodul": 6, "load_data": [6, 13], "an": [6, 9, 10, 13, 14], "necessari": [6, 14], "access": 6, "tool": 6, "get_wrapped_model": 6, "pytorch": [6, 7, 9, 12, 13, 14], "lightn": [6, 12, 14], "make_dataload": 6, "within": 6, "mmvae_bas": [7, 14], "torchmmva": [7, 8, 14], "n_latent": [7, 8, 11, 13, 14], "obj": [7, 8, 9, 13, 14], "beta": [7, 9, 11, 13, 14], "mmvae": [7, 9, 12], "add_va": 7, "vae_dict": 7, "moduledict": [7, 14], "updat": 7, "nn": [7, 13, 14], "A": [7, 8], "kei": [7, 8, 9, 10, 11, 13, 14], "baseva": [7, 11], "distribut": [7, 8, 9, 11], "qz_x": [7, 8, 9, 14], "k": [7, 8, 9, 11, 14], "namedtupl": 7, "vaeoutput": [7, 14], "get_missing_mod": 7, "mod": [7, 8, 14], "miss": [7, 8, 10, 14], "properti": [7, 8, 10, 11], "latent_factor": 7, "factor": [7, 9], "subspac": 7, "els": [7, 13], "make_output_dict": [7, 14], "encoder_dist": [7, 14], "decoder_dist": [7, 14], "latent_sampl": [7, 14], "joint_dist": [7, 14], "enc_dist_priv": [7, 14], "dec_dist_priv": [7, 14], "joint_decoder_dist": [7, 14], "cross_decoder_dist": [7, 14], "singl": 7, "come": [7, 9, 10], "modality_mix": [7, 8, 14], "mix": [7, 8, 13, 14], "chosen": [7, 8], "approach": [7, 8], "loss": [7, 8, 9, 10, 11, 14], "static": [7, 9], "product_of_expert": [7, 14], "mu": [7, 8, 14], "logvar": [7, 8, 11, 14], "product": [7, 14], "expert": [7, 8, 14], "posterior": [7, 8, 9, 11, 14], "pz_param": [7, 8, 9, 11, 14], "mmvae_model": [8, 14], "dmvae": [8, 12, 14], "obj_config": [8, 14], "model_config": [8, 14], "take": [8, 10, 14], "privat": [8, 11], "replac": [8, 13, 14], "get_remaining_mods_data": 8, "exclude_mod": 8, "logsumexp": 8, "dim": [8, 13, 14], "keepdim": 8, "smooth": 8, "maximum": 8, "keep": [8, 14], "squeez": 8, "seqam": 8, "lab": 8, "obligatori": [8, 14], "which": [8, 10, 13, 14], "optim": [8, 9, 10, 13, 14], "plu": [8, 14], "other": [8, 13, 14], "you": [8, 12, 13, 14], "wish": [8, 13, 14], "moe": [8, 13, 14], "runpath": 8, "epoch": [8, 10, 13, 14], "individu": [8, 14], "mopo": [8, 12, 14], "mixture_component_select": 8, "w_modal": 8, "input_batch": 8, "moe_fus": 8, "weight": [8, 9], "comput": [8, 9, 14], "elbo": [8, 9, 13, 14], "arxiv": 8, "org": 8, "pdf": 8, "2105": 8, "02470": 8, "poe_fus": 8, "reparameter": 8, "reweight_weight": 8, "w": [8, 13], "set_subset": 8, "powerset": 8, "poe": [8, 14], "infer": [8, 10, 12, 14], "prior_expert": [8, 14], "use_cuda": [8, 14], "univers": [8, 14], "prior": [8, 9, 14], "here": [8, 13, 14], "spheric": [8, 14], "gaussian": [8, 9, 14], "n": [8, 11, 13, 14], "dimension": [8, 14], "cast": [8, 14], "cuda": [8, 14], "baseobject": 9, "calc_kld": [9, 14], "dist1": 9, "dist2": 9, "kl": [9, 14], "diverg": [9, 14], "dist": [9, 11, 14], "latent_dist": 9, "th": 9, "kld": [9, 14], "compute_microbatch_split": 9, "broken": 9, "down": 9, "further": 9, "fit": 9, "memori": 9, "made": 9, "microbatch": 9, "lpx_z": [9, 14], "most": 9, "e": [9, 10, 13, 14], "disentangl": 9, "iwa": [9, 14], "lp_z": 9, "lqz_x": 9, "probabl": 9, "learn": 9, "normal": [9, 11, 14], "recon_loss_fn": [9, 14], "reshape_for_loss": 9, "reshap": [9, 13], "likelihood": [9, 11], "ltype": [9, 11, 14], "set_ltyp": [9, 14], "through": 9, "assert": [9, 13], "weighted_group_kld": 9, "group": 9, "multimodalobject": [9, 14], "common": [9, 10, 13], "calculate_loss": [9, 14], "px_z": [9, 14], "requir": 9, "e_": 9, "p": 9, "reconloss": 9, "store": [9, 14], "bce": [9, 13, 14], "binari": 9, "entropi": 9, "category_c": [9, 13], "categor": [9, 13], "classif": 9, "problem": 9, "gaussian_nl": 9, "nll": 9, "sigma": 9, "orybkin": 9, "l1": 9, "lprob": 9, "mse": 9, "squar": 9, "error": 9, "l2": 9, "norm": 9, "unimodalobject": 9, "onli": [9, 13, 14], "prior_dist": [9, 11, 14], "were": 9, "drawn": 9, "dreg": 9, "estim": 9, "p_": 9, "heta": 9, "fulli": 9, "vectoris": 9, "iffsid": 9, "import": [9, 14], "lightningmodul": 10, "architectur": 10, "configur": [10, 13, 14], "analyse_data": 10, "250": 10, "path_nam": 10, "savedir": 10, "plot": 10, "sne": 10, "under": 10, "valid": 10, "check_config": 10, "configure_optim": 10, "datamod": 10, "when": 10, "pl": 10, "eval_forward": 10, "g": [10, 13], "dure": 10, "get_mod_nam": 10, "get_model": 10, "file": [10, 13, 14], "save_joint_sampl": 10, "16": [10, 13], "random": [10, 11], "randomli": 10, "save_reconstruct": 10, "iter": 10, "over": [9, 10], "test_epoch_end": 10, "end": 10, "test_step": 10, "test_batch": 10, "batch_idx": 10, "loader": 10, "otherwis": 10, "training_step": 10, "train_batch": 10, "validation_epoch_end": 10, "validation_step": 10, "val_batch": 10, "enc": [11, 14], "dec": [11, 14], "likelihood_dist": 11, "post_dist": 11, "inp": 11, "dencoderfactori": 11, "classmethod": 11, "get_nework_class": 11, "enc_nam": 11, "dec_nam": 11, "private_lat": 11, "instanti": 11, "network": [11, 13, 14], "obj_fn": [11, 14], "id_nam": 11, "mod_1": [11, 14], "generate_sampl": 11, "traversal_rang": 11, "rang": [11, 14], "plausibl": 11, "scenario": [11, 12], "loss_fn": 11, "pz_params_priv": 11, "qz_x_param": 11, "set_objective_fn": 11, "case": [11, 13, 14], "repositori": [12, 13], "purpos": 12, "offer": 12, "unifi": [12, 14], "wai": 12, "state": 12, "art": 12, "variat": 12, "autoencod": [12, 14], "arbitrari": [12, 13, 14], "both": [12, 13], "uni": 12, "By": [12, 13], "default": [12, 13, 14], "mvae": 12, "paper": 12, "anyon": 12, "free": 12, "contribut": 12, "synthet": 12, "call": [12, 14], "cdsprite": [12, 13], "design": 12, "capabl": 12, "read": 12, "about": 12, "util": [12, 13], "propos": [12, 13], "link": 12, "ad": 12, "soon": 12, "5": [12, 14], "difficulti": 12, "minim": 12, "moreov": 12, "its": 12, "rigid": 12, "structur": [12, 13], "enabl": 12, "automat": [12, 13, 14], "qualit": 12, "For": [12, 13], "more": [12, 14], "see": [12, 13, 14], "below": [12, 14], "framework": [12, 14], "page": 12, "work": 12, "progress": 12, "sub": 12, "add": 12, "new": 12, "multimodalva": 12, "well": [13, 14], "describ": [13, 14], "your": [13, 14], "In": [13, 14], "prefer": 13, "pickl": 13, "pkl": [13, 14], "pt": 13, "numpi": 13, "npy": 13, "hdf5": 13, "h5": 13, "contain": [13, 14], "jpg": 13, "To": [13, 14], "600": 13, "exp_nam": [13, 14], "lr": [13, 14], "1e": [13, 14], "adam": [13, 14], "pre_train": [13, 14], "null": [13, 14], "seed": [13, 14], "viz_freq": [13, 14], "test_split": [13, 14], "dataset_nam": [13, 14], "cnn": 13, "recon_loss": [13, 14], "modality_2": [13, 14], "txttransform": 13, "cub_capt": 13, "exampl": 13, "download": 13, "readm": 13, "As": 13, "caption": 13, "expect": [13, 14], "so": 13, "semant": 13, "pair": 13, "first": [13, 14], "should": [13, 14], "must": [13, 14], "inherit": [13, 14], "have": 13, "some": 13, "like": [13, 14], "def": [13, 14], "__init__": [13, 14], "super": [13, 14], "text2img_s": 13, "380": 13, "has_mask": 13, "one_hot_encod": 13, "len": [13, 14], "from_numpi": 13, "np": 13, "asarrai": 13, "lengths_to_mask": 13, "unsqueez": [13, 14], "rnn": 13, "pad_sequ": 13, "batch_first": 13, "padding_valu": 13, "data_and_mask": 13, "cat": [13, 14], "isinst": 13, "output_onehot2text": 13, "count_nonzero": 13, "enumer": [13, 14], "phrase": 13, "phr": 13, "newphr": 13, "deepcopi": 13, "stringcount": 13, "40": 13, "insert": 13, "join": 13, "d": [13, 14], "output_process": 13, "add_recon_titl": 13, "170": 13, "input_process": 13, "item": [13, 14], "turn_text2imag": 13, "img_siz": 13, "255": 13, "append": [13, 14], "vstack": 13, "ones": [13, 14], "125": 13, "hstack": 13, "astyp": 13, "uint8": 13, "final": [13, 14], "cv2": 13, "imwrit": 13, "cvtcolor": 13, "color_bgr2rgb": 13, "eventhough": 13, "therefor": 13, "constructor": 13, "eventu": 13, "after": [13, 14], "modality_typ": 13, "distinguish": 13, "chose": 13, "produc": 13, "adjust": 13, "layer": 13, "next": [13, 14], "thing": 13, "sinc": 13, "rewrit": 13, "handl": [13, 14], "perform": 13, "note": 13, "concaten": 13, "last": 13, "anoth": 13, "do": [13, 14], "map": 13, "just": 13, "abov": 13, "mention": 13, "select": [13, 14], "onc": 13, "done": 13, "readi": 13, "launch": 13, "shown": 13, "suit": 13, "print": 13, "text2image_s": 13, "line": [13, 14], "11": 13, "want": 13, "unsupport": 13, "issu": 13, "altern": 13, "try": 13, "incorpor": 13, "matter": 13, "suffix": 13, "startswith": 13, "o": 13, "get_root_fold": 13, "exist": [13, 14], "doe": 13, "isdir": 13, "pathlib": 13, "load_pickl": 13, "h5py": 13, "r": 13, "rais": 13, "except": 13, "unrecogn": 13, "pleas": 13, "32x32x3": 13, "64x64x3": 13, "resolut": 13, "resp": 13, "28x28x1": 13, "pixel": 13, "suitabl": 13, "encourag": 14, "author": 14, "toolkit": 14, "written": 14, "possibl": 14, "dedic": 14, "cours": 14, "support": 14, "show": 14, "step": 14, "tutori": 14, "start": 14, "modelnam": 14, "variaion": 14, "mhw32": 14, "public": 14, "obj_cofig": 14, "model_cofig": 14, "pz": [9, 14], "bare": 14, "minimum": 14, "newli": 14, "integr": 14, "possibli": 14, "single_param": 14, "rsampl": 14, "qz_d": 14, "px_d": 14, "z_d": 14, "find_out_batch_s": 14, "initi": 14, "m": 14, "mod_mu": 14, "mod_logvar": 14, "combin": 14, "integ": 14, "zero": 14, "locat": 14, "output_storag": 14, "proper": 14, "placement": 14, "insid": 14, "thu": 14, "what": 14, "mod_2": 14, "data2": 14, "procedur": 14, "_": 14, "mods_input": 14, "subsample_input_mod": 14, "output_d": 14, "unpack_valu": 14, "sum": 14, "loc_lpx_z": 14, "llik_scal": [11, 14], "mod_": 14, "format": 14, "stack": 14, "ind_loss": 14, "reconstruction_loss": 14, "subsampl": 14, "strategi": 14, "13": 14, "It": 14, "term": 14, "help": 14, "code": 14, "part": 14, "among": 14, "1d": 14, "tensorboard": 14, "__all__": 14, "although": 14, "re": 14, "now": 14, "abl": 14, "follow": 14, "700": 14, "poe_exp": 14, "20": 14, "experi": 14, "cd": 14, "python": 14, "main": 14, "check_indices_pres": 2, "set_likelihood_scal": 7, "_m_dreg_loos": 9, "zss": 9, "multi": 9, "looser": 9, "bound": 9, "averag": 9, "auto": 11}, "objects": {"multimodal_compare.eval": [[5, 0, 0, "-", "eval_cdsprites"], [6, 0, 0, "-", "infer"]], "multimodal_compare.eval.eval_cdsprites": [[5, 1, 1, "", "calculate_cross_coherency"], [5, 1, 1, "", "calculate_joint_coherency"], [5, 1, 1, "", "check_cross_sample_correct"], [5, 1, 1, "", "count_same_letters"], [5, 1, 1, "", "eval_all"], [5, 1, 1, "", "eval_cdsprites_over_seeds"], [5, 1, 1, "", "eval_single_model"], [5, 1, 1, "", "eval_with_classifier"], [5, 1, 1, "", "fill_cats"], [5, 1, 1, "", "find_in_list"], [5, 1, 1, "", "get_all_classifiers"], [5, 1, 1, "", "get_attribute"], [5, 1, 1, "", "get_attribute_from_recon"], [5, 1, 1, "", "get_mean_stats"], [5, 1, 1, "", "get_mod_mappings"], [5, 1, 1, "", "image_to_text"], [5, 1, 1, "", "load_classifier"], [5, 1, 1, "", "load_images"], [5, 1, 1, "", "manhattan_distance"], [5, 1, 1, "", "search_att"], [5, 1, 1, "", "text_to_image"], [5, 1, 1, "", "try_retrieve_atts"]], "multimodal_compare.eval.infer": [[6, 2, 1, "", "MultimodalVAEInfer"]], "multimodal_compare.eval.infer.MultimodalVAEInfer": [[6, 3, 1, "", "eval_statistics"], [6, 3, 1, "", "get_base_path"], [6, 3, 1, "", "get_config"], [6, 3, 1, "", "get_datamodule"], [6, 3, 1, "", "get_wrapped_model"], [6, 3, 1, "", "make_dataloaders"]], "multimodal_compare.models": [[0, 0, 0, "-", "config_cls"], [1, 0, 0, "-", "dataloader"], [2, 0, 0, "-", "datasets"], [3, 0, 0, "-", "decoders"], [4, 0, 0, "-", "encoders"], [7, 0, 0, "-", "mmvae_base"], [8, 0, 0, "-", "mmvae_models"], [9, 0, 0, "-", "objectives"], [10, 0, 0, "-", "trainer"], [11, 0, 0, "-", "vae"]], "multimodal_compare.models.config_cls": [[0, 2, 1, "", "Config"]], "multimodal_compare.models.config_cls.Config": [[0, 3, 1, "", "_define_params"], [0, 3, 1, "", "_get_mods_config"], [0, 3, 1, "", "_load_config"], [0, 3, 1, "", "_parse_args"], [0, 3, 1, "", "_setup_savedir"], [0, 3, 1, "", "change_seed"], [0, 3, 1, "", "dump_config"], [0, 3, 1, "", "find_version"], [0, 3, 1, "", "get_vis_dir"], [0, 3, 1, "", "parse_params"]], "multimodal_compare.models.dataloader": [[1, 2, 1, "", "DataModule"]], "multimodal_compare.models.dataloader.DataModule": [[1, 3, 1, "", "check_load_testdata"], [1, 3, 1, "", "check_testdata_avail"], [1, 3, 1, "", "collate_fn"], [1, 3, 1, "", "get_dataset_class"], [1, 3, 1, "", "get_label_for_indices"], [1, 3, 1, "", "get_labels"], [1, 3, 1, "", "get_num_samples"], [1, 3, 1, "", "make_masks"], [1, 3, 1, "", "predict_dataloader"], [1, 3, 1, "", "prepare_data_classes"], [1, 3, 1, "", "prepare_singlemodal"], [1, 3, 1, "", "setup"], [1, 3, 1, "", "test_dataloader"], [1, 3, 1, "", "train_dataloader"], [1, 3, 1, "", "val_dataloader"]], "multimodal_compare.models.datasets": [[2, 2, 1, "", "BaseDataset"], [2, 2, 1, "", "CDSPRITESPLUS"], [2, 2, 1, "", "CELEBA"], [2, 2, 1, "", "CUB"], [2, 2, 1, "", "FASHIONMNIST"], [2, 2, 1, "", "MNIST_SVHN"], [2, 2, 1, "", "POLYMNIST"], [2, 2, 1, "", "SPRITES"]], "multimodal_compare.models.datasets.BaseDataset": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess"], [2, 3, 1, "", "_postprocess_all2img"], [2, 3, 1, "", "_preprocess"], [2, 3, 1, "", "_preprocess_images"], [2, 3, 1, "", "_preprocess_text_onehot"], [2, 3, 1, "", "current_datatype"], [2, 3, 1, "", "eval_statistics_fn"], [2, 3, 1, "", "get_data"], [2, 3, 1, "", "get_data_raw"], [2, 3, 1, "", "get_labels"], [2, 3, 1, "", "get_processed_recons"], [2, 3, 1, "", "get_test_data"], [2, 3, 1, "", "labels"], [2, 3, 1, "", "save_traversals"]], "multimodal_compare.models.datasets.CDSPRITESPLUS": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess_images"], [2, 3, 1, "", "_postprocess_text"], [2, 3, 1, "", "_preprocess_images"], [2, 3, 1, "", "_preprocess_text"], [2, 3, 1, "", "eval_statistics_fn"], [2, 4, 1, "", "feature_dims"], [2, 3, 1, "", "labels"], [2, 3, 1, "", "save_recons"], [2, 3, 1, "", "set_vis_image_shape"]], "multimodal_compare.models.datasets.CELEBA": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess_all2img"], [2, 3, 1, "", "_postprocess_atts"], [2, 3, 1, "", "_postprocess_images"], [2, 3, 1, "", "_preprocess_atts"], [2, 3, 1, "", "_preprocess_images"], [2, 4, 1, "", "feature_dims"], [2, 3, 1, "", "save_recons"], [2, 3, 1, "", "save_traversals"]], "multimodal_compare.models.datasets.CUB": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess_text"], [2, 3, 1, "", "_preprocess_images"], [2, 3, 1, "", "_preprocess_text"], [2, 3, 1, "", "_preprocess_text_onehot"], [2, 4, 1, "", "feature_dims"], [2, 3, 1, "", "labels"], [2, 3, 1, "", "save_recons"]], "multimodal_compare.models.datasets.FASHIONMNIST": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess_image"], [2, 3, 1, "", "_postprocess_label"], [2, 3, 1, "", "_process_image"], [2, 3, 1, "", "_process_label"], [2, 4, 1, "", "feature_dims"], [2, 3, 1, "", "get_data_raw"], [2, 3, 1, "", "labels"], [2, 3, 1, "", "save_recons"]], "multimodal_compare.models.datasets.MNIST_SVHN": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess_all2img"], [2, 3, 1, "", "_postprocess_mnist"], [2, 3, 1, "", "_postprocess_svhn"], [2, 3, 1, "", "_process_mnist"], [2, 3, 1, "", "_process_svhn"], [2, 3, 1, "", "check_indices_present"], [2, 4, 1, "", "feature_dims"], [2, 3, 1, "", "labels"], [2, 3, 1, "", "save_recons"]], "multimodal_compare.models.datasets.POLYMNIST": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess_mnist"], [2, 3, 1, "", "_process_mnist"], [2, 4, 1, "", "feature_dims"], [2, 3, 1, "", "save_recons"], [2, 3, 1, "", "save_traversals"]], "multimodal_compare.models.datasets.SPRITES": [[2, 3, 1, "", "_mod_specific_loaders"], [2, 3, 1, "", "_mod_specific_savers"], [2, 3, 1, "", "_postprocess_actions"], [2, 3, 1, "", "_postprocess_all2img"], [2, 3, 1, "", "_postprocess_attributes"], [2, 3, 1, "", "_postprocess_frames"], [2, 3, 1, "", "eval_statistics_fn"], [2, 4, 1, "", "feature_dims"], [2, 3, 1, "", "get_actions"], [2, 3, 1, "", "get_attributes"], [2, 3, 1, "", "get_frames"], [2, 3, 1, "", "iter_over_inputs"], [2, 3, 1, "", "labels"], [2, 3, 1, "", "make_masks"], [2, 3, 1, "", "save_recons"], [2, 3, 1, "", "save_traversals"]], "multimodal_compare.models.decoders": [[3, 2, 1, "", "Dec_CNN"], [3, 2, 1, "", "Dec_FNN"], [3, 2, 1, "", "Dec_MNIST"], [3, 2, 1, "", "Dec_MNIST2"], [3, 2, 1, "", "Dec_PolyMNIST"], [3, 2, 1, "", "Dec_SVHN"], [3, 2, 1, "", "Dec_SVHN2"], [3, 2, 1, "", "Dec_Transformer"], [3, 2, 1, "", "Dec_TransformerIMG"], [3, 2, 1, "", "Dec_TxtTransformer"], [3, 2, 1, "", "Dec_VideoGPT"], [3, 2, 1, "", "VaeDecoder"], [3, 1, 1, "", "extra_hidden_layer"]], "multimodal_compare.models.decoders.Dec_CNN": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_FNN": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_MNIST": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_MNIST2": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_PolyMNIST": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_SVHN": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_SVHN2": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_Transformer": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_TransformerIMG": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_TxtTransformer": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.Dec_VideoGPT": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 3, 1, "", "forward"], [3, 4, 1, "", "training"]], "multimodal_compare.models.decoders.VaeDecoder": [[3, 4, 1, "", "_is_full_backward_hook"], [3, 4, 1, "", "training"]], "multimodal_compare.models.encoders": [[4, 2, 1, "", "Enc_CNN"], [4, 2, 1, "", "Enc_FNN"], [4, 2, 1, "", "Enc_MNIST"], [4, 2, 1, "", "Enc_MNIST2"], [4, 2, 1, "", "Enc_PolyMNIST"], [4, 2, 1, "", "Enc_SVHN"], [4, 2, 1, "", "Enc_SVHN2"], [4, 2, 1, "", "Enc_Transformer"], [4, 2, 1, "", "Enc_TransformerIMG"], [4, 2, 1, "", "Enc_TxtTransformer"], [4, 2, 1, "", "Enc_VideoGPT"], [4, 2, 1, "", "VaeComponent"], [4, 2, 1, "", "VaeEncoder"], [4, 1, 1, "", "extra_hidden_layer"]], "multimodal_compare.models.encoders.Enc_CNN": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_FNN": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_MNIST": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_MNIST2": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_PolyMNIST": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_SVHN": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_SVHN2": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_Transformer": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_TransformerIMG": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_TxtTransformer": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.Enc_VideoGPT": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.VaeComponent": [[4, 4, 1, "", "_is_full_backward_hook"], [4, 3, 1, "", "forward"], [4, 4, 1, "", "training"]], "multimodal_compare.models.encoders.VaeEncoder": [[4, 4, 1, "", "_backward_hooks"], [4, 4, 1, "", "_buffers"], [4, 4, 1, "", "_forward_hooks"], [4, 4, 1, "", "_forward_pre_hooks"], [4, 4, 1, "", "_is_full_backward_hook"], [4, 4, 1, "", "_load_state_dict_post_hooks"], [4, 4, 1, "", "_load_state_dict_pre_hooks"], [4, 4, 1, "", "_modules"], [4, 4, 1, "", "_non_persistent_buffers_set"], [4, 4, 1, "", "_parameters"], [4, 4, 1, "", "_state_dict_hooks"], [4, 4, 1, "", "training"]], "multimodal_compare.models.mmvae_base": [[7, 2, 1, "", "TorchMMVAE"]], "multimodal_compare.models.mmvae_base.TorchMMVAE": [[7, 4, 1, "", "_is_full_backward_hook"], [7, 3, 1, "", "add_vaes"], [7, 3, 1, "", "decode"], [7, 3, 1, "", "encode"], [7, 3, 1, "", "forward"], [7, 3, 1, "", "get_missing_modalities"], [7, 5, 1, "", "latent_factorization"], [7, 3, 1, "", "make_output_dict"], [7, 3, 1, "", "modality_mixing"], [7, 3, 1, "", "objective"], [7, 3, 1, "", "product_of_experts"], [7, 5, 1, "", "pz_params"], [7, 3, 1, "", "set_likelihood_scales"], [7, 4, 1, "", "training"]], "multimodal_compare.models.mmvae_models": [[8, 2, 1, "", "DMVAE"], [8, 2, 1, "", "MOE"], [8, 2, 1, "", "MoPOE"], [8, 2, 1, "", "POE"]], "multimodal_compare.models.mmvae_models.DMVAE": [[8, 4, 1, "", "_is_full_backward_hook"], [8, 3, 1, "", "forward"], [8, 3, 1, "", "get_remaining_mods_data"], [8, 3, 1, "", "logsumexp"], [8, 3, 1, "", "objective"], [8, 5, 1, "", "pz_params"], [8, 4, 1, "", "training"]], "multimodal_compare.models.mmvae_models.MOE": [[8, 4, 1, "", "_is_full_backward_hook"], [8, 3, 1, "", "forward"], [8, 3, 1, "", "objective"], [8, 5, 1, "", "pz_params"], [8, 3, 1, "", "reconstruct"], [8, 4, 1, "", "training"]], "multimodal_compare.models.mmvae_models.MoPOE": [[8, 4, 1, "", "_is_full_backward_hook"], [8, 3, 1, "", "forward"], [8, 3, 1, "", "mixture_component_selection"], [8, 3, 1, "", "modality_mixing"], [8, 3, 1, "", "moe_fusion"], [8, 3, 1, "", "objective"], [8, 3, 1, "", "poe_fusion"], [8, 5, 1, "", "pz_params"], [8, 3, 1, "", "reparameterize"], [8, 3, 1, "", "reweight_weights"], [8, 3, 1, "", "set_subsets"], [8, 4, 1, "", "training"]], "multimodal_compare.models.mmvae_models.POE": [[8, 4, 1, "", "_is_full_backward_hook"], [8, 3, 1, "", "forward"], [8, 3, 1, "", "modality_mixing"], [8, 3, 1, "", "objective"], [8, 3, 1, "", "prior_expert"], [8, 5, 1, "", "pz_params"], [8, 4, 1, "", "training"]], "multimodal_compare.models.objectives": [[9, 2, 1, "", "BaseObjective"], [9, 2, 1, "", "MultimodalObjective"], [9, 2, 1, "", "ReconLoss"], [9, 2, 1, "", "UnimodalObjective"]], "multimodal_compare.models.objectives.BaseObjective": [[9, 3, 1, "", "calc_kld"], [9, 3, 1, "", "calc_klds"], [9, 3, 1, "", "compute_microbatch_split"], [9, 3, 1, "", "elbo"], [9, 3, 1, "", "iwae"], [9, 3, 1, "", "normalize"], [9, 3, 1, "", "recon_loss_fn"], [9, 3, 1, "", "reshape_for_loss"], [9, 3, 1, "", "set_ltype"], [9, 3, 1, "", "weighted_group_kld"]], "multimodal_compare.models.objectives.MultimodalObjective": [[9, 3, 1, "", "_m_dreg_looser"], [9, 3, 1, "", "calculate_loss"], [9, 3, 1, "", "dreg"], [9, 3, 1, "", "elbo"], [9, 3, 1, "", "iwae"]], "multimodal_compare.models.objectives.ReconLoss": [[9, 3, 1, "", "bce"], [9, 3, 1, "", "category_ce"], [9, 3, 1, "", "gaussian_nll"], [9, 3, 1, "", "l1"], [9, 3, 1, "", "lprob"], [9, 3, 1, "", "mse"]], "multimodal_compare.models.objectives.UnimodalObjective": [[9, 3, 1, "", "calculate_loss"], [9, 3, 1, "", "dreg"], [9, 3, 1, "", "elbo"], [9, 3, 1, "", "iwae"]], "multimodal_compare.models.trainer": [[10, 2, 1, "", "MultimodalVAE"]], "multimodal_compare.models.trainer.MultimodalVAE": [[10, 3, 1, "", "analyse_data"], [10, 3, 1, "", "check_config"], [10, 3, 1, "", "configure_optimizers"], [10, 5, 1, "", "datamod"], [10, 3, 1, "", "eval_forward"], [10, 3, 1, "", "get_mod_names"], [10, 3, 1, "", "get_model"], [10, 3, 1, "", "save_joint_samples"], [10, 3, 1, "", "save_reconstructions"], [10, 3, 1, "", "test_epoch_end"], [10, 3, 1, "", "test_step"], [10, 3, 1, "", "training_step"], [10, 3, 1, "", "validation_epoch_end"], [10, 3, 1, "", "validation_step"]], "multimodal_compare.models.vae": [[11, 2, 1, "", "BaseVae"], [11, 2, 1, "", "DencoderFactory"], [11, 2, 1, "", "VAE"]], "multimodal_compare.models.vae.BaseVae": [[11, 4, 1, "", "_is_full_backward_hook"], [11, 3, 1, "", "decode"], [11, 3, 1, "", "encode"], [11, 3, 1, "", "forward"], [11, 4, 1, "", "training"]], "multimodal_compare.models.vae.DencoderFactory": [[11, 3, 1, "", "get_nework_classes"]], "multimodal_compare.models.vae.VAE": [[11, 4, 1, "", "_backward_hooks"], [11, 4, 1, "", "_buffers"], [11, 4, 1, "", "_forward_hooks"], [11, 4, 1, "", "_forward_pre_hooks"], [11, 4, 1, "", "_is_full_backward_hook"], [11, 4, 1, "", "_load_state_dict_post_hooks"], [11, 4, 1, "", "_load_state_dict_pre_hooks"], [11, 4, 1, "", "_modules"], [11, 4, 1, "", "_non_persistent_buffers_set"], [11, 4, 1, "", "_parameters"], [11, 4, 1, "", "_state_dict_hooks"], [11, 3, 1, "", "generate_samples"], [11, 3, 1, "", "objective"], [11, 5, 1, "", "pz_params"], [11, 5, 1, "", "pz_params_private"], [11, 5, 1, "", "qz_x_params"], [11, 3, 1, "", "set_objective_fn"], [11, 4, 1, "", "training"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class", "3": "py:method", "4": "py:attribute", "5": "py:property"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"], "3": ["py", "method", "Python method"], "4": ["py", "attribute", "Python attribute"], "5": ["py", "property", "Python property"]}, "titleterms": {"config": [0, 13], "class": [0, 2, 7, 10, 11, 13], "dataload": 1, "dataset": [2, 5, 13], "decod": 3, "encod": 4, "evalu": 5, "cdsprite": 5, "infer": 6, "modul": 6, "multimod": [7, 8, 12], "vae": [7, 8, 11, 12], "base": 7, "model": [8, 14], "object": 9, "multimodalva": 10, "comparison": 12, "toolkit": 12, "tutori": 12, "code": 12, "document": 12, "add": [13, 14], "new": [13, 14], "support": 13, "data": 13, "format": 13, "ad": [13, 14], "differ": 13, "gener": 14, "requir": 14}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 57}, "alltitles": {"Config class": [[0, "module-multimodal_compare.models.config_cls"]], "DataLoader": [[1, "dataloader"]], "Dataset Classes": [[2, "module-multimodal_compare.models.datasets"]], "Decoders": [[3, "decoders"]], "Encoders": [[4, "encoders"]], "Evaluate on CdSprites+ dataset": [[5, "evaluate-on-cdsprites-dataset"]], "Inference module": [[6, "module-multimodal_compare.eval.infer"]], "Multimodal VAE Base Class": [[7, "module-multimodal_compare.models.mmvae_base"]], "Multimodal VAE models": [[8, "module-multimodal_compare.models.mmvae_models"]], "Objectives": [[9, "objectives"]], "MultimodalVAE class": [[10, "multimodalvae-class"]], "VAE class": [[11, "module-multimodal_compare.models.vae"]], "Multimodal VAE Comparison Toolkit": [[12, "multimodal-vae-comparison-toolkit"]], "Tutorials": [[12, null]], "Code documentation": [[12, null]], "Add a new model": [[14, "add-a-new-model"]], "General requirements": [[14, "general-requirements"]], "Adding a new model": [[14, "adding-a-new-model"]], "Add a new dataset": [[13, "add-a-new-dataset"]], "Supported data formats, config": [[13, "supported-data-formats-config"]], "Adding a new dataset class": [[13, "adding-a-new-dataset-class"]], "Different data formats": [[13, "different-data-formats"]]}, "indexentries": {}})