import os

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch


def mmhr_to_dbz(mmhr):
    """Convert rain rate(s) in mm/h to reflectivity-derived values.

    Reflectivity follows ``Z = 200 * R**1.6`` with ``dBZ = 10 * log10(Z)``.
    Rain rates are first clamped from below to ``(1/200)**(5/8)`` (the rate at
    which ``Z == 1``, i.e. 0 dBZ), so zero or negative rates never reach
    ``log(0)``. The dBZ value is then mapped to ``2 * dBZ - 30`` and negative
    results are set to 0.

    NOTE(review): the meaning of the final ``2 * dBZ - 30`` mapping is not
    documented here — confirm the intended encoding with the callers.

    Args:
        mmhr: Rain rate(s): a Python scalar/sequence, NumPy array, or torch
            tensor (any device, with or without ``requires_grad``). Integer
            inputs are promoted to float32 before taking the logarithm.

    Returns:
        NumPy array with the same shape as ``mmhr`` (a NumPy scalar for 0-d input).
    """
    if isinstance(mmhr, torch.Tensor):
        # torch.tensor(<Tensor>) warns and copies; detaching also lets .numpy()
        # work below when the input tracks gradients.
        mmhr = mmhr.detach()
    else:
        mmhr = torch.tensor(mmhr)
    if not torch.is_floating_point(mmhr):
        # log/clamp with a fractional floor need a floating dtype.
        mmhr = mmhr.float()

    # Rain rate at which 200 * R**1.6 == 1 (0 dBZ).
    zero_mmhr = (1 / 200) ** (5 / 8)
    log10 = float(np.log(10.0))
    log200 = float(np.log(200.0))
    dbz = (8 / 5 * torch.log(torch.clamp(mmhr, min=zero_mmhr)) + log200) * 10 / log10

    # .cpu() so GPU inputs can be converted to NumPy.
    dbz = dbz.cpu().numpy() * 2 - 30
    # np.maximum also works when the arithmetic above produced a NumPy scalar (0-d input).
    return np.maximum(dbz, 0)


def save_plots(fig_names, outputs, save_root, figsize=(6, 6), vmin=0, vmax=40, cmap="turbo", **imshow_args):
    """Save each 2-D field in ``outputs`` as an axis-free ``imshow`` image.

    Negative values are drawn as 0. Pixels with value < 1 get alpha 0
    (transparent); pixels with value >= 1 get alpha 1 (opaque).

    Args:
        fig_names: Output file names, joined onto ``save_root``. Paired with
            ``outputs`` via ``zip``, so the shorter of the two sets the number
            of figures written.
        outputs: NumPy array-likes to plot. They are not modified.
        save_root: Output directory; created if it does not exist.
        figsize: Matplotlib figure size in inches.
        vmin: Lower bound of the color scale, forwarded to ``imshow``.
        vmax: Upper bound of the color scale, forwarded to ``imshow``.
        cmap: Colormap name, forwarded to ``imshow``.
        **imshow_args: Extra keyword arguments forwarded to ``Axes.imshow``.
    """
    os.makedirs(save_root, exist_ok=True)

    for fig_name, output in zip(fig_names, outputs):
        # np.clip returns a new array, so the caller's data is left untouched.
        output = np.clip(output, 0, None)
        alpha = output.copy()
        alpha[alpha < 1] = 0
        alpha[alpha > 1] = 1

        fig = plt.figure(figsize=figsize)
        try:
            ax = plt.axes()
            ax.set_axis_off()
            ax.imshow(output, alpha=alpha, vmin=vmin, vmax=vmax, cmap=cmap, **imshow_args)
            fig.savefig(os.path.join(save_root, fig_name), bbox_inches='tight')
        finally:
            # Close this specific figure even if drawing or saving fails.
            plt.close(fig)


