from __future__ import print_function

import os
import sys
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
import wandb
import matplotlib.pyplot as plt

from distributions.clf import ClfImgMNIST, ClfImgSVHN, ClfText, ClfImgMMNIST, ClfImgTransMMNIST, ClfCelebAImg, ClfCelebAText
from models.clf_model import Model
from models import JVAE, MVAE, MMVAE, MoPoE, MMJSD, CRMVAE, CRMVAE_Update, CRMVAE_New, CRMVAE_New2, CRMVAE2

# import dataset
sys.path.append('../')
from multimodal_datasets.SVHNMNISTDataset import SVHNMNIST
from multimodal_datasets.CelebADataset import CelebaDataset
from multimodal_datasets.PolyMNIST import PolyMNIST
from multimodal_datasets.CUBImageCaptions import MMCUB
# Shared data root: the current working directory with 'jmvae_journal' stripped
# out, joined with 'multimodal_datasets/data'.
# NOTE(review): str.replace drops every occurrence of the substring, not only a
# trailing path component — confirm this is intended.
dir_data = os.path.join(os.getcwd().replace('jmvae_journal', ''), 'multimodal_datasets/data')

# Registry from the "model" entry of a run config to the class implementing it.
# Aliases: "MVTCAE" and "MMJSD_PoE" resolve to CRMVAE; "MVTCAE_Update" resolves
# to CRMVAE_Update.
models = dict(
    JVAE=JVAE,
    MVAE=MVAE,
    MMVAE=MMVAE,
    MoPoE=MoPoE,
    MMJSD=MMJSD,
    CRMVAE=CRMVAE,
    MVTCAE=CRMVAE,
    MMJSD_PoE=CRMVAE,
    CRMVAE_Update=CRMVAE_Update,
    MVTCAE_Update=CRMVAE_Update,
    CRMVAE_New=CRMVAE_New,
    CRMVAE_New2=CRMVAE_New2,
    CRMVAE2=CRMVAE2,
)

# W&B public-API client, created once when this module is imported.
api = wandb.Api()

def _parse_modality_id(modality_id):
    """Return *modality_id* as a list of ints.

    The run config may hold the value as a list (returned unchanged) or as a
    comma-separated string such as ``"0,1,2"``; a bare int is wrapped in a list.
    """
    if isinstance(modality_id, list):
        return modality_id
    if isinstance(modality_id, int):
        return [modality_id]
    return [int(item) for item in modality_id.split(',')]


def _default_rec_weights(data):
    """Return the default per-modality reconstruction weights for dataset *data*.

    A fresh list is returned on every call so callers may mutate it safely.

    Raises:
        NotImplementedError: if *data* is not a supported dataset name.
    """
    if data == "SVHNMNIST":
        # Flattened modality sizes (presumably MNIST 28x28, SVHN 3x32x32 and a
        # 71-dim text encoding — TODO confirm); weights are relative to modality 2.
        d_size_m1 = 28**2
        d_size_m2 = 3*32**2
        d_size_m3 = 71
        return [d_size_m2/d_size_m1, 1.0, d_size_m2/d_size_m3]
    if data == "CelebAText":
        d_size_m1 = 64*64*3
        d_size_m2 = 256
        return [1.0, d_size_m1/d_size_m2]
    if data in ("PolyMNIST", "TranslatedPolyMNIST"):
        return [1.0, 1.0, 1.0, 1.0, 1.0]
    if data == "CUB":
        return [1.0, 1.0]
    raise NotImplementedError("unsupported dataset: {}".format(data))


def _build_dataloaders(data, batch_size):
    """Build ``(train_loader, test_loader)`` for dataset *data*.

    Both loaders shuffle and drop the last incomplete batch.

    Raises:
        NotImplementedError: if *data* is not a supported dataset name.
    """
    if data == "SVHNMNIST":
        train_set = SVHNMNIST(dir_data, len_sequence=8, train=True,
                              data_multiplications=20, two_modality=False)
        test_set = SVHNMNIST(dir_data, len_sequence=8, train=False,
                             data_multiplications=20, two_modality=False)
        num_workers = 5
    elif data == "CelebAText":
        train_set = CelebaDataset(dir_data, partition=0, use_text=True)
        test_set = CelebaDataset(dir_data, partition=1, use_text=True)
        num_workers = 8
    elif data in ("PolyMNIST", "TranslatedPolyMNIST"):
        subdir = "MMNIST" if data == "PolyMNIST" else "translated_polymnist_new"
        data_path = os.path.join(dir_data, subdir)
        train_set = PolyMNIST(data_path, split="train")
        test_set = PolyMNIST(data_path, split="test")
        num_workers = 1
    elif data == "CUB":
        train_set = MMCUB(dir_data, split="train")
        test_set = MMCUB(dir_data, split="test")
        num_workers = 1
    else:
        raise NotImplementedError("unsupported dataset: {}".format(data))

    def make_loader(dataset):
        # NOTE(review): the test loader is shuffled as well (kept from the
        # original code) — confirm evaluation does not need a fixed order.
        return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, drop_last=True)

    return make_loader(train_set), make_loader(test_set)


def load_model(path, device, params_update=None, return_dataset=True):
    """Restore a trained multimodal model from a W&B run.

    Args:
        path: W&B run path, passed to ``api.run`` and ``model.load_models``.
        device: forwarded to the model constructor.
        params_update: optional dict of overrides applied on top of the run's
            config. The run object's own config is not modified.
        return_dataset: if True, also build and return train/test DataLoaders.

    Returns:
        ``(model, train_loader, test_loader)`` when *return_dataset* is True,
        otherwise ``model`` alone. The model has had ``eval()`` and
        ``load_clf()`` called on it.

    Raises:
        NotImplementedError: if ``params["data"]`` is not a supported dataset.
    """
    run = api.run(path)
    # Copy so that overrides and the keys added below do not leak into the
    # run's cached config.
    params = dict(run.config)
    if params_update:
        params.update(params_update)

    torch.backends.cudnn.benchmark = True
    torch.manual_seed(params["seed"])
    np.random.seed(params["seed"])
    random.seed(params["seed"])

    params["modality_id"] = _parse_modality_id(params["modality_id"])

    # Always resolve the defaults so that an unsupported dataset fails fast,
    # even when rec_weight_all is already set or no datasets are requested.
    default_rec_weights = _default_rec_weights(params["data"])
    if params.get("rec_weight_all") is None:
        params["rec_weight_all"] = default_rec_weights

    if return_dataset:
        # Built before the model, as before, so any RNG use by the dataset
        # constructors follows the seeding above in the same order.
        train_loader, test_loader = _build_dataloaders(params["data"], params["batch_size"])

    # Read from params so that params_update overrides of lr/weight_decay apply.
    params["optimizer_params"] = {"lr": params["lr"], "weight_decay": params["weight_decay"]}
    print(params)

    model = models[params["model"]](params, device=device)
    model.load_models(path)
    model.eval()

    # load classifier
    model.load_clf()

    if return_dataset:
        return model, train_loader, test_loader
    return model

def imshow(x, show=True):
    """Move channels last on tensor *x*, then plot the image or return it.

    Args:
        x: a torch tensor with exactly three dims, channels first (assumed
            (C, H, W) — TODO confirm). It is detached and copied to the CPU.
        show: when True, draw the image with ``plt.imshow`` and return None.
            When False, skip plotting and return the channels-last array.
    """
    channels_last = x.detach().cpu().numpy().transpose(1, 2, 0)
    if not show:
        return channels_last
    plt.imshow(channels_last)