import torch

def blind(mixture, bm):
    """Linearly mix a group of tensors with the matrix ``bm``.

    The tensors in ``mixture`` are concatenated along dim 0 and the result is
    reshaped into ``bm.size(-1)`` rows. ``bm`` is left-multiplied onto those
    rows, and the product is reshaped back and split into chunks of
    ``mixture[0].size(0)`` along dim 0.

    When ``len(mixture) == bm.size(-1)`` and all tensors share one shape, each
    tensor becomes exactly one row, so output ``i`` equals
    ``sum_j bm[i, j] * mixture[j]``.

    Args:
        mixture: Non-empty sequence of tensors. They are split back using the
            dim-0 size of the first tensor, so all tensors are expected to
            have the same shape.
        bm: Mixing matrix. It must be square (k x k), because the product is
            reshaped back to the input size. k is normally ``len(mixture)``.

    Returns:
        Tuple of tensors (views into one freshly computed tensor), one chunk
        per input tensor under the assumptions above.

    Raises:
        IndexError: If ``mixture`` is empty.
        RuntimeError: If the element count does not fit ``bm``'s row layout.
    """
    merge_len = mixture[0].size(0)
    concated_image = torch.cat(mixture, dim=0)
    # Take the row count from the matrix instead of hard-coding 3. For any
    # 3x3 matrix this reproduces the original (1, 3, -1) layout exactly, and
    # it also lets a k x k matrix mix k tensors.
    rows = concated_image.reshape(bm.size(-1), -1)
    concated_res = bm.matmul(rows).reshape(concated_image.size())
    return torch.split(concated_res, merge_len, dim=0)

def unblind(mixture, um):
    """Linearly mix a group of tensors with the matrix ``um``.

    This performs the same operation as ``blind``; it is intended to be
    called with the inverse of the blinding matrix, although any square
    matrix is accepted.

    The tensors in ``mixture`` are concatenated along dim 0 and the result is
    reshaped into ``um.size(-1)`` rows. ``um`` is left-multiplied onto those
    rows, and the product is reshaped back and split into chunks of
    ``mixture[0].size(0)`` along dim 0.

    When ``len(mixture) == um.size(-1)`` and all tensors share one shape,
    output ``i`` equals ``sum_j um[i, j] * mixture[j]``.

    Args:
        mixture: Non-empty sequence of tensors. All are expected to have the
            same shape, since the split uses the first tensor's dim-0 size.
        um: Square (k x k) mixing matrix. k is normally ``len(mixture)``.

    Returns:
        Tuple of tensors (views into one freshly computed tensor).

    Raises:
        IndexError: If ``mixture`` is empty.
        RuntimeError: If the element count does not fit ``um``'s row layout.
    """
    merge_len = mixture[0].size(0)
    concated_image = torch.cat(mixture, dim=0)
    # Take the row count from the matrix instead of hard-coding 3. This is
    # identical for 3x3 matrices and also works for any k x k matrix.
    rows = concated_image.reshape(um.size(-1), -1)
    concated_res = um.matmul(rows).reshape(concated_image.size())
    return torch.split(concated_res, merge_len, dim=0)

def group_detach(x):
    """Return the items of ``x`` as a new list, in the same order.

    NOTE(review): despite the name, this does not call ``.detach()``. Tensors
    are passed through unchanged, so any autograd history is kept. Confirm
    whether detaching was intended before relying on the name.

    Args:
        x: Any iterable, e.g. a tuple of tensors.

    Returns:
        A new list containing the same objects as ``x``.
    """
    return list(x)


# Prefer the first CUDA device, but fall back to the CPU so the demo still
# runs on machines without a GPU. On a CUDA machine this is the same
# "cuda:0" device as before.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Blinding matrix and its numerical inverse. Unblinding blinded data with the
# inverse recovers the original rows, up to float32 rounding.
blind_mat = torch.tensor([[1.2, 0.8, 1.0], [0.8, 1.2, 1.0], [0.0, 0.0, 1.0]], dtype=torch.float32, device=device)
unblind_mat = blind_mat.inverse()
# update_mat = P @ inverse(blind_mat), where every row of P is [0.5, 0.5, 0].
# Applied to data blinded with blind_mat it yields P @ image, so every output
# row is the mean of the first two original rows.
update_mat = torch.tensor([[0.5, 0.5, 0.], [0.5, 0.5, 0.], [0.5, 0.5, 0.]], dtype=torch.float32, device=device).matmul(unblind_mat)

image = [
    torch.tensor([3.0, 4.1], dtype=torch.float32, device=device),
    torch.tensor([2.2, 23.9], dtype=torch.float32, device=device),
    torch.tensor([132.1, 4334.0], dtype=torch.float32, device=device),
]


blinded = group_detach(blind(image, blind_mat))
print(blinded)


revealed = group_detach(unblind(blinded, unblind_mat))
print(revealed)

updated = group_detach(unblind(blinded, update_mat))
print(updated)