import wandb
from load_models import load_model as _load_model
import pandas as pd 
import torch
import numpy as np

import matplotlib.pyplot as plt
import torchvision
from PIL import ImageDraw
from PIL import ImageFont
from PIL import Image

import textwrap
# Shared font for the text-rendering helpers in this module (text_to_pil and
# text_to_pil_celeba). It is loaded once, at import time, from 'FreeSerif.ttf'
# at size 38.
# NOTE(review): presumably the import fails if the font file cannot be
# located -- confirm the lookup path on the target machine.
FONT = ImageFont.truetype('FreeSerif.ttf', 38)

def text_to_pil(t, imgsize, alphabet):
    """Render the decoded text of ``t`` as black-on-white image data.

    The first sample decoded from ``t`` is wrapped at 8 characters per line,
    drawn vertically centred on a white 128x128 RGB canvas with the module
    level ``FONT``, and the canvas is then resized and converted to a tensor.

    Args:
        t: Encoded text passed unchanged to ``tensor_to_text``; only the first
            decoded sample is rendered.
        imgsize: Indexable whose items 1 and 2 give the target size handed to
            ``PIL.Image.resize``.
            NOTE(review): ``resize`` expects (width, height); if ``imgsize``
            is (C, H, W) the two are swapped for non-square targets -- confirm.
        alphabet: Lookup from character index to character, forwarded to
            ``tensor_to_text``.

    Returns:
        The tensor produced by ``torchvision.transforms.ToTensor`` from the
        resized image.
    """
    w = 128
    h = 128
    blank_img = torch.ones([3, w, h])
    pil_img = torchvision.transforms.ToPILImage()(blank_img.cpu()).convert("RGB")
    draw = ImageDraw.Draw(pil_img)
    text_sample = tensor_to_text(alphabet, t)[0]
    lines = textwrap.wrap(''.join(text_sample), width=8)
    num_lines = len(lines)
    for idx, line in enumerate(lines):
        # ``ImageFont.getsize`` was removed in Pillow 10. The bottom edge of
        # ``getbbox`` is the same value the legacy call returned as height
        # (glyph height plus vertical offset), so the layout is unchanged.
        if hasattr(FONT, "getbbox"):
            height = FONT.getbbox(line)[3]
        else:  # Pillow < 8.0 has no getbbox
            height = FONT.getsize(line)[1]
        # Centre the block of lines around the vertical midpoint of the canvas.
        draw.text((0, (h/2) - (num_lines/2 - idx)*height), line, (0, 0, 0), font=FONT)
    # ``Image.ANTIALIAS`` was an alias of ``Image.LANCZOS`` and was removed in
    # Pillow 10; ``LANCZOS`` is the same filter and exists in old and new releases.
    resized = pil_img.resize((imgsize[1], imgsize[2]), Image.LANCZOS)
    text_pil = torchvision.transforms.ToTensor()(resized)

    return text_pil

