from __future__ import print_function

import os
import sys
from abc import abstractmethod
import itertools

import torch
from torch import optim
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
import wandb

from pixyz.utils import tolist
from tqdm import tqdm
import numpy as np
from pixyz.distributions import Normal
from pixyz.models.model import Model
from pixyz.utils import get_dict_values, detach_dict
from copy import copy
import Levenshtein


from .metrics import accuracy, avg_precision, calculate_fretchet, calculate_activation, get_fid
from .utils import get_mean, get_ac_number_of_modalities, get_modality_id, metric_update, text_to_tensor

sys.path.append('../')
from distributions.distributions import *
from distributions.networks import *
from distributions.clf import *
from distributions.inception import InceptionV3
# Module-wide cosine-similarity op computed along axis 1 (eps guards against
# division by zero); Base.eval_metrics uses it to compare unimodal latent means.
cos = torch.nn.CosineSimilarity(eps=1e-6, dim=1)

class Base(Model):
    """Common training / evaluation logic for the multimodal VAEs of this project.

    Every key of ``params`` becomes an instance attribute. The methods below
    read, among others, ``data``, ``modality_id``, ``z_dim``, ``s_dim``,
    ``optimizer_name``, ``optimizer_params``, ``beta``, ``kl_annealing``,
    ``kl_annealing_start``, ``test_epoch``, ``clip_norm``, ``clip_value``,
    ``eval_fid``, ``eval_reconst``, ``eval_train``, ``plot_image``,
    ``fix_elbo_smvae`` and (for text modalities) ``alphabet``.
    Subclasses implement ``get_loss`` and ``get_aggregated_inference``.
    """

    def __init__(self, params, device="cpu"):
        """Build distributions, loss, optimizers and the FID feature extractor.

        Args:
            params: mapping whose items are copied onto ``self`` as attributes.
            device: device the distributions, classifiers and InceptionV3 are
                moved to.
        """
        for k in params:
            setattr(self, k, params[k])
        self.device = device

        # set distributions (fills self.dist_dict, self.clf and self.latent_clf)
        self.dist_dict = {}
        self.set_distributions()

        # get a loss function (provided by the subclass)
        loss = self.get_loss()

        # load pretrained InceptionV3 returning 2048-d features, used for FID
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_v3 = InceptionV3([block_idx]).to(device)

        super().__init__(loss=loss.mean(),
                         distributions=list(self.dist_dict.values()),
                         optimizer=self.get_optimizer(),
                         optimizer_params=self.optimizer_params,
                         clip_grad_norm=1.0,
                         retain_graph=False)

        # the latent classifier is trained separately with its own optimizer
        clf_params = self.latent_clf.parameters()
        self.optimizer_clf = self.get_optimizer()(clf_params, **self.optimizer_params)

    def set_optimizer(self, distributions):
        """Replace ``self.optimizer`` by a fresh optimizer over ``distributions``' parameters."""
        distributions = nn.ModuleList(tolist(distributions))
        params = distributions.parameters()
        self.optimizer = self.get_optimizer()(params, **self.optimizer_params)

    def get_optimizer(self):
        """Return the torch optimizer class selected by ``self.optimizer_name``.

        Raises:
            KeyError: if ``optimizer_name`` is not one of the supported names.
        """
        optimizers = {'rmsprop': optim.RMSprop,
                      'adadelta': optim.Adadelta,
                      'adagrad': optim.Adagrad,
                      'adam': optim.Adam}
        return optimizers[self.optimizer_name]

    @abstractmethod
    def get_loss(self):
        """Return the pixyz loss to optimise; implemented by subclasses."""
        raise NotImplementedError()

    def get_aggregated_inference(self, q_z=None):
        """Combine the unimodal posteriors ``q_z`` into one distribution over ``z``.

        Implemented by subclasses; ``conditional_inference`` calls it with the
        list of ``q_z_<i>`` distributions of the observed modalities.
        """
        raise NotImplementedError()

    def get_caption(self, label):
        """Return a scalar tensor label as a string caption."""
        return str(label.item())

    def id2txt(self, x2, blank_id=69):
        """Decode text tensors into strings, dropping ``blank_id`` symbols.

        Each item is decoded with ``argmax`` over its last axis and mapped
        through ``self.alphabet``. NOTE(review): assumes ``x2`` is
        (batch, seq_len, num_symbols) — confirm against the text decoders.
        """
        return ["".join([self.alphabet[ids] for ids in _x2.argmax(-1) if ids != blank_id])
                for _x2 in x2.cpu().numpy()]

    def pred_text(self, x2):
        """Predict a digit label (0-9) for each decoded text sample.

        Each decoded string is assigned the index of the closest English digit
        word by Levenshtein distance. Returns a CPU int64 tensor.
        """
        numbers = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
        return torch.tensor([np.argmin([Levenshtein.distance(_str, _str2) for _str in numbers])
                             for _str2 in self.id2txt(x2)])

    def plot_data(self, epoch, loader, batch, _plot_id, _plot_sample, _plot_label=None):
        """Turn samples of modality ``_plot_id`` into wandb media.

        Returns a list of ``wandb.Image``, a ``wandb.Table`` (CUB text:
        original vs. generated captions), or ``None`` when the modality is not
        plotted for this dataset. ``epoch`` is unused and kept for interface
        compatibility; ``loader`` is only used for CUB text decoding.
        """
        if self.data == "CelebAText":
            # only the image modality is rendered
            if _plot_id == 0:
                return [wandb.Image(x) for x in _plot_sample]
            return None

        if self.data == "CUB":
            if _plot_id == 0:  # Image
                return [wandb.Image(x) for x in _plot_sample]
            if _plot_id == 1:  # Text
                data = [[loader.dataset.convert_text(batch["x1"][j].cpu()),
                         loader.dataset.convert_text(_plot_sample[j].cpu())] for j in range(len(_plot_sample))]
                return wandb.Table(data=data, columns=["Original", "Generation"])
            return None

        if (self.data == "SVHNMNIST") and (_plot_id == 2):
            # render the text modality as an image so it can be logged like the others
            _plot_sample = text_to_tensor(_plot_sample, alphabet=self.alphabet, img_size=(3, 128, 128))

        if _plot_label is not None:
            return [wandb.Image(x, caption=self.get_caption(l)) for x, l in zip(_plot_sample, _plot_label)]
        return [wandb.Image(x) for x in _plot_sample]

    def eval_fid_flag(self, _plot_id, cond_id=None):
        """Return whether FID is computed for generated modality ``_plot_id``.

        PolyMNIST variants: always. Other datasets: only for generations
        conditioned on exactly one modality (``cond_id`` of length 1), and never
        for text outputs (CelebAText/CUB modality 1, SVHNMNIST modality 2) nor
        for SVHNMNIST generations conditioned on the text modality.
        """
        if self.data in ("PolyMNIST", "TranslatedPolyMNIST"):
            return True
        if cond_id is None or len(cond_id) != 1:
            return False
        text_involved = (((self.data == "CelebAText") and (_plot_id == 1)) or
                         ((self.data == "SVHNMNIST") and ((_plot_id == 2) or (cond_id[0] == 2))) or
                         ((self.data == "CUB") and (_plot_id == 1)))
        return not text_involved

    def eval_metrics(self, epoch, train_flag, loader, batch, metrics_dict, fid_metrics_dict,
                    iterative=None, eval_fid=False, eval_reconst=False, plot_images=False, plot_image_size=8, wandb_log=True,
                    std=1.0, sampling=True):
        """Evaluate one batch and accumulate the results in place.

        For every non-empty subset of modalities the subset is encoded and the
        missing modalities (all of them when the subset is complete or
        ``eval_reconst`` is set) are generated. The following are recorded in
        ``metrics_dict`` via ``metric_update``:
        log-likelihood (``ll_*``), classifier accuracy of sampled / mean
        generations (``ac_*`` / ``ac_mean_*``; not for CUB), latent-classifier
        accuracy (``lt_*``; not for CUB), cosine similarity of unimodal latent
        means (``cosine_sim_*``) and, where random-sample labels can be
        predicted (not CUB / CelebAText), their agreement across all
        modalities (``ac``). Inception activations for FID go into
        ``fid_metrics_dict["true"]`` / ``["fake"]`` when ``eval_fid`` is set.

        Args:
            epoch: wandb step used when logging media.
            train_flag: prefix of the wandb keys (e.g. "train", "val").
            loader: data loader (only needed for CUB text plotting).
            batch: dict with "x<i>" tensors for every modality and "labels".
            iterative, std, sampling: forwarded to ``conditional_inference``.
            plot_images, plot_image_size, wandb_log: media logging controls.

        Returns:
            ``(metrics_dict, fid_metrics_dict)``, the objects passed in.
        """
        self.eval()

        # Keyed by modality id (not list position) so the pairwise metrics at
        # the end also work when modality_id is not [0, 1, ..., n-1].
        sample_mean_z_all = {}
        pred_random_label_all = {}

        # generate random samples of every modality from the prior
        batch_size = batch["x%d" % self.modality_id[0]].shape[0]
        fake_imgs = self.random_generation(self.modality_id, batch_size)

        for j in range(1, len(self.modality_id) + 1):
            for conb in itertools.combinations(self.modality_id, j):
                if (set(self.modality_id) == set(conb)) or eval_reconst:
                    generation_id = self.modality_id
                else:
                    generation_id = list(set(self.modality_id) - set(conb))
                # select inputs
                var = ["x" + str(k) for k in conb]
                name = "".join(var)
                _batch = get_dict_values(batch, var, return_dict=True)

                if j == 1:  # unimodal: evaluate prior samples of this modality
                    _plot_id = conb[0]
                    true_img = batch["x%d" % _plot_id]
                    fake_img = fake_imgs["x%d" % _plot_id]

                    ## predicted labels of the random samples
                    if (self.data != "CUB") and (self.data != "CelebAText"):
                        with torch.no_grad():
                            if (self.data == "SVHNMNIST") and (_plot_id == 2):
                                _plot_label = self.pred_text(fake_img)
                            else:
                                preds = self.clf["f_%d" % _plot_id](fake_img)
                                _plot_label = preds.argmax(-1)

                        pred_random_label_all[_plot_id] = _plot_label.cpu().float().numpy()

                    ## FID
                    if eval_fid and self.eval_fid_flag(_plot_id):
                        if (self.data == "SVHNMNIST") and (_plot_id == 0):
                            # single-channel MNIST -> 3 channels for InceptionV3
                            _true_img, _fake_img = true_img.expand(-1, 3, -1, -1), fake_img.expand(-1, 3, -1, -1)
                        else:
                            _true_img, _fake_img = true_img, fake_img

                        act = calculate_activation(_true_img, self.inception_v3)
                        metric_update(fid_metrics_dict["true"], "%d" % _plot_id, act)

                        act = calculate_activation(_fake_img, self.inception_v3)
                        metric_update(fid_metrics_dict["fake"], "%d" % _plot_id, act)

                    # plot images
                    if plot_images:
                        # ground-truth inputs captioned with their true labels
                        _plot_label = batch["labels"][:plot_image_size]
                        _plot_sample = true_img[:plot_image_size]
                        images = self.plot_data(epoch, loader, batch, _plot_id, _plot_sample, _plot_label)
                        if wandb_log and (images is not None):
                            wandb.log({"%s_%d_original" % (train_flag, _plot_id): images}, step=epoch)

                        # random samples from the prior
                        _plot_sample = fake_img[:plot_image_size]
                        images = self.plot_data(epoch, loader, batch, _plot_id, _plot_sample)
                        if wandb_log and (images is not None):
                            wandb.log({"%s_%d_random" % (train_flag, _plot_id): images}, step=epoch)

                # inference: a (possibly noisy) sample and the posterior mean
                sample_z = self.conditional_inference(_batch, sampling=sampling, std=std, iterative=iterative)
                sample_mean_z = self.conditional_inference(_batch, sampling=False, iterative=iterative)

                if len(conb) == 1:
                    sample_mean_z_all[conb[0]] = sample_mean_z["z"]

                # generation
                for i in generation_id:
                    sample = self.conditional_generation(sample_z, [i])["x%d" % i].detach()  # for fid
                    sample_mean = self.conditional_generation(sample_mean_z, [i])["x%d" % i].detach()  # for accuracy

                    ## log-likelihood of the true x_i under the sampled z
                    input_dict = {"x%d" % i: batch["x%d" % i], "z": sample_z["z"]}
                    with torch.no_grad():
                        ll = self.dist_dict["p_x_%d" % i].log_prob().eval(input_dict)
                    metric_update(metrics_dict, "ll_%d_%s" % (i, name), ll.cpu().detach().numpy())

                    if self.data != "CUB":
                        with torch.no_grad():
                            preds_mean = self.clf["f_%d" % i](sample_mean)
                            preds = self.clf["f_%d" % i](sample)
                        _plot_label = preds.argmax(-1)[:plot_image_size]
                    else:
                        _plot_label = False

                    ## cross-coherence
                    if plot_images:
                        _plot_id = i
                        _plot_sample = sample[:plot_image_size].cpu()

                        images = self.plot_data(epoch, loader, batch, _plot_id, _plot_sample, _plot_label)
                        if wandb_log and (images is not None):
                            wandb.log({"%s_%d_%s_gen" % (train_flag, _plot_id, name): images}, step=epoch)

                    ### FID
                    if eval_fid and self.eval_fid_flag(i, conb):
                        if (self.data == "SVHNMNIST") and (i == 0):
                            _true_img, _fake_img = batch["x%d" % i].expand(-1, 3, -1, -1), sample.expand(-1, 3, -1, -1)
                        else:
                            _true_img, _fake_img = batch["x%d" % i], sample

                        act = calculate_activation(_true_img, self.inception_v3)
                        metric_update(fid_metrics_dict["true"], "%d_%s" % (i, name), act)

                        act = calculate_activation(_fake_img, self.inception_v3)
                        metric_update(fid_metrics_dict["fake"], "%d_%s" % (i, name), act)

                    ### accuracy
                    if self.data != "CUB":
                        if (self.data == "SVHNMNIST") and (i == 2):
                            preds_label_mean = self.pred_text(sample_mean)
                            preds_label = self.pred_text(sample)
                            ac_mean = (preds_label_mean == batch["labels"].cpu()).float().numpy()
                            ac = (preds_label == batch["labels"].cpu()).float().numpy()
                        elif self.data != "CelebAText":
                            ac_mean = accuracy(preds_mean, batch["labels"])
                            ac = accuracy(preds, batch["labels"])
                        else:
                            ac_mean = avg_precision(preds_mean, batch["labels"])
                            ac = avg_precision(preds, batch["labels"])
                        metric_update(metrics_dict, "ac_%d_%s" % (i, name), ac)
                        metric_update(metrics_dict, "ac_mean_%d_%s" % (i, name), ac_mean)

                # latent accuracy
                if self.data != "CUB":
                    with torch.no_grad():
                        preds = self.latent_clf.sample_mean({"z": sample_z["z"]})
                    if self.data != "CelebAText":
                        lt = accuracy(preds, batch["labels"])
                    else:
                        lt = avg_precision(preds, batch["labels"])
                    metric_update(metrics_dict, "lt_%s" % name, lt)

        # Label agreement of random samples is only defined when every modality
        # got predicted labels above (not the case for CUB / CelebAText).
        labels_available = len(pred_random_label_all) == len(self.modality_id)
        ac_flag_all = []

        for conb in itertools.combinations(self.modality_id, 2):
            # accuracy of random images
            if labels_available:
                ac_flag_all.append(pred_random_label_all[conb[0]] == pred_random_label_all[conb[1]])

            # cosine similarity between unimodal posterior means
            var = ["x" + str(k) for k in conb]
            name = "".join(var)
            eval_cos = cos(sample_mean_z_all[conb[0]], sample_mean_z_all[conb[1]]).cpu().numpy()
            metric_update(metrics_dict, "cosine_sim_%s" % name, eval_cos)

        if labels_available:
            # 1 only where every pair of modalities agrees on the predicted label
            ac_flag_all = np.prod(ac_flag_all, 0)
            metric_update(metrics_dict, "ac", ac_flag_all)
            print(ac_flag_all.mean())

        return metrics_dict, fid_metrics_dict

    def train_clf(self, input_dict):
        """Take one optimisation step of the latent classifier.

        ``input_dict`` must hold the inputs of ``latent_clf.log_prob()``
        (``train`` passes "z" and "labels"). Returns the loss, i.e. the
        negative mean log-likelihood.
        """
        self.latent_clf.train()

        self.optimizer_clf.zero_grad()
        loss = -self.latent_clf.log_prob().mean().eval(input_dict)

        # backprop
        loss.backward()

        if self.clip_norm:
            clip_grad_norm_(self.latent_clf.parameters(), self.clip_norm)
        if self.clip_value:
            clip_grad_value_(self.latent_clf.parameters(), self.clip_value)

        # update params
        self.optimizer_clf.step()

        return loss

    def train(self, epoch, loader, **kwargs):
        """Run one training epoch and return the epoch-averaged metrics.

        Each batch is a sequence whose entries at the modality ids are the
        inputs and whose last entry holds the labels. On epoch 1 and every
        ``test_epoch``-th epoch the latent classifier is also trained on
        posterior means. If ``eval_train`` is set, metrics (with plots) are
        computed on the last batch.
        """
        t = tqdm(loader)
        metrics_dict = {}
        beta = self.set_beta(epoch)
        eval_dict = None

        for batch in t:
            labels_batch = batch[-1].to(self.device)
            input_dict = {"x%d" % i: batch[i].to(self.device) for i in self.modality_id}

            train_dict = {"beta": beta}
            train_dict.update(input_dict)
            eval_dict = {"labels": labels_batch}
            eval_dict.update(input_dict)

            loss = super(Base, self).train(train_dict)
            metric_update(metrics_dict, "loss", loss.item())

            # train clf given latent variable
            if (epoch == 1) or (epoch % self.test_epoch == 0):
                self.distributions.eval()
                sample_zlf = self.conditional_inference(input_dict, sampling=False)

                clf_labels = labels_batch
                if (self.data != "CelebAText") and (self.data != "CUB"):
                    # one-hot encode; the identity matrix must live on the labels'
                    # device, otherwise indexing it with GPU labels fails
                    clf_labels = torch.eye(10, device=labels_batch.device)[labels_batch]
                sample_zlf.update({"labels": clf_labels.to(self.device)})

                if self.data != "CUB":
                    loss_clf = self.train_clf(sample_zlf)
                    metric_update(metrics_dict, "loss_clf", loss_clf.item())

            t.set_postfix(Epoch=epoch, beta=beta, loss=np.mean(metrics_dict["loss"]))

        # eval_dict is None only when the loader yielded no batch
        if self.eval_train and eval_dict is not None:
            self.eval_metrics(epoch, "train", loader, eval_dict, metrics_dict, {}, plot_images=True, wandb_log=True)

        metrics_dict = get_mean(metrics_dict)
        return metrics_dict

    def test(self, epoch, loader, iterative=None, wandb_log=True, **kwargs):
        """Evaluate on ``loader`` and return averaged metrics including FID.

        ``kwargs`` may override ``eval_fid`` and ``eval_reconst``; otherwise the
        instance attributes of the same name are used. Media are plotted for
        the first batch only, when ``self.plot_image`` is set.
        """
        t = tqdm(loader)
        metrics_dict = {}
        fid_metrics_dict = {"true": {}, "fake": {}}
        beta = self.set_beta(epoch)

        # evaluation switches are constant across batches
        eval_fid = kwargs["eval_fid"] if "eval_fid" in kwargs else self.eval_fid
        eval_reconst = kwargs["eval_reconst"] if "eval_reconst" in kwargs else self.eval_reconst

        for j, batch in enumerate(t):
            labels_batch = batch[-1].to(self.device)
            input_dict = {"x%d" % i: batch[i].to(self.device) for i in self.modality_id}

            train_dict = {"beta": beta}
            train_dict.update(input_dict)
            loss = super(Base, self).test(train_dict)
            metric_update(metrics_dict, "loss", loss.item())

            eval_dict = {"labels": labels_batch}
            eval_dict.update(input_dict)

            plot_image = ((j == 0) and self.plot_image)  # plot only the 1st batch

            # eval metrics
            self.eval_metrics(epoch, "val", loader, eval_dict, metrics_dict, fid_metrics_dict,
                              iterative=iterative, eval_fid=eval_fid, eval_reconst=eval_reconst,
                              plot_images=plot_image, wandb_log=wandb_log)

            t.set_postfix(Epoch=epoch, loss=np.mean(metrics_dict["loss"]))

        metrics_dict = get_mean(metrics_dict)
        fid_metrics_dict = get_fid(fid_metrics_dict)
        metrics_dict.update(fid_metrics_dict)
        metrics_dict = get_ac_number_of_modalities(metrics_dict, self.modality_id)

        return metrics_dict

    def conditional_inference(self, modality_input, sampling=True, iterative=None, std=1.0):
        """Infer the shared latent from the modalities present in ``modality_input``.

        The ``q_z_<i>`` of the given modalities are combined with
        ``get_aggregated_inference``. With ``sampling`` the result is
        ``mean + eps * sigma`` with ``eps ~ N(0, std^2)``; otherwise the
        posterior mean. Returns a detached ``{latent_var_name: tensor}``.
        ``iterative`` is accepted for interface compatibility and unused here.
        """
        with torch.no_grad():
            modality_input = copy(modality_input)

            modality_input_id = get_modality_id(modality_input)
            q_z = [self.dist_dict["q_z_%d" % i] for i in modality_input_id]
            q_z_group = self.get_aggregated_inference(q_z)

            if sampling:
                mean = q_z_group.sample_mean(modality_input)
                noise = torch.normal(mean=torch.zeros_like(mean), std=std * torch.ones_like(mean))
                # pixyz's sample_variance returns sigma^2; the reparameterised
                # noise has to be scaled by the standard deviation sigma
                scale = q_z_group.sample_variance(modality_input).sqrt()
                sample_z = {"%s" % q_z_group.var[0]: mean + noise * scale}
            else:
                sample_z = {"%s" % q_z_group.var[0]: q_z_group.sample_mean(modality_input)}

            sample_z = detach_dict(sample_z)
        return sample_z

    def conditional_generation(self, modality_input, given_modality_id):
        """Decode the modalities in ``given_modality_id`` from ``modality_input``.

        Uses the mean of each ``p_x_<i>``; returns ``{"x<i>": tensor}``.
        """
        with torch.no_grad():
            p_x = [self.dist_dict["p_x_%d" % i] for i in given_modality_id]

            # generate
            generate = {}
            for i, p in zip(given_modality_id, p_x):
                _sample = get_dict_values(modality_input, p.cond_var, return_dict=True)
                _sample = p.sample_mean(_sample).detach()

                generate["x%d" % i] = _sample

        return generate

    def random_generation(self, given_modality_id, batch_n=1):
        """Draw ``batch_n`` latents from ``prior_z`` and decode the given modalities."""
        with torch.no_grad():
            modality_input = self.dist_dict["prior_z"].sample(batch_n=batch_n)["z"]
            gen_images = self.conditional_generation({"z": modality_input}, given_modality_id)
        return gen_images

    def set_distributions(self):
        """Build the networks for ``self.data`` and register the pixyz distributions.

        Fills ``self.dist_dict`` with ``q_z_<i>`` and ``p_x_<i>`` per modality
        (plus ``q_s_<i>`` / ``prior_s_<i>`` when ``s_dim > 0``) and
        ``prior_z``; sets ``self.enc_all``, ``self.hidden_dim``,
        ``self.latent_clf``, ``self.clf`` and, except for CUB,
        ``self.input_dim``.

        Raises:
            NotImplementedError: for an unknown ``self.data``.
        """
        if self.data == "SVHNMNIST":
            # MNIST
            input_mnist = 28
            num_hidden_layers = 1
            hidden_dim_mnist = 400

            # SVHN
            hidden_dim_svhn = 128

            # Text
            num_features = 71
            text_dim = 64

            self.input_dim = [(1, input_mnist, input_mnist), (3, 32, 32), (8, num_features)]

            encoder_mnist = EncoderMNIST(hidden_dim_mnist, num_hidden_layers)
            encoder_svhn = EncoderSVHN()
            encoder_text = EncoderText(text_dim, num_features)
            self.enc_all = [encoder_mnist, encoder_svhn, encoder_text]

            self.hidden_dim = [hidden_dim_mnist, hidden_dim_svhn, 2 * text_dim]

            z_dim = self.z_dim + self.s_dim

            decoder_mnist = DecoderMNIST(z_dim, hidden_dim_mnist, num_hidden_layers, input_mnist)
            decoder_svhn = DecoderSVHN(z_dim)
            decoder_text = DecoderText(z_dim, num_features)
            dec_all = [decoder_mnist, decoder_svhn, decoder_text]
            dec_dist_all = [GenerationLaplace, GenerationLaplace, GenerationCategorical]  # (input_var)

            clf_all = [ClfImgMNIST(), ClfImgSVHN(), ClfText(num_features, text_dim)]

            num_category = 10
            self.latent_clf = CLF(z_dim=self.z_dim, num_category=num_category).to(self.device)

        elif self.data == "CelebAText":
            DIM_text = 128
            # CelebA
            hidden_dim_celeba = DIM_text  # 5*DIM_text

            # CelebA Text
            hidden_dim_text = DIM_text  # 5*DIM_text
            num_features = 71

            self.input_dim = [(3, 64, 64), (256, 71)]

            self.enc_all = [EncoderCelebA(), EncoderCelebAText()]

            self.hidden_dim = [hidden_dim_celeba, hidden_dim_text]

            z_dim = self.z_dim + self.s_dim

            dec_all = [DecoderCelebA(z_dim), DecoderCelebAText(z_dim)]
            dec_dist_all = [GenerationLaplace, GenerationCategorical]  # (input_var)

            clf_all = [ClfCelebAImg(), ClfCelebAText()]

            num_category = 40
            self.latent_clf = CLFBernoulli(z_dim=self.z_dim, num_category=num_category).to(self.device)

        elif self.data == "PolyMNIST":
            # MNIST
            input_mnist = 28

            self.input_dim = [(3, input_mnist, input_mnist)] * 5

            self.enc_all = [EncoderMMNIST(self.z_dim) for _ in range(5)]

            self.hidden_dim = [self.z_dim] * 5

            z_dim = self.z_dim + self.s_dim

            dec_all = [DecoderMMNIST(z_dim) for _ in range(5)]
            dec_dist_all = [GenerationLaplace] * 5  # (input_var)

            clf_all = [ClfImgMMNIST() for _ in range(5)]

            num_category = 10
            self.latent_clf = CLF(z_dim=self.z_dim, num_category=num_category).to(self.device)

        elif self.data == "TranslatedPolyMNIST":
            # MNIST
            input_mnist = 28

            self.input_dim = [(3, input_mnist, input_mnist)] * 5

            self.enc_all = [EncoderTransMMNIST(self.z_dim) for _ in range(5)]

            self.hidden_dim = [self.z_dim] * 5

            z_dim = self.z_dim + self.s_dim

            dec_all = [DecoderTransMMNIST(z_dim) for _ in range(5)]
            dec_dist_all = [GenerationLaplace] * 5  # (input_var)

            clf_all = [ClfImgTransMMNIST() for _ in range(5)]

            num_category = 10
            self.latent_clf = CLF(z_dim=self.z_dim, num_category=num_category).to(self.device)

        elif self.data == "CUB":
            self.enc_all = [EncoderCubImage(), EncoderCubText()]

            self.hidden_dim = [1024, 160]

            z_dim = self.z_dim + self.s_dim

            dec_all = [DecoderCubImage(z_dim), DecoderCubText(z_dim)]
            dec_dist_all = [GenerationLaplace, GenerationCategorical]

            clf_all = []
            num_category = 32
            self.latent_clf = CLFCategorical(z_dim=self.z_dim, num_category=num_category, y_dim=1590).to(self.device)
        else:
            raise NotImplementedError

        for i in self.modality_id:
            self.dist_dict["q_z_%d" % i] = Inference([self.enc_all[i]], [self.hidden_dim[i]], self.z_dim, cond_var=["x%d" % i]).to(self.device)
            if self.s_dim > 0:
                self.dist_dict["q_s_%d" % i] = Inference([self.enc_all[i]], [self.hidden_dim[i]], self.s_dim, cond_var=["x%d" % i], var=["s%d" % i]).to(self.device)
                # NOTE(review): p_x conditions on "c<i>" while q_s / prior_s use
                # "s<i>"; nothing in this class produces "c<i>" — confirm intended.
                self.dist_dict["p_x_%d" % i] = dec_dist_all[i](dec_all[i], var=["x%d" % i], cond_var=["c%d" % i, "z"]).to(self.device)
                self.dist_dict["prior_s_%d" % i] = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                                                          var=["s%d" % i], features_shape=[self.s_dim],
                                                          name="p_{prior}").to(self.device)
            else:
                self.dist_dict["p_x_%d" % i] = dec_dist_all[i](dec_all[i], var=["x%d" % i]).to(self.device)

        self.dist_dict["prior_z"] = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                                           var=["z"], features_shape=[self.z_dim],
                                           name="p_{prior}").to(self.device)

        self.clf = {}
        if self.data != "CUB":
            for i in self.modality_id:
                self.clf["f_%d" % i] = clf_all[i]

    def set_beta(self, epoch):
        """Return the KL weight for ``epoch``.

        With ``kl_annealing > 0`` beta is 0 up to ``kl_annealing_start``, then
        grows linearly to reach ``self.beta`` at epoch ``kl_annealing`` and
        stays there; otherwise ``self.beta`` is returned.
        """
        if (self.kl_annealing > 0) and (epoch < self.kl_annealing):
            span = self.kl_annealing - self.kl_annealing_start
            if span > 0:
                beta = (epoch - self.kl_annealing_start) * (self.beta / span)
                if beta < 0:
                    beta = 0
            else:
                # empty ramp (start >= end): epoch precedes the start, keep KL off
                beta = 0
        else:
            beta = self.beta

        if (self.fix_elbo_smvae) and (epoch == self.kl_annealing_start + 1):
            print("train only unimodal inferences")

        return beta

    def log_models(self, path):
        """Save every distribution's and the latent classifier's state_dict under ``path``."""
        for k, v in self.dist_dict.items():
            torch.save(v.state_dict(), os.path.join(path, k))
        torch.save(self.latent_clf.state_dict(), os.path.join(path, "latent_clf"))

    def load_clf(self):
        """Load the pretrained per-modality classifiers for ``self.data``.

        Datasets without pretrained classifiers (e.g. CUB) are left untouched.
        """
        # data -> (directory, file name pattern, offset added to the modality id)
        clf_sources = {
            "SVHNMNIST": ("./trained_classifiers/trained_clfs_mst", "clf_m%d", 1),
            "CelebAText": ("./trained_classifiers/trained_clfs_celeba", "clf_m%d", 1),
            "PolyMNIST": ("./trained_classifiers/trained_clfs_polyMNIST", "pretrained_img_to_digit_clf_m%d", 0),
            "TranslatedPolyMNIST": ("./trained_classifiers/trained_clfs_translatedpolyMNIST", "pretrained_img_to_digit_clf_m%d", 0),
        }
        if self.data not in clf_sources:
            return

        dir_clf, file_pattern, offset = clf_sources[self.data]
        for i in self.modality_id:
            path = os.path.join(dir_clf, file_pattern % (i + offset))
            # map_location lets checkpoints saved on GPU load on CPU-only hosts
            self.clf["f_%d" % i].load_state_dict(torch.load(path, map_location=self.device))
            self.clf["f_%d" % i].to(self.device)

    def load_models(self, path):
        """Restore the files written by ``log_models`` from the wandb run ``path``.

        Files are downloaded under ``./log_models/<last component of path>``.
        """
        model_id = os.path.split(path)[-1]
        for k, v in self.dist_dict.items():
            model = wandb.restore(k, run_path=path, root="./log_models/%s" % model_id)
            v.load_state_dict(torch.load(model.name, map_location=self.device))
            v.to(self.device)
        model = wandb.restore("latent_clf", run_path=path, root="./log_models/%s" % model_id)
        self.latent_clf.load_state_dict(torch.load(model.name, map_location=self.device))
        self.latent_clf.to(self.device)

    def eval(self):
        """Switch every module used for evaluation to eval mode.

        Covers the distributions, the latent classifier, and also the
        pretrained per-modality classifiers and InceptionV3. Nothing else in
        this class toggles the latter two, so without this they would run in
        train mode (dropout / batch statistics) when computing metrics.
        """
        for _, v in self.dist_dict.items():
            v.eval()
        self.latent_clf.eval()
        for clf in self.clf.values():
            clf.eval()
        self.inception_v3.eval()


