import torch
import random
import numpy as np
from tqdm.auto import tqdm
from .training import inference, backward, inference_multi, bar_to_numpy

def single_step(hide, find, opt, writer, batch, cur_log_step, accelerator=None):
    """Run one optimisation step for a single hide/find pair.

    Args:
        hide: Callable invoked as ``hide(img1, img2)``; its output is trained
            to stay close to ``img2``.
        find: Callable invoked as ``find(h)`` on the output of ``hide``; its
            output is trained to stay close to ``img1``.
        opt: Optimizer; ``zero_grad()`` is called first and ``step()`` last.
        writer: Scalar logger exposing ``add_scalar(tag, value, step)``.
        batch: Pair ``(img1, img2)``.
        cur_log_step: Step index passed to ``writer``.
        accelerator: Optional object exposing ``is_main_process``. When it is
            ``None`` the loss is always logged; otherwise only the main
            process logs.

    Returns:
        The combined loss exactly as it was handed to ``backward`` (not
        converted to a Python float).
    """
    opt.zero_grad()
    secret, cover = batch

    # The container should look like the cover while still carrying the secret.
    container = hide(secret, cover)
    recovered = find(container)

    cover_err = ((container - cover) ** 2).mean()
    secret_err = ((recovered - secret) ** 2).mean()
    loss = cover_err + 1.0 * secret_err
    backward(accelerator, loss)

    # Log once per step: always without an accelerator, else main process only.
    should_log = accelerator is None or accelerator.is_main_process
    if should_log:
        writer.add_scalar('train loss', float(loss), cur_log_step)

    opt.step()
    return loss

def single_multistep(hide, find, opt, writer, batch, spec_len, cur_log_step,
                     use_grad=True, accelerator=None):
    """Run one unconditioned multi-step training update.

    Forwards to ``single_multistep_cond`` when ``use_grad`` is truthy and to
    ``single_multistep_cond_nograd`` otherwise. In both cases the weight is
    fixed at 10.0 and conditioning is disabled (``condition=False``).

    Returns:
        Whatever the selected step function returns.
    """
    step_fn = single_multistep_cond if use_grad else single_multistep_cond_nograd
    return step_fn(hide, find, opt, writer, batch, spec_len, 10.0,
                   cur_log_step, accelerator, False)

def single_multistep_cond(hide, find, opt, writer, batch, spec_len, cur_lambda,
                          cur_log_step, accelerator=None, condition=True):
    """Run one multi-step training update, keeping the whole hide chain attached.

    ``spec_len`` slices of ``img1`` are hidden one after another (last slice
    first), starting from ``img2``. The final output is then passed through
    ``find`` ``spec_len`` times to recover the slices in forward order.

    Args:
        hide: Called as ``hide(slice, h)`` or, when ``condition`` is truthy,
            ``hide(slice, h, cur_lambda)``.
        find: Called as ``find(h)`` or ``find(h, cur_lambda)``; must return a
            pair ``(h_rec, f)``.
        opt: Optimizer; ``zero_grad()`` is called first and ``step()`` last.
        writer: Scalar logger exposing ``add_scalar(tag, value, step)``.
        batch: Pair ``(img1, img2)``.
        spec_len: Number of slices of ``img1`` to hide and recover.
        cur_lambda: Weight applied to the slice-recovery loss; also forwarded
            to ``hide``/``find`` when ``condition`` is truthy.
        cur_log_step: Step index passed to ``writer``.
        accelerator: Optional object; logging happens unless it exposes a
            falsy ``is_main_process``.
        condition: Whether ``cur_lambda`` is forwarded to ``hide`` and ``find``.

    Returns:
        The accumulated, un-weighted loss as a Python number.
    """
    opt.zero_grad()
    img1, img2 = batch

    # img1: (bs, l, 2, h, w)
    # img2: (bs, c, h, w)

    # Extra positional argument handed to both networks when conditioning is on.
    cond_args = (cur_lambda,) if condition else ()
    running_loss = 0

    # Hiding pass, last slice first. Each input is remembered (still attached
    # to the graph) so the finding pass can be compared against it.
    container = img2
    hide_inputs = []
    for idx in reversed(range(spec_len)):
        hide_inputs.append(container)
        container = hide(img1[:, idx, :, :, :], container, *cond_args)
        cover_loss = ((container - img2) ** 2).mean() / spec_len
        running_loss += float(cover_loss)
        backward(accelerator, cover_loss, retain_graph=True)

    # Finding pass in forward order: step ``idx`` should give back the input of
    # the hide step that embedded slice ``idx``, plus the slice itself.
    recovered = container
    for idx, expected in enumerate(reversed(hide_inputs)):
        recovered, found = find(recovered, *cond_args)
        chain_loss = ((recovered - expected) ** 2).mean() / spec_len * 2
        slice_loss = ((found - img1[:, idx, :, :, :]) ** 2).mean() / spec_len * cur_lambda
        # The logged value undoes the per-term weights (2 and cur_lambda).
        running_loss += float(chain_loss) / 2 + float(slice_loss) / cur_lambda
        # NOTE(review): chain_loss is only logged, never back-propagated here,
        # whereas single_multistep_cond_nograd back-propagates its counterpart.
        # Confirm whether this is intentional.
        backward(accelerator, slice_loss, retain_graph=True)

    if getattr(accelerator, 'is_main_process', True):
        writer.add_scalar(f'train loss/{spec_len}', float(running_loss), cur_log_step)

    opt.step()
    return running_loss