def text_to_pil_celeba(t, imgsize, alphabet, w=256, h=256):
    """Render the decoded text of ``t`` as black-on-white image data.

    The first sample decoded from ``t`` has every ``'*'`` character removed,
    is wrapped at 16 characters per line, and is drawn vertically centred on
    a white RGB canvas with the module-level ``FONT``. The canvas is then
    resized and converted to a tensor.

    Args:
        t: Encoded text passed unchanged to ``tensor_to_text``; only the first
            decoded sample is rendered.
        imgsize: Indexable whose items 1 and 2 give the target size handed to
            ``PIL.Image.resize``.
            NOTE(review): ``resize`` expects (width, height); if ``imgsize``
            is (C, H, W) the two are swapped for non-square targets -- confirm.
        alphabet: Lookup from character index to character, forwarded to
            ``tensor_to_text``.
        w: First spatial dimension of the blank canvas tensor.
        h: Second spatial dimension of the blank canvas tensor; also used as
            the reference for vertical centring.
            NOTE(review): the canvas tensor is built as ``[3, w, h]``; if
            ToPILImage reads that as (C, H, W), centring on ``h/2`` is only
            exact for square canvases -- confirm before using ``w != h``.

    Returns:
        The tensor produced by ``torchvision.transforms.ToTensor`` from the
        resized image.
    """
    blank_img = torch.ones([3, w, h])
    pil_img = torchvision.transforms.ToPILImage()(blank_img.cpu()).convert("RGB")
    draw = ImageDraw.Draw(pil_img)
    text_sample = tensor_to_text(alphabet, t)[0]
    # Drop every '*' from the decoded text before wrapping.
    text_sample = ''.join(text_sample).translate({ord('*'): None})
    lines = textwrap.wrap(text_sample, width=16)
    num_lines = len(lines)
    for idx, line in enumerate(lines):
        # ``ImageFont.getsize`` was removed in Pillow 10. The bottom edge of
        # ``getbbox`` is the same value the legacy call returned as height
        # (glyph height plus vertical offset), so the layout is unchanged.
        if hasattr(FONT, "getbbox"):
            height = FONT.getbbox(line)[3]
        else:  # Pillow < 8.0 has no getbbox
            height = FONT.getsize(line)[1]
        # Centre the block of lines around the vertical midpoint of the canvas.
        draw.text((0, (h/2) - (num_lines/2 - idx)*height), line, (0, 0, 0), font=FONT)
    # ``Image.ANTIALIAS`` was an alias of ``Image.LANCZOS`` and was removed in
    # Pillow 10; ``LANCZOS`` is the same filter and exists in old and new releases.
    resized = pil_img.resize((imgsize[1], imgsize[2]), Image.LANCZOS)
    text_pil = torchvision.transforms.ToTensor()(resized)
    return text_pil

def text_to_tensor(t_tensor, alphabet, img_size=(3, 32, 32)):
    """Render every entry of ``t_tensor`` as a text image and stack the results.

    Args:
        t_tensor: Iterable of encoded text samples. Each sample gets a leading
            dimension via ``unsqueeze(0)`` before being handed to
            ``text_to_pil``.
        alphabet: Lookup from character index to character, forwarded to
            ``text_to_pil``.
        img_size: Target size forwarded to ``text_to_pil``.

    Returns:
        A single tensor made by ``torch.stack`` over the rendered images, in
        the same order as ``t_tensor``.
    """
    rendered = [
        text_to_pil(sample.unsqueeze(0), img_size, alphabet)
        for sample in t_tensor
    ]
    return torch.stack(rendered)

def tensor_to_text(alphabet, gen_t):
    """Decode a batch of per-character scores into lists of characters.

    The input is moved to the CPU and converted to a NumPy array, the index
    of the largest value along the last axis is taken, and each resulting row
    of indices is decoded with ``seq2text``.

    Args:
        alphabet: Lookup from character index to character.
        gen_t: Tensor of scores. The argmax over the last axis followed by
            per-row decoding implies a (sample, position, alphabet entry)
            layout.

    Returns:
        A list with one entry per sample; each entry is the list of decoded
        characters for that sample.
    """
    best_indices = np.argmax(gen_t.cpu().data.numpy(), axis=-1)
    return [seq2text(alphabet, row) for row in best_indices]

def seq2text(alphabet, seq):
    """Map a sequence of indices to the matching entries of ``alphabet``.

    Args:
        alphabet: Indexable lookup (for example a string, list or dict) from
            index to character.
        seq: Sequence of indices into ``alphabet``.

    Returns:
        A list holding ``alphabet[i]`` for every ``i`` in ``seq``, in order.
    """
    return [alphabet[index] for index in seq]


