import copy
import math

from PIL import Image
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

from deptheval.utils.common import pathlib_file


def imread_rgb(img_f):
    '''
    Read an image file with OpenCV and return it with channels in RGB order.
    :param img_f: path to the image (anything accepted by pathlib_file)
    :return: np.ndarray with the channel axis reversed from OpenCV's BGR, copied so it is contiguous
    :raises OSError: if OpenCV cannot read the file (missing, unreadable or unsupported format)
    '''
    img_f = pathlib_file(img_f)
    bgr = cv2.imread(str(img_f))
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising
        raise OSError(f'Could not read image {img_f}')
    return bgr[..., ::-1].copy()


def imwrite_rgb(img_f, img, verbose=False):
    '''
    Write an RGB image to disk with OpenCV, creating parent directories as needed.
    :param img_f: destination path; the format is chosen by OpenCV from the suffix
    :param img: image array with channels in RGB order (reversed to BGR before writing)
    :param verbose: if True, print the resolved output path after a successful write
    :raises OSError: if cv2.imwrite reports failure
    '''
    img_f = pathlib_file(img_f)
    img_f.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite signals failure through its boolean return value, not an exception
    if not cv2.imwrite(str(img_f), img[..., ::-1]):
        raise OSError(f'Failed to write image {img_f}')
    if verbose:
        print(f'Saved to {img_f.resolve()}')


def imwrite_rgb_pdf(img_f, img, verbose=False):
    '''
    Save an RGB array through Pillow (e.g. to a .pdf), creating parent directories as needed.
    :param img_f: destination path; Pillow picks the output format from the suffix
    :param img: array interpreted by Pillow in 'RGB' mode
    :param verbose: if True, print the resolved output path afterwards
    '''
    out_path = pathlib_file(img_f)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(img, mode='RGB').save(out_path)
    if verbose:
        print(f'Saved to {out_path.resolve()}')


def resize(img, H=None, W=None, interpolation=cv2.INTER_NEAREST, return_sc=False):
    '''
    Resize `img` while keeping its aspect ratio.

    Pass H or W to fix that side. If both are given, the one implying the smaller
    scale factor wins, so the result fits inside an H x W box.
    :param img: image array; only its first two axes (rows, cols) are used for sizing
    :param H: target number of rows, or None
    :param W: target number of columns, or None
    :param interpolation: OpenCV interpolation flag forwarded to cv2.resize
    :param return_sc: if True, also return new_rows / old_rows
    :return: the resized image, or (resized image, scale) when return_sc is True
    '''
    src_rows, src_cols = img.shape[:2]
    rows = None if H is None else int(H)
    cols = None if W is None else int(W)
    if rows is not None and cols is not None:
        # keep only the tighter of the two constraints
        if rows / src_rows < cols / src_cols:
            cols = None
        else:
            rows = None

    out = img
    if rows is not None:
        out = cv2.resize(img, (int(src_cols / src_rows * rows), rows), interpolation=interpolation)
    elif cols is not None:
        out = cv2.resize(img, (cols, int(src_rows / src_cols * cols)), interpolation=interpolation)

    if not return_sc:
        return out
    return out, out.shape[0] / src_rows


