import torch
import torchvision
from ..utils.training import inference, inference_multi, normalize_stft_vis, stft_to_wave

@torch.no_grad()
def single_vis(hide, find, writer, batch, cur_log_step, accelerator=None):
    """Log one example from ``batch`` as a 4-panel image grid.

    Runs ``inference`` on the batch and writes the first sample of ``img1``,
    ``img2`` and of the two tensors returned by ``inference`` side by side
    to ``writer`` under the tag ``'train example'``.

    Args:
        hide: Forwarded to ``inference`` as its first argument.
        find: Forwarded to ``inference`` as its second argument.
        writer: Object exposing ``add_image(tag, image, step)``.
        batch: Pair ``(img1, img2)`` of batched tensors.
        cur_log_step: Step value passed to ``writer.add_image``.
        accelerator: Optional object with an ``is_main_process`` attribute.
            Nothing is logged when that attribute is falsy; ``None`` or an
            object without the attribute is treated as the main process.

    Returns:
        None.
    """
    # Only one process should write to the logger.
    if not getattr(accelerator, 'is_main_process', True):
        return

    img1, img2 = batch
    h1, f1 = inference(hide, find, img1, img2, quantize_mid=True)

    # Map each panel with (x + 1) / 2 for display.
    # NOTE(review): assumes the tensors live in [-1, 1] -- confirm.
    panels = [
        (tensor[0].detach().cpu() + 1) / 2.0
        for tensor in (img1, img2, h1, f1)
    ]
    grid = torchvision.utils.make_grid(torch.stack(panels), 4)

    # Clamp before the uint8 cast: values outside [0, 1] would otherwise
    # leave the 0..255 range and wrap around, producing speckle artefacts.
    grid = (grid.clamp(0.0, 1.0).numpy() * 255).astype('uint8')
    writer.add_image('train example', grid, cur_log_step)

def single_multivis(hide, find, writer, batch, spec_len, cur_log_step, accelerator=None):
    """Log a multi-spectrogram example without conditioning.

    Thin wrapper around ``single_multivis_cond`` that fixes ``cur_lambda``
    to ``10.0`` and ``condition`` to ``False``; every other argument is
    forwarded unchanged. Returns whatever ``single_multivis_cond`` returns.
    """
    fixed_lambda = 10.0
    return single_multivis_cond(
        hide,
        find,
        writer,
        batch,
        spec_len,
        fixed_lambda,
        cur_log_step,
        accelerator=accelerator,
        condition=False,
    )

@torch.no_grad()
def single_multivis_cond(hide, find, writer, batch, spec_len, cur_lambda, cur_log_step,
                         accelerator=None, condition=True):
    """Log a multi-spectrogram example as one image grid plus audio clips.

    Runs ``inference_multi`` on the batch, then writes to ``writer``:

    * one image grid with ``2 + 2 * spec_len`` panels per row, made of the
      ``spec_len`` input spectrogram panels, ``img2``, the first tensor
      returned by ``inference_multi``, and the ``spec_len`` recovered
      spectrogram panels (first sample of the batch in every case);
    * for each index ``i`` below ``spec_len`` (as many as ``stft_to_wave``
      yields), a "GT" and a "Pred" audio clip logged with the value 16000
      as the fourth ``add_audio`` argument.

    Args:
        hide: Forwarded to ``inference_multi`` as its first argument.
        find: Forwarded to ``inference_multi`` as its second argument.
        writer: Object exposing ``add_image`` and ``add_audio``.
        batch: Pair ``(img1, img2)`` of batched tensors.
        spec_len: Number of spectrogram slices to visualise.
        cur_lambda: Forwarded to ``inference_multi``; also embedded in the
            logging tags.
        cur_log_step: Step value passed to the writer.
        accelerator: Optional object with an ``is_main_process`` attribute.
            Nothing is logged when that attribute is falsy; ``None`` or an
            object without the attribute is treated as the main process.
        condition: Forwarded to ``inference_multi``.

    Returns:
        None.
    """
    # Only one process should write to the logger.
    if not getattr(accelerator, 'is_main_process', True):
        return

    img1, img2 = batch
    h1, _, fs = inference_multi(hide, find, img1, img2, spec_len, cur_lambda, condition, quantize_mid=True)

    # Layout note carried over from the original code, presumably for img1:
    # (bs, max_audio_len, 2, 224, 224) -- TODO confirm.
    # Ground-truth slices of the first sample, channels moved to the last axis.
    s1_audio = [img1[0, i][None].permute(0, 2, 3, 1) for i in range(spec_len)]
    s1_audio = stft_to_wave(s1_audio)
    # First channel of each slice, repeated three times to form an image panel.
    s1 = [img1[0, i, :1, :, :].cpu().repeat(3, 1, 1) for i in range(spec_len)]
    s1 = normalize_stft_vis(s1, 0.05, 0.95)

    # NOTE(review): (x + 1) / 2 assumes these tensors live in [-1, 1] -- confirm.
    s2 = [(img2[0].cpu() + 1) / 2.0, (h1[0].cpu() + 1) / 2.0]

    # Recovered slices, handled the same way as the ground truth above.
    s3_audio = [fs[i][:1].permute(0, 2, 3, 1) for i in range(spec_len)]
    s3_audio = stft_to_wave(s3_audio)
    # NOTE(review): the +1 offset is kept from the original code; whether
    # normalize_stft_vis makes it irrelevant cannot be told from here.
    s3 = [(fs[i][0, :1].cpu().repeat(3, 1, 1) + 1) for i in range(spec_len)]
    s3 = normalize_stft_vis(s3, 0.05, 0.95)

    panels = s1 + s2 + s3
    grid = torchvision.utils.make_grid(torch.stack(panels), 2 + spec_len * 2)

    # Clamp before the uint8 cast: values outside [0, 1] would otherwise
    # leave the 0..255 range and wrap around, producing speckle artefacts.
    grid = (grid.clamp(0.0, 1.0).numpy() * 255).astype('uint8')
    writer.add_image(f'train example - lambda:{cur_lambda}/{spec_len}', grid, cur_log_step)

    for i, (cur_s1_audio, cur_s3_audio) in enumerate(zip(s1_audio, s3_audio)):
        writer.add_audio(f'train example - lambda:{cur_lambda} - GT/{spec_len}/{i}', cur_s1_audio, cur_log_step, 16000)
        writer.add_audio(f'train example - lambda:{cur_lambda} - Pred/{spec_len}/{i}', cur_s3_audio, cur_log_step, 16000)
