import math
import os
from typing import Tuple
import pickle

import torch
import torchvision
import torchvision.transforms.functional as F

import sys


def adjust_gen_images(imgs: torch.Tensor,
                      bounds: Tuple[torch.Tensor, torch.Tensor],
                      size: int,
                      crop_size: int = 700):
    """
    Clip, center-crop and resize images generated by StyleGAN2.

    Generator outputs are usually roughly in the range [-1, 1] but can
    overshoot. Values above the upper bound are replaced by the upper bound
    and values below the lower bound by the lower bound. The value range is
    NOT rescaled here (no linear transformation is applied); only clipping,
    cropping and resizing happen.

    Args:
        imgs (torch.Tensor): generated images; the last two dimensions are
            treated as height and width by the torchvision operations
        bounds (Tuple[torch.Tensor, torch.Tensor]): (lower_bound, upper_bound)
            tensors; they are cast to float, moved to the device of ``imgs``
            and must be broadcastable against ``imgs``
        size (int): target size forwarded to ``torchvision`` ``resize``
            (which uses bilinear interpolation by default)
        crop_size (int, optional): side length of the square center crop
            applied before resizing. Defaults to 700, the value that used to
            be hard-coded. Pass ``None`` to skip cropping.

    Returns:
        torch.Tensor: the clipped, cropped and resized images
    """
    lower_bound, upper_bound = bounds
    lower_bound = lower_bound.float().to(imgs.device)
    upper_bound = upper_bound.float().to(imgs.device)

    # Clip to the upper bound first, then to the lower bound.
    imgs = torch.where(imgs > upper_bound, upper_bound, imgs)
    imgs = torch.where(imgs < lower_bound, lower_bound, imgs)

    if crop_size is not None:
        # NOTE: torchvision zero-pads inputs that are smaller than the
        # requested crop, so small images grow to crop_size here.
        imgs = F.center_crop(imgs, (crop_size, crop_size))
    imgs = F.resize(imgs, size)
    return imgs


def save_images(imgs: torch.Tensor, folder, filename, center_crop=800):
    """Save StyleGAN output images in file(s).

    Every image of the batch is written to ``<folder>/<filename>_<i>.png``
    where ``i`` is the index of the image in the batch. The output folder is
    created if it does not exist yet.

    Args:
        imgs (torch.tensor): generated images in [-1, 1] range
        folder (str): output folder
        filename (str): name of the files
        center_crop (int, optional): side length of the square center crop
            applied before saving. Defaults to 800. Any falsy value
            (``None`` or ``0``) disables cropping.
    """
    imgs = imgs.detach()
    if center_crop:
        # NOTE: torchvision zero-pads images smaller than the crop size.
        imgs = F.center_crop(imgs, (center_crop, center_crop))

    # Map to [0, 1]; equivalent to (x * 127.5 + 128) / 255 followed by clipping.
    imgs = (imgs * 0.5 + 128 / 255).clamp(0, 1)

    # An empty folder means "current directory"; os.makedirs('') would raise.
    if folder:
        os.makedirs(folder, exist_ok=True)

    for i, img in enumerate(imgs):
        path = os.path.join(folder, f'{filename}_{i}.png')
        torchvision.utils.save_image(img, path)


def create_image(w,
                 generator,
                 crop_size=None,
                 resize=None,
                 batch_size=20,
                 device='cuda:0'):
    """Generate images from latent vectors in mini-batches without gradients.

    Args:
        w (torch.Tensor): latent vectors. If the second dimension has size 1,
            it is repeated ``generator.num_ws`` times along that dimension;
            otherwise ``w`` is used unchanged.
        generator: callable invoked as
            ``generator(w_batch, noise_mode='const', force_fp32=True)``
        crop_size (int, optional): if given, side length of a square center
            crop applied to the generated images
        resize (int, optional): if given, target size forwarded to the
            torchvision ``resize`` function (applied after cropping)
        batch_size (int): number of latent vectors per generator call
        device (str): device the latent vectors are moved to before generation

    Returns:
        torch.Tensor: all generated images concatenated along dimension 0,
        located on the CPU
    """
    with torch.no_grad():
        # Only touch generator.num_ws when the latents actually need repeating.
        latents = (torch.repeat_interleave(w, repeats=generator.num_ws, dim=1)
                   if w.shape[1] == 1 else w)
        latents = latents.to(device)

        num_batches = math.ceil(latents.shape[0] / batch_size)
        # Each batch is moved to the CPU right away so only one batch of
        # images lives on the generation device at a time.
        outputs = [
            generator(latents[idx * batch_size:(idx + 1) * batch_size],
                      noise_mode='const',
                      force_fp32=True).cpu()
            for idx in range(num_batches)
        ]
        result = torch.cat(outputs, dim=0)

        if crop_size is not None:
            result = F.center_crop(result, (crop_size, crop_size))
        if resize is not None:
            result = F.resize(result, resize)
        return result

def load_generator(filepath):
    """Load pre-trained generator using the running average of the weights ('ema').

    The pickle is expected to contain a dict with the key ``'G_ema'``. The
    loaded module is moved to the default CUDA device via ``.cuda()``.

    Args:
        filepath (str): Path to .pkl file

    Returns:
        torch.nn.Module: G_ema from pickle
    """
    repo_dir = 'stylegan2-ada-pytorch'
    with open(filepath, 'rb') as f:
        # Unpickling presumably needs modules from the StyleGAN2-ADA repo to be
        # importable. Add the repo to sys.path only once so repeated calls do
        # not keep growing sys.path with duplicate entries.
        if repo_dir not in sys.path:
            sys.path.insert(0, repo_dir)
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file; only load model files from trusted sources.
        G = pickle.load(f)['G_ema'].cuda()
    return G


def load_discrimator(filepath):
    """Load pre-trained discriminator

    The pickle is expected to contain a dict with the key ``'D'``. The loaded
    module is moved to the default CUDA device via ``.cuda()``.

    Args:
        filepath (str): Path to .pkl file

    Returns:
        torch.nn.Module: D from pickle
    """
    repo_dir = 'stylegan2-ada-pytorch'
    with open(filepath, 'rb') as f:
        # Unpickling presumably needs modules from the StyleGAN2-ADA repo to be
        # importable. Add the repo to sys.path only once so repeated calls do
        # not keep growing sys.path with duplicate entries.
        if repo_dir not in sys.path:
            sys.path.insert(0, repo_dir)
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file; only load model files from trusted sources.
        D = pickle.load(f)['D'].cuda()
    return D
