import torch
import numpy as np
from matplotlib import pyplot as plt

import torch.nn as nn
from collections import OrderedDict


# Module-wide default for the ``console_plot`` argument of the plotting helpers
# below; when true they also display the figure with ``plt.show()``.
console_plot = False


import io

from PIL import Image

# Torch device string used where this module places modules/tensors explicitly:
# "cuda" when a CUDA device is available, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

class MLP(nn.Module):
    """Fully connected network built from a list of hidden-layer widths.

    Every hidden layer is ``Linear -> ReLU -> Dropout(0.2)`` and is followed by
    a final ``Linear`` layer named ``out``. Hidden sub-modules are named
    ``fc{i}``, ``relu{i}`` and ``drop{i}`` with ``i`` starting at 1. The
    assembled ``nn.Sequential`` is moved to the module-level ``device``.

    Args:
        input_dims: Number of features per sample after flattening (an int).
        n_hiddens: A single hidden width, or an iterable of hidden widths.
        n_class: Width of the output layer.

    Raises:
        AssertionError: If ``input_dims`` is not an int.
    """

    def __init__(self, input_dims, n_hiddens, n_class):
        super().__init__()
        assert isinstance(input_dims, int), 'Please provide int for input_dims'
        self.input_dims = input_dims

        # Accept either a single width or any iterable of widths.
        if isinstance(n_hiddens, int):
            n_hiddens = [n_hiddens]
        else:
            n_hiddens = list(n_hiddens)

        layers = OrderedDict()
        current_dims = input_dims
        for i, n_hidden in enumerate(n_hiddens, start=1):
            layers['fc{}'.format(i)] = nn.Linear(current_dims, n_hidden)
            layers['relu{}'.format(i)] = nn.ReLU()
            layers['drop{}'.format(i)] = nn.Dropout(0.2)
            current_dims = n_hidden
        layers['out'] = nn.Linear(current_dims, n_class)

        self.model = nn.Sequential(layers).to(device)

    def forward(self, input):
        """Flatten each sample to one vector and run it through the network.

        Args:
            input: Tensor whose first dimension is the batch; the remaining
                dimensions must multiply out to ``input_dims``.

        Returns:
            The output of the final linear layer.

        Raises:
            AssertionError: If the flattened feature count differs from
                ``input_dims``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or permuted tensors, copying only when it has to.
        input = input.reshape(input.size(0), -1)
        assert input.size(1) == self.input_dims, \
            'expected {} features per sample, got {}'.format(self.input_dims, input.size(1))
        # Call the module rather than .forward() so registered hooks still run.
        return self.model(input)



# One-slot cache holding the classifier so it is only built on first use.
memo_model = [None]
def get_classification_accuracy(images, target):
    """Return the fraction of ``images`` the cached classifier labels as ``target``.

    The classifier is created once via ``mnist(pretrained=True)`` and kept in
    ``memo_model[0]`` for subsequent calls.
    NOTE(review): ``mnist`` is not defined or imported in the visible part of
    this file -- confirm where it comes from.

    Args:
        images: Batch passed to the classifier's ``forward``.
        target: Labels compared element-wise against the arg-max over dim 1
            of the classifier output.

    Returns:
        A 0-dim float tensor: the mean of the per-sample matches.
    """
    classifier = memo_model[0]
    if classifier is None:
        classifier = mnist(pretrained=True)
        memo_model[0] = classifier

    logits = classifier.forward(images)
    _, predicted = logits.max(dim=1)
    hits = (predicted == target).float()
    return hits.mean()



def plot_data(triples, targets, wnum = 0):
    """Show the observed values of one batch entry on a 28x28 canvas.

    Entry ``wnum`` is taken from ``triples`` and ``targets``. For each of its
    triples the target value is written at ``[triple[0], triple[2]]``; cells
    that are never written keep the value -1. The canvas is displayed with the
    'hot' colormap.

    Args:
        triples: Tensor indexed as ``[entry, i, component]``; components 0 and
            2 are used as integer canvas coordinates.
        targets: Tensor indexed as ``[entry, i]`` with the value for triple ``i``.
        wnum: Index of the batch entry to draw.
    """
    world_triples = triples.cpu().numpy()[wnum]
    world_targets = targets.cpu().numpy()[wnum]

    # -1 marks cells without an observation.
    canvas = np.full((28, 28), -1.0)
    for idx, triple in enumerate(world_triples):
        canvas[triple[0], triple[2]] = world_targets[idx]

    plt.imshow(canvas, cmap='hot', interpolation='none')
    plt.show()

def plotGeneratedImages(generatedImages, image_shape=(28,28), cell_shape=(28,28), console_plot=console_plot):
    """Draw a grid of images with matplotlib and return it as a PIL image.

    The first two dimensions of ``generatedImages`` are the grid rows and
    columns; everything after them is reshaped to ``image_shape``. Each cell is
    drawn with the 'gnuplot' colormap on a fixed value range of [-1, 1].

    Args:
        generatedImages: ``torch.Tensor`` or array-like with at least two
            dimensions whose trailing elements per cell fit ``image_shape``.
        image_shape: ``(height, width)`` of a single cell image.
        cell_shape: Reference ``(height, width)``; the figure size grows with
            ``image_shape / cell_shape``.
        console_plot: If true, also display the figure via ``plt.show()``.

    Returns:
        ``PIL.Image.Image`` opened on an in-memory PNG rendering of the figure.
    """
    # Convert to numpy *before* reshaping: the torch-style ``view(*sizes)`` call
    # is not valid on ndarrays and also rejects non-contiguous tensors.
    if isinstance(generatedImages, torch.Tensor):
        generatedImages = generatedImages.detach().cpu().numpy()
    else:
        generatedImages = np.asarray(generatedImages)
    dim = generatedImages.shape[0], generatedImages.shape[1]
    generatedImages = generatedImages.reshape(*dim, *image_shape)

    figure = plt.figure(figsize=(2 * dim[1] * image_shape[1] // cell_shape[1],
                                 2 * dim[0] * image_shape[0] // cell_shape[0]))
    try:
        for i in range(dim[0]):
            for j in range(dim[1]):
                # Subplot indices are 1-based and run row by row.
                ind = i * dim[1] + j + 1
                plt.subplot(dim[0], dim[1], ind)
                plt.imshow(generatedImages[i, j], interpolation='none', cmap='gnuplot', vmin=-1., vmax=1)
                plt.axis('off')
        plt.tight_layout()

        buf = io.BytesIO()
        # Pin the format so PIL can always decode the buffer, independent of
        # the user's ``savefig.format`` rc setting.
        plt.savefig(buf, format="png")
        buf.seek(0)
        img = Image.open(buf)
        if console_plot:
            plt.show()
    finally:
        # Always release the figure, even if rendering or saving failed.
        plt.close(figure)
    return img

def plot_gif(generatedImages):
    """Write the frames to ``temp.gif`` and wrap that file in a ``wandb.Video``.

    The first frame's ``save`` method is used with the remaining frames
    appended, 200 ms per frame, looping forever (``loop=0``).

    Args:
        generatedImages: Non-empty, sliceable sequence of frames exposing a
            PIL-style ``save`` method (the ``__main__`` demo passes the images
            returned by ``plotGeneratedImages``).

    Returns:
        ``wandb.Video`` created from ``temp.gif`` in the working directory.
    """
    # Imported lazily so the module can be used without wandb installed.
    import wandb

    gif_path = "temp.gif"
    first_frame = generatedImages[0]
    remaining_frames = generatedImages[1:]
    first_frame.save(gif_path, save_all=True, append_images=remaining_frames, duration=200, loop=0)
    return wandb.Video(gif_path, format="gif")

def all_triples_tensor(num_relations=1, num_entities=28):
    """Enumerate every ``(row, relation, column)`` triple of the grid.

    Triples are ordered relation-major: all ``num_entities ** 2`` cells of
    relation 0 in row-major order, then relation 1, and so on.

    Args:
        num_relations: Number of relation planes.
        num_entities: Side length of the square grid.

    Returns:
        Long tensor of shape ``(num_relations * num_entities ** 2, 3)`` with
        column 0 = row, column 1 = relation, column 2 = column.
    """
    side = torch.arange(num_entities)
    rows = side.repeat_interleave(num_entities)  # 0,0,...,1,1,...
    cols = side.repeat(num_entities)             # 0,1,...,0,1,...
    per_relation = [
        torch.stack([rows, torch.full_like(rows, relation), cols], dim=1)
        for relation in range(num_relations)
    ]
    return torch.cat(per_relation, dim=0)


def get_acc(ps, ys):
    """Return the fraction of elements where prediction and target agree.

    A prediction counts as positive when it is ``>= 0.0`` and a target counts
    as positive when it is ``>= 0.01``; both inputs are flattened first.

    Args:
        ps: Tensor of predictions.
        ys: Tensor of targets.

    Returns:
        The mean agreement as a Python float.
    """
    predicted_positive = ps.flatten() >= 0.0
    target_positive = ys.flatten() >= 0.01
    agreement = predicted_positive == target_positive
    return float(agreement.double().mean())

def visualiseModel(CONFIG, model, datahandler, world_states,
                    run=None, batch_size=1, init_frame=0, relation_offset=0, labeled=False, progressive=False,
                    console_plot=console_plot):
    """Render ground truth, observations and model reconstructions as one grid.

    Grid rows, top to bottom: the images from ``datahandler.get_images``;
    (only if ``progressive``) the images from ``datahandler.get_seen_images``;
    the observed values of ``run`` (-1 where nothing was observed); the sigmoid
    of the model output for every triple; that sigmoid thresholded at 0.5; and
    (only if ``labeled``) the sigmoid of the model output for an all-zero
    query, broadcast over the grid. Grid columns are the batch entries.

    Args:
        CONFIG: Object whose ``model_vars["model_inits"]`` holds
            ``num_entities`` and ``num_relations``.
        model: Object exposing ``forward(triples, world_states)``.
        datahandler: Object exposing ``get_images(init_frame, num_relations)``
            and, when ``progressive`` is set, ``get_seen_images(CONFIG)``.
        world_states: Sliced to the first ``batch_size`` entries and passed to
            ``model.forward``.
        run: ``(image_idx, triples, targets)``; ``image_idx`` is unused and
            ``triples`` is indexed as ``[batch, i, component]``. Required.
        batch_size: Number of batch entries to draw, capped at 5.
        init_frame: Forwarded to ``datahandler.get_images``.
        relation_offset: Subtracted from the relation component of ``triples``
            before it is used as a plane index.
        labeled: Append the all-zero-query row described above.
        progressive: Append the ``get_seen_images`` row described above.
        console_plot: Forwarded to ``plotGeneratedImages``.

    Returns:
        The image returned by ``plotGeneratedImages``.

    Raises:
        NotImplementedError: If ``run`` is None.
    """
    if run is None:
        # Drawing a run from the data loader was never implemented.
        raise NotImplementedError("visualiseModel requires an explicit `run`")

    model_inits = CONFIG.model_vars["model_inits"]
    num_entities = model_inits["num_entities"]
    num_relations = model_inits["num_relations"]
    grid_size = num_entities ** 2
    batch_size = min(batch_size, 5)

    _, triples, targets = run
    triples = triples[:batch_size]
    targets = targets[:batch_size].float()
    images = datahandler.get_images(init_frame, num_relations)[:batch_size].float()
    all_triples = all_triples_tensor(num_relations, num_entities)

    # Scatter the observed targets into a canvas pre-filled with -1.
    observations = torch.zeros(batch_size, num_relations, grid_size).to(device) - 1
    # Flat index of each observed triple inside one batch entry's
    # (relation, row, column) block ...
    pixels = triples[:, :, 0] * num_entities + triples[:, :, 2] + (triples[:, :, 1] - relation_offset) * grid_size
    # ... shifted by the start of that batch entry's block, because put_
    # addresses the tensor as if it were flattened.
    offset = torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1) * grid_size * num_relations
    observations.put_(offset + pixels, targets, False)

    # Query the model for every triple, at most chunk_size triples per call.
    chunk_size = 1024
    total_triples = grid_size * num_relations
    chunks = (total_triples + chunk_size - 1) // chunk_size
    preds = [
        model.forward(all_triples[c * chunk_size: (c + 1) * chunk_size].unsqueeze(0).repeat(batch_size, 1, 1),
                      world_states[:batch_size])
        for c in range(chunks)
    ]
    pred = torch.cat(preds, dim=1).view(batch_size, num_relations, grid_size)
    reconstructions = torch.sigmoid(pred)
    threshold_reconstructions = (reconstructions > .5).float()

    # Relation planes are stacked vertically inside each grid cell.
    image_size = (num_relations * num_entities, num_entities)
    stack = [images.view_as(observations)]
    if progressive:
        stack += [datahandler.get_seen_images(CONFIG)[:batch_size].view_as(observations)]
    stack += [observations, reconstructions, threshold_reconstructions]
    if labeled:
        predicted_label = model.forward(torch.zeros((batch_size, 1, 1), dtype=torch.long), world_states[:batch_size])
        stack += [torch.sigmoid(predicted_label).expand_as(observations)]

    stacked = torch.stack([x.to("cpu") for x in stack])
    stacked = stacked.view(len(stack), batch_size, num_relations, num_entities, num_entities)
    return plotGeneratedImages(stacked.reshape(len(stack), batch_size, *image_size), image_shape=image_size,
                               cell_shape=(num_entities, num_entities), console_plot=console_plot)




def digit_distance(samples_by_digit):
    """Compute per-digit spread and pairwise distances between digit centers.

    For every digit ``i`` the center is the mean of ``samples_by_digit[i]``
    over dim 0, and the spread is the norm of all deviations from that center
    divided by the square root of the sample count. Both results are printed
    before being returned. Gradients are not tracked.

    Args:
        samples_by_digit: Container indexable by ``0 .. len - 1`` whose entries
            are 2-D tensors with the same second dimension.

    Returns:
        Tuple ``(variance, dists)``: ``variance`` is a list with one 0-dim
        tensor per digit (the spread described above) and ``dists`` is a
        symmetric ``(num_digits, num_digits)`` tensor of Euclidean distances
        between centers with a zero diagonal.
    """
    with torch.no_grad():
        num_digits = len(samples_by_digit)
        # Allocate on the samples' own device rather than a module-level
        # default, so mixing devices cannot cause a mismatch below.
        sample_device = samples_by_digit[0].device
        centers = torch.zeros((num_digits, samples_by_digit[0].shape[1]), device=sample_device)
        variance = []
        for i in range(num_digits):
            samples = samples_by_digit[i]
            centers[i] = samples.mean(dim=0)
            variance += [(samples - centers[i]).norm() / (samples.shape[0] ** .5)]
        # All pairwise center distances at once via broadcasting.
        dists = (centers.unsqueeze(1) - centers.unsqueeze(0)).norm(dim=-1)
        print(torch.Tensor(variance))
        print(dists)
        return variance, dists


if __name__ == "__main__":
    # Manual smoke test: render random 3x3 grids, then write a still image and
    # two GIFs (one through plot_gif, one directly through PIL).
    still = plotGeneratedImages(torch.rand((3, 3, 28, 28)))
    frames = []
    for _ in range(10):
        frames.append(plotGeneratedImages(torch.rand((3, 3, 28, 28))))
    plot_gif(frames)
    still.save("testimg.png")
    still.save("testgif.gif", save_all=True, append_images=frames, duration=200, loop=0)