def display_img(img, H=None, W=None, order='rgb', name='img'):
    '''
    Show an image in a blocking OpenCV window that closes on any key press.
    :param img: np.ndarray or torch.Tensor; tensors are detached and moved to CPU first
    :param H: optional display height in rows
    :param W: optional display width in columns; if both H and W are given, the smaller
              scale factor is used so the image fits
    :param order: 'rgb' reverses the last axis before display (OpenCV shows BGR); any other
                  value displays the array as is
    :param name: window title
    '''
    if isinstance(img, torch.Tensor):
        # detach() so grad-tracking tensors can be converted; forward `name` as well
        return display_img(img.detach().cpu().numpy(), H, W, order, name)
    if (H is not None) and (W is not None):
        cur_H, cur_W = img.shape[0], img.shape[1]
        # W / cur_W > H / cur_H  -> height is the binding constraint
        if W * cur_H > H * cur_W:
            W = None
        else:
            H = None
    img_to_show = img.copy()
    if H is not None:
        img_to_show = cv2.resize(img_to_show, (int(img_to_show.shape[1] / img_to_show.shape[0] * H), H))
    if W is not None:
        img_to_show = cv2.resize(img_to_show, (W, int(img_to_show.shape[0] / img_to_show.shape[1] * W)))
    if order == 'rgb':
        img_to_show = img_to_show[:, :, ::-1]
    cv2.imshow(name, img_to_show)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def draw_rectangle(img, tl_i, tl_j, br_i, br_j, color=(0, 255, 0), thickness=None):
    '''
    Return a copy of `img` with a rectangle outline drawn on it (i = row, j = column).
    :param img: numpy array, shape (H, W, 3), dtype=np.uint8
    :param tl_i: top left i
    :param tl_j: top left j
    :param br_i: bottom right i
    :param br_j: bottom right j
    :param color: tuple of 3 ints, e.g. (0, 255, 0)
    :param thickness: line width in pixels; defaults to ceil(0.25% of the shorter image side)
    :return: the modified copy; `img` itself is not changed
    '''
    if thickness is None:
        thickness = int(math.ceil(0.0025 * min(img.shape[:2])))
    canvas = img.copy()
    # OpenCV points are (x, y) = (column, row)
    cv2.rectangle(canvas, (tl_j, tl_i), (br_j, br_i), color, thickness)
    return canvas


def draw_dot(img, i, j, color=(255, 0, 0), rad=2, inplace=False):
    '''
    Paint a filled square of side 2*rad+1 centred at (row i, column j), clipped to the image.
    :param img: image array; the first two axes are rows and columns
    :param i: centre row
    :param j: centre column
    :param color: value assigned to every covered pixel
    :param rad: half-size of the square
    :param inplace: if True, modify and return `img`; otherwise work on a copy
    :return: the painted array (the same object as `img` when inplace is True)
    '''
    H, W = img.shape[:2]
    tl_i = max(0, i - rad)
    tl_j = max(0, j - rad)
    # Clamp the exclusive stops to [0, H] / [0, W]. Without the outer max(0, ...), a point
    # lying entirely above/left of the image yields a negative stop, which NumPy treats as
    # "count from the end" and would paint a large stripe of the image.
    br_i = max(0, min(H, i + rad + 1))
    br_j = max(0, min(W, j + rad + 1))
    out = img if inplace else img.copy()
    out[tl_i:br_i, tl_j:br_j] = color
    return out


def ds_img(img, valid, max_H, max_W):
    '''
    Downsample an image by an integer factor, averaging only valid pixels.

    The factor ds_sc is the smallest integer that divides both H and W and brings the
    result within max_H x max_W. Each ds_sc x ds_sc block becomes one output pixel equal
    to the mean of its valid entries, or 0 if the block has no valid entry.
    :param img: np.ndarray of shape (H, W) or (H, W, D)
    :param valid: array of shape (H, W); nonzero entries mark pixels used in the average
    :param max_H: max output rows, or None for no limit on rows
    :param max_W: max output columns, or None for no limit on columns
    :return: (img_ds, valid_ds). img_ds has shape (H/ds_sc, W/ds_sc[, D]) and img's dtype
             (the float average is cast back, so integer dtypes truncate). valid_ds is a bool
             mask, True where the block had at least one valid pixel.
             If both max_H and max_W are None, the inputs are returned unchanged.
    :raises ValueError: if no common divisor of H and W satisfies the size limits
    '''
    if max_H is None and max_W is None:
        return img, valid
    squeeze = img.ndim == 2
    if squeeze:
        img = img[:, :, None]
    assert img.ndim == 3, img.shape
    H, W, D = img.shape
    if max_H is None:
        max_H = H
    if max_W is None:
        max_W = W
    dtype = img.dtype
    assert img.shape[:2] == valid.shape

    # ds_sc must tile the image exactly, i.e. it must divide gcd(H, W). Restricting the
    # search to those divisors keeps it finite even when no factor satisfies the limits
    # (e.g. coprime H and W with H > max_H).
    g = math.gcd(H, W) or 1  # gcd(0, 0) == 0 for an empty image
    ds_sc = next((s for s in range(1, g + 1)
                  if g % s == 0 and H / s <= max_H and W / s <= max_W), None)
    if ds_sc is None:
        raise ValueError(f'no integer factor dividing both H={H} and W={W} fits within '
                         f'max_H={max_H}, max_W={max_W}')

    n, m = H // ds_sc, W // ds_sc
    # (n, ds_sc, m, ds_sc, D) -> (n, m, ds_sc, ds_sc, D): each block's pixels on axes 2-3
    blocks = np.transpose(img.reshape(n, ds_sc, m, ds_sc, D), (0, 2, 1, 3, 4)).astype(np.float32)
    weights = np.transpose(valid.reshape(n, ds_sc, m, ds_sc), (0, 2, 1, 3)).astype(np.float32)[:, :, :, :, None]
    # masked mean; the clip avoids 0/0 for blocks without valid pixels (they become 0)
    img = ((blocks * weights).sum(axis=(2, 3)) / np.clip(weights.sum(axis=(2, 3)), a_min=1e-10, a_max=None)).astype(dtype)  # (n, m, D)
    valid = weights.squeeze(-1).sum(axis=(2, 3)) > 0  # (n, m)
    if squeeze:
        img = img.squeeze(axis=-1)
    return img, valid