def add_palette(array2d):
    """Wrap a 2-D array of palette indices in a PIL image with a fixed color table.

    Each pixel value indexes into ``palette`` below, a flat
    ``[R0, G0, B0, R1, G1, B1, ...]`` list in the form accepted by
    ``Image.putpalette``. Index 0 is black (0, 0, 0); the following indices are
    grouped into runs of identical colors, and the last entries of the table are
    all (173, 245, 141).

    NOTE(review): ``putpalette`` only accepts "L"/"LA"/"P"/"PA" images, so this
    presumably expects ``array2d`` to be uint8 (``Image.fromarray`` maps 2-D
    uint8 data to mode "L"); other dtypes are likely to raise — confirm with callers.

    Args:
        array2d: Array-like of palette indices.

    Returns:
        PIL image carrying this palette.
    """
    # Flat RGB triplets: black at index 0, then runs of repeated colors; the tail repeats (173, 245, 141).
    palette = [0, 0, 0, 0, 237, 237, 0, 237, 237, 0, 237, 237, 0, 237, 237, 0, 237, 237, 0, 237, 237, 0, 237, 237, 0,
               237, 237, 0, 237, 237, 0, 217, 0, 0, 217, 0, 0, 217, 0, 0, 217, 0, 0, 217, 0, 0, 217, 0, 0, 217, 0, 0,
               217, 0, 0, 217, 0, 0, 217, 0, 0, 145, 0, 0, 145, 0, 0, 145, 0, 0, 145, 0, 0, 145, 0, 0, 145, 0, 0, 145,
               0, 0, 145, 0, 0, 145, 0, 0, 145, 0, 255, 255, 0, 255, 255, 0, 255, 255, 0, 255, 255, 0, 255, 255, 0, 255,
               255, 0, 255, 255, 0, 255, 255, 0, 255, 255, 0, 255, 255, 0, 231, 193, 0, 231, 193, 0, 231, 193, 0, 231,
               193, 0, 231, 193, 0, 231, 193, 0, 231, 193, 0, 231, 193, 0, 231, 193, 0, 231, 193, 0, 255, 145, 0, 255,
               145, 0, 255, 145, 0, 255, 145, 0, 255, 145, 0, 255, 145, 0, 255, 145, 0, 255, 145, 0, 255, 145, 0, 255,
               145, 0, 255, 0, 0, 255, 0, 0, 255, 0, 0, 255, 0, 0, 255, 0, 0, 255, 0, 0, 255, 0, 0, 255, 0, 0, 255, 0,
               0, 255, 0, 0, 200, 0, 0, 200, 0, 0, 200, 0, 0, 200, 0, 0, 200, 0, 0, 200, 0, 0, 200, 0, 0, 200, 0, 0,
               200, 0, 0, 200, 0, 0, 170, 0, 0, 170, 0, 0, 170, 0, 0, 170, 0, 0, 170, 0, 0, 170, 0, 0, 170, 0, 0, 170,
               0, 0, 170, 0, 0, 170, 0, 0, 255, 0, 241, 255, 0, 241, 255, 0, 241, 255, 0, 241, 255, 0, 241, 255, 0, 241,
               255, 0, 241, 255, 0, 241, 255, 0, 241, 255, 0, 241, 151, 0, 181, 151, 0, 181, 151, 0, 181, 151, 0, 181,
               151, 0, 181, 151, 0, 181, 151, 0, 181, 151, 0, 181, 151, 0, 181, 151, 0, 181, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245,
               141, 173, 245, 141, 173, 245, 141, 173, 245, 141, 173, 245, 141]
    # np.array copies the input, so the image never shares memory with the caller's array.
    new_im = Image.fromarray(np.array(array2d))
    new_im.putpalette(palette)
    return new_im


def save_plots_palette(fig_names, outputs, save_root, figsize=(6, 6), **imshow_args):
    """Save each field in ``outputs`` colored with the fixed ``add_palette`` color table.

    Values are clipped to [0, 255] and cast to uint8 palette indices. Pixels
    whose index is >= 1 are drawn opaque; index 0 is fully transparent.

    Args:
        fig_names: Output file names, joined onto ``save_root``. Paired with
            ``outputs`` via ``zip``, so the shorter of the two sets the number
            of figures written.
        outputs: 2-D NumPy array-likes to plot. They are not modified.
        save_root: Output directory; created if it does not exist.
        figsize: Matplotlib figure size in inches.
        **imshow_args: Extra keyword arguments forwarded to ``Axes.imshow``.
    """
    os.makedirs(save_root, exist_ok=True)

    for fig_name, output in zip(fig_names, outputs):
        # Clip before casting: converting out-of-range floats to uint8 does not
        # saturate, so a negative value could otherwise land on a high palette index.
        indices = np.clip(output, 0, 255).astype('uint8')
        r, g, b = add_palette(indices).convert('RGB').split()
        # Opaque where index >= 1, the same visibility threshold save_plots uses;
        # a plain "> 1" test would leave index 1 at alpha 1/255 (invisible).
        # 2-D uint8 data is mapped to mode "L" by fromarray, as Image.merge requires.
        alpha = Image.fromarray(np.where(indices >= 1, 255, 0).astype('uint8'))

        fig = plt.figure(figsize=figsize)
        try:
            ax = plt.axes()
            ax.set_axis_off()
            ax.imshow(Image.merge('RGBA', (r, g, b, alpha)), **imshow_args)
            fig.savefig(os.path.join(save_root, fig_name), bbox_inches='tight')
        finally:
            # Close this specific figure even if drawing or saving fails.
            plt.close(fig)


