import numpy as np
from torchvision.utils import make_grid
import numpy as np
import torch
import itertools

import torchvision
from PIL import ImageDraw
from PIL import ImageFont
from PIL import Image

import textwrap
# TrueType font at size 38, shared by text_to_pil and text_to_pil_celeba for
# drawing decoded text. NOTE(review): loaded at import time, so 'FreeSerif.ttf'
# must be resolvable by PIL's font lookup or importing this module fails (OSError).
FONT = ImageFont.truetype('FreeSerif.ttf', 38)

def get_mean(metrics_dict):
    """Collapse each metric's recorded values to their NaN-ignoring mean.

    Args:
        metrics_dict: Mapping from metric name to a sized collection of values.

    Returns:
        A new dict mapping each metric name to ``np.nanmean`` of its values.
        Metrics with no recorded values are left out entirely.
    """
    return {
        metric_name: np.nanmean(values)
        for metric_name, values in metrics_dict.items()
        if len(values) != 0
    }

def get_modality_id(modality_input):
    """Return the integer modality ids encoded in the keys of ``modality_input``.

    Every key starting with ``"x"`` is treated as a modality entry named
    ``"x<id>"``; all other keys are ignored. The id is the run of decimal digits
    immediately following the ``"x"``, so multi-digit ids such as ``"x12"`` are
    read as ``12``. Ids are returned in the mapping's iteration order.

    Args:
        modality_input: Mapping whose keys are presumably strings such as
            ``"x0"``, ``"x1"`` — confirm against callers.

    Returns:
        List of ints, one per ``"x..."`` key.

    Raises:
        ValueError: if a key starts with ``"x"`` but no digit follows it.
    """
    ids = []
    for key in modality_input.keys():
        if not key.startswith("x"):
            continue
        # Take every leading digit, not just key[1], so "x12" -> 12 rather than 1.
        digits = "".join(itertools.takewhile(str.isdecimal, key[1:]))
        ids.append(int(digits))  # int("") raises ValueError for keys like "x_mask"
    return ids


def get_ac_number_of_modalities(metrics_dict, modality_id):
    """Average per-target metrics by the number of conditioning modalities.

    For each subset size ``j`` and each subset ``conb`` of ``modality_id`` of
    that size, ``name`` is built by joining the subset's ids as ``"x<k>"`` (e.g.
    ``"x0x2"``). For every modality ``i`` not in the subset, the entries
    ``"ac_<i>_<name>"`` and ``"fid_<i>_<name>"`` are collected if present. The
    mean of all collected accuracies for size ``j`` is stored as ``"ac<j>"``, and
    the mean FID as ``"fid<j>"``. Sizes with no matching entries add no key.

    Example with ``modality_id = [0, 1, 2]``::

        {"ac_1_x0", "ac_2_x0", "ac_0_x1", "ac_2_x1", "ac_0_x2", "ac_1_x2"} -> "ac1"
        {"ac_2_x0x1", "ac_1_x0x2", "ac_0_x1x2"}                           -> "ac2"

    Args:
        metrics_dict: Mapping of metric keys to numeric values; updated in place.
        modality_id: Sequence of integer modality ids.

    Returns:
        ``metrics_dict`` itself, with the aggregated keys added.
    """
    for j in range(1, len(modality_id) + 1):
        acc_list = []
        fid_list = []

        for conb in itertools.combinations(modality_id, j):
            # sorted() gives a deterministic accumulation order; bare set order is not.
            generation_id = sorted(set(modality_id) - set(conb))
            name = "".join("x" + str(k) for k in conb)
            for i in generation_id:
                # Explicit membership checks replace bare `except: pass`, which
                # also hid unrelated errors (e.g. bad formatting of the key).
                acc_key = "ac_%d_%s" % (i, name)
                if acc_key in metrics_dict:
                    acc_list.append(metrics_dict[acc_key])
                fid_key = "fid_%d_%s" % (i, name)
                if fid_key in metrics_dict:
                    fid_list.append(metrics_dict[fid_key])

        if acc_list:
            metrics_dict["ac%d" % j] = np.mean(acc_list)
        if fid_list:
            metrics_dict["fid%d" % j] = np.mean(fid_list)

    return metrics_dict


def metric_update(metric_dict, index, value):
    """Append ``value`` to the list stored under ``index``, creating it if absent.

    Args:
        metric_dict: Mapping from metric name to a list of recorded values;
            modified in place.
        index: Key of the metric to update.
        value: Value to record.
    """
    # setdefault replaces a bare try/except. That construct also swallowed
    # unrelated errors and silently discarded a non-list value stored under ``index``.
    metric_dict.setdefault(index, []).append(value)