def gen_heatmap(x: np.ndarray, valid=None, lb=None, ub=None, col_map=cv2.COLORMAP_JET):
    '''
    Colour-map an array with OpenCV. No colour bar; each pixel is just assigned a colour.

    :param x: numeric array (e.g. (H, W)); it is NOT modified
    :param valid: mask of x's shape (truthy = valid); None means all finite entries are valid
    :param lb: value mapped to the lowest colour; defaults to the min over valid pixels
    :param ub: value mapped to the highest colour; defaults to the max over valid pixels
    :param col_map: OpenCV colormap id forwarded to cv2.applyColorMap
    :return: uint8 output of cv2.applyColorMap with invalid pixels set to 0
    '''
    # Float view/copy: avoids unsigned-integer wrap-around in `x - lb` below. Note this may
    # alias the caller's array for float64 input, so it must not be written to.
    x = np.asarray(x, dtype=np.float64)
    if valid is None:
        valid = ~(np.isnan(x) | np.isinf(x))
    # a bool mask makes both `x[valid]` and `~valid` behave as masks (not indices/bitwise not)
    valid = np.asarray(valid, dtype=bool)
    if lb is None:
        lb = x[valid].min()
    if ub is None:
        ub = x[valid].max()
    # fill invalid pixels (e.g. NaN) in a new array instead of writing into the caller's x
    x = np.where(valid, x, lb)
    span = ub - lb
    if span == 0:
        # constant range: avoid 0/0 -> NaN; every value then maps to the lowest colour
        span = 1.0
    x = np.round(np.clip((x - lb) / span, 0, 1) * 255).astype(np.uint8)
    ret = cv2.applyColorMap(x, col_map)
    ret[~valid] = 0
    return ret


def convert_png_image_to_pdf(png_image_f, pdf_image_f, verbose=False):
    '''
    Re-save an image file through Pillow, typically PNG -> PDF (format taken from the suffix).
    :param png_image_f: source image path
    :param pdf_image_f: destination path; parent directories are created if missing
    :param verbose: if True, print the destination path afterwards
    '''
    png_image_f = pathlib_file(png_image_f)
    pdf_image_f = pathlib_file(pdf_image_f)
    pdf_image_f.parent.mkdir(parents=True, exist_ok=True)
    # context manager closes the source file handle that Image.open keeps open
    with Image.open(png_image_f) as img:
        img.save(pdf_image_f)
    if verbose:
        print(f'Saved to {pdf_image_f}')


