# coding: utf-8
import copy
import json
import logging
import os

import numpy as np
import torch
from PIL import Image
from scipy.stats import truncnorm

# Module-wide logger. NOTE: calling getLogger without a name returns the *root*
# logger, so records are attributed to "root" and any handler/level set on this
# object applies to logging globally.
logger = logging.getLogger(name=None)


def generate_categorical_labels(label, num_classes, n_samples):
    """Create a batch of one-hot vectors from a class label.

    Params:
        label: int class index shared by every sample, or an array-like of
            exactly ``n_samples`` class indices (one per row).
        num_classes: number of classes, i.e. the width of each one-hot row.
        n_samples: number of rows (batch size).
    Output:
        float32 torch tensor of shape (n_samples, num_classes) holding a
        single 1.0 per row, at the column given by that row's label.
    """
    # broadcast_to repeats a scalar n_samples times and accepts a per-sample
    # sequence only if it has exactly n_samples entries. (np.repeat would repeat
    # *each* element of a sequence n_samples times and overflow the rows.)
    labels = np.broadcast_to(np.asarray(label), (n_samples,))

    one_hot = np.zeros((n_samples, num_classes), dtype=np.float32)
    one_hot[np.arange(n_samples), labels] = 1.0
    return torch.from_numpy(one_hot)


def sample_categorical_labels(num_classes, n_samples):
    """Draw random class labels and return them as a batch of one-hot vectors.

    Labels are drawn uniformly from ``[0, num_classes)`` with the global NumPy
    RNG (``np.random.randint``), so ``np.random.seed`` makes the result
    reproducible.

    Params:
        num_classes: number of classes, i.e. the width of each one-hot row.
        n_samples: number of rows (batch size).
    Output:
        float32 torch tensor of shape (n_samples, num_classes) with exactly one
        1.0 per row.
    """
    labels = np.random.randint(low=0, high=num_classes, size=n_samples)

    one_hot = np.zeros((n_samples, num_classes), dtype=np.float32)
    one_hot[np.arange(n_samples), labels] = 1.0
    return torch.from_numpy(one_hot)


def convert_to_images(obj):
    """Convert an output tensor from BigGAN into a list of Pillow images.

    Params:
        obj: torch tensor (on any device) or numpy array of shape
            (batch_size, channels, height, width). Values in [-1, 1] are mapped
            to [0, 255]; anything outside that range is clipped.
    Output:
        list of ``batch_size`` Pillow Images whose ``size`` is (width, height).
    """
    if not isinstance(obj, np.ndarray):
        # .numpy() only works on host memory, so move CUDA tensors to CPU first.
        obj = obj.detach().cpu().numpy()

    # NCHW -> NHWC, the channel-last layout Image.fromarray expects.
    obj = obj.transpose((0, 2, 3, 1))
    # Scale [-1, 1] to [0, 256) and clip; truncating to uint8 then yields 256
    # equal-width bins over the input range, with exactly +1.0 clipped to 255.
    obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255).astype(np.uint8)

    return [Image.fromarray(out) for out in obj]


def save_label_pair_as_binary(pseudo_labels, target_labels, config):
    """Save paired pseudo/target labels together in one ``.npz`` file.

    The file is written to
    ``<config["output_dataset_path"]>/labels/target_and_pseudo_labels.npz``
    and holds two arrays under the keys ``pseudo_labels`` and ``target_labels``.

    Args:
        pseudo_labels: array-like of pseudo labels (converted by ``np.savez``).
        target_labels: array-like of target labels (converted by ``np.savez``).
        config: mapping that provides ``"output_dataset_path"``.
    """
    out_dir = os.path.join(config["output_dataset_path"], "labels")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)

    file_name = os.path.join(out_dir, "target_and_pseudo_labels.npz")
    np.savez(file_name, pseudo_labels=pseudo_labels, target_labels=target_labels)
    logger.info(f"## Completed to save {file_name}")


