import yaml
import torch
import torch.nn as nn
from apex import amp

from models.munit.networks import AdaINGen

def load_model(args, reverse=False):
    """Build a MUNIT model on the GPU, optionally prepared for half precision.

    Params:
        args: Command line arguments for main.py. This function reads
            ``args.model_path`` (checkpoint passed to MUNIT_Model) and
            ``args.half_prec`` (enables the apex AMP / half-precision path).
        reverse: Forwarded unchanged as MUNIT_Model's ``reverse`` argument.

    Returns:
        MUNIT model as nn.Module instance.
    """

    model = MUNIT_Model(args.model_path, reverse=reverse).cuda()

    # Full-precision path: nothing more to do.
    if not args.half_prec:
        return model

    # Register the model with apex AMP (opt level O1), then cast it to fp16.
    # NOTE(review): apex's documentation advises against calling .half()
    # manually on an AMP-initialized model -- confirm this cast is intended.
    amp_model = amp.initialize(model, opt_level='O1')
    return amp_model.half()

    # G1 = MUNIT_Model('training/models/imagenet-snow-3.pt', reverse=False).cuda()
    # G2 = MUNIT_Model('training/models/imagenet-brightness-3.pt', reverse=False).cuda()
    # if args.half_prec:
    #     G1 = amp.initialize(G1, opt_level='O1').half()
    #     G2 = amp.initialize(G2, opt_level='O1').half()

    # return Compose_Model(G1, G2)
    

class Compose_Model(nn.Module):
    """Chain two models so the output of the first feeds the second.

    Both wrapped models are called with the same ``delta``; the composed
    result is ``G2(G1(x, delta), delta)``.

    Params:
        G1: Model applied first; must be callable as ``G1(x, delta)``.
        G2: Model applied to the output of ``G1``; must be callable as
            ``G2(y, delta)``.
    """

    def __init__(self, G1, G2):
        super(Compose_Model, self).__init__()
        # Assigning nn.Module instances as attributes registers them as
        # sub-modules of this wrapper.
        self.G1, self.G2 = G1, G2

    def forward(self, x, delta):
        """Apply ``G1`` and then ``G2``, passing ``delta`` to both.

        Params:
            x: Input handed to ``G1``.
            delta: Second argument handed to both ``G1`` and ``G2``.

        Returns:
            Output of ``G2`` evaluated on the output of ``G1``.
        """
        # Call the modules themselves instead of `.forward(...)` so that
        # nn.Module.__call__ runs and any registered hooks are honoured.
        intermediate = self.G1(x, delta)
        return self.G2(intermediate, delta)

    

class MUNIT_Model(nn.Module):
    """Pre-trained MUNIT generator pair used as a single translation model.

    The input is encoded with one generator and decoded with the other,
    using a caller-supplied ``delta`` as the second argument to ``decode``.

    Attributes:
        delta_dim: Value of ``gen.style_dim`` from the configuration file.
    """

    def __init__(self, fname: str, reverse: bool,
                 config_path: str = 'training/models/munit.yaml'):
        """Instantiation of pre-trained MUNIT model.

        Params:
            fname: File name of trained MUNIT checkpoint file. The checkpoint
                is indexed with the keys 'a' and 'b' (one state dict per
                generator).
            reverse: If falsy, generator 'a' encodes and generator 'b'
                decodes; if truthy, the two roles are swapped.
            config_path: Path to the .yaml configuration file describing the
                generators. Defaults to the path that was previously
                hard-coded.
        """

        super(MUNIT_Model, self).__init__()

        self._config = self.__get_config(config_path)
        self._fname = fname
        self._reverse = reverse
        self._gen_A, self._gen_B = self.__load()
        self.delta_dim = self._config['gen']['style_dim']

    def forward(self, x, delta):
        """Encode ``x`` with the first generator, decode with the second.

        Params:
            x: Input passed to the encoding generator.
            delta: Passed as the second argument to the decoding generator's
                ``decode`` (presumably the style code -- confirm against
                AdaINGen).

        Returns:
            Output of the decoding generator.
        """

        # Only the first value returned by encode() is used.
        orig_content, _ = self._gen_A.encode(x)
        # Detached copy: no gradient flows back to `x` through the content
        # code, only through `delta` and the decoder.
        orig_content = orig_content.clone().detach().requires_grad_(False)
        return self._gen_B.decode(orig_content, delta)

    def __load(self):
        """Build both generators and return them as (encoder, decoder)."""

        # Deserialize the checkpoint once and share it between generators.
        # map_location='cpu' keeps the raw checkpoint off the GPU; the
        # weights are copied into freshly constructed modules anyway.
        checkpoint = torch.load(self._fname, map_location='cpu')
        gen_A = self.__load_munit(self._fname, 'a', checkpoint)
        gen_B = self.__load_munit(self._fname, 'b', checkpoint)

        # Truthiness rather than `is False`, so that falsy values such as
        # 0 or None also select the original order.
        if not self._reverse:
            return gen_A, gen_B     # original order
        return gen_B, gen_A         # reversed order

    def __load_munit(self, fname: str, letter: str, checkpoint=None):
        """Construct one generator and load its weights in eval mode.

        Params:
            fname: Checkpoint file name; only read when ``checkpoint`` is
                not supplied.
            letter: 'a' or 'b'; selects both the ``input_dim_<letter>``
                config entry and the state dict inside the checkpoint.
            checkpoint: Already-loaded checkpoint, to avoid re-reading the
                file. Loaded from ``fname`` when None.

        Returns:
            AdaINGen instance with loaded weights, switched to eval mode.
        """

        if checkpoint is None:
            checkpoint = torch.load(fname, map_location='cpu')

        gen = AdaINGen(self._config[f'input_dim_{letter}'], self._config['gen'])
        gen.load_state_dict(checkpoint[letter])
        gen.eval()

        return gen

    @staticmethod
    def __get_config(path):
        """Load .yaml file as dictionary.

        Params:
            path: Path to .yaml configuration file.

        Returns:
            Parsed content of the file.
        """

        # NOTE(review): FullLoader is acceptable for a trusted local config;
        # prefer yaml.safe_load if `path` can ever point at untrusted input.
        with open(path, 'r') as stream:
            return yaml.load(stream, Loader=yaml.FullLoader)