def gen_err_heatmap(x, valid, max_H=720, max_W=1280, lb=0, ub=None, title=None, xlabel=None, ylabel=None,
                    xticks=None, yticks=None, xtick_labels=None, ytick_labels=None):
    '''
    Render a 2-D error map with Matplotlib (jet colormap plus a vertical 'Error' colorbar on
    the side) and return the figure as an RGB image. Invalid pixels are drawn black.

    The map is first block-averaged with ds_img so that it fits within max_H x max_W.
    :param x: 2-D np.ndarray of per-pixel errors
    :param valid: mask of x's shape, or None to treat all finite entries as valid
    :param max_H: max rows after downsampling (None = no row limit)
    :param max_W: max columns after downsampling (None = no column limit)
    :param lb: lower end of the colour scale
    :param ub: upper end of the colour scale; defaults to the 95th percentile of the valid
               (downsampled) values
    :param title: optional axes title
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label
    :param xticks, yticks: optional tick positions
    :param xtick_labels, ytick_labels: optional tick labels
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB
    '''
    assert isinstance(x, np.ndarray) and x.ndim == 2
    if valid is None:
        valid = ~(np.isnan(x) | np.isinf(x))

    x, valid = ds_img(x, valid, max_H, max_W)

    # for heatmap
    if ub is None:
        ub = np.quantile(x[valid], .95)

    x_masked = np.ma.masked_where(~valid, x)

    # plt.cm.jet is a module-level object shared by the whole process; calling set_bad on
    # it directly would turn masked pixels black for every later user, so use a copy.
    cmap = copy.copy(plt.cm.jet)
    cmap.set_bad(color='black')  # Color for masked (invalid) pixels

    fig, ax = plt.subplots()
    norm = mcolors.Normalize(vmin=lb, vmax=ub)
    heatmap = ax.imshow(x_masked, cmap=cmap, norm=norm)
    fig.colorbar(heatmap, ax=ax, orientation='vertical', label='Error')

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)
    if xtick_labels is not None:
        ax.set_xticklabels(xtick_labels)
    if ytick_labels is not None:
        ax.set_yticklabels(ytick_labels)

    # Render to an Agg canvas; copy so the result is a contiguous array that owns its
    # memory rather than a strided view into the canvas buffer.
    canvas = FigureCanvas(fig)
    canvas.draw()
    ret = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)
    return ret


def reformat_to_list(x):
    '''
    Convert array-like plotting input to a plain Python list.
    :param x: list (returned as is), tuple/range (converted with list()), np.ndarray or
              torch.Tensor (via .tolist(); tensors are detached and moved to CPU first).
              Anything else, e.g. None or a scalar, is returned unchanged instead of being
              silently turned into None.
    :return: a list, or `x` itself for unsupported types
    '''
    if isinstance(x, list):
        return x
    if isinstance(x, np.ndarray):
        return x.tolist()
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().tolist()
    if isinstance(x, (tuple, range)):
        return list(x)
    return x


def plot_y_on_x(x, y, xmin=None, xmax=None, ymin=None, ymax=None, title=None, xlabel=None, ylabel=None):
    '''
    Draw a single line plot of y against x and return it as an RGB image.
    :param x: x values (converted with reformat_to_list)
    :param y: y values (converted with reformat_to_list)
    :param xmin, xmax, ymin, ymax: axis limits; None leaves that limit unchanged
    :param title: optional axes title
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB
    '''
    fig, ax = plt.subplots()
    ax.plot(reformat_to_list(x), reformat_to_list(y))
    if title is not None:
        ax.set_title(title)
    # No ax.legend(): the single line has no label, so the legend would be empty and
    # Matplotlib warns "No artists with labels found to put in legend".
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() instead of tostring_rgb(), which Matplotlib deprecated in 3.8 and
    # removed in 3.10; copy to get a contiguous array independent of the canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    plt.close(fig)
    return img


