from typing import Optional, Tuple

import numpy as np
import torch

from configs.paths_config import interfacegan_edit_paths
# from models.stylegan3.model import GeneratorType
# from models.stylegan3.networks_stylegan3 import Generator

from models.swagan.swagan import Generator
from utils.common import tensor2im


class FaceEditor:
    """Apply InterFaceGAN latent directions to latents and render the edited results.

    Directions are loaded once at construction from ``interfacegan_edit_paths`` and
    stored in ``self.interfacegan_directions``, keyed by direction name.
    """

    # Directions loaded at construction. 'smile' and 'Male' were previously loaded
    # here as well (from the same path mapping) but are currently disabled.
    _DIRECTION_NAMES = ('age', 'pose')

    def __init__(self, stylegan_generator: Generator, device: 'str | torch.device' = 'cuda'):
        """Store the generator and load the enabled InterFaceGAN direction tensors.

        Args:
            stylegan_generator: generator used to render edited latents.
            device: device the direction tensors are moved to. Defaults to ``'cuda'``
                (the current CUDA device), matching the previous hard-coded ``.cuda()``.
        """
        self.generator = stylegan_generator
        paths = interfacegan_edit_paths
        self.interfacegan_directions = {
            name: torch.from_numpy(np.load(paths[name])).to(device)
            for name in self._DIRECTION_NAMES
        }

    def edit_sg3(self, latents: torch.Tensor, direction: str, factor: float = 1,
                 factor_range: Optional[Tuple[int, int]] = None,
                 user_transforms: Optional[np.ndarray] = None,
                 apply_user_transformations: Optional[bool] = False, all_s=None,
                 *, range_scale: float = 30):
        """Edit latents along a direction and render them via ``_latents_to_image``.

        Args:
            latents: latent codes to edit; the direction tensor must broadcast against
                them (presumably W/W+ codes -- TODO confirm expected shape).
            direction: key into ``self.interfacegan_directions``.
            factor: edit strength, used only when ``factor_range`` is None.
            factor_range: ``(start, stop)`` passed to ``range``; one edit is rendered
                per integer step, e.g. ``(-5, 5)``.
            user_transforms, apply_user_transformations, all_s: forwarded to
                ``_latents_to_image``, which currently does not use them.
            range_scale: multiplier applied to every ``factor_range`` step
                (default 30, the previously hard-coded value).

        Returns:
            With ``factor_range``: ``(images, latents)`` lists with one entry per step;
            each image entry is the image list returned by ``_latents_to_image``.
            Without it: the image list and the single edited latent tensor.

        Raises:
            KeyError: if ``direction`` is not a loaded direction.
        """
        direction_vec = self.interfacegan_directions[direction]

        if factor_range is None:
            edit_latents = latents + factor * direction_vec
            edit_images, _ = self._latents_to_image(edit_latents, apply_user_transformations,
                                                    user_transforms, all_s=all_s)
            return edit_images, edit_latents

        edit_latents = []
        edit_images = []
        for f in range(*factor_range):
            # Range steps are scaled by range_scale; the single-factor branch is not.
            edit_latent = latents + f * range_scale * direction_vec
            edit_image, user_transforms = self._latents_to_image(
                edit_latent, apply_user_transformations, user_transforms, all_s=all_s)
            edit_latents.append(edit_latent)
            edit_images.append(edit_image)
        return edit_images, edit_latents

    def edit(self, latents: torch.Tensor, direction: str, factor: float = 1,
             factor_range: Optional[Tuple[int, int]] = None, *, step_scale: float = 1.5):
        """Render one edit per step of ``factor_range`` by calling the generator directly.

        Args:
            latents: latent codes to edit (passed to the generator with
                ``input_is_latent=True``); the direction must broadcast against them.
            direction: key into ``self.interfacegan_directions``.
            factor: unused; only range editing is implemented.
            factor_range: ``(start, stop)`` passed to ``range``, e.g. ``(-5, 5)``.
            step_scale: multiplier applied to every step (default 1.5, the previously
                hard-coded value).

        Returns:
            ``(images, latents)`` lists with one entry per step; each image entry is a
            list of ``tensor2im`` results for the generator's output batch.

        Raises:
            KeyError: if ``direction`` is not a loaded direction.
            NotImplementedError: if ``factor_range`` is None.
        """
        # Looked up first so an unknown direction raises KeyError before anything else.
        direction_vec = self.interfacegan_directions[direction]
        if factor_range is None:
            raise NotImplementedError('FaceEditor.edit currently requires factor_range')

        edit_latents = []
        edit_images = []
        for f in range(*factor_range):
            edit_latent = latents + step_scale * f * direction_vec
            # Rendering here is inference-only (outputs go straight to tensor2im), so
            # avoid building an autograd graph through the generator on every step.
            with torch.no_grad():
                edit_image, _, _ = self.generator([edit_latent], None, None,
                                                  input_is_latent=True, randomize_noise=False,
                                                  return_latents=True, is_inference=True)
            edit_latents.append(edit_latent)
            edit_images.append([tensor2im(image) for image in edit_image])
        return edit_images, edit_latents

    def edit_single(self, latents: torch.Tensor, direction: str, factor: float = 1,
                    edit_degree: Optional[float] = None, *, step_scale: float = 1.5):
        """Apply a single edit of strength ``edit_degree`` and return raw generator output.

        Args:
            latents: latent codes to edit (passed with ``input_is_latent=True``).
            direction: key into ``self.interfacegan_directions``.
            factor: unused; kept for signature compatibility.
            edit_degree: edit strength; required.
            step_scale: multiplier applied to ``edit_degree`` (default 1.5, the
                previously hard-coded value).

        Returns:
            ``(edit_image, edit_latent)`` where ``edit_image`` is the generator's first
            output, not converted with ``tensor2im``.

        Raises:
            KeyError: if ``direction`` is not a loaded direction.
            NotImplementedError: if ``edit_degree`` is None.
        """
        direction_vec = self.interfacegan_directions[direction]
        if edit_degree is None:
            raise NotImplementedError('FaceEditor.edit_single currently requires edit_degree')

        edit_latent = latents + step_scale * edit_degree * direction_vec
        # No torch.no_grad() here: the raw tensor is returned and callers may need gradients.
        edit_image, _, _ = self.generator([edit_latent], None, None, input_is_latent=True,
                                          randomize_noise=False, return_latents=True,
                                          is_inference=True)
        return edit_image, edit_latent

    def _latents_to_image(self, all_latents: torch.Tensor, apply_user_transformations: bool = False,
                          user_transforms: Optional[torch.Tensor] = None, all_s=None):
        """Render latents with the generator's synthesis network and convert them to images.

        ``apply_user_transformations`` and ``all_s`` are currently unused, and
        ``user_transforms`` is returned unchanged.

        NOTE(review): this relies on the generator exposing
        ``synthesis(ws, noise_mode=...)`` (StyleGAN3-style), while ``edit`` calls the
        generator with a different signature; confirm ``edit_sg3`` is only used with a
        generator that provides ``synthesis``.

        Returns:
            ``(images, user_transforms)`` where ``images`` is a list of ``tensor2im``
            results, one per element of the synthesis output.
        """
        with torch.no_grad():
            images = self.generator.synthesis(all_latents, noise_mode='const')
            images = [tensor2im(image) for image in images]
        return images, user_transforms
