#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import PIL
import numpy as np
import torch
import matplotlib.pyplot as plt
import json
from scipy import io
from folders import folders
from modules import PredsegModule, PredsegSequential
from helpers import get_neigh, get_sparse_p, get_neigh
from corr_clust import corr_clust
from pixel import sparse_spectral_clustering
from script_general import get_module, get_shifts


def show_all_BSD(model, version, neigh, split='train', device='cpu'):
    """Call :func:`show_BSD` for every ``<number>.mat`` prediction of a model.

    The predictions are looked up in
    ``folders['BSD_predictions']/<model>/version<version>/<split>``.

    Args:
        model: model name; sub-folder of the prediction and model folders.
        version: integer model version (folder ``version%d``).
        neigh: neighbourhood specification forwarded to ``show_BSD``.
        split: BSD split sub-folder, e.g. ``'train'``.
        device: torch device string forwarded to ``show_BSD``.
    """
    pred_folder = os.path.join(
        folders['BSD_predictions'], model, 'version%d' % version, split)
    # Sort so the figures come out in a reproducible order.
    for fname in sorted(os.listdir(pred_folder)):
        stem, ext = os.path.splitext(fname)
        # show_BSD only reads '<number>.mat'. Skip anything else (hidden
        # files, logs, other formats) instead of crashing in int() or
        # processing the same number twice.
        if ext != '.mat':
            continue
        try:
            number = int(stem)
        except ValueError:
            continue
        show_BSD(model, version, neigh, number=number, split=split,
                 device=device)


def show_BSD(model, version, neigh, number=310007, split='train',
             device='cpu', savefig=None):
    """Plot a BSD image, the model's inferred edge map and the stored UCM.

    Three panels are drawn side by side:

    1. the RGB image ``folders['BSD']/<split>/<number>.jpg``;
    2. an edge map: the logistic sigmoid of the ``w`` maps returned by
       ``module.infer_w``, summed (ignoring NaNs) over their first two axes;
    3. the ``'ucm'`` array from the matching prediction ``.mat`` file.

    Args:
        model: model name; sub-folder of the prediction and model folders.
        version: integer model version (folder ``version%d``).
        neigh: neighbourhood specification passed to ``get_neigh``.
        number: BSD image id.
        split: BSD split sub-folder, e.g. ``'train'``.
        device: torch device used as ``map_location`` when loading
            ``pars.pth``.
        savefig: if truthy, the figure is also saved as ``'<savefig>.pdf'``.
    """
    image_path = os.path.join(folders['BSD'], split, '%d.jpg' % number)
    with PIL.Image.open(image_path) as img:
        # convert() loads the pixels into a new image, so the file can be
        # closed. It also guarantees 3 channels for the permute below.
        im = img.convert('RGB')
    im_tensor = torch.Tensor(np.array(im)).permute(2, 0, 1).unsqueeze(0)

    cont_global = io.loadmat(
        os.path.join(folders['BSD_predictions'],
                     model, 'version%d' % version,
                     split,
                     '%d.mat' % number))['ucm']

    module, _ = get_module(model, get_neigh(neigh))
    shift, subsamp = get_shifts(model)
    pars_file = os.path.join(folders['models'], model,
                             'version%d' % version, 'pars.pth')
    module.load_state_dict(torch.load(
        pars_file, map_location=torch.device(device)))

    # Pure inference: no autograd graph is needed.
    # NOTE(review): assumes infer_w does not rely on gradients — confirm.
    with torch.no_grad():
        module(im_tensor)
        w_maps, _, _ = module.infer_w(shift, subsamp, interpolate=True)

    # Numerically stable sigmoid. The old exp(w) / (1 + exp(w)) overflowed to
    # inf/inf = NaN for large w, and nansum silently dropped those pixels.
    # Here exp(-w) can only overflow for very negative w, which correctly
    # yields 0. NaNs in w_maps stay NaN and are still ignored by nansum.
    with np.errstate(over='ignore'):
        probs = 1 / (1 + np.exp(-w_maps))
    edgemap = np.nansum(np.nansum(probs, 0), 0)

    plt.figure(figsize=(20, 5))
    for idx, panel in enumerate((im, edgemap, cont_global), start=1):
        plt.subplot(1, 3, idx)
        plt.imshow(panel)
        plt.axis('off')

    if savefig:
        plt.savefig('%s.pdf' % savefig, bbox_inches='tight')


def show_w_mat(model, version, neigh, number=310007, split='train',
               device='cpu', savefig=None, window=(10000, 12000)):
    """Display a square sub-block of the sparse matrix built from ``w`` maps.

    The model is run on ``folders['BSD']/<split>/<number>.jpg``. The first
    entry of the ``w`` maps from ``module.infer_w`` is turned into a sparse
    matrix with ``get_sparse_p``, and the block ``[start:stop, start:stop]``
    is shown densely with a reversed-gray colormap.

    Args:
        model: model name; sub-folder of the model folder.
        version: integer model version (folder ``version%d``).
        neigh: neighbourhood specification passed to ``get_neigh``.
        number: BSD image id.
        split: BSD split sub-folder, e.g. ``'train'``.
        device: torch device used as ``map_location`` when loading
            ``pars.pth``.
        savefig: if truthy, the figure is also saved as ``'<savefig>.pdf'``.
        window: ``(start, stop)`` index range of the displayed square block.
            The default reproduces the previously hard-coded
            ``[10000:12000, 10000:12000]``.
    """
    image_path = os.path.join(folders['BSD'], split, '%d.jpg' % number)
    with PIL.Image.open(image_path) as img:
        # Load into a standalone 3-channel image so the file can be closed.
        rgb = img.convert('RGB')
    im_tensor = torch.Tensor(np.array(rgb)).permute(2, 0, 1).unsqueeze(0)

    module, _ = get_module(model, get_neigh(neigh))
    shift, subsamp = get_shifts(model)
    pars_file = os.path.join(folders['models'], model,
                             'version%d' % version, 'pars.pth')
    module.load_state_dict(torch.load(
        pars_file, map_location=torch.device(device)))

    # Pure inference: no autograd graph is needed.
    # NOTE(review): assumes infer_w does not rely on gradients — confirm.
    with torch.no_grad():
        module(im_tensor)
        w_maps, w_neighbors, _ = module.infer_w(shift, subsamp,
                                                interpolate=True)
    # Use the neighbour list returned by infer_w, not the one passed to
    # get_module.
    # w_maps[0] is presumably the single batch element — confirm.
    edge_mat = get_sparse_p(w_maps[0], w_neighbors)

    start, stop = window
    plt.figure(figsize=(5, 5))
    plt.imshow(edge_mat[start:stop, start:stop].todense(), cmap='gray_r')
    plt.colorbar()
    plt.tick_params(
        bottom=False, top=False, left=False, right=False,
        labelbottom=False, labelleft=False,
    )

    if savefig:
        plt.savefig('%s.pdf' % savefig, bbox_inches='tight')


def plot_factor(C=10, p=0.2, out_dir='../figures'):
    """Plot the pairwise mixture factor and its effect on two Gaussians.

    The pairwise factor between values ``a`` and ``b`` is::

        psi(a, b) = (1 - p) + p * sqrt(1 + 2C) * exp(-C/2 * (a - b)**2)

    The constant ``sqrt(1 + 2C)`` makes the coupling term preserve the
    integral of the product of two unit-variance (unnormalised) Gaussians.

    Three PDFs are written to ``out_dir``:

    * ``factors_2D_1.pdf``: the constant part, the coupling part and ``psi``.
    * ``factors_2D.pdf``: the Gaussian product alone, times the coupling
      term, and times ``psi``.
    * ``factor_1D.pdf``: ``psi`` as a function of ``a - b`` (LaTeX labels).

    Args:
        C: precision of the coupling Gaussian. It is now used both in the
            exponent and in the normaliser; previously the exponent was
            hard-coded to 10.
        p: mixture weight of the coupling term.
        out_dir: directory the PDFs are written to (default keeps the
            previous hard-coded location).
    """
    x = np.linspace(-3, 3, 1000)
    y = np.flip(x)  # flipped so that y increases upwards in imshow
    xx, yy = np.meshgrid(x, y)

    fx = np.exp(-0.5 * x ** 2)
    fy = np.exp(-0.5 * y ** 2).reshape(-1, 1)
    Z = np.sqrt(1 + 2 * C)
    # Normalised coupling term, computed once and reused below.
    coupling = np.exp(-0.5 * C * (xx - yy) ** 2) * Z
    fxy1 = p * coupling
    fxy = (1 - p) + fxy1
    f0 = fx * fy
    f1 = f0 * coupling
    f = f0 * fxy

    # The factor panels share one colour scale so they are comparable.
    vmax = np.max(fxy)
    plt.figure()
    for idx, panel in enumerate((np.full(xx.shape, 1 - p), fxy1, fxy),
                                start=1):
        plt.subplot(1, 3, idx)
        plt.imshow(panel, cmap='inferno', vmin=0, vmax=vmax)
        plt.axis('off')
    plt.savefig(os.path.join(out_dir, 'factors_2D_1.pdf'), bbox_inches='tight')

    plt.figure()
    for idx, panel in enumerate((f0, f1, f), start=1):
        plt.subplot(1, 3, idx)
        plt.imshow(panel, cmap='inferno')
        plt.axis('off')
    plt.savefig(os.path.join(out_dir, 'factors_2D.pdf'), bbox_inches='tight')

    # Scope usetex to this figure. A global plt.rc() call would leak into
    # every figure created later in the session.
    with plt.rc_context({'text.usetex': True}):
        plt.figure()
        f1d = (1 - p) + p * np.exp(-0.5 * C * x ** 2) * Z
        plt.plot(x, f1d)
        plt.ylim(bottom=0)
        ax = plt.gca()
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        plt.xlabel('$f_i - f_j$', fontsize=16)
        # Raw string: '\p' is an invalid escape (SyntaxWarning on 3.12+).
        plt.ylabel(r'$\psi(f_i, f_j)$', fontsize=16)
        plt.xticks([])
        plt.yticks([])
        plt.savefig(os.path.join(out_dir, 'factor_1D.pdf'),
                    bbox_inches='tight')


def show_weights_subplots(w, title=False):
    """Draw each filter ``w[i]`` as an image in its own subplot.

    Each filter is min-max normalised to ``[0, 1]`` independently. A filter
    with constant values (zero range) is shown as all zeros instead of NaN.

    Args:
        w: array-like of filters. Each ``w[i]`` is transposed with
            ``(1, 2, 0)`` before display, i.e. it is expected to be
            channels-first (presumably ``(3, H, W)`` — confirm).
        title: if true, title each subplot with its index and the filter's
            raw min/max values.
    """
    w = np.array(w)
    n = w.shape[0]
    if n == 0:
        return
    # plt.subplot requires integer grid sizes; np.ceil returns floats.
    n_rows = int(np.ceil(np.sqrt(n)))
    n_cols = int(np.ceil(n / n_rows))
    plt.subplots_adjust(wspace=0.05, hspace=0.5)
    for i in range(n):
        plt.subplot(n_rows, n_cols, i + 1)
        w_i = w[i].transpose(1, 2, 0)
        lo, hi = np.min(w_i), np.max(w_i)
        span = hi - lo
        w_i = (w_i - lo) / span if span > 0 else np.zeros(w_i.shape)
        plt.imshow(w_i)
        if title:
            plt.title('%d: %.3f - %.3f' % (i, np.min(w[i]), np.max(w[i])))
        plt.xticks([])
        plt.yticks([])


def show_weights(w, title=False):
    """Tile all filters of ``w`` into one RGB mosaic and display it.

    Filters are placed column-major on a grid of ``ceil(sqrt(n))`` rows. Each
    filter is min-max normalised to ``[0, 1]`` independently; a filter with
    constant values (zero range) is drawn as zeros instead of NaN.

    Args:
        w: array-like indexed ``w[i, c, h, w]`` (4-D, channels-first). The
            mosaic has 3 channels; 1-channel filters broadcast to gray, while
            other channel counts are not supported.
        title: unused; kept for interface compatibility with
            ``show_weights_subplots``.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure(figsize=(10, 10), dpi=100)
    w = np.array(w)
    n = w.shape[0]
    height, width = w.shape[2], w.shape[3]
    n_rows = int(np.ceil(np.sqrt(n)))
    n_cols = int(np.ceil(n / n_rows)) if n else 0
    mosaic = np.zeros((3, height * n_rows, width * n_cols))
    for i in range(n):
        row, col = i % n_rows, i // n_rows
        lo, hi = np.min(w[i]), np.max(w[i])
        span = hi - lo
        w_i = (w[i] - lo) / span if span > 0 else np.zeros(w[i].shape)
        mosaic[:,
               row * height:(row + 1) * height,
               col * width:(col + 1) * width] = w_i
    plt.imshow(mosaic.transpose(1, 2, 0))
    plt.xticks([])
    plt.yticks([])
    return fig