def draw_figure(plot, x, y, title='', xlabel='', ylabel='', xlim=None, ylim=None, xticks=None, yticks=None, grid=True,
                labels=None, hidden_spines=False, hidden_ticks=False, legend=True, color=None):
    """Decorate the current axes of ``plot`` and draw one line per series.

    ``plot`` is used through a ``matplotlib.pyplot``-style interface (``title``,
    ``xlabel``, ``gca``, ``plot``, ``legend``, ...). Series ``i`` is drawn as
    ``plot.plot(x[i], y[i], ...)``.

    Args:
        plot: pyplot-like module or object.
        x: Sequence of x-series; one line is drawn per entry.
        y: Sequence of y-series, indexed in parallel with ``x``.
        title: Axes title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        xlim: X-axis limits, or None to leave them unchanged.
        ylim: Y-axis limits, or None to leave them unchanged.
        xticks: ``(positions, labels)`` pair for the x axis, or None.
        yticks: ``(positions, labels)`` pair for the y axis, or None.
        grid: Whether to call ``plot.grid()``.
        labels: Per-series legend labels; defaults to the series index as a string.
        hidden_spines: Hide the right spine and draw the other spines in light grey.
        hidden_ticks: Draw ticks in light grey, pointing inwards.
        legend: Whether to draw a legend after plotting.
        color: Per-series colors, or None to use the default color cycle.
    """
    plot.title(title)
    plot.xlabel(xlabel)
    plot.ylabel(ylabel)

    light_grey = '#B0B0B0'
    if hidden_ticks:
        plot.tick_params(color=light_grey, direction='in')
    if hidden_spines:
        spines = plot.gca().spines
        spine_colors = {'right': 'None', 'top': light_grey, 'bottom': light_grey, 'left': light_grey}
        for side, spine_color in spine_colors.items():
            spines[side].set_color(spine_color)
    if grid:
        plot.grid()

    # Attributes are looked up by name only when the option is actually set.
    for setter_name, limits in (('xlim', xlim), ('ylim', ylim)):
        if limits is not None:
            getattr(plot, setter_name)(limits)
    for setter_name, ticks in (('xticks', xticks), ('yticks', yticks)):
        if ticks is not None:
            getattr(plot, setter_name)(ticks[0], ticks[1])

    for idx, x_series in enumerate(x):
        line_kwargs = {'label': str(idx) if labels is None else labels[idx]}
        if color is not None:
            line_kwargs['color'] = color[idx]
        plot.plot(x_series, y[idx], **line_kwargs)

    if legend:
        plot.legend()


def draw_csi(csi, x, bar=4):
    """Plot CSI curves in one row of subplots, one subplot per precipitation threshold.

    Clears the current pyplot figure. Then, for each key of ``csi`` in iteration
    order, it adds a subplot and delegates to ``draw_figure`` with fixed styling:
    - y range 0-0.8;
    - y tick labels (and the "CSI" label) only on the first subplot;
    - light-grey spines and no legend.

    Args:
        csi: Mapping ``threshold -> y`` where ``y`` is passed to ``draw_figure``
            as the per-series y values (one entry per series in ``x``).
        x: Per-series x values; ``len(x[0])`` sets the x-axis range ``(0, len(x[0]))``.
        bar: Spacing of the x-axis ticks (default 4).
    """
    plt.clf()
    if not csi:
        # Nothing to draw; also avoids touching x[0] below.
        return

    n_steps = len(x[0])
    # Ticks at 0, bar, 2*bar, ... strictly below n_steps (the right edge is left unlabelled).
    x_ticks = list(range(0, n_steps, bar))
    y_ticks = [0, 0.2, 0.4, 0.6, 0.8]
    blank_y_labels = [''] * len(y_ticks)

    for n_subfigure, threshold in enumerate(csi, start=1):
        plt.subplot(1, len(csi), n_subfigure)
        is_first = n_subfigure == 1
        draw_figure(plt, x, csi[threshold], title=r'Precipitation [mm/h] $\geq$ ' + str(threshold),
                    xlabel='Prediction interval [min]', ylabel='CSI' if is_first else '',
                    xlim=(0, n_steps), ylim=(0, 0.8),
                    yticks=(y_ticks, y_ticks if is_first else blank_y_labels),
                    legend=False, hidden_spines=True, xticks=(x_ticks, x_ticks))
