
import sys
root = "/home/******/Documents/code/mld/"
sys.path.append(root)

import pickle
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer, seed_everything

from src.dataLoaders.MnistSvhnText.MnistSvhnText import get_data_set_svhn_mnist

from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from src.utils import create_forlder
import os
import torch
from src.unimodal_vae.BaseAE import AE
from src.models.MLD_uni import MLD_Uni
from src.models.LateFusionAE import LateFusionAE
from src.dataLoaders.MnistSvhnText.modalities import MNIST, SVHN, LABEL

# Turn off TensorFlow's oneDNN optimisations. NOTE(review): TensorFlow is not
# imported in this file; presumably it is pulled in indirectly (e.g. by the
# TensorBoard logger) - confirm before removing.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Command-line interface of the training script. Numeric options are parsed
# with their real types so that invalid values are rejected at parse time and
# callers always receive a float / an int, whether or not the flag was given.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

parser.add_argument('--model', type=str, default="mmld_uni",
                    help=("Model name; used as the sub-directory under "
                          "which logs and checkpoints are written"))

parser.add_argument('--dropout', type=float, default=0.2,
                    help="Dropout rate passed to the MLD_Uni model")

parser.add_argument('--seed', type=int, default=0,
                    help=("Random seed for the run; also used as the "
                          "TensorBoard run name"))

# Evaluation toggles and the eval/log intervals for the run.
do_evaluation, do_fd, do_class = False, True, False
eval_epoch, log_epoch = 15, 50

# Training length, training batch size and learning rate.
NUM_epoch = 150
batch_size = 128
lr = 1e-4

# Reconstruction weight of each modality (all equal).
r_w_mnist = r_w_svhn = r_w_text = 1.0

# Checkpoints of the pretrained auto-encoders, looked up by modality name.
PATHS = dict(
    mnist="/home/******/work/mld/trained_models/MNISTSVHN/ae_deter_4_laplace_64mnist/version_0/checkpoints/epoch=149-step=1314000.ckpt",
    svhn="/home/******/work/mld/trained_models/MNISTSVHN/ae_deter_4_laplace_64svhn/version_0/checkpoints/epoch=149-step=1314000.ckpt",
)


# MNIST-SVHN train and test datasets, requested without the text modality.
train, test = get_data_set_svhn_mnist(with_text=False)

# Training batches: shuffled, incomplete last batch dropped.
train_loader = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    drop_last=True,
    pin_memory=True,
)

# Test batches use a fixed size of 512; note that they are shuffled and the
# incomplete last batch is dropped as well.
test_loader = DataLoader(
    test,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=True,
)


# Make sure the root output folder exists, then point CHECKPOINT_DIR at the
# dataset-specific sub-folder (that sub-folder itself is not created here).
CHECKPOINT_DIR = "trained_models/"
create_forlder(CHECKPOINT_DIR)
CHECKPOINT_DIR = os.path.join(CHECKPOINT_DIR, 'MNISTSVHN')


def get_stat(file_name, device="cuda"):
    """Load pickled statistics and move their mean/std values to *device*.

    The pickle must contain a dict. Every entry except the one stored under
    the key ``'cat'`` must itself be a dict holding ``'mean'`` and ``'std'``
    values that expose a ``.to(device)`` method (presumably torch tensors).

    Args:
        file_name: Path of the pickle file to read.
        device: Target handed to ``.to``. Defaults to ``"cuda"``, which is
            what this function always used before the parameter existed.

    Returns:
        The loaded dict, with each non-``'cat'`` entry's ``'mean'`` and
        ``'std'`` replaced by the result of ``.to(device)``.
    """
    # NOTE: unpickling can execute arbitrary code - only load trusted files.
    # The context manager guarantees the file handle is closed.
    with open(file_name, "rb") as handle:
        stats = pickle.load(handle)

    for key, entry in stats.items():
        if key == 'cat':
            # The 'cat' entry is left exactly as it was stored.
            continue
        entry['mean'] = entry['mean'].to(device)
        entry['std'] = entry['std'].to(device)
    return stats


