import io
import math
import torch
import numpy as np
import matplotlib.pyplot as plt

def bar_to_numpy(a, b):
    """Render two overlaid bar series and return the image as a uint8 array.

    ``a`` and ``b`` are each drawn as a bar series at x positions
    ``0 .. len(series) - 1`` with ``alpha=0.5``, so overlapping bars stay
    visible. The figure is rendered at 200 dpi.

    Args:
        a: heights of the first bar series.
        b: heights of the second bar series.

    Returns:
        A channel-first ``(C, H, W)`` uint8 array of the rendered figure.
        ``C`` is inferred from the size of matplotlib's ``'raw'`` output
        (presumably 4, RGBA — confirm for non-Agg backends).
    """
    fig = plt.figure(dpi=200)
    try:
        # Draw on this figure's axes explicitly rather than pyplot's global
        # "current axes" (identical result, no reliance on global state).
        ax = fig.gca()
        ax.bar(range(len(a)), a, alpha=0.5)
        ax.bar(range(len(b)), b, alpha=0.5)
        with io.BytesIO() as buffer:
            fig.savefig(buffer, format='raw', dpi=200)
            raw = buffer.getvalue()
        # Canvas size in pixels; consistent with the raw buffer because
        # savefig renders at the same dpi the figure was created with.
        w, h = fig.bbox.bounds[2:]
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call leaks a figure and its render buffers.
        plt.close(fig)
    img = np.frombuffer(raw, dtype=np.uint8)
    # (H, W, C) -> (C, H, W)
    return img.reshape(int(h), int(w), -1).transpose(2, 0, 1)

def normalize_stft_vis(s, p1=0.05, p2=0.95):
    """Robustly rescale each tensor in ``s`` to the range [0, 1] for display.

    Each tensor is first clipped to its own ``[p1, p2]`` quantile band, which
    suppresses outliers. It is then min-max scaled.

    Args:
        s: iterable of floating-point tensors (``torch.quantile`` requires a
            float dtype).
        p1: lower quantile used as the clip floor.
        p2: upper quantile used as the clip ceiling.

    Returns:
        A list of tensors, one per input, with values in [0, 1]. A tensor
        whose clipped values are all equal maps to all zeros instead of the
        NaNs a plain 0/0 would produce.
    """
    out = []
    for cur_s in s:
        clipped = torch.clip(cur_s, torch.quantile(cur_s, p1), torch.quantile(cur_s, p2))
        c_min = clipped.min()
        span = clipped.max() - c_min
        # Replace a zero span by 1 so constant inputs give 0 rather than 0/0.
        # torch.where keeps this on-device (no host sync from a Python `if`).
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        out.append((clipped - c_min) / safe_span)
    return out

def stft_to_wave(s, n_fft=446, hop_length=None, win_length=None):
    """Invert each STFT in ``s`` to a waveform and min-max scale it to [0, 1].

    Args:
        s: iterable of STFT tensors accepted by ``torch.istft`` (presumably
            complex, onesided spectrograms — confirm against the producer).
        n_fft: FFT size used to produce the spectrograms.
        hop_length: hop between frames; defaults to ``n_fft``, i.e.
            non-overlapping frames, as in the original hard-coded value.
        win_length: window length; defaults to ``n_fft``.

    Returns:
        A list of real waveforms with values in [0, 1]. The output shape
        follows ``torch.istft``; a comment elsewhere suggests ``(1, L)`` —
        confirm for your inputs. A waveform whose samples are all equal
        (e.g. silence) maps to all zeros instead of 0/0 NaNs.

    No ``window`` is passed, so ``torch.istft`` uses its default. That is a
    rectangular window, which keeps the non-overlapping default hop
    invertible.
    """
    if hop_length is None:
        hop_length = n_fft
    if win_length is None:
        win_length = n_fft

    waves = []
    for cur_s in s:
        wave = torch.istft(cur_s, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, return_complex=False, normalized=True)
        w_min = wave.min()
        span = wave.max() - w_min
        # Guard the silent/constant case: divide by 1 instead of 0.
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        waves.append((wave - w_min) / safe_span)
    return waves

def backward(accelerator, loss, retain_graph=False):
    """Back-propagate ``loss``, delegating to ``accelerator`` when one is given.

    Args:
        accelerator: object exposing ``backward(loss, retain_graph=...)``, or
            ``None`` to call ``loss.backward`` directly.
        loss: the value to differentiate.
        retain_graph: forwarded unchanged to whichever ``backward`` is called.
    """
    if accelerator is not None:
        accelerator.backward(loss, retain_graph=retain_graph)
        return
    loss.backward(retain_graph=retain_graph)