def single_multistep_cond_nograd(hide, find, opt, writer, batch, spec_len, cur_lambda,
                                 cur_log_step, accelerator=None, condition=True):
    """Run one multi-step training update with the hide chain cut at every step.

    Each iteration detaches the previous ``hide`` output before feeding it
    back in, so every hide/find pair is trained against a constant target and
    no gradient reaches earlier iterations.

    Args:
        hide: Called as ``hide(slice, h)`` or, when ``condition`` is truthy,
            ``hide(slice, h, cur_lambda)``.
        find: Called as ``find(h_new)`` or ``find(h_new, cur_lambda)``; must
            return a pair ``(h_rec, f)``.
        opt: Optimizer; ``zero_grad()`` is called first and ``step()`` last.
        writer: Scalar logger exposing ``add_scalar(tag, value, step)``.
        batch: Pair ``(img1, img2)``.
        spec_len: Number of slices of ``img1`` to process.
        cur_lambda: Weight applied to the slice-recovery loss; also forwarded
            to ``hide``/``find`` when ``condition`` is truthy.
        cur_log_step: Step index passed to ``writer``.
        accelerator: Optional object; logging happens unless it exposes a
            falsy ``is_main_process``.
        condition: Whether ``cur_lambda`` is forwarded to ``hide`` and ``find``.

    Returns:
        The accumulated, un-weighted loss as a Python number.
    """
    opt.zero_grad()
    img1, img2 = batch

    # img1: (bs, l, 2, h, w)
    # img2: (bs, c, h, w)

    # Extra positional argument handed to both networks when conditioning is on.
    cond_args = (cur_lambda,) if condition else ()
    running_loss = 0

    # Slices are hidden last-to-first.
    container = img2
    for idx in reversed(range(spec_len)):
        # Detached copy: serves as both the hide input and the fixed target.
        anchor = container.detach()
        secret = img1[:, idx, :, :, :]

        container = hide(secret, anchor, *cond_args)
        rebuilt, found = find(container, *cond_args)

        stay_loss = ((container - anchor) ** 2).mean() / spec_len
        undo_loss = ((rebuilt - anchor) ** 2).mean() / spec_len * 2
        slice_loss = ((found - secret) ** 2).mean() / spec_len * cur_lambda
        # The logged value undoes the per-term weights (2 and cur_lambda).
        running_loss += float(stay_loss) + float(undo_loss) / 2 + float(slice_loss) / cur_lambda

        for term in (stay_loss, undo_loss, slice_loss):
            backward(accelerator, term, retain_graph=True)

    if getattr(accelerator, 'is_main_process', True):
        writer.add_scalar(f'train loss/{spec_len}', float(running_loss), cur_log_step)

    opt.step()
    return running_loss