if __name__ == "__main__":
    # Entry point: load the pretrained per-modality auto-encoders, wrap them
    # in a LateFusionAE, then build and fit one MLD_Uni model per seed.

    args = parser.parse_args()

    # Latent dimensions handed to the modality definitions.
    dim_mnist = 16
    dim_svhn = 64
    dim_label = 10  # not used below: no LABEL modality is built here

    modalities_list = [
        MNIST(latent_dim=dim_mnist, reconstruction_weight=r_w_mnist,
              deterministic=True, lhood_name="laplace"),
        SVHN(latent_dim=dim_svhn, reconstruction_weight=r_w_svhn,
             deterministic=True, lhood_name="laplace", v2='3')
    ]

    # One pretrained auto-encoder per modality; the checkpoint is selected
    # through the modality's name.
    aes = [
        AE.load_from_checkpoint(PATHS[mod.name], modality=mod)
        for mod in modalities_list
    ]
    aes_model = LateFusionAE(aes=aes)
    aes_model.eval()

    # Explicit conversions keep this block working even if the parser hands
    # the values over as strings.
    dropout = float(args.dropout)
    seeds = [int(args.seed)]

    # Resolve the model-specific output directory once. Doing this inside the
    # seed loop would append the model name again on every iteration.
    CHECKPOINT_DIR = os.path.join(CHECKPOINT_DIR, str(args.model))

    for seed in seeds:

        seed_everything(seed)

        model = MLD_Uni(
            aes=aes_model,
            model_name="mmld uni",
            batch_size=batch_size,
            train_loader=train_loader,
            test_loader=test_loader,
            eval_epoch=eval_epoch,
            do_evaluation=do_evaluation,
            do_fd=do_fd,
            nb_samples=8,
            n_fd=5000,
            log_epoch=log_epoch,
            lr=lr,
            nb_batchs=None,
            init_dim=512,
            # Follow the module-level switch instead of a hard-coded False.
            do_class=do_class,
            time_dim=256,
            unet_architecture=(1, 1),
            unet_type='linear',
            dropout=dropout,
            preprocess_type="modality",
            # Spelling kept as-is: the value is presumably matched literally
            # inside MLD_Uni - confirm before correcting it.
            preprocess_op="standerdize",
            check_stat=False,
            group_norm=8,
            # Alternative configuration without preprocessing:
            # preprocess_type = None,
            # preprocess_op = None,
            # check_stat = False,
            betas=[0.1, 20],
            train_batch_size=512,
            N_step=250,
            importance_sampling=False,
            ll_weight=False,
            debug=False,
            use_attention=False,
            shift_scale=False,
            num_head=3,
            use_ema=True,
            cross_gen="repaint"
        )

        # Name the run after the seed actually being trained, so that several
        # seeds would not end up sharing one log directory.
        tb_logger = TensorBoardLogger(save_dir=CHECKPOINT_DIR,
                                      name=str(seed))
        # model.stat = get_stat("/home/******/work/mld/trained_models/MNISTSVHN/mmld_uni/3/version_0/stat.pickle")
        print(model)
        trainer = pl.Trainer(
            logger=tb_logger,
            check_val_every_n_epoch=5,
            accelerator='gpu',
            devices=1,
            max_epochs=NUM_epoch,
            default_root_dir=CHECKPOINT_DIR,
            # Skip the validation sanity check that runs before training.
            num_sanity_val_steps=0,
            # resume_from_checkpoint= "/home/******/work/mld/trained_models/MNISTSVHN/mmld_uni/3/version_0/checkpoints/epoch=9-step=87600.ckpt"
        )

        # The test loader doubles as the validation loader.
        trainer.fit(model=model, train_dataloaders=model.train_loader,
                    val_dataloaders=model.test_loader)
