#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import torch
import numpy as np
import json
from PIL import Image
from matplotlib import pyplot as plt
from lucent.optvis import render, param, transform, objectives
from folders import folders
from helpers import get_neigh, get_last_hparam
from get_module import get_module, get_shifts
from make_illustrations import show_weights


def show_one_feature(model, version, neigh, layer, channel, device='cpu',
                     resolution=128, show_image=True):
    """Render a lucent feature visualisation for a single neuron.

    The module for ``model`` is built from the neighbourhood spec ``neigh``.
    If ``pars.pth`` exists in the model's version folder, those parameters
    are loaded into the module. An input image is then optimised for
    ``objectives.neuron(layer, channel)`` with ``render.render_vis``.

    Args:
        model: model name; passed to ``get_module`` and used as a folder name
            under ``folders['models']``.
        version: version number, selecting the folder ``version<version>``.
        neigh: neighbourhood spec handed to ``get_neigh``.
        layer: layer name handed to ``objectives.neuron``.
        channel: channel index handed to ``objectives.neuron``.
        device: device the stored parameters are mapped to and the module is
            moved to.
        resolution: side length of the optimised image.
        show_image: when True, draw ``result[0][0]`` with ``plt.imshow``.

    Returns:
        Whatever ``render.render_vis`` returns.
    """
    net, _ = get_module(model, get_neigh(neigh))
    version_dir = os.path.join(folders['models'], model, 'version%d' % version)

    weights_path = os.path.join(version_dir, 'pars.pth')
    if os.path.isfile(weights_path):
        state = torch.load(weights_path, map_location=torch.device(device))
        net.load_state_dict(state)
    net.to(device).eval()

    def make_param():
        # plain pixel-space parameterisation: no FFT, no colour decorrelation
        return param.image(resolution, fft=False, decorrelate=False)

    jitter = [transform.jitter(4)]
    result = render.render_vis(
        net,
        objectives.neuron(layer, channel),
        param_f=make_param,
        show_image=False,
        transforms=jitter,
    )
    if show_image:
        plt.imshow(result[0][0])
    return result


def show_kernel(model, version, channel, neigh=None, device='cpu', resolution=24):
    """Show a linear kernel approximation at 0.

    This is particularly helpful for the first layers of the network, which
    are still linear. The gradient of one output unit, taken with respect to
    an all-zero input image, is plotted as an RGB image.

    Args:
        model: model name; passed to ``get_module`` and used as a folder name
            under ``folders['models']``.
        version: version number, selecting the folder ``version<version>``.
        channel: output channel whose centre unit is differentiated.
        neigh: neighbourhood spec for ``get_neigh``. If None, it is read
            (``.neigh``) from ``hparam.json`` via ``get_last_hparam``.
        device: device the parameters, module and probe input are placed on.
        resolution: side length of the zero input image.
    """
    folder = os.path.join(folders['models'], model, 'version%d' % version)
    if neigh is None:
        hparam = get_last_hparam(os.path.join(folder, 'hparam.json'))
        neigh = hparam.neigh
    neighbors = get_neigh(neigh)
    module, _ = get_module(model, neighbors)

    pars_file = os.path.join(folder, 'pars.pth')
    if os.path.isfile(pars_file):
        module.load_state_dict(torch.load(
            pars_file, map_location=torch.device(device)))
    if model == 'predseg1':
        # keep only the first three submodules
        # NOTE(review): presumably the early, near-linear stage; confirm
        module = module[:3]
    module.to(device)

    # The probe must live on the same device as the module. A CPU tensor
    # fed to a CUDA module raises a device-mismatch error.
    im_in = torch.zeros(1, 3, resolution, resolution,
                        device=device, requires_grad=True)
    out = module(im_in)
    # centre unit of the spatial map (identical to the original for square maps)
    row = out.shape[-2] // 2
    col = out.shape[-1] // 2
    out[0, channel, row, col].backward()

    # .cpu() is required before .numpy() for CUDA tensors; CHW -> HWC for imshow
    d_grad = im_in.grad.detach().cpu().numpy()[0].transpose(1, 2, 0)
    d_grad -= np.min(d_grad)
    span = np.max(d_grad)
    if span > 0:
        # rescale to [0, 1]; a constant gradient stays all-zero instead of NaN
        d_grad /= span
    plt.imshow(d_grad)


def _figure_to_image(fig):
    """Rasterise a matplotlib figure into an RGB PIL image (copying the pixels)."""
    fig.canvas.draw()
    # buffer_rgba() has the true pixel shape (also on HiDPI canvases); drop alpha
    # and copy so the image stays valid after the figure is closed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return Image.fromarray(np.ascontiguousarray(rgba[..., :3]))


def make_training_animation(model, version, neigh=None, device='cpu',
                            filename='animation.gif'):
    """Build an animated GIF of the weights stored in successive checkpoints.

    Checkpoints ``checkpoints/cp_0.pth``, ``cp_1.pth``, ... in the model's
    version folder are read in order until the first missing index. Each
    checkpoint's raw module weight is drawn with ``show_weights``, and the
    frames are written to ``filename`` as an endlessly looping GIF with
    200 ms per frame.

    Args:
        model: model name; passed to ``get_module`` and used as a folder name
            under ``folders['models']``.
        version: version number, selecting the folder ``version<version>``.
        neigh: neighbourhood spec for ``get_neigh``. If None, it is looked up
            in ``models.json`` (only for 'predseg1' and 'linearbig').
        device: ``map_location`` used when loading checkpoints.
        filename: output path of the GIF.

    Returns:
        The written GIF, reopened with ``Image.open``.

    Raises:
        FileNotFoundError: if ``cp_0.pth`` does not exist.
    """
    if neigh is None:
        with open('models.json', 'r') as f:
            models = json.load(f)
        if model == 'predseg1':
            neigh = models[48 + version]['neigh']
        elif model == 'linearbig':
            neigh = models[24 + version]['neigh']
    # NOTE(review): the module itself is not used below; it is still built so
    # that an invalid model/neigh combination fails as before.
    get_module(model, get_neigh(neigh))
    folder = os.path.join(folders['models'], model, 'version%d' % version)

    cp_file = os.path.join(folder, 'checkpoints', 'cp_%d.pth')
    weight_key = ('2.raw_module.weight' if model == 'predseg1'
                  else 'raw_module.weight')
    images = []
    i = 0
    while os.path.exists(cp_file % i):
        pars = torch.load(cp_file % i, map_location=device)
        fig = show_weights(pars[weight_key])
        try:
            images.append(_figure_to_image(fig))
        finally:
            plt.close(fig)
        i += 1
    if not images:
        raise FileNotFoundError('no checkpoints found: %s' % (cp_file % 0))

    # loop=0 means "repeat forever" for Pillow's GIF writer; the previous
    # loop=True (== 1) only repeated a finite number of times.
    images[0].save(fp=filename, format='GIF', append_images=images[1:],
                   save_all=True, duration=200, loop=0)
    return Image.open(filename)
