"""
Compute the entanglement of quantum states using the Meyer-Wallach measure.
"""
import math
import torch
from tqdm import tqdm


def MW(test_x, model, params):
    """Measure Meyer-Wallach entanglement before/after the circuit for each sample.

    Each sample of ``test_x`` is flattened and passed to ``model.circuit_state``
    twice: once with ``exec_=False`` and once with ``exec_=True``. The resulting
    states are scored with ``entQ``.

    NOTE(review): presumably ``exec_=False`` yields the encoded input state and
    ``exec_=True`` the state after the parameterized circuit -- confirm against
    ``model.circuit_state``.

    Args:
        test_x: Iterable of input samples (tensors); each is flattened to 1-D.
        model: Object exposing ``circuit_state(x, params, exec_=...)``.
        params: Circuit parameters forwarded unchanged to ``circuit_state``.

    Returns:
        Tuple ``(in_ent, out_ent, out_ent - in_ent)`` of tensors with one
        entry per sample.
    """
    ent_before, ent_after = [], []
    for sample in tqdm(test_x):
        flat = torch.flatten(sample, start_dim=0)
        state_before = model.circuit_state(flat, params, exec_=False)
        state_after = model.circuit_state(flat, params, exec_=True)
        # The first value returned by entQ (the scaled gain) is recomputed below.
        _, q_in, q_out = entQ(state_before, state_after, 1)
        ent_before.append(q_in)
        ent_after.append(q_out)
    before = torch.tensor(ent_before)
    after = torch.tensor(ent_after)
    return before, after, after - before


### Meyer-Wallach measure helpers
def gener_distance(u, v):
    """Return the generalized distance D(u, v) used by the Meyer-Wallach measure.

    Computes ``0.5 * ||u (x) v - v (x) u||**2``. The antisymmetric Kronecker
    difference has entries ``u_i v_j - v_i u_j``; every unordered pair (i, j)
    appears twice, so halving the squared norm sums each pair once.

    Args:
        u, v: 1-D tensors of equal length (real or complex).

    Returns:
        0-dim tensor holding D(u, v).
    """
    antisym = torch.kron(u, v) - torch.kron(v, u)
    squared_norm = torch.linalg.norm(torch.abs(antisym)) ** 2
    return squared_norm * 0.5

def liner_map(b, j, psi):
    """Apply the Meyer-Wallach projection iota_j(b) to a state vector.

    Keeps the amplitudes ``psi[i]`` whose basis index ``i`` has qubit ``j`` equal
    to ``b``. Qubit 0 is the most significant bit of ``i``. Any trailing
    dimensions of ``psi`` are preserved.

    Args:
        b: Bit value to keep, 0 or 1.
        j: Qubit index in ``[0, num_qubits)``, counted from the most significant bit.
        psi: Tensor whose first dimension indexes the computational basis.

    Returns:
        Tensor of the selected amplitudes, in their original order.

    Raises:
        ValueError: If ``b`` is not 0/1 or ``j`` is out of range.
    """
    if b not in (0, 1):
        raise ValueError(f"b must be 0 or 1, got {b!r}")
    dim = psi.size(0)
    # (dim - 1).bit_length() == ceil(log2(dim)) for dim >= 1, computed exactly.
    num_qubits = (dim - 1).bit_length()
    if not 0 <= j < num_qubits:
        raise ValueError(f"qubit index j={j} out of range for {num_qubits} qubits")
    shift = num_qubits - 1 - j
    # One vectorized boolean mask replaces a per-amplitude Python loop + torch.cat.
    basis = torch.arange(dim, device=psi.device)
    mask = ((basis >> shift) & 1) == b
    return psi[mask]

def ent_state(psi):
    """Return the Meyer-Wallach entanglement Q of the state ``psi``.

    Q(psi) = (4 / n) * sum_j D(iota_j(0) psi, iota_j(1) psi), where n is
    ceil(log2(len(psi))), ``liner_map`` implements iota_j(b), and
    ``gener_distance`` implements D.

    NOTE(review): assumes ``psi`` is a normalized state vector -- the code
    does not normalize it.

    Args:
        psi: Tensor whose first dimension indexes the computational basis.

    Returns:
        The entanglement value (a tensor when at least one qubit exists).
    """
    n_qubits = math.ceil(math.log2(psi.size(0)))
    total = 0.0
    for qubit in range(n_qubits):
        zero_branch = liner_map(0, qubit, psi)
        one_branch = liner_map(1, qubit, psi)
        total += gener_distance(zero_branch, one_branch)
    return total * 4 / n_qubits

def entQ(in_state, out_state, k):
    """Compare Meyer-Wallach entanglement of two states.

    Args:
        in_state: First state tensor.
        out_state: Second state tensor.
        k: Scale factor applied to the entanglement difference.

    Returns:
        Tuple ``(k * (Q(out_state) - Q(in_state)), Q(in_state), Q(out_state))``.
    """
    before = ent_state(in_state)
    after = ent_state(out_state)
    gain = k * (after - before)
    return gain, before, after