def text_to_pil(t, imgsize, alphabet, w=128, h=128):
    """Render the first text sequence in ``t`` as an image tensor.

    The sequence is decoded with ``tensor_to_text`` (argmax over the last axis,
    mapped through ``alphabet``) and wrapped to 8 characters per line. It is
    drawn in black with ``FONT`` on a white ``w`` x ``h`` canvas, with lines
    stacked around the canvas' vertical centre. The canvas is then resized to
    the spatial size given by ``imgsize``.

    Args:
        t: Tensor with a leading batch dimension; only its first entry is
            rendered (``text_to_tensor`` passes one sample via ``unsqueeze(0)``).
        imgsize: Target size indexed as ``(channels, height, width)``; only
            ``imgsize[1]`` (height) and ``imgsize[2]`` (width) are used.
        alphabet: Indexable lookup from token index to character.
        w: Width of the intermediate drawing canvas in pixels.
        h: Height of the intermediate drawing canvas in pixels.

    Returns:
        Float tensor of shape ``(3, imgsize[1], imgsize[2])`` with values in
        ``[0, 1]``, as produced by ``torchvision.transforms.ToTensor``.
    """
    # White RGB canvas. PIL sizes are (width, height), so the axes cannot be
    # swapped the way a (C, w, h) torch tensor would swap them.
    canvas = Image.new("RGB", (w, h), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    text_sample = tensor_to_text(alphabet, t)[0]
    lines = textwrap.wrap(''.join(text_sample), width=8)
    num_lines = len(lines)
    # FreeTypeFont.getsize was removed in Pillow 10. getbbox (Pillow >= 8)
    # replaces it; its bottom edge, measured from the ascender line,
    # corresponds to the height getsize reported.
    measure = getattr(FONT, "getbbox", None)
    for row, line in enumerate(lines):
        height = measure(line)[3] if measure is not None else FONT.getsize(line)[1]
        # Shift each line so the block of lines sits around the vertical centre.
        draw.text((0, (h / 2) - (num_lines / 2 - row) * height), line, (0, 0, 0), font=FONT)
    # Image.resize takes (width, height). ANTIALIAS was an alias of LANCZOS
    # and was removed in Pillow 10.
    resized = canvas.resize((imgsize[2], imgsize[1]), Image.LANCZOS)
    return torchvision.transforms.ToTensor()(resized)

def text_to_pil_celeba(t, imgsize, alphabet, w=256, h=256):
    """Render the first text sequence in ``t`` as an image tensor (CelebA variant).

    Like ``text_to_pil``, but with three differences:
      - every ``'*'`` character is removed from the decoded text before wrapping;
      - lines wrap at 16 characters;
      - the default canvas is 256 x 256.
    The text is drawn in black with ``FONT`` on a white ``w`` x ``h`` canvas,
    stacked around the vertical centre. The canvas is then resized to
    ``imgsize``'s spatial size.

    Args:
        t: Tensor with a leading batch dimension; only its first entry is rendered.
        imgsize: Target size indexed as ``(channels, height, width)``; only
            ``imgsize[1]`` (height) and ``imgsize[2]`` (width) are used.
        alphabet: Indexable lookup from token index to character.
        w: Width of the intermediate drawing canvas in pixels.
        h: Height of the intermediate drawing canvas in pixels.

    Returns:
        Float tensor of shape ``(3, imgsize[1], imgsize[2])`` with values in
        ``[0, 1]``, as produced by ``torchvision.transforms.ToTensor``.
    """
    # PIL sizes are (width, height). Building the canvas this way keeps its
    # height equal to ``h``, which the vertical centring below relies on.
    canvas = Image.new("RGB", (w, h), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    text_sample = tensor_to_text(alphabet, t)[0]
    text_sample = ''.join(text_sample).translate({ord('*'): None})
    lines = textwrap.wrap(text_sample, width=16)
    num_lines = len(lines)
    # FreeTypeFont.getsize was removed in Pillow 10. getbbox (Pillow >= 8)
    # replaces it; its bottom edge, measured from the ascender line,
    # corresponds to the height getsize reported.
    measure = getattr(FONT, "getbbox", None)
    for row, line in enumerate(lines):
        height = measure(line)[3] if measure is not None else FONT.getsize(line)[1]
        # Shift each line so the block of lines sits around the vertical centre.
        draw.text((0, (h / 2) - (num_lines / 2 - row) * height), line, (0, 0, 0), font=FONT)
    # Image.resize takes (width, height). ANTIALIAS was an alias of LANCZOS
    # and was removed in Pillow 10.
    resized = canvas.resize((imgsize[2], imgsize[1]), Image.LANCZOS)
    return torchvision.transforms.ToTensor()(resized)

def text_to_tensor(t_tensor, alphabet, img_size=(3, 32, 32)):
    """Render every text sample in a batch and stack the resulting images.

    Args:
        t_tensor: Batch of text tensors; iterating it yields one sample at a time.
        alphabet: Indexable lookup from token index to character.
        img_size: Target image size forwarded to ``text_to_pil``.

    Returns:
        The per-sample image tensors from ``text_to_pil``, stacked along a new
        leading dimension with ``torch.stack``.
    """
    # text_to_pil expects a leading batch axis, so each sample gets one added.
    rendered = [
        text_to_pil(sample.unsqueeze(0), img_size, alphabet) for sample in t_tensor
    ]
    return torch.stack(rendered)

def tensor_to_text(alphabet, gen_t):
    """Decode a batch of per-position score tensors into character lists.

    The tensor is detached, moved to the CPU and converted to NumPy. It is then
    reduced with an argmax over its last axis, and each row of indices is mapped
    through ``alphabet`` by ``seq2text``.

    Args:
        alphabet: Indexable lookup from token index to character.
        gen_t: Torch tensor whose last axis holds scores over ``alphabet``.

    Returns:
        List with one decoded list of characters per leading-axis entry.
    """
    scores = gen_t.detach().cpu().numpy()
    best_indices = np.argmax(scores, axis=-1)
    return [seq2text(alphabet, row) for row in best_indices]

def seq2text(alphabet, seq):
    """Map a sequence of token indices to their ``alphabet`` entries.

    Args:
        alphabet: Indexable lookup from token index to character.
        seq: Sequence of indices into ``alphabet``.

    Returns:
        List of ``alphabet[idx]`` for each ``idx`` in ``seq``, in order.
    """
    return [alphabet[idx] for idx in seq]