import os

import numpy as np
import torch

import wandb
from attack.msssim import msssim
from utils import gaus_skl


def eval_attack(model, x_ref, x_adv, step, x_trg=None, task=None, save_dir=None):
    """
    Evaluate an adversarial attack: reconstruct the reference and adversarial
    inputs with `model`, compute the reference-based metrics and, when targets
    are given, the target-based metrics. Optionally dump the tensors to disk.

    model: object exposing `q_z(x)` (unpacked here into three values) and
        `reconstruction(x, use_sample=True)`
    x_trg: torch.tensor (N_trg, x_dim), optional; enables the supervised metrics
    x_ref: torch.tensor (1, x_dim)
    x_adv: torch.tensor (N_trg, x_dim)
    step: value used in the saved file names; forwarded to `eval_attack_target`
    task: optional identifier; when given, the saved `x_adv` file name is
        prefixed with "task{task}_"
    save_dir: optional directory for the saved tensors (created if missing)

    Returns a dict with the merged logs of `eval_attack_reference` and, if
    `x_trg` is not None, `eval_attack_target`.
    """
    name_pref = f"task{task}_" if task is not None else ""

    if save_dir is not None:
        # exist_ok=True avoids the check-then-create race of
        # `if not os.path.exists(...): os.makedirs(...)` when several runs
        # share the same output directory.
        os.makedirs(save_dir, exist_ok=True)
        # NOTE(review): only x_adv carries the task prefix; x_ref and x_adv_rec
        # files share one name per step across tasks -- confirm this is intended.
        torch.save(x_adv.cpu(), os.path.join(save_dir, f"{name_pref}x_adv_{step}.pth"))
        torch.save(x_ref.cpu(), os.path.join(save_dir, f"x_ref_{step}.pth"))

    # get reconstructions
    with torch.no_grad():
        # This mirrors `forward_reshaped` from
        # https://github.com/AKuzina/defend_vae_mcmc/
        # blob/e993225228fdacf67fac4bfb403db2938ef82bed/vae/model/vae.py#L185,
        # which returns (x_mean, x_logvar, z_q, z_q_mean, z_q_logvar) for a
        # Gaussian decoder. Only x_mean is compared as the reconstruction there,
        # so here a reconstruction from the model is used instead and the
        # log-variance slots are left as None (averaging several samples would
        # be a possible refinement).
        z_ref, z_ref_m, z_ref_lv = model.q_z(x_ref)
        x_ref_m = model.reconstruction(x_ref, use_sample=True)
        x_ref_lv = None

        z_adv, z_adv_m, z_adv_lv = model.q_z(x_adv)
        x_adv_m = model.reconstruction(x_adv, use_sample=True)
        x_adv_lv = None

        if save_dir is not None:
            torch.save(x_adv_m.cpu(), os.path.join(save_dir, f"x_adv_rec_{step}.pth"))

    # Metrics relative to the reference point.
    ref_logs = eval_attack_reference(
        ref_dist=(x_ref, x_ref_m, x_ref_lv),
        adv_dist=(x_adv, x_adv_m, x_adv_lv),
        zref_dist=(z_ref, z_ref_m, z_ref_lv),
        zadv_dist=(z_adv, z_adv_m, z_adv_lv),
    )
    logs = dict(ref_logs)

    # Add supervised-only metrics
    if x_trg is not None:
        with torch.no_grad():
            x_trg_m = model.reconstruction(x_trg, use_sample=True)
        trg_logs = eval_attack_target(x_trg, x_trg_m, x_adv_m, step, save_dir=save_dir)
        logs.update(trg_logs)

    return logs


def eval_attack_reference(ref_dist, adv_dist, zref_dist, zadv_dist):
    """
    Compute attack metrics that compare the adversarial points with the
    reference point, in input space and in latent space.

    ref_dist: tuple (x_ref, x_ref_m, x_ref_lv); the third entry is unused
    adv_dist: tuple (x_adv, x_adv_m, x_adv_lv); the third entry is unused
    zref_dist: tuple (z_ref, z_ref_m, z_ref_lv)
    zadv_dist: tuple (z_adv, z_adv_m, z_adv_lv)

    Returns a dict with the keys "Adversarial Inputs", "Adversarial Rec"
    (wandb images), "ref_sim", "ref_rec_sim", "eps_norm" (means over the
    adversarial points), "s_kl" and "z_dist" (CPU tensors).
    """
    x_ref, x_ref_m, _ = ref_dist
    x_adv, x_adv_m, _ = adv_dist
    z_ref, z_ref_m, z_ref_lv = zref_dist
    z_adv, z_adv_m, z_adv_lv = zadv_dist

    # eps norm: norm of the perturbation for every adversarial point.
    # detach() keeps np.mean below working even if the inputs still require
    # grad (this function is not necessarily called under torch.no_grad()).
    eps_norms = [
        torch.norm(x_ref - x_a.unsqueeze(0)).detach().cpu() for x_a in x_adv
    ]

    # msssims: reference vs. adversarial inputs, and reference reconstruction
    # vs. adversarial reconstructions.
    # 14 is passed as the third positional argument of msssim
    # (presumably the window size -- confirm against attack.msssim).
    ref_sim = [
        msssim(x_ref, x_a.unsqueeze(0), 14, normalize="relu").detach().cpu()
        for x_a in x_adv
    ]
    ref_rec_sim = [
        msssim(x_ref_m, x_a_m.unsqueeze(0), 14, normalize="relu").detach().cpu()
        for x_a_m in x_adv_m
    ]

    # latent space
    # gaus_skl is presumably a symmetric KL between Gaussians given
    # (mean, log-variance) pairs -- confirm against utils.gaus_skl.
    s_kl = gaus_skl(z_ref_m, z_ref_lv, z_adv_m, z_adv_lv).mean()
    # NOTE(review): despite the name `mus`, this uses the first tuple entries
    # (z_ref, z_adv), not z_ref_m / z_adv_m -- confirm this is intended.
    mus = (z_ref - z_adv).pow(2).sum(1).mean()

    logs = {
        "Adversarial Inputs": wandb.Image(x_adv.cpu()),
        "Adversarial Rec": wandb.Image(x_adv_m.cpu()),
        "ref_sim": np.mean(ref_sim),
        "ref_rec_sim": np.mean(ref_rec_sim),
        "eps_norm": np.mean(eps_norms),
        "s_kl": s_kl.cpu(),
        "z_dist": mus.cpu(),
    }
    return logs


def eval_attack_target(x_trg, x_trg_m, x_adv_m, step, save_dir):
    """
    Compute the target-based attack metric: the mean MS-SSIM between each
    target reconstruction and the adversarial reconstruction at the same index.

    x_trg: target inputs; its first dimension sets the number of comparisons
    x_trg_m: reconstructions of the targets
    x_adv_m: reconstructions of the adversarial inputs
    step: when equal to 0, the targets and their reconstructions are also
        logged as wandb images and, if `save_dir` is given, saved to disk
    save_dir: optional directory for "x_trg.pth" and "x_trg_rec.pth"

    Returns a dict with "trg_rec_sim" and, at step 0, "Target Inputs" and
    "Target Rec".
    """
    n_targets = x_trg.shape[0]

    # One-row slices are kept (instead of plain indexing) so msssim receives
    # tensors that still carry the leading dimension.
    similarities = []
    for idx in range(n_targets):
        trg_rec = x_trg_m[idx : idx + 1]
        adv_rec = x_adv_m[idx : idx + 1]
        score = msssim(trg_rec, adv_rec, 14, normalize="relu")
        similarities.append(score.data.cpu())

    logs_trg = {"trg_rec_sim": np.mean(similarities)}

    # The targets are only logged and saved on the first step.
    if step != 0:
        return logs_trg

    logs_trg["Target Inputs"] = wandb.Image(x_trg.cpu())
    logs_trg["Target Rec"] = wandb.Image(x_trg_m.cpu())
    if save_dir is None:
        return logs_trg

    trg_path = os.path.join(save_dir, "x_trg.pth")
    torch.save(x_trg.cpu(), trg_path)
    rec_path = os.path.join(save_dir, "x_trg_rec.pth")
    torch.save(x_trg_m.cpu(), rec_path)
    return logs_trg
