import os, sys
from functools import partial
from numpy.core.defchararray import startswith

import torch
import argparse
from torchvision.transforms import functional
from collections import OrderedDict

from hps import Hyperparams
from vae import VAE, ConditionalVAE
from train import enforce_obs as enforce_obs_fn


# Per-dataset pixel normalization statistics, expressed on the [0, 1] pixel
# scale (the raw 0-255 statistics divided by 255). Images are standardized as
# (x - mean) / std; the entry is selected by the dataset name in the config.
normalizers = dict(
    cifar10=dict(mean=120.63838 / 255, std=64.16736 / 255),
    ffhq_256=dict(mean=112.8666757481 / 255, std=69.84780273 / 255),
    xray=dict(mean=0.5, std=0.25),
)


@torch.no_grad()
def inpaint(self, img_batch, mask_batch, data_normalizer, H, enforce_obs=True):
    """Fill in the unobserved pixels of ``img_batch`` with a model sample.

    Intended to be bound to a model (``self``), e.g. via ``functools.partial``.
    The model and all of its submodules are switched to eval mode for the
    duration of the call and restored to their previous modes afterwards,
    even if the forward pass raises.

    Args:
        self: The model; must provide ``part_encoder``, ``decoder.run`` and
            ``decoder.out_net.sample``.
        img_batch: Image tensor, assumed (B, C, H, W) with values in [0, 1]
            (it is multiplied by 255 before being handed to ``enforce_obs_fn``).
        mask_batch: Mask tensor of shape (B, H, W). It multiplies the image,
            so presumably 1 marks observed pixels and 0 missing ones.
        data_normalizer: Mapping with ``"mean"`` and ``"std"`` used to
            standardize the image before it is encoded.
        H: Hyperparameters, forwarded to ``enforce_obs_fn``.
        enforce_obs: If True, pass the sample and the observed pixels through
            ``enforce_obs_fn``; otherwise return the raw sample.

    Returns:
        A (B, C, H, W) tensor on ``img_batch.device``: the sampler's output
        (on a 0-255 scale) divided by 255.
    """
    # Snapshot every submodule's mode. Assigning only ``self.training`` would
    # leave the submodules in training mode; ``eval()`` propagates to them.
    prev_modes = [(module, module.training) for module in self.modules()]
    self.eval()
    try:
        ## Prepare the full image batch and the part image batch
        # Channels-last layout: (B, C, H, W) -> (B, H, W, C).
        img_batch = img_batch.permute(0, 2, 3, 1)
        img_batch_normalized = (img_batch - data_normalizer["mean"]) / data_normalizer["std"]
        # (B, H, W) -> (B, H, W, 1) so the mask broadcasts over the channels
        # and can be appended as an extra channel.
        mask_batch = mask_batch.unsqueeze(1).permute(0, 2, 3, 1)
        normalized_masked_img_batch = torch.cat(
            [img_batch_normalized * mask_batch, mask_batch], dim=-1)
        unnorm_masked_img_batch = torch.cat(
            [(img_batch * 255) * mask_batch, mask_batch], dim=-1)
        ## Forward-pass
        part_activations = self.part_encoder(normalized_masked_img_batch)
        px_z, _ = self.decoder.run(sample_from="part",
                                   part_activations=part_activations)
        sample_batch = self.decoder.out_net.sample(px_z)
        ## Put the observed parts in the image
        if enforce_obs:
            inpainted = enforce_obs_fn(H, sample_batch, unnorm_masked_img_batch)
        else:
            inpainted = sample_batch
        # as_tensor accepts both NumPy arrays and tensors (torch.tensor warns
        # and copies when given a tensor) and moves the result to the input's
        # device in one step.
        inpainted = torch.as_tensor(inpainted, device=img_batch.device)
        inpainted = inpainted.permute(0, 3, 1, 2).contiguous()
        ## Map the pixel values to [0-1] from [0-255]
        return inpainted / 255
    finally:
        # Restore the exact per-module modes, even if the forward pass raised.
        for module, mode in prev_modes:
            module.training = mode


def create_model_from_path(path):
    """Build an inpainting-ready conditional VAE from a checkpoint directory.

    ``path`` must contain ``config.th`` (the hyperparameter config) and
    ``model-ema.th`` (the state dict to load). The returned model gets an
    ``inpaint`` attribute: ``inpaint`` bound to the model, the dataset's
    entry in ``normalizers`` and the hyperparameters.

    Args:
        path: Directory holding the checkpoint files.

    Returns:
        A ``ConditionalVAE`` with the weights from ``model-ema.th`` loaded and
        ``model.inpaint`` attached.

    Raises:
        NotImplementedError: If the config describes an unconditional model,
            which the inpaint function does not support yet.
        KeyError: If ``H.dataset`` has no entry in ``normalizers``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    config = torch.load(os.path.join(path, "config.th"), map_location="cpu")
    H = Hyperparams(config)
    # Validate before loading the weight file. Raise instead of assert so the
    # check is not stripped under `python -O`.
    if not H.conditional:
        raise NotImplementedError(
            "Unconditional models are not supported by the inpaint function yet.")
    data_normalizer = normalizers[H.dataset]
    ## Instantiate the model
    model = ConditionalVAE(H)
    state_dict = torch.load(os.path.join(path, "model-ema.th"), map_location="cpu")
    ## Remove "module." from the beginning of state_dict names
    ## (caused by the model being a (Distributed)DataParallel object)
    prefix = "module."
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    ## Load the model weights
    model.load_state_dict(state_dict)
    ## Add the inpainting function to the model
    model.inpaint = partial(inpaint, model, data_normalizer=data_normalizer, H=H)
    return model