def plot_y_and_err_bar_on_x(x, y, y_err, xmin=None, xmax=None, ymin=None, ymax=None, title=None, xlabel=None, ylabel=None, label=None, hlines=None):
    '''
    Plot y against x with error bars and return the figure as an RGB image.
    :param x: x values (converted with reformat_to_list)
    :param y: y values (converted with reformat_to_list)
    :param y_err: error-bar sizes, passed unchanged to Axes.errorbar(yerr=...)
    :param xmin: left x limit (None leaves it unchanged)
    :param xmax: right x limit (None leaves it unchanged)
    :param ymin: bottom y limit (None leaves it unchanged)
    :param ymax: top y limit (None leaves it unchanged)
    :param title: optional axes title
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label
    :param label: legend label for the error-bar series
    :param hlines: list of horizontal lines, each element takes form of dict(y=, label=, color=, linestyle=), where label, color, and linestyle are optional
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB. A legend is drawn only if
             `label` or some hline label is given.
    '''
    if hlines is None:
        hlines = []
    fig, ax = plt.subplots()
    ax.errorbar(reformat_to_list(x), reformat_to_list(y), yerr=y_err, fmt='-o', label=label)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    # Build the legend once, after all artists exist, and only if something is labeled
    # (an unconditional legend() warns and draws an empty box when nothing is labeled).
    need_legend = label is not None
    for hline in hlines:
        ax.axhline(**hline)
        if hline.get('label', None) is not None:
            need_legend = True
    if need_legend:
        ax.legend()

    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() instead of tostring_rgb(), which Matplotlib deprecated in 3.8 and
    # removed in 3.10; copy to get a contiguous array independent of the canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    plt.close(fig)
    return img


def plot_ys_and_err_bar_on_x(x, ys, y_errs, xmin=None, xmax=None, ymin=None, ymax=None, title=None, xlabel=None, ylabel=None, colors=None, labels=None, hlines=None):
    '''
    Plot several y series against a shared x, each with error bars, and return an RGB image.
    :param x: shared x values (converted with reformat_to_list)
    :param ys: sequence of y series
    :param y_errs: sequence of error-bar sizes, one entry per series in `ys`
    :param xmin: left x limit (None leaves it unchanged)
    :param xmax: right x limit (None leaves it unchanged)
    :param ymin: bottom y limit (None leaves it unchanged)
    :param ymax: top y limit (None leaves it unchanged)
    :param title: optional axes title
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label
    :param colors: optional per-series colours (None entries use the default cycle)
    :param labels: optional per-series legend labels
    :param hlines: list of horizontal lines, each element takes form of dict(y=, label=, color=, linestyle=), where label, color, and linestyle are optional
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB. A legend is drawn only if some
             series or hline has a label. Series are zipped, so extra entries are ignored.
    '''
    if hlines is None:
        hlines = []
    need_legend = False
    fig, ax = plt.subplots()
    if labels is None:
        labels = [None for _ in ys]
    if colors is None:
        colors = [None for _ in ys]
    for y, y_err, color, label in zip(ys, y_errs, colors, labels):
        ax.errorbar(reformat_to_list(x), reformat_to_list(y), yerr=reformat_to_list(y_err), fmt='-o', color=color, label=label)
        if label is not None:
            need_legend = True
    if title is not None:
        ax.set_title(title)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    # Legend is built once, after the hlines are added, and only if something is labeled.
    for hline in hlines:
        ax.axhline(**hline)
        if hline.get('label', None) is not None:
            need_legend = True
    if need_legend:
        ax.legend()

    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() instead of tostring_rgb(), which Matplotlib deprecated in 3.8 and
    # removed in 3.10; copy to get a contiguous array independent of the canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    plt.close(fig)
    return img


def plot_ys_on_same_x(x, ys, zs, xmin=None, xmax=None, ymin=None, ymax=None, title=None, xlabel=None, ylabel=None, legend_loc='best'):
    '''
    Plot several y series as lines against a shared x and return the figure as an RGB image.
    :param x: shared x values (converted with reformat_to_list)
    :param ys: sequence of y series
    :param zs: legend labels, zipped with `ys`
    :param xmin, xmax, ymin, ymax: axis limits; None leaves that limit unchanged
    :param title: optional axes title
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label
    :param legend_loc: `loc` argument forwarded to Axes.legend
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB
    '''
    fig, ax = plt.subplots()
    for y, z in zip(ys, zs):
        ax.plot(reformat_to_list(x), reformat_to_list(y), label=z)
    if title is not None:
        ax.set_title(title)
    ax.legend(loc=legend_loc)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() instead of tostring_rgb(), which Matplotlib deprecated in 3.8 and
    # removed in 3.10; copy to get a contiguous array independent of the canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    plt.close(fig)
    return img


def get_img_from_plt_fig(fig):
    '''
    Rasterise a Matplotlib figure with the Agg backend and return its pixels.
    :param fig: matplotlib Figure; it is closed (plt.close) before returning
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB
    '''
    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() instead of tostring_rgb(), which Matplotlib deprecated in 3.8 and
    # removed in 3.10; copy to get a contiguous array independent of the canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    plt.close(fig)
    return img