@torch.no_grad()
def single_val(hide, find, writer, val_dl, cur_log_step, accelerator=None):
    """Log the sample-weighted mean validation loss for a single hide/find pair.

    Runs only when ``accelerator`` is ``None`` or reports
    ``is_main_process``; otherwise nothing is evaluated. Batches are used as
    yielded by ``val_dl`` (this function does not move them to a device).

    Args:
        hide: Forwarded to ``inference``.
        find: Forwarded to ``inference``.
        writer: Scalar logger exposing ``add_scalar(tag, value, step)``.
        val_dl: Iterable of ``(img1, img2)`` batches.
        cur_log_step: Step index passed to ``writer``.
        accelerator: Optional object exposing ``is_main_process``.

    Returns:
        None. Nothing is logged when ``val_dl`` yields no samples.
    """
    if not getattr(accelerator, 'is_main_process', True):
        return

    loss_sum, sample_count = 0.0, 0
    for img1, img2 in tqdm(val_dl):
        bs = img1.shape[0]
        h, f = inference(hide, find, img1, img2, quantize_mid=True)
        loss_1 = ((h - img2) ** 2).mean()
        loss_2 = ((f - img1) ** 2).mean()
        loss = loss_1 + 1.0 * loss_2
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += float(loss) * bs
        sample_count += bs

    if sample_count == 0:
        # Empty loader: there is no mean to report (and dividing would raise).
        return

    writer.add_scalar('validation loss', loss_sum / sample_count, cur_log_step)

@torch.no_grad()
def single_multival(hide, find, writer, val_dl, max_audio_len, cur_log_step,
                    accelerator=None, device=None):
    """Log validation loss per sequence length for the unconditioned model.

    Every batch is evaluated at one length drawn uniformly from
    ``1..max_audio_len``. Losses are accumulated per length (weighted by
    batch size) and written as ``validation loss/<length>``.

    Runs only when ``accelerator`` is ``None`` or reports
    ``is_main_process``.

    Args:
        hide: Forwarded to ``inference_multi``.
        find: Forwarded to ``inference_multi``.
        writer: Scalar logger exposing ``add_scalar(tag, value, step)``.
        val_dl: Iterable of ``(img1, img2)`` batches.
        max_audio_len: Largest length that may be drawn.
        cur_log_step: Step index passed to ``writer``.
        accelerator: Optional object exposing ``is_main_process``.
        device: When not ``None``, each batch element is moved with
            ``.to(device)`` before evaluation.

    Returns:
        None. Lengths that were never drawn are not logged.
    """
    if not getattr(accelerator, 'is_main_process', True):
        return

    loss_sums = np.zeros(max_audio_len)
    sample_counts = np.zeros(max_audio_len)
    for batch in tqdm(val_dl):
        if device is not None:
            batch = [b.to(device) for b in batch]

        spec_len = random.randint(1, max_audio_len)
        img1, img2 = batch
        bs = img1.shape[0]
        _, hs, fs = inference_multi(hide, find, img1, img2, spec_len, 10.0,
                                    condition=False, quantize_mid=True)
        # not necessarily aligned with the training loss
        loss_1 = sum(((h - img2) ** 2).mean() for h in hs[::-1]) / len(hs)
        loss_2 = sum(((f - img1[:, i]) ** 2).mean() for i, f in enumerate(fs)) / len(fs)
        loss = loss_1 + 10.0 * loss_2

        loss_sums[spec_len - 1] += float(loss) * bs
        sample_counts[spec_len - 1] += bs

    for i in range(max_audio_len):
        if sample_counts[i] == 0:
            # This length was never drawn, so its mean would be 0/0 (NaN);
            # skip it rather than write NaN to the log.
            continue
        writer.add_scalar(f'validation loss/{i+1}',
                          loss_sums[i] / sample_counts[i], cur_log_step)

@torch.no_grad()
def single_multival_cond(hide, find, writer, val_dl, conds, max_audio_len,
                         cur_log_step, accelerator=None, device=None):
    """Log validation loss per sequence length, averaged over condition values.

    Every batch is evaluated at one length drawn uniformly from
    ``1..max_audio_len`` and one entry of ``conds`` chosen at random. Two
    things are written: a scalar ``validation loss/<length>`` per length, and
    one image produced by ``bar_to_numpy`` from the two per-length loss terms.

    Runs only when ``accelerator`` is ``None`` or reports
    ``is_main_process``.

    Args:
        hide: Forwarded to ``inference_multi``.
        find: Forwarded to ``inference_multi``.
        writer: Logger exposing ``add_scalar`` and ``add_image``.
        val_dl: Iterable of ``(img1, img2)`` batches.
        conds: Sequence of condition values to sample from.
        max_audio_len: Largest length that may be drawn.
        cur_log_step: Step index passed to ``writer``.
        accelerator: Optional object exposing ``is_main_process``.
        device: When not ``None``, each batch element is moved with
            ``.to(device)`` before evaluation.

    Returns:
        None. The scalar for a length is the mean over the condition values
        that were actually sampled for it; lengths with no samples at all are
        not logged as scalars.
    """
    if not getattr(accelerator, 'is_main_process', True):
        return

    # (max_audio_len, cond_n, 2), (max_audio_len, cond_n)
    loss_sums = np.zeros((max_audio_len, len(conds), 2))
    sample_counts = np.zeros((max_audio_len, len(conds)))
    for batch in tqdm(val_dl):
        if device is not None:
            batch = [b.to(device) for b in batch]

        spec_len = random.randint(1, max_audio_len)
        cur_cond_idx = random.choice(range(len(conds)))
        cur_cond = conds[cur_cond_idx]

        img1, img2 = batch
        bs = img1.shape[0]
        # NOTE(review): condition=False is passed even though a condition
        # value was just chosen; confirm inference_multi is meant to run
        # unconditioned here.
        _, hs, fs = inference_multi(hide, find, img1, img2, spec_len, cur_cond,
                                    condition=False, quantize_mid=True)

        loss_1 = sum(((h - img2) ** 2).mean() for h in hs[::-1]) / len(hs)
        loss_2 = sum(((f - img1[:, i]) ** 2).mean() for i, f in enumerate(fs)) / len(fs)

        loss_sums[spec_len - 1][cur_cond_idx][0] += float(loss_1) * bs
        loss_sums[spec_len - 1][cur_cond_idx][1] += float(loss_2) * bs
        sample_counts[spec_len - 1][cur_cond_idx] += bs

    # Scalars: for now the second term is always weighted by 10, whatever the
    # sampled condition value, so the numbers stay comparable between runs.
    weighted = loss_sums[:, :, 0] + 10 * loss_sums[:, :, 1]
    sampled = sample_counts > 0
    # Divide only where samples exist. A plain division would give NaN for any
    # unsampled (length, cond) cell and poison the whole per-length mean.
    per_cell = np.divide(weighted, sample_counts,
                         out=np.zeros_like(weighted), where=sampled)
    sampled_per_len = sampled.sum(1)
    for i in range(max_audio_len):
        if sampled_per_len[i] == 0:
            continue
        writer.add_scalar(f'validation loss/{i+1}',
                          per_cell[i].sum() / sampled_per_len[i], cur_log_step)

    # Bar chart: per-length mean of each loss term, pooled over conditions.
    # Lengths with no samples stay NaN exactly as before; only the NumPy
    # warning for the 0/0 division is silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        pooled_sums, pooled_counts = loss_sums.mean(1), sample_counts.mean(1)
        avg_loss_1 = pooled_sums[:, 0] / pooled_counts
        avg_loss_2 = pooled_sums[:, 1] / pooled_counts
    bar_chart = bar_to_numpy(avg_loss_1, avg_loss_2)
    writer.add_image('validation gamma vs loss', bar_chart, cur_log_step)
