
class UF:
    """Union-find (disjoint-set) structure over the integers ``0 .. n-1``.

    ``p[i]`` holds the parent of element ``i``; an element that is its own
    parent is the representative (root) of its set.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets, every element being its own root."""
        self.p = np.arange(n, dtype=int)

    def find(self, i):
        """Return the root of the set containing ``i`` and compress the path.

        The walk is iterative on purpose: ``merge`` performs no rank/size
        balancing, so a parent chain can become as long as ``n`` and a
        recursive walk would exceed Python's recursion limit on large inputs.
        """
        # First pass: climb to the root.
        root = i
        while self.p[root] != root:
            root = self.p[root]
        # Second pass: re-point every node on the walked path at the root.
        while self.p[i] != root:
            parent = self.p[i]
            self.p[i] = root
            i = parent
        return root

    def merge(self, i, j):
        """Union the sets of ``i`` and ``j``.

        The root of ``i``'s set is attached beneath the root of ``j``'s set.
        """
        if i != j:
            self.p[self.find(i)] = self.find(j)
class PHC:
    """Callable that extracts the component-merging edges of a weight matrix.

    Edges are visited in order of increasing weight (Kruskal's algorithm with
    the ``UF`` union-find); every edge that joins two previously separate
    components is recorded.
    """

    def __call__(self, m):
        """Return the edges of ``m`` that merge two components.

        Args:
            m: 2-D NumPy array of pairwise weights. Only entries above the
               diagonal are read; expected to be square, since ``m.shape[0]``
               sizes the union-find while column indices are looked up in it.

        Returns:
            A pair ``(pairs, empty)``. ``pairs`` is an array of ``(i, j)``
            index pairs with ``i < j``, in the order the edges were accepted
            (a size-0 array when nothing is merged). ``empty`` is always an
            empty array.
        """
        u = UF(m.shape[0])
        # k=1 skips the diagonal: a self-edge (i, i) can never join two
        # different components, so visiting it is wasted work.
        rows, cols = np.triu_indices_from(m, k=1)
        # Stable sort keeps row-major order among equal weights.
        order = np.argsort(m[rows, cols], kind='stable')
        pp = []
        for idx in order:
            i, j = rows[idx], cols[idx]
            c1, c2 = u.find(i), u.find(j)
            if c1 == c2:
                continue
            # Attach the smaller root beneath the larger one.
            if c1 < c2:
                u.merge(i, j)
            else:
                u.merge(j, i)
            # i < j already holds for every strict-upper-triangle index.
            pp.append((i, j))
        return np.array(pp), np.array([])
def cdm(d, p=2):
    """Compute the pairwise p-norm distances between the rows of ``d``.

    Args:
        d: tensor of shape ``(n, k)``, or a batch of shape ``(b, n, k)``.
        p: order of the norm forwarded to ``torch.cdist`` (default 2).

    Returns:
        Distance tensor of shape ``(n, n)`` for 2-D input and ``(b, n, n)``
        for 3-D input.

    Raises:
        ValueError: if ``d`` is neither 2-D nor 3-D.
    """
    if d.dim() not in (2, 3):
        raise ValueError(
            f"cdm expects a 2-D or 3-D tensor, got {d.dim()}-D input "
            f"with shape {tuple(d.shape)}"
        )
    # torch.cdist handles both the plain and the batched layout directly.
    return torch.cdist(d, d, p=p)
def gp(d):
    """Run ``PHC`` on a distance tensor and return its first output.

    ``d`` is detached from autograd and copied to host memory as a NumPy
    array before being handed to ``PHC``; the second output is discarded.
    """
    host_matrix = d.detach().cpu().numpy()
    pairs, _unused = PHC()(host_matrix)
    return pairs
def gsp(p, dev):
    """Split an array of index pairs into two index tensors on ``dev``.

    Args:
        p: NumPy array of shape ``(k, 2)``, or any array with no elements.
        dev: device on which the returned tensors are placed.

    Returns:
        ``(first, second)``: 1-D ``torch.long`` tensors holding column 0 and
        column 1 of ``p``. Both are empty when ``p`` has no elements.
    """
    if p.size == 0:
        return (
            torch.tensor([], dtype=torch.long, device=dev),
            torch.tensor([], dtype=torch.long, device=dev),
        )
    # Cast explicitly: NumPy's default integer width is platform dependent,
    # and the empty branch above already promises torch.long.
    first = torch.from_numpy(p[:, 0]).to(device=dev, dtype=torch.long)
    second = torch.from_numpy(p[:, 1]).to(device=dev, dtype=torch.long)
    return first, second
def cpa(p, l):
    """Keep only the index pairs whose two endpoints carry different labels.

    Args:
        p: container of ``[b, d]`` index pairs exposing ``tolist()``.
        l: 1-D label container exposing ``tolist()``.

    Returns:
        NumPy array built from the surviving pairs, in their original order
        (a size-0 array when none survive). A pair is dropped when either
        index is at or beyond ``len(l)`` or when both labels are equal.
    """
    pairs = p.tolist()
    labels = l.tolist()
    n_labels = len(labels)

    def crosses_labels(pair):
        first, second = pair[0], pair[1]
        in_range = first < n_labels and second < n_labels
        return in_range and labels[first] != labels[second]

    return np.array([pair for pair in pairs if crosses_labels(pair)])
