import os
import sys
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import blobfile as bf
import torch


def create_dir(dir):
    """Create directory `dir`, including missing parents, if it is absent.

    Args:
        dir: path of the directory to create.

    Raises:
        FileExistsError: if `dir` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race: another process may create
    # the directory between an os.path.exists() test and os.makedirs().
    os.makedirs(dir, exist_ok=True)


def create_mask(width, height, mask_width, mask_height, x=None, y=None):
    """Build a (height, width) array of zeros with one rectangle set to 1.

    Args:
        width: number of columns of the returned array.
        height: number of rows of the returned array.
        mask_width: number of columns covered by the rectangle.
        mask_height: number of rows covered by the rectangle.
        x: left column of the rectangle; drawn at random when None.
        y: top row of the rectangle; drawn at random when None.

    Returns:
        numpy.ndarray of shape (height, width) holding 0 outside and 1 inside
        the rectangle.
    """
    # Draw x before y so the sequence of random numbers consumed is stable.
    left = random.randint(0, width - mask_width) if x is None else x
    top = random.randint(0, height - mask_height) if y is None else y

    canvas = np.zeros((height, width))
    canvas[top:top + mask_height, left:left + mask_width] = 1
    return canvas


def stitch_images(inputs, *outputs, img_per_row=2):
    """Tile `inputs` and the matching `outputs` batches into one RGB image.

    Each sample occupies one cell made of `1 + len(outputs)` pictures placed
    side by side (the input first, then one picture per output batch).
    Cells are arranged `img_per_row` to a row, separated by a small gap.

    Args:
        inputs: indexable batch of images; each item is indexed as
            `[:, :, 0]` and must provide `.cpu()`
            (presumably torch tensors laid out (H, W, C) -- confirm with callers).
        *outputs: further batches with the same length and image size.
        img_per_row: number of samples placed on each row.

    Returns:
        PIL.Image.Image containing the stitched grid.
    """
    gap = 5
    columns = len(outputs) + 1

    # Array shapes are (rows, cols): rows give the height, cols the width.
    height, width = inputs[0][:, :, 0].shape
    # Round up so a trailing, partially filled row still fits on the canvas.
    rows = -(-len(inputs) // img_per_row)
    img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * rows))
    images = [inputs, *outputs]

    for ix in range(len(inputs)):
        col = ix % img_per_row
        row = ix // img_per_row
        xoffset = col * width * columns + col * gap
        yoffset = row * height

        for cat, batch in enumerate(images):
            im = np.array(batch[ix].cpu()).astype(np.uint8).squeeze()
            img.paste(Image.fromarray(im), (xoffset + cat * width, yoffset))

    return img


def imshow(img, title=''):
    """Display `img` in the current matplotlib figure and block until closed.

    Args:
        img: image data accepted by `plt.imshow`.
        title: text used as the window title.
    """
    fig = plt.gcf()
    # `FigureCanvas.set_window_title` was deprecated in Matplotlib 3.4 and
    # later removed; the window title now lives on the canvas manager.
    # The manager can be None for figures without a GUI window.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None and hasattr(manager, 'set_window_title'):
        manager.set_window_title(title)
    elif hasattr(fig.canvas, 'set_window_title'):
        fig.canvas.set_window_title(title)
    plt.axis('off')
    plt.imshow(img, interpolation='none')
    plt.show()


def imsave(img, path):
    """Save `img` to `path` as an image file.

    Args:
        img: object providing `.cpu().numpy()` (e.g. a torch tensor); values
            are cast to uint8 and singleton dimensions are dropped.
        path: destination file path handed to `PIL.Image.save`.
    """
    pixels = img.cpu().numpy()
    pixels = pixels.astype(np.uint8).squeeze()
    Image.fromarray(pixels).save(path)


class Progbar(object):
    """Text progress bar written to `sys.stdout`.

    Metrics passed to `update` are shown after the bar. Ordinary metrics are
    displayed as a running average weighted by the number of steps each
    update covered; metrics named in `stateful_metrics` are displayed as the
    last value given.

    Args:
        target: total number of steps expected, or None if unknown.
        width: width of the bar in characters.
        verbose: 1 redraws the bar on every (rate-limited) update, 2 prints a
            single summary line once finished; other values print nothing.
        interval: minimum number of seconds between redraws when verbose is 1.
        stateful_metrics: iterable of metric names that must not be averaged.
    """

    def __init__(self, target, width=25, verbose=1, interval=0.05,
                 stateful_metrics=None):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        # When true the bar is redrawn in place; otherwise every update is
        # written on a new line.
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules or
                                 'posix' in sys.modules)
        self._total_width = 0
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values = {}
        self._values_order = []
        self._start = time.time()
        self._last_update = 0

    @staticmethod
    def _format_value(value):
        """Return the display text (with a leading space) for a stored metric.

        Averaged metrics are stored as `[weighted_sum, step_count]`; stateful
        metrics are stored as the raw value and printed unchanged.
        """
        if isinstance(value, list):
            avg = np.mean(value[0] / max(1, value[1]))
            # Compare the magnitude so small negative averages also switch
            # to exponent notation and large negative ones do not.
            if abs(avg) > 1e-3:
                return ' %.4f' % avg
            return ' %.4e' % avg
        return ' %s' % value

    def update(self, current, values=None):
        """Record progress and redraw according to the verbosity mode.

        Args:
            current: index of the current step (total steps seen so far).
            values: optional list of `(name, value)` pairs. For averaged
                metrics `value` is weighted by the steps since the last call.
        """

        values = values or []
        for k, v in values:
            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics:
                if k not in self._values:
                    self._values[k] = [v * (current - self._seen_so_far),
                                       current - self._seen_so_far]
                else:
                    self._values[k][0] += v * (current - self._seen_so_far)
                    self._values[k][1] += (current - self._seen_so_far)
            else:
                self._values[k] = v
        self._seen_so_far = current

        now = time.time()
        info = ' - %.0fs' % (now - self._start)
        if self.verbose == 1:
            # Rate-limit redraws, but always draw the final step.
            if (now - self._last_update < self.interval and
                    self.target is not None and current < self.target):
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self.target is not None:
                numdigits = int(np.floor(np.log10(self.target))) + 1
                barstr = '%%%dd/%d [' % (numdigits, self.target)
                bar = barstr % current
                prog = float(current) / self.target
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'
                    else:
                        bar += '='
                bar += ('.' * (self.width - prog_width))
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self._total_width = len(bar)
            sys.stdout.write(bar)

            if current:
                time_per_unit = (now - self._start) / current
            else:
                time_per_unit = 0
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = '%d:%02d:%02d' % (eta // 3600,
                                                   (eta % 3600) // 60,
                                                   eta % 60)
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta

                info = ' - ETA: %s' % eta_format
            else:
                if time_per_unit >= 1:
                    info += ' %.0fs/step' % time_per_unit
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/step' % (time_per_unit * 1e3)
                else:
                    info += ' %.0fus/step' % (time_per_unit * 1e6)

            for k in self._values_order:
                info += ' - %s:' % k
                info += self._format_value(self._values[k])

            self._total_width += len(info)
            # Pad with spaces so leftovers of a longer previous line vanish.
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))

            if self.target is not None and current >= self.target:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            if self.target is None or current >= self.target:
                for k in self._values_order:
                    info += ' - %s:' % k
                    # Shared formatter: stateful metrics are stored as plain
                    # values and must not be indexed like [sum, count].
                    info += self._format_value(self._values[k])
                info += '\n'

                sys.stdout.write(info)
                sys.stdout.flush()

        self._last_update = now

    def add(self, n, values=None):
        """Advance the bar by `n` steps, forwarding `values` to `update`."""
        self.update(self._seen_so_far + n, values)

def pil_sample(img_batches, img_names, idx, dest, shape=(256, 256), name_prefix=None, assemble_masked=True, refine_masked=False):
    """Save sample `idx` of every batch in `img_batches` as a PNG via PIL.

    NOTE (kept from the original): this function cannot be used for saving
    masked images.

    The caller's `img_batches` and `img_names` lists are left untouched; all
    rescaling and the derived "masked" entries live in local copies.

    Args:
        img_batches: list of batches; `batch[idx]` must be a tensor that can
            be permuted with `(1, 2, 0)`
            (presumably (C, H, W) torch tensors -- confirm with callers).
        img_names: list of names, parallel to `img_batches`. Names containing
            'mask' or 'uMap' are scaled by 255, all others by `(x + 1) * 127.5`.
        idx: index of the sample inside each batch.
        dest: output directory, created if missing.
        shape: target size compared against and passed to PIL, i.e.
            (width, height).
        name_prefix: optional prefix joined to the file name with '_'.
        assemble_masked: also save 'gt_masked' and 'tv_masked', built from the
            'gt', 'tv', 'gt_mask' and 'tv_mask' entries.
        refine_masked: with `assemble_masked`, also save 'gt_conf_masked' and
            'tv_conf_masked', built from 'gt_misf', 'tv_misf',
            'gt_refine_mask' and 'tv_refine_mask'.

    Raises:
        ValueError: if a name required for the masked composites is missing.
    """
    os.makedirs(dest, exist_ok=True)

    # Work on copies so repeated calls with the same lists neither rescale
    # the data twice nor accumulate duplicate "*_masked" entries.
    names = list(img_names)
    batches = list(img_batches)

    # convert images into [0, 255] except masks
    for i, batch in enumerate(batches):
        if 'mask' in names[i] or 'uMap' in names[i]:
            batches[i] = batch * 255.0
        else:
            batches[i] = (batch + 1) * 127.5

    # construct masked gt&tv
    # The masks were already scaled by 255 above, so `1 - mask` is negative
    # wherever the mask is set; the clip to [0, 255] below turns those
    # pixels to 0.
    if assemble_masked:
        gt, tv = batches[names.index('gt')], batches[names.index('tv')]
        gt_mask, tv_mask = batches[names.index('gt_mask')], batches[names.index('tv_mask')]
        batches += [gt * (1 - gt_mask), tv * (1 - tv_mask)]
        names += ['gt_masked', 'tv_masked']

        if refine_masked:
            gt_misf, tv_misf = batches[names.index('gt_misf')], batches[names.index('tv_misf')]
            gt_mk, tv_mk = batches[names.index('gt_refine_mask')], batches[names.index('tv_refine_mask')]
            batches += [gt_misf * (1 - gt_mk), tv_misf * (1 - tv_mk)]
            names += ['gt_conf_masked', 'tv_conf_masked']

    img_list = [batch[idx] for batch in batches]
    for i, img in enumerate(img_list):

        # Copy to a channel-last numpy array without touching the tensor.
        img_copy = img.clone().detach().permute(1, 2, 0).cpu().numpy()
        img_copy = np.clip(img_copy, 0, 255)
        img_copy = img_copy.astype(np.uint8)

        pil_img = Image.fromarray(img_copy)
        if pil_img.size != shape:
            pil_img = pil_img.resize(shape)

        # Save to certain path
        if name_prefix:
            save_img_name = name_prefix + '_' + names[i] + '.png'
        else:
            save_img_name = names[i] + '.png'
        save_img_path = os.path.join(dest, save_img_name)

        pil_img.save(save_img_path)

def cv2_sample(img_batches, img_names, idx, dest, mode, shape=(256, 256), name_prefix=None):
    """Save sample `idx` of every batch in `img_batches` as a PNG via OpenCV.

    Besides the given batches, 'gt_masked' and 'tv_masked' composites are
    built from the 'gt'/'gt_mask' and 'tv'/'tv_mask' entries and saved too.
    The caller's `img_batches` and `img_names` lists are left untouched.

    Args:
        img_batches: list of batches; `batch[idx]` must be a tensor that can
            be permuted with `(1, 2, 0)`
            (presumably (C, H, W) torch tensors -- confirm with callers).
        img_names: list of names, parallel to `img_batches`. Batches whose
            name contains 'mask' are not rescaled; all others are mapped with
            `(x + 1) * 127.5`.
        idx: index of the sample inside each batch.
        dest: base output directory.
        mode: sub-directory of `dest` that receives the files.
        shape: target size compared against the array's first two dimensions
            and passed to `cv2.resize` as its size argument
            (NOTE(review): these agree only for square shapes -- confirm the
            intended order for non-square sizes).
        name_prefix: optional prefix joined to the file name with '_'.

    Raises:
        ValueError: if 'gt', 'gt_mask', 'tv' or 'tv_mask' is not in `img_names`.
    """
    img_save_dir = os.path.join(dest, mode)
    os.makedirs(img_save_dir, exist_ok=True)

    # Work on copies so repeated calls with the same lists neither rescale
    # the data twice nor accumulate duplicate "*_masked" entries.
    names = list(img_names)
    batches = list(img_batches)

    # convert images into [0, 255] except masks
    for i, batch in enumerate(batches):
        if 'mask' in names[i]:
            continue
        batches[i] = (batch + 1) * 127.5

    # construct mask gt&tv
    gt, gt_mask = batches[names.index('gt')], batches[names.index('gt_mask')]
    batches.append(gt * (1 - gt_mask))
    names.append('gt_masked')
    tv, tv_mask = batches[names.index('tv')], batches[names.index('tv_mask')]
    batches.append(tv * (1 - tv_mask))
    names.append('tv_masked')

    img_list = [batch[idx] for batch in batches]
    for i, img in enumerate(img_list):
        # Recover normalization of masks
        if names[i] in ['gt_mask', 'tv_mask']:
            img = img * 255

        # Copy to a channel-last numpy array without touching the tensor.
        img_copy = img.clone().detach().permute(1, 2, 0).cpu().numpy()
        img_copy = np.clip(img_copy, 0, 255)
        img_copy = img_copy.astype(np.uint8)
        # Swap the first and third channel before writing with OpenCV.
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_BGR2RGB)
        if img_copy.shape[:2] != shape:
            img_copy = cv2.resize(img_copy, shape)

        # Save to certain path
        if name_prefix:
            save_img_name = name_prefix + '_' + names[i] + '.png'
        else:
            save_img_name = names[i] + '.png'
        save_img_path = os.path.join(img_save_dir, save_img_name)

        cv2.imwrite(save_img_path, img_copy)

def mask_img_with_grid(img, pix_interval=50, grid_width=10):
    """Zero out a regular grid of lines in `img`, in place.

    The grid is drawn over the last two dimensions of `img`, so plain 2-D
    images as well as tensors with leading dimensions (e.g. channels or
    batch) are supported.

    Args:
        img: tensor providing `.size()`; modified in place.
        pix_interval: distance in pixels between the ends of two grid lines.
        grid_width: thickness of each grid line in pixels.

    Returns:
        The same `img` object, with the grid pixels set to 0.
    """
    rows, cols = img.size()[-2:]

    # A line ends at every multiple of `pix_interval` that lies strictly
    # inside the image. Indexing with `...` keeps the slices on the last two
    # dimensions no matter how many leading dimensions exist.
    for end in range(pix_interval, rows, pix_interval):
        img[..., end - grid_width:end, :] *= 0
    for end in range(pix_interval, cols, pix_interval):
        img[..., :, end - grid_width:end] *= 0

    return img

def load_grid_mask(fpath, size=(256, 256), device='cpu'):
    """Read a mask image from `fpath` and return it as a binary tensor.

    Args:
        fpath: path opened through `blobfile.BlobFile` in binary mode.
        size: size passed to `PIL.Image.resize`.
        device: device the resulting tensor is moved to.

    Returns:
        float32 torch tensor in channel-first layout with 3 channels, holding
        1.0 where the resized RGB pixel value exceeds 127.5 and 0.0 elsewhere.
    """
    with bf.BlobFile(fpath, "rb") as handle:
        image = Image.open(handle)
        # Read the pixel data while the file handle is still open.
        image.load()

    resized = image.convert("RGB").resize(size)
    binary = np.array(resized) > 127.5
    channel_first = binary.astype(np.float32).transpose(2, 0, 1)

    return torch.from_numpy(channel_first).to(device)

def get_grid_mask(size: tuple, pix_interval: int = 30, mask_width: int = 5, device: str = 'cpu'):
    """Generate a binary grid mask.

    Args:
        size: (width, height) of the mask.
        pix_interval: distance in pixels between the ends of two grid lines.
        mask_width: thickness of each grid line in pixels.
        device: device the resulting tensor is moved to.

    Returns:
        float32 torch tensor of shape (3, height, width) holding 1.0 on the
        grid lines and 0.0 elsewhere.
    """
    width, height = size
    grid = np.zeros((height, width, 3), dtype=np.uint8)

    # A line ends at every multiple of `pix_interval` that lies strictly
    # inside the mask. Assign instead of adding: `+= 255` on uint8 wraps
    # around wherever lines overlap (255 + 255 -> 254).
    for end in range(pix_interval, height, pix_interval):
        grid[end - mask_width:end, :] = 255

    for end in range(pix_interval, width, pix_interval):
        grid[:, end - mask_width:end] = 255

    grid_mask = (grid > 127.5).astype(np.float32).transpose(2, 0, 1)
    return torch.from_numpy(grid_mask).to(device)

def get_uncertainty(kernel, batch_size, height=256, width=256):
    """Get the uncertainty map of given kernel.

    The kernel is viewed as (batch_size, 3, -1, height, width), averaged over
    the third and then the second dimension, and min-max normalised using
    the minimum and maximum of the whole reduced tensor (all samples share
    one scale).

    Args:
        kernel (torch.Tensor): the kernel to process; must be viewable as
            (batch_size, 3, -1, height, width)
        batch_size (int): batch size
        height (int): the height of desired output
        width (int): the width of desired output

    Returns:
        vis_kernel (torch.Tensor): tensor of shape
            (batch_size, 3, height, width) with values in [0, 1]; the three
            channels are identical. All zeros when the reduced map is
            constant.
    """
    N = 3

    kernel = kernel.view(batch_size, N, -1, height, width)

    # Mean over the kernel entries, then over the N groups.
    kernel_tmp = kernel.sum(dim=2) / kernel.size(2)
    kernel_tmp = kernel_tmp.sum(dim=1) / kernel_tmp.size(1)
    kernel_tmp = kernel_tmp.unsqueeze(dim=1)

    # normalize to [0,1]
    lowest = torch.min(kernel_tmp)
    value_range = torch.max(kernel_tmp) - lowest
    # A constant map has zero range; divide by 1 instead so the result is
    # all zeros rather than NaN.
    value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    kernel_vis = (kernel_tmp - lowest) / value_range

    return torch.cat([kernel_vis] * 3, dim=1)