import torch
import torch.nn.functional as F
from einops import repeat

def symlog(x):
    """Return the symmetric log transform ``sign(x) * log(1 + |x|)``.

    Compresses large magnitudes while preserving sign; symlog(0) == 0.
    Uses ``torch.log1p`` so that small |x| (where ``1 + |x|`` would round to
    exactly 1 in floating point) still yields an accurate, non-zero result.
    """
    return torch.sign(x) * torch.log1p(torch.abs(x))

@torch.no_grad()
def encode_imgs(guidance, rgb_image):
    """Encode an image batch into scaled VAE latents of ``guidance``.

    The batch is bilinearly resized to 512x512 and passed through
    ``guidance.image_processor.preprocess``. It is then encoded by ``guidance.vae``.
    One sample is drawn from the resulting latent distribution and multiplied
    by ``guidance.vae.config.scaling_factor``. Runs without autograd.

    Args:
        guidance: object exposing ``image_processor`` and ``vae``.
        rgb_image: image tensor accepted by ``F.interpolate`` with a 2-D spatial
            size (presumably (B, 3, H, W) with values in [0, 1] — confirm with caller).

    Returns:
        The scaled latent sample. Because ``.sample()`` is used rather than the
        distribution mean, it is stochastic.
    """
    resized = F.interpolate(rgb_image, size=(512, 512), mode='bilinear', align_corners=False)
    model_input = guidance.image_processor.preprocess(resized)
    vae = guidance.vae
    posterior = vae.encode(model_input).latent_dist
    return posterior.sample() * vae.config.scaling_factor


def disentangled_avdc_alignment(guidance, avdc_trainer, prev_rgb, current_rgb, avdc_text_embeddings, noise_level=400, alignment_scale=200, recon_scale=2000, noise=None):
    """Score a frame transition against a text condition using an AVDC denoiser.

    Steps:
      1. Encode ``prev_rgb`` and ``current_rgb`` into VAE latents (``encode_imgs``).
      2. Noise the current-frame latent at a randomly drawn timestep ``t``.
      3. Predict that noise with the AVDC denoiser, conditioned on the
         previous-frame latent (concatenated along dim 1). This is done twice:
         once with ``avdc_text_embeddings`` and once with all-zero embeddings.
      4. Combine two per-sample terms:
         - alignment: mean squared difference between the text-conditioned and
           unconditional noise predictions;
         - reconstruction: unconditional error minus text-conditioned error
           against the true noise. It is positive when the text helps.

    Args:
        guidance: object exposing ``device`` and the attributes ``encode_imgs`` needs.
        avdc_trainer: object whose ``model.q_sample(x, t, noise)`` noises latents
            and whose ``model.model(x, t, text_embeddings)`` predicts noise.
        prev_rgb, current_rgb: image batches with the same leading batch size
            (presumably (B, 3, H, W) with values in [0, 255] — confirm with caller).
        avdc_text_embeddings: conditioning embeddings for the AVDC denoiser.
        noise_level: lower bound of the timestep draw, before the ``// 10`` rescale.
        alignment_scale: multiplier on the alignment term before ``symlog``.
        recon_scale: multiplier on the reconstruction term before ``symlog``.
        noise: optional noise used for ``q_sample`` and as the error target.
            If None, it is sampled with ``torch.randn_like`` on the current latents.

    Returns:
        Per-sample reward: ``symlog(alignment_scale * alignment)
        + symlog(recon_scale * recon)``, after averaging over dims 1-3.

    Raises:
        ValueError: if ``prev_rgb`` and ``current_rgb`` have different batch sizes.
    """
    b = prev_rgb.shape[0]
    if current_rgb.shape[0] != b:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            f"prev_rgb and current_rgb must have the same batch size, got {b} and {current_rgb.shape[0]}"
        )

    # Cast to fp16 before encoding; presumably the VAE runs in half precision — confirm.
    prev_rgb = prev_rgb.half()
    current_rgb = current_rgb.half()

    # Division by 255 assumes 0-255 pixel values. Latents are promoted to fp32
    # for the scoring arithmetic below.
    prev_latents = encode_imgs(guidance, prev_rgb / 255.0).float()
    current_latents = encode_imgs(guidance, current_rgb / 255.0).float()

    # t is drawn uniformly from [noise_level, noise_level + 100), then floor-divided
    # by 10 (default: t in [40, 49]). Presumably this maps a 1000-step noise level
    # onto the AVDC model's shorter schedule — TODO confirm against its num_timesteps.
    t = torch.randint(noise_level, noise_level + 100, [b], dtype=torch.long, device=guidance.device)
    t = t // 10

    if noise is None:
        noise = torch.randn_like(current_latents)

    # Model calls are inference only; no autograd graph is built for them.
    with torch.no_grad():
        current_latents_noisy = avdc_trainer.model.q_sample(current_latents, t, noise)
        # Previous-frame conditioning: stack the latents along dim 1. The input
        # is identical for both passes, so it is built once.
        model_input = torch.cat([current_latents_noisy, prev_latents], dim=1)
        noise_pred_pos = avdc_trainer.model.model(model_input, t, avdc_text_embeddings)
        # Unconditional pass with all-zero text embeddings. zeros_like (rather
        # than `* 0.0`) stays zero even if the embeddings contain inf/NaN.
        noise_pred_uncond = avdc_trainer.model.model(model_input, t, torch.zeros_like(avdc_text_embeddings))

    # Per-sample squared errors, averaged over dims 1-3.
    alignment_pred = ((noise_pred_pos - noise_pred_uncond) ** 2).mean([1, 2, 3])
    pos_natural_pred = ((noise_pred_pos - noise) ** 2).mean([1, 2, 3])
    uncond_natural_pred = ((noise_pred_uncond - noise) ** 2).mean([1, 2, 3])
    # Positive when text conditioning lowers the noise-prediction error.
    recon_pred = uncond_natural_pred - pos_natural_pred

    # symlog compresses large magnitudes while keeping the sign of recon_pred.
    return symlog(alignment_scale * alignment_pred) + symlog(recon_scale * recon_pred)