def save_as_dataset(data, labels, config, start_index):
    """Convert a BigGAN output batch to images and save each in its class folder.

    Image ``i`` is written as
    ``<output_dataset_path>/train/<label>/<start_index + i, 8 digits>.png``,
    where ``<label>`` is ``str(labels[i].item())`` zero-padded to width 5.

    Params:
        data: tensor or numpy array of shape (batch_size, channels, height, width),
            passed to ``convert_to_images``.
        labels: per-image labels, indexable by ``i``, whose elements support
            ``.item()`` (e.g. a 1-D torch tensor or numpy array).
        config: mapping that provides ``"output_dataset_path"``.
        start_index: offset added to each image's position in the batch to form
            its file number.
    """
    img = convert_to_images(data)
    out_path = config["output_dataset_path"]
    for i, out in enumerate(img):
        class_dir = out_path + "/train/" + str(labels[i].item()).zfill(5)
        # exist_ok: no check-then-create race when several writers share a class.
        os.makedirs(class_dir, exist_ok=True)
        current_file_name = f"{class_dir}/{(i + start_index):08}.png"
        out.save(current_file_name, "png")


def save_as_dataset_per_class(data, label, config, start_index, test=False):
    """Convert a single-class BigGAN output batch to images and save them.

    Every image goes into the same folder,
    ``<output_dataset_path>/<train|test>/<label>/``, where ``<label>`` is
    ``str(label)`` zero-padded to width 5. Image ``i`` is named
    ``<start_index + i, 8 digits>.png``.

    Params:
        data: tensor or numpy array of shape (batch_size, channels, height, width),
            passed to ``convert_to_images``.
        label: class label shared by every image in the batch.
        config: mapping that provides ``"output_dataset_path"``.
        start_index: offset added to each image's position in the batch to form
            its file number.
        test: if True write under ``test/``, otherwise under ``train/``.
    """
    img = convert_to_images(data)
    out_path = config["output_dataset_path"]
    dest = "test" if test else "train"
    # The target folder is identical for every image, so build it once.
    class_dir = out_path + f"/{dest}/" + str(label).zfill(5)
    if img:
        # Only create the folder when there is something to write into it.
        os.makedirs(class_dir, exist_ok=True)
    for i, out in enumerate(img):
        out.save(f"{class_dir}/{(i + start_index):08}.png", "png")


def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1.0, seed=None):
    """Sample a batch of truncated-normal latent vectors.

    Values come from a standard normal truncated to [-2, 2], cast to float32,
    and then multiplied by ``truncation``.

    Params:
        batch_size: number of vectors (rows).
        dim_z: length of each vector.
        truncation: scale factor applied to the sampled values.
        seed: if given, seeds a private ``np.random.RandomState`` for
            reproducibility; if None, the global NumPy random state is used.
    Output:
        float32 array of shape (batch_size, dim_z)
    """
    rng = np.random.RandomState(seed) if seed is not None else None
    noise = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=rng)
    noise = noise.astype(np.float32)
    return noise * truncation


class BigGANConfig(object):
    """Configuration class to store the configuration of a `BigGAN`.

    Defaults are for the 128x128 model.
    Each entry of ``layers`` is a tuple
    ``(up-sample in the layer ?, input channels, output channels)``.
    """

    def __init__(
        self,
        output_dim=128,
        z_dim=128,
        class_embed_dim=128,
        channel_width=128,
        num_classes=1000,
        layers=None,
        attention_layer_position=8,
        eps=1e-4,
        n_stats=51,
    ):
        """Constructs BigGANConfig.

        ``layers=None`` selects the default 128x128 layer stack. A new list is
        built on every call, so instances never share (and mutate) one default
        list object.
        """
        if layers is None:
            layers = [
                (False, 16, 16),
                (True, 16, 16),
                (False, 16, 16),
                (True, 16, 8),
                (False, 8, 8),
                (True, 8, 4),
                (False, 4, 4),
                (True, 4, 2),
                (False, 2, 2),
                (True, 2, 1),
            ]
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config (of the calling class) from a dictionary of parameters.

        Starts from the defaults and copies every key verbatim onto the
        instance, so unknown keys become extra attributes.
        """
        config = cls()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a UTF-8 JSON file of parameters.

        JSON arrays load as lists, so ``layers`` entries come back as lists
        rather than tuples.
        """
        with open(json_file, "r", encoding="utf-8") as reader:
            return cls.from_dict(json.load(reader))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a (deep-copied) Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string with sorted keys."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