@torch.no_grad()
def inference(hide, find, img1, img2, quantize_mid=True):
    """Run one hide -> find pass in eval mode without tracking gradients.

    Args:
        hide: module called as ``hide(img1, img2)``.
        find: module called as ``find(h)`` on the (optionally quantized)
            output of ``hide``.
        img1: first input forwarded to ``hide``.
        img2: second input forwarded to ``hide``.
        quantize_mid: if True, ``h`` is mapped from [-1, 1] to 0..255,
            clamped, truncated to integers and mapped back to [-1, 1].
            This presumably simulates storing ``h`` as an 8-bit image.

    Returns:
        ``(h, f)``: the (possibly quantized) ``hide`` output and the
        ``find`` output.

    Each module's own train/eval flag is restored on exit, including when a
    forward pass raises.
    """
    hide_was_training = hide.training
    find_was_training = find.training
    hide.eval()
    find.eval()
    try:
        h = hide(img1, img2)
        if quantize_mid:
            # Clamp before the uint8 cast: casting out-of-range floats to
            # uint8 is not saturating (values outside [0, 255] wrap or are
            # undefined). In-range values are truncated exactly as before.
            h = ((h + 1) / 2.0 * 255).clamp(0, 255).byte().float()
            h = (h / 255.0 - 0.5) * 2
        f = find(h)
    finally:
        # Restore each module individually: the caller may keep one of them
        # frozen in eval while training the other.
        hide.train(hide_was_training)
        find.train(find_was_training)
    return h, f

@torch.no_grad()
def inference_multi(hide, find, img1, img2, spec_len, cur_lambda, condition=False, quantize_mid=True):
    """Run ``spec_len`` chained hide steps, then ``spec_len`` find steps, in eval mode.

    Hide phase: starting from ``h = img2``, ``hide`` is applied once per index
    ``i`` of ``img1``'s second axis, from ``spec_len - 1`` down to 0. Each step
    is called as ``hide(img1[:, i], h)``, or ``hide(img1[:, i], h, cur_lambda)``
    when ``condition`` is True. Find phase: ``find`` is applied ``spec_len``
    times. Each step returns a pair and is called as ``h, f = find(h)``, or
    ``find(h, cur_lambda)`` when ``condition`` is True.

    Args:
        hide: module used for the hide steps.
        find: module used for the find steps; must return a 2-tuple.
        img1: tensor indexed as ``img1[:, i, :, :, :]`` (i.e. at least 5-D).
        img2: initial state fed to the first hide step.
        spec_len: number of hide steps and of find steps.
        cur_lambda: extra argument passed to both modules when ``condition``.
        condition: whether to pass ``cur_lambda`` to the modules.
        quantize_mid: if True, the final hide output is mapped from [-1, 1]
            to 0..255, clamped, truncated to integers and mapped back.
            This presumably simulates 8-bit storage.

    Returns:
        ``(h_out, h_outs, f_outs)``:
        - ``h_out``: the (possibly quantized) final hide output.
        - ``h_outs``: a copy of every hide-step output, in execution order
          (pre-quantization).
        - ``f_outs``: a copy of every find-step ``f`` output, in execution
          order.

    Each module's own train/eval flag is restored on exit, including when a
    forward pass raises.
    """
    hide_was_training = hide.training
    find_was_training = find.training
    hide.eval()
    find.eval()
    try:
        h = img2
        h_outs = []
        for i in reversed(range(spec_len)):
            x = img1[:, i, :, :, :]
            h = hide(x, h, cur_lambda) if condition else hide(x, h)
            h_outs.append(h.detach().clone())

        if quantize_mid:
            # Clamp before the uint8 cast: out-of-range floats do not
            # saturate when cast to uint8 (they wrap or are undefined).
            h = ((h + 1) / 2.0 * 255).clamp(0, 255).byte().float()
            h = (h / 255.0 - 0.5) * 2
        h_out = h.detach().clone()

        f_outs = []
        for _ in range(spec_len):
            h, f = find(h, cur_lambda) if condition else find(h)
            f_outs.append(f.detach().clone())
    finally:
        # Restore each module individually rather than forcing both into the
        # same mode.
        hide.train(hide_was_training)
        find.train(find_was_training)
    return h_out, h_outs, f_outs