from .FCN import *
from .VAE import *
from .xSampling import *
from .models import PytorchEncoder, PytorchDecoder, ClassConditionalEncoder, ClassConditionalDecoder

import numpy as np


def init_generator(state, variant, im_mean=None, im_std=None):
    """Build the ``(enc, dec)`` pair for a dataset and generator variant.

    The dataset is selected by substring match on ``state['dataset']``
    ('MNIST', 'CIFAR10' or 'Imagenet'). It fixes the input resolution and,
    for the sampling variants, which down/up-sampler classes are used.

    Args:
        state: Configuration mapping. Which keys are read depends on the branch:
            ``'dataset'`` always; ``'resize_dim'`` for the sampling variants;
            ``'encoder_resize'`` and ``'original_size'`` whenever the dataset
            is Imagenet; ``'compress_mode'``, ``'enc_path'`` and ``'dec_path'``
            for AE/VAE/CC-AE; ``'classes_path'`` for CC-AE.
        variant: One of ``'AE'``, ``'VAE'``, ``'CC-AE'``, ``'sampling'`` or
            ``'randsampling'``.
        im_mean: Forwarded unchanged to ``PytorchEncoder``/``PytorchDecoder``.
        im_std: Forwarded unchanged to ``PytorchEncoder``/``PytorchDecoder``.

    Returns:
        Tuple ``(enc, dec)`` of the wrapped encoder and decoder.

    Raises:
        ValueError: ``state['dataset']`` matches none of the known datasets.
        NotImplementedError: ``variant`` is unknown, is ``'CC-VAE'``, or is a
            sampling variant for which the dataset defines no sampler
            (e.g. ``'randsampling'`` on MNIST).
    """
    dataset = state['dataset']
    # Stays None unless a dataset branch below builds a sampler pair.
    enc = dec = None

    if "MNIST" in dataset:
        resolution = [1, 28, 28]
        resize_dim = None
        original_dim = None
        if variant == "sampling":
            enc = MNIST_Downsampler(state["resize_dim"])
            dec = MNIST_Upsampler()
    # Substring match: this branch also accepts names such as "CIFAR100".
    elif "CIFAR10" in dataset:
        resolution = [3, 32, 32]
        resize_dim = None
        original_dim = None
        if variant == "sampling":
            enc = CIFAR_Downsampler(state["resize_dim"])
            dec = CIFAR_Upsampler()
        elif variant == "randsampling":
            enc = CIFAR_RandomDownsampler(state["resize_dim"])
            dec = CIFAR_RandomUpsampler()
    elif "Imagenet" in dataset:
        resolution = [3, 224, 224]
        resize_dim = state['encoder_resize']
        original_dim = state['original_size']
        if variant == "sampling":
            enc = Imagenet_Downsampler(state["resize_dim"])
            dec = Imagenet_Upsampler()
        elif variant == "randsampling":
            enc = Imagenet_RandomDownsampler(state["resize_dim"])
            dec = Imagenet_RandomUpsampler()
    else:
        # Fail fast: without a match, resolution/enc/dec would be unbound below.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected a name containing "
            "'MNIST', 'CIFAR10' or 'Imagenet'."
        )

    if variant in ("AE", "VAE"):
        if variant == "AE":
            enc_cls, dec_cls = Encoder, Decoder
        else:
            enc_cls, dec_cls = VariationalEncoder, VariationalDecoder

        enc = enc_cls(resolution, compress_mode=state['compress_mode'], resize_dim=resize_dim)
        dec = dec_cls(resolution, compress_mode=state['compress_mode'], original_dim=original_dim)

        # NOTE(review): [0, 1] is presumably the wrapped model's value range — confirm.
        enc = PytorchEncoder(enc, [0, 1], im_mean, im_std)
        dec = PytorchDecoder(dec, [0, 1], im_mean, im_std)

        enc.load_state_dict(state['enc_path'])
        dec.load_state_dict(state['dec_path'])
    elif variant == 'CC-AE':
        # One encoder/decoder per entry of the array stored at classes_path.
        possible_ix = np.load(state['classes_path'])

        # CC-AE always builds its sub-models with resize_dim/original_dim=None,
        # regardless of the dataset's values.
        enc_dict = {k: PytorchEncoder(
                            Encoder(resolution, compress_mode=state['compress_mode'], resize_dim=None),
                            [0, 1], im_mean, im_std)
                    for k in possible_ix}

        dec_dict = {k: PytorchDecoder(
                            Decoder(resolution, compress_mode=state['compress_mode'], original_dim=None),
                            [0, 1], im_mean, im_std)
                    for k in possible_ix}

        enc = ClassConditionalEncoder(enc_dict)
        dec = ClassConditionalDecoder(dec_dict)

        enc.load_state_dict(path_template=state['enc_path'])
        dec.load_state_dict(path_template=state['dec_path'])
    elif variant == 'CC-VAE':
        raise NotImplementedError()
    elif variant in ("sampling", "randsampling"):
        if enc is None:
            # e.g. MNIST defines no random sampler pair.
            raise NotImplementedError(
                f"Variant {variant!r} is not available for dataset {dataset!r}."
            )
        enc = PytorchEncoder(enc, [-2, 2], im_mean, im_std)
        dec = PytorchDecoder(dec, [-2, 2], im_mean, im_std)
        if variant == "randsampling":
            # NOTE(review): random samplers are limited to batch size 1;
            # presumably they draw per-sample randomness — confirm.
            enc.max_batch_size = 1
            dec.max_batch_size = 1
    else:
        raise NotImplementedError(
            f"No HLM variant of {variant}. "
            "Options: [AE, VAE, CC-AE, sampling, randsampling]."
        )

    return enc, dec
