#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from collections.abc import Iterable
import multiprocessing
import time
import pathlib
import os
import json
import torch
import torchvision
import numpy as np
import scipy
from scipy import io
import PIL
from tqdm import tqdm
from helpers import extract_cluster_features, get_sparse_laplacian
from helpers import edges_to_laplacian, get_sparse_p
from data import ImageFolderName
from contour import contour
from folders import folders


# NOTE(review): KMP_DUPLICATE_LIB_OK presumably works around the Intel OpenMP
# runtime aborting when it gets loaded twice (e.g. once via torch and once via
# numpy/MKL). It silences the check rather than removing the duplicate load;
# confirm it is still required.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# BSD handling
def _process_image_BSD(w_m, neighbors, out_dir, file, im_size=None,
                       verbose=False, transform='quant', prior_shift=4,
                       q_zero=0.3, pad=0, n_neigh=None):
    if n_neigh is None:
        n_neigh = w_m.shape[0]
    # normalization to avoid ceiling effects connecting everything
    if transform == 'expit':
        # sets maximum over all w_maps to prior_shift
        # 4 (= 98 percent connected)
        w_m -= np.nanmax(w_m) - prior_shift
        w_m = scipy.special.expit(w_m)
        laplacian, divisor = get_sparse_laplacian(w_m, neighbors)
    elif transform == 'quant':
        quantiles = np.nanquantile(
            np.reshape(w_m, (w_m.shape[0], np.prod(w_m.shape[1:]))),
            q_zero, axis=1)
        w_m -= np.reshape(quantiles, (w_m.shape[0], 1, 1))
        w_m[w_m < 0.01] = 0.01
        laplacian, divisor = get_sparse_laplacian(w_m, neighbors)
    if verbose:
        print('performed transform')
    cont = contour(laplacian, divisor, im_size=im_size, verbose=verbose,
                   im_size_small=(w_m.shape[1], w_m.shape[2]), pad=pad)
    if verbose:
        print('got contours')
    # normalization: first to [0,1] then set mean to 0.2
    # the second step is good for improving the dataset-wide threshold evaluation
    cont -= np.min(cont)
    cont /= np.max(cont)
    cont = np.mean(cont, 0)
    cont -= np.mean(cont)
    cont += 0.2
    file = os.path.split(file)[-1]
    # out_file = os.path.join(out_dir, file.replace('.jpg', '.png'))
    # pil_im = PIL.Image.fromarray(255*cont)
    # pil_im.convert('RGB').save(out_file, format='png')
    out_file = os.path.join(out_dir, file.replace('.jpg', '.mat'))
    io.savemat(out_file, {'ucm': cont})
    if verbose:
        print('saved contours')


def _infer_BSD_weights(module, im, shift, subsamp, separate, interpolate,
                       downsample):
    """Run ``module`` on one image and return ``(w_m, neigh, pad, n_neigh)``.

    When ``separate`` is true, ``w_m`` is a list with one weight map per
    entry returned by ``module.infer_w_sep``. Otherwise it is the first batch
    entry of ``module.infer_w``. ``pad`` and ``n_neigh`` are the values to
    forward to ``_process_image_BSD``.
    """
    if downsample >= 1:
        im = torch.nn.functional.interpolate(
            im.unsqueeze(0), scale_factor=1 / downsample, mode='bilinear')[0]
    # the output is discarded; the forward pass presumably stores the state
    # that infer_w / infer_w_sep read back (TODO confirm)
    module(im.unsqueeze(0))
    if separate:
        w_m, neigh, _ = module.infer_w_sep(shift, subsamp)
        return [wm[0] for wm in w_m], neigh, 0, int(w_m[0].shape[1])
    # assuming that the first entry has the least shifts: make all shifts
    # relative to it and use it as padding for the contour extraction
    i0 = shift[0][0]
    j0 = shift[0][1]
    shift_0 = [[i - i0, j - j0] for i, j in shift]
    w_m, neigh, _ = module.infer_w(shift_0, subsamp, interpolate=interpolate)
    return w_m[0], neigh, shift[0], int(w_m.shape[1] / len(shift))