def get_metrics(name, load_model=False, test_model=False, eval_fid=True, eval_reconst=False, wandb_log=False, std=1.0, sampling=True, update=False, force_update=False, device="cuda"):
    """Return the validation metrics of a finished W&B run, re-evaluating on demand.

    The run is looked up as ``"masa-su/SMVAEs/" + name`` and its state is
    printed. Metrics stored under the ``"val"`` key of the run summary are
    used as the starting point. If they already contain
    ``"cosine_sim_x0x1"`` and no FID evaluation or forced update is pending,
    they are returned as they are. Otherwise, when ``test_model`` is set, the
    model is loaded and evaluated, and the new results are merged in.

    Args:
        name: Run identifier appended to the ``masa-su/SMVAEs/`` path.
        load_model: When true, the loaded model is added to the result under
            the ``"model"`` key. This happens only in the two branches that
            load a model; if the cached metrics are not used and
            ``test_model`` is false, no model is loaded.
        test_model: When true and the cached metrics are not used, load the
            model and merge the output of ``model.test`` into the metrics.
        eval_fid: Forwarded to ``model.test``. Forced to ``False`` when a
            stored metric key already contains ``"fid"`` and ``force_update``
            is ``False``.
        eval_reconst: Forwarded to ``model.test``.
        wandb_log: Forwarded to ``model.test``.
        std: Forwarded to ``model.test``.
        sampling: Forwarded to ``model.test``.
        update: When true, the merged metrics are written back to the run
            summary under ``"val"`` after evaluation.
        force_update: When true, the cached metrics are never used directly
            and ``eval_fid`` is not overridden.
        device: Device handed to the model loader in both branches.

    Returns:
        ``None`` if the run state is not ``"finished"``; otherwise a dict of
        metrics, with the model under ``"model"`` where one was loaded.
    """
    path = "masa-su/SMVAEs/" + name

    api = wandb.Api()
    run = api.run(path)
    print(run.state)
    if run.state != "finished":
        return None

    test_dict = dict(run.summary["val"])

    # Skip the FID evaluation when some FID metric is already stored, unless
    # the caller explicitly forces a refresh.
    has_fid = any('fid' in key for key in test_dict)
    if has_fid and (force_update is False):
        eval_fid = False

    if ("cosine_sim_x0x1" in test_dict) and (eval_fid is False) and (force_update is False):
        # Cached metrics are complete: only load the model if asked to.
        if load_model:
            # Honour ``device`` here as well; this branch used to hard-code
            # "cuda" and ignore the parameter.
            model, _, _ = _load_model(path, device)
            test_dict.update({"model": model})

    elif test_model:
        model, train_loader, test_loader = _load_model(path, device)
        # NOTE(review): the leading 0 is presumably an epoch index -- confirm
        # against the signature of ``model.test``.
        test_dict.update(model.test(0, test_loader, eval_fid=eval_fid, eval_reconst=eval_reconst,
                                    wandb_log=wandb_log, std=std, sampling=sampling))
        if update:
            # Written back before the model is attached, so the model object
            # is never stored in the run summary.
            run.summary["val"] = test_dict
            run.summary.update()

        if load_model:
            test_dict.update({"model": model})

    return test_dict


def load_table(path):
    """Collect the hyperparameters of every run under ``path`` into a table.

    Args:
        path: Path handed to ``wandb.Api().runs`` to select the runs.

    Returns:
        A ``pandas.DataFrame`` with one row per run, indexed by run id, whose
        columns are the config keys of the runs. Config keys starting with
        ``_`` are left out.
    """
    api = wandb.Api()

    # .config contains the hyperparameters; special values whose key starts
    # with "_" are dropped. Run summaries and names are not part of the
    # returned table, so they are not fetched.
    configs_by_run_id = {
        run.id: {
            key: value
            for key, value in run.config.items()
            if not key.startswith('_')
        }
        for run in api.runs(path)
    }

    # from_dict puts run ids in the columns; transpose to get one row per run.
    df_dict = pd.DataFrame.from_dict(configs_by_run_id).T

    return df_dict