import torch
import numpy as np


def check_z(z, zh, d, lenlog, i_run, run_cond_v):
    """Measure how far each ``z[l+1]`` is from ``zh[l] + d[l]``.

    Two residuals are computed for every index ``l`` in
    ``range(len(lenlog.zh_len))``:

    * length residual: ``lenlog.z_len[l+1] - (lenlog.zh_len[l] + lenlog.d_len[l])``
    * value residual: ``torch.norm(z[l+1] - zh[l] - d[l])`` as a Python float

    Args:
        z: Indexable sequence of tensors; entry ``l + 1`` is compared
            against entry ``l`` of ``zh`` and ``d``.
        zh: Indexable sequence of tensors, indexed by ``l``.
        d: Indexable sequence of tensors, indexed by ``l``.
        lenlog: Object exposing ``z_len``, ``zh_len`` and ``d_len``
            sequences (``z_len`` is read at ``l + 1``).
        i_run: Not used by this function.
        run_cond_v: Not used by this function.

    Returns:
        tuple: ``(z_l_diff, z_v_diff)``, two lists with one entry per
        index ``l``. Both are expected to be close to zero when the
        relation ``z[l+1] == zh[l] + d[l]`` holds.
    """
    z_len = lenlog.z_len
    zh_len = lenlog.zh_len
    d_len = lenlog.d_len
    n_checks = len(zh_len)

    z_l_diff = [
        z_len[l + 1] - (zh_len[l] + d_len[l]) for l in range(n_checks)
    ]
    z_v_diff = [
        torch.norm(z[l + 1] - zh[l] - d[l]).item() for l in range(n_checks)
    ]
    # The residuals used to be computed and then thrown away; hand them
    # back so the caller can actually inspect or log them.
    return z_l_diff, z_v_diff


def check_d_rel(model, opt, zs, L, i_run, run_cond_v):
    """Check that the state gradients equal the negated ``d`` terms.

    Runs ``pc_update_step`` with ``trg_dtc=True`` and, for every index
    ``l`` in ``range(len(d_1) - 1)``, evaluates
    ``torch.norm(zg_1[l] + d_1[l].detach())``. A value near zero means
    ``zg_1[l] == -d_1[l]``.

    Args:
        model: Forwarded to ``pc_update_step``.
        opt: Forwarded to ``pc_update_step``.
        zs: Sequence of state tensors forwarded to ``pc_update_step``.
        L: Upper index bound forwarded to ``pc_update_step``.
        i_run: Not used by this function.
        run_cond_v: Not used by this function.

    Returns:
        list[float]: One residual norm per checked index.
    """
    zg_1, d_1 = pc_update_step(model, opt, zs, L, trg_dtc=True)
    # NOTE(review): further cross-checks were sketched here but never
    # implemented: a second pass with trg_dtc=False on detached states to
    # test d_1 == d_2 and z_grad_both == z_grad_td + d_1. The detached copy
    # of ``zs`` built for that purpose was never used and has been dropped.
    #
    # Implemented check: z_grad_td == -d_1.
    d_v_diff = [
        torch.norm(zg_1[l] + d_1[l].detach()).item()
        for l in range(len(d_1) - 1)
    ]
    # Previously computed and discarded; return so the caller can use it.
    return d_v_diff


def check_w_grad(w_grad, z, d, i_run, run_cond_v):
    """Compare each weight gradient with the batch-summed product of d and z.

    For every layer index ``l`` of ``w_grad`` a reference update is built
    as ``matmul(d[l].unsqueeze(-1), z[l].unsqueeze(-2)).sum(-3)``, i.e. a
    per-sample outer product of ``d[l]`` and ``z[l]`` summed over the
    third-from-last axis (the batch axis according to the original
    comment). The ``torch.norm`` of ``w_grad[l] - reference`` is recorded.

    Args:
        w_grad: Sequence of weight-gradient tensors; its length sets the
            number of layers checked.
        z: Sequence of tensors indexed like ``w_grad``.
            Assumes each entry is (bsz, in_features) -- TODO confirm.
        d: Sequence of tensors indexed like ``w_grad``.
            Assumes each entry is (bsz, out_features) -- TODO confirm.
        i_run: Not used by this function.
        run_cond_v: Not used by this function.

    Returns:
        list[float]: One residual norm per entry of ``w_grad``; values
        near zero mean the gradient matches the reference update.
    """
    w_diff_lst = []
    for layer, grad in enumerate(w_grad):
        d_m = d[layer].unsqueeze(-1)
        z_m = z[layer].unsqueeze(-2)
        update_w = torch.matmul(d_m, z_m).sum(-3)  # sum over bsz
        # The original computed this value but never stored it, leaving
        # ``w_diff_lst`` permanently empty.
        w_diff_lst.append(torch.norm(grad - update_w).item())
    return w_diff_lst


def pc_update_step(model, opt, zs, L, trg_dtc=False):
    """Run one predict/compare/backward pass and gather state gradients.

    With gradient tracking forced on, the function calls
    ``model.predict(zs, False, check=True)``, feeds the first returned
    value into ``model.compare_ls(...)``, clears gradients through
    ``opt.zero_grad()`` and back-propagates the returned loss.

    Args:
        model: Object providing ``predict`` and ``compare_ls``.
        opt: Object providing ``zero_grad``.
        zs: Sequence of tensors whose ``.grad`` is read after backward.
        L: Exclusive upper index; gradients are taken for ``0 < l < L``.
        trg_dtc: Passed through to ``model.compare_ls``.

    Returns:
        tuple: ``(z_grad, d)`` where ``z_grad`` is the list of detached
        gradients of ``zs[l]`` for ``0 < l < L`` and ``d`` is the second
        value returned by ``model.compare_ls``.
    """
    with torch.enable_grad():
        z_enc, _second, _third = model.predict(zs, False, check=True)
        loss, d = model.compare_ls(
            zs, False, None, z_enc, check=True, trg_dtc=trg_dtc
        )
        opt.zero_grad()
        loss.backward()

        # Only the interior states (index 1 .. L-1) contribute.
        z_grad = []
        for idx, state in enumerate(zs):
            if 0 < idx < L:
                z_grad.append(state.grad.detach())
    return z_grad, d


# Helpers for reproducing the exponential-expressivity plot.

def _fwd_pre(fwd, x, run_cond_v):
    """Call ``fwd`` with ``x`` as both leading arguments and return ``'hs'``.

    ``fwd`` is invoked as ``fwd(x, x, run_cond_v, i_run=0)`` and must
    return a mapping; the value stored under the key ``'hs'`` is returned.
    """
    return fwd(x, x, run_cond_v, i_run=0)['hs']


def _fwd_len(fwd, x, y, run_cond_v):
    """Return the first ``'len/h_enc_len'`` entry logged by ``fwd``.

    ``fwd`` is invoked as ``fwd(x, x, run_cond_v, i_run=0)``. The ``y``
    argument is accepted but not used: ``x`` is passed in both leading
    positions.
    """
    logged = fwd(x, x, run_cond_v, i_run=0)
    enc_lens = logged['len/h_enc_len']
    return enc_lens[0]


def _fwd_lens(fwd, x, y, run_cond_v):
    """Stack the first ``'len/z_len'`` entry with all ``'len/h_enc_len'`` entries.

    ``fwd`` is invoked as ``fwd(x, x, run_cond_v, i_run=0)``; ``y`` is
    accepted but not used. The first element of the logged
    ``'len/z_len'`` sequence is prepended to the logged
    ``'len/h_enc_len'`` list and the result is passed to ``np.stack``.
    """
    logged = fwd(x, x, run_cond_v, i_run=0)
    first_z_len = logged['len/z_len'][0]
    rows = [first_z_len] + logged['len/h_enc_len']
    return np.stack(rows)