def save_BSD(
    module, shift, subsamp,
    out_dir=os.path.join(folders['BSD_predictions'], 'version999', 'train'),
    in_dir=os.path.join(folders['BSD'], 'train'),
    verbose=False, t_max=np.inf, num_workers=1, transform='dist',
    downsample=0, prior_shift=4, separate=False, interpolate=False, q_zero=0.15):
    """Save the predictions for the BSD segmentation benchmark.

    Images are loaded in batches of ``num_workers``. The model is first run
    on each image of the batch. The contour extraction
    (``_process_image_BSD``, mostly CPU work) then runs in one process per
    job. An informed guess for a good number of processes is the number of
    CPUs - 1, leaving one for collecting the new data.

    Args:
        module: model; called on each image and then queried with
            ``infer_w`` (or ``infer_w_sep`` when ``separate``).
        shift: list of ``[i, j]`` shifts; the first entry is assumed to be
            the smallest one.
        subsamp: subsampling forwarded to ``infer_w`` / ``infer_w_sep``.
        out_dir: output directory (created if missing).
        in_dir: image folder read with ``ImageFolderName``.
        verbose: print progress messages.
        t_max: iterable of image indices, or the number of leading images to
            process (``np.inf`` = all).
        num_workers: number of worker processes and batch size; ``0`` runs
            everything sequentially in this process.
        transform: ``'expit'`` or ``'quant'`` (see ``_process_image_BSD``).
            The default ``'dist'`` is not supported and must be overridden.
        downsample: if ``>= 1``, images are downscaled by this factor before
            inference.
        prior_shift, q_zero: forwarded to ``_process_image_BSD``.
        separate: use ``infer_w_sep`` and save each weight map's result in
            its own subdirectory ``out_dir/<index>``.
        interpolate: forwarded to ``infer_w`` (unused when ``separate``).

    Raises:
        ValueError: if ``transform`` is not supported. This is checked before
            any data is loaded or the model is run.
    """
    # keep in sync with the transforms implemented in _process_image_BSD
    if transform not in ('expit', 'quant'):
        raise ValueError(
            f"unknown transform {transform!r}; expected 'expit' or 'quant'")
    data = ImageFolderName(in_dir)
    os.makedirs(out_dir, exist_ok=True)
    if isinstance(t_max, Iterable):
        data = torch.utils.data.Subset(data, t_max)
    elif np.isfinite(t_max):
        # clamp so a t_max beyond the dataset size processes every image
        data = torch.utils.data.Subset(data, range(min(int(t_max), len(data))))
    # one batch provides one job per worker process
    data_load = torch.utils.data.DataLoader(
        data, batch_size=max(num_workers, 1),
        num_workers=0, collate_fn=collate_fn_BSD)
    if verbose:
        print('starting outer loop\n')
    for files, images in tqdm(data_load):
        if verbose:
            print('starting pytorch processing\n')
        # first loop applying the model; collects (args, kwds) jobs for
        # _process_image_BSD with per-image pad / n_neigh
        jobs = []
        for file, im in zip(files, images):
            w_m, neigh, pad, n_neigh = _infer_BSD_weights(
                module, im, shift, subsamp, separate, interpolate, downsample)
            kwds = {'verbose': verbose, 'im_size': im.shape[1:],
                    'prior_shift': prior_shift, 'transform': transform,
                    'pad': pad, 'q_zero': q_zero, 'n_neigh': n_neigh}
            if separate:
                for i_w, (w_m_s, nei) in enumerate(zip(w_m, neigh)):
                    sub_dir = os.path.join(out_dir, str(i_w))
                    os.makedirs(sub_dir, exist_ok=True)
                    jobs.append(((w_m_s, nei, sub_dir, file), kwds))
            else:
                jobs.append(((w_m, neigh, out_dir, file), kwds))
        # second loop using CPU parallelism to run contour extraction
        if verbose:
            print('starting CPU processing\n')
        if num_workers > 0:
            # A fresh pool per batch, shut down before the next forward pass
            # so it does not interfere with pytorch's multiprocessing. The
            # context manager also cleans up the workers if a job raises.
            with multiprocessing.Pool(processes=num_workers) as workers:
                pending = [workers.apply_async(_process_image_BSD,
                                               args=args, kwds=kwds)
                           for args, kwds in jobs]
                for p in pending:
                    p.get()
        else:
            for args, kwds in jobs:
                _process_image_BSD(*args, **kwds)


def collate_fn_BSD(data):
    """Split a batch of ``(file, image)`` samples into two parallel lists.

    Intended as the DataLoader ``collate_fn``. Images are returned as a plain
    list rather than stacked, so they may differ in size.

    Returns:
        ``(files, images)``, where ``files[k]`` is ``data[k][0]`` and
        ``images[k]`` is ``data[k][1]``.
    """
    def column(idx):
        return [sample[idx] for sample in data]

    return column(0), column(1)