def plot_points(xs, ys, xlim=None, ylim=None, xlabel=None, ylabel=None, title=None):
    '''
    Scatter-plot points in red and return the figure as an RGB image.
    :param xs: x coordinates (converted with reformat_to_list)
    :param ys: y coordinates (converted with reformat_to_list)
    :param xlim: optional dict of keyword arguments for Axes.set_xlim (e.g. dict(left=, right=))
    :param ylim: optional dict of keyword arguments for Axes.set_ylim (e.g. dict(bottom=, top=))
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label
    :param title: optional axes title
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB
    '''
    fig, ax = plt.subplots()

    ax.scatter(np.array(reformat_to_list(xs)), np.array(reformat_to_list(ys)), color='red')
    if xlim is not None:
        ax.set_xlim(**xlim)
    if ylim is not None:
        ax.set_ylim(**ylim)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    # No ax.legend(): the scatter has no label, so the legend would be empty and
    # Matplotlib would warn "No artists with labels found to put in legend".

    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() instead of tostring_rgb(), which Matplotlib deprecated in 3.8 and
    # removed in 3.10; copy to get a contiguous array independent of the canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    plt.close(fig)
    return img


def caption_below(img, txt, h=60, font=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255, 0, 0), thickness=2):
    '''
    Append a white band of `h` rows beneath `img` and write `txt` into it with cv2.putText.
    :param img: image array; the band copies its width, trailing axes and dtype
    :param txt: caption text
    :param h: band height in rows; the text origin (bottom-left) is at (h // 2, h // 2) in the band
    :param font, fontScale, color, thickness: forwarded to cv2.putText
    :return: new array of img stacked above the caption band
    '''
    band = np.ones((h,) + img.shape[1:], dtype=img.dtype) * 255
    anchor = (h // 2, h // 2)
    band = cv2.putText(band, txt, anchor, font, fontScale, color, thickness, cv2.LINE_AA)
    return np.concatenate([img, band], axis=0)


def get_txt_img(w, txt, h=60, font=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255, 0, 0), thickness=2):
    '''
    Render `txt` centred on a white (h, w, 3) uint8 canvas.
    :param w: canvas width in pixels
    :param txt: text to draw
    :param h: canvas height in pixels
    :param font, fontScale, color, thickness: forwarded to cv2.getTextSize / cv2.putText
    :return: the canvas with the text drawn on it
    '''
    canvas = np.ones((h, w, 3), dtype=np.uint8) * 255
    (text_w, text_h), _baseline = cv2.getTextSize(txt, font, fontScale, thickness)
    # cv2.putText anchors text at its bottom-left corner, hence "+ text_h / 2" for y
    origin = (int(w // 2 - text_w / 2), int(h // 2 + text_h / 2))
    return cv2.putText(canvas, txt, origin, font, fontScale, color, thickness, cv2.LINE_AA)


def select_pxl(img: np.ndarray, H=None, W=None, order='rgb', name='img', do_round=True):
    '''
    Show `img` in an OpenCV window and return the pixel the user left-clicks.

    After clicking, the user presses any key (e.g. `Esc`) to close the window. If a key is
    pressed before any click, the window is shown again. The clicked position is mapped
    back to the resolution of the input `img`.
    :param img: image array (indexed as [rows, cols, channels] when order == 'rgb')
    :param H: optional display height
    :param W: optional display width; if both H and W are given, the smaller scale factor is used
    :param order: 'rgb' reverses the channel axis for display; anything else shows `img` as is
    :param name: suffix of the window title
    :param do_round: if True, round the returned coordinates to ints
    :return: u, v -- column (x) and row (y) of the click in `img` coordinates
    '''
    state = {'xy': None}

    def on_mouse(event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            state['xy'] = (x, y)

    src_h, src_w = img.shape[0], img.shape[1]
    if H is not None and W is not None:
        if W * src_h > H * src_w:
            W = None
        else:
            H = None
    shown = img.copy()
    scale = 1
    if H is not None:
        scale = H / src_h
        shown = cv2.resize(shown, (int(shown.shape[1] / shown.shape[0] * H), H))
    elif W is not None:
        scale = W / src_w
        shown = cv2.resize(shown, (W, int(shown.shape[0] / shown.shape[1] * W)))
    if order == 'rgb':
        shown = shown[:, :, ::-1]

    window = f'select pxl-{name}'
    while state['xy'] is None:
        cv2.imshow(window, shown)
        cv2.setMouseCallback(window, on_mouse)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    click_x, click_y = state['xy']
    u, v = click_x / scale, click_y / scale
    if do_round:
        u, v = int(round(u)), int(round(v))
    return u, v


def save_video(imgs, video_path, format='MP4V', fps=15, order='rgb'):
    '''
    Encode a sequence of equally-sized frames into a video file with cv2.VideoWriter.
    :param imgs: indexable sequence of frames; imgs[0] defines the frame size
    :param video_path: output path; parent directories are created if missing
    :param format: four-character codec code passed to cv2.VideoWriter_fourcc
    :param fps: frame rate
    :param order: 'bgr' writes frames as given; any other value reverses the channel axis
                  (frames are treated as RGB)
    :raises OSError: if the writer cannot be opened (e.g. unsupported codec/container)
    '''
    video_path = pathlib_file(video_path)
    frame_h, frame_w = imgs[0].shape[0], imgs[0].shape[1]
    # create the output directory, consistent with the imwrite helpers in this module
    video_path.parent.mkdir(parents=True, exist_ok=True)
    video_writer = cv2.VideoWriter(video_path.as_posix(), cv2.VideoWriter_fourcc(*format), fps, (frame_w, frame_h))
    if not video_writer.isOpened():
        # otherwise every write() is silently dropped and no video is produced
        raise OSError(f'Failed to open video writer for {video_path}')
    try:
        for img in imgs:
            video_writer.write(img if order == 'bgr' else img[:, :, ::-1])
    finally:
        # release even if a frame raises, so the file handle / encoder is not leaked
        video_writer.release()


def gen_hist_img(x, bins=500, amin=None, amax=None, vlines=None, title=None):
    '''
    Plot a histogram of `x` with Matplotlib and return it as an RGB image.
    :param x: np array, shape (n,)
    :param bins: int, number of bins
    :param amin: lower end of the histogram range; x.min() if amin is None
    :param amax: upper end of the histogram range; x.max() if amax is None
    :param vlines: optional iterable of (x, color) pairs drawn as vertical lines
    :param title: optional axes title
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB
    '''
    if vlines is None:
        vlines = []
    if amin is None:
        amin = x.min()
    if amax is None:
        amax = x.max()
    fig, ax = plt.subplots()
    ax.hist(x, bins=bins, range=(amin, amax))
    for vline_x, col in vlines:
        ax.axvline(x=vline_x, color=col)
    if title is not None:
        ax.set_title(title)
    canvas = FigureCanvas(fig)
    canvas.draw()
    # copy: a contiguous array owning its memory, not a strided view into the canvas buffer
    ret = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)
    return ret


def gen_dist_img(x, p, other_ps=None, vlines=None, title=None):
    '''
    Plot one or more distributions over a shared x and return the figure as an RGB image.
    :param x: np array, shape (n,)
    :param p: np array, shape (n,), drawn with the default colour
    :param other_ps: optional [(prob, col)]; prob: np array, shape (n,); col: str
    :param vlines: optional [(x, col)]; x: float; col: str; drawn as vertical lines
    :param title: optional axes title
    :return: np.uint8 array of shape (fig_H, fig_W, 3), RGB; the y axis starts at 0
    '''
    if other_ps is None:
        other_ps = []
    if vlines is None:
        vlines = []
    fig, ax = plt.subplots()
    ax.plot(x, p)
    for other_p, col in other_ps:
        ax.plot(x, other_p, color=col)
    for vline_x, col in vlines:
        ax.axvline(x=vline_x, color=col)
    if title is not None:
        ax.set_title(title)
    ax.set_ylim(bottom=0)
    canvas = FigureCanvas(fig)
    canvas.draw()
    # copy: a contiguous array owning its memory, not a strided view into the canvas buffer
    ret = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)
    return ret