def chl(f1, f2, l):
    """Mean ``1 - cosine`` between paired displacements in ``f1`` and ``f2``.

    Index pairs come from ``gp(cdm(f1))`` and are filtered by ``cpa`` against
    the labels ``l``. For every surviving pair the displacement between the
    two rows of ``f1`` is compared with the displacement between the two
    label-selected slices ``f2[row, :, label]``.

    Args:
        f1: tensor whose rows (dim 0) are the items being paired.
        f2: tensor indexed as ``f2[row, :, label]``.
        l: integer label tensor with one entry per row; it is used as an
           index into the last axis of ``f2``.

    Returns:
        Scalar tensor: the mean over pairs of ``1 - cos`` between the
        L2-normalised displacements, or ``0.0`` on ``f1.device`` when no pair
        survives the filtering.
    """
    device = f1.device
    pairs = cpa(gp(cdm(f1)), l.cpu())
    if pairs.shape[0] == 0:
        return torch.tensor(0.0, device=device)

    src_idx, dst_idx = gsp(pairs, device)

    # Displacement between the two endpoints of every pair in f1.
    f1_delta = f1.index_select(0, dst_idx) - f1.index_select(0, src_idx)

    # For f2, pick at each endpoint the slice belonging to that endpoint's label.
    src_labels = l.index_select(0, src_idx)
    dst_labels = l.index_select(0, dst_idx)
    f2_src = f2.index_select(0, src_idx)
    f2_dst = f2.index_select(0, dst_idx)
    f2_src_sel = f2_src[torch.arange(len(src_labels)), :, src_labels]
    f2_dst_sel = f2_dst[torch.arange(len(dst_labels)), :, dst_labels]
    f2_delta = f2_dst_sel - f2_src_sel

    f1_unit = F.normalize(f1_delta, p=2, dim=-1)
    f2_unit = F.normalize(f2_delta, p=2, dim=-1)
    cosine = (f1_unit * f2_unit).sum(dim=-1)
    return (1.0 - cosine).mean()
def _csb(fb):
    """Score every point by the spanning-tree edges that touch it.

    For each batch item the pairwise distances between its ``n`` points are
    computed, edges are visited in ascending distance order (Kruskal's
    algorithm with ``UF``), and every edge that joins two separate components
    adds its length to the score of both of its endpoints.

    Args:
        fb: tensor of shape ``(b, n, d)``.

    Returns:
        Tensor of shape ``(b, n)`` on ``fb.device``; all zeros when
        ``n <= 1``. The distances are computed under ``torch.no_grad()``, so
        no gradient flows back to ``fb``.
    """
    dev = fb.device
    b, n, _ = fb.shape
    if n <= 1:
        return torch.zeros(b, n, device=dev)
    with torch.no_grad():
        dist = cdm(fb.cpu()).numpy()

    # Accumulate on the host through a NumPy view of the result and move it
    # to `dev` once at the end: writing single elements into a device tensor
    # from a Python loop costs one tiny kernel launch per write.
    scores = torch.zeros(b, n)
    acc = scores.numpy()

    # The edge list depends only on n, so build it once for all batch items.
    rows, cols = np.triu_indices(n, k=1)
    for item, dmat in enumerate(dist):
        ew = dmat[rows, cols]
        u = UF(n)
        joined = 0
        for idx in np.argsort(ew, kind='stable'):
            j, k = rows[idx], cols[idx]
            if u.find(j) == u.find(k):
                continue
            w = ew[idx]
            acc[item, j] += w
            acc[item, k] += w
            u.merge(j, k)
            joined += 1
            if joined == n - 1:
                # Spanning tree complete: no remaining edge can merge anything.
                break
    return scores.to(dev)
class TGA(nn.Module):
    """Attention pooling whose weights come from per-cell ``_csb`` scores.

    Attributes (all set from the constructor arguments):
        gs: number the grid side is integer-divided by to get the cell size.
        t: temperature dividing the scores before the softmax.
        res: when true, the mean over tokens is added to the pooled output.
        a: factor applied to the attention-pooled features in the ``res`` case.
    """

    def __init__(self, gs=4, t=1.0, res=True, a=0.5):
        super().__init__()
        self.gs, self.t, self.res, self.a = gs, t, res, a

    def forward(self, pfb):
        """Pool ``pfb`` of shape ``(b, n, d)`` into ``(b, d)`` features.

        The ``n`` tokens are laid out on a ``side x side`` grid with
        ``side = int(sqrt(n))`` (assumes ``n`` is a perfect square, otherwise
        the reshape cannot succeed), cut into non-overlapping square cells of
        side ``side // gs``, and every cell is scored independently by
        ``_csb``. The scores are folded back to the grid and soft-maxed over
        all tokens.

        Returns:
            ``(features, weights)``: ``features`` has shape ``(b, d)``;
            ``weights`` holds the softmax weights reshaped to
            ``(b, side, side)``.
        """
        batch, n_tok, dim = pfb.shape
        side = int(np.sqrt(n_tok))
        grid = pfb.permute(0, 2, 1).view(batch, dim, side, side)

        # Cut the grid into non-overlapping cells; each cell becomes one item.
        cell = side // self.gs
        cell_area = cell * cell
        patches = F.unfold(grid, kernel_size=cell, stride=cell)
        n_cells = patches.shape[-1]
        per_cell = (
            patches.view(batch, dim, cell_area, n_cells)
            .permute(0, 3, 2, 1)
            .reshape(batch * n_cells, cell_area, dim)
        )

        # Score the points of every cell, then scatter the scores back.
        cell_scores = _csb(per_cell).view(batch, n_cells, cell_area, 1)
        score_cols = cell_scores.permute(0, 3, 2, 1).reshape(batch, cell_area, n_cells)
        score_map = F.fold(
            score_cols, output_size=(side, side), kernel_size=cell, stride=cell
        )

        weights = F.softmax(score_map.view(batch, n_tok) / self.t, dim=1)
        weight_map = weights.view_as(score_map.squeeze(1))

        pooled = torch.sum(pfb * weights.unsqueeze(2), dim=1)
        if self.res:
            pooled = torch.mean(pfb, dim=1) + self.a * pooled
        return pooled, weight_map