import warnings

import torch
from projop.utils import satisfies
import projop

def random_round(Xs, adjs, constraint_config, adj_vals=(0, 1), feat_vals=()):
    """Stochastically round relaxed graphs, resampling until constraints hold.

    For each graph ``i`` in the batch, the adjacency is clamped to
    ``[min(adj_vals), max(adj_vals)]`` and rounded as
    ``floor(a) + Bernoulli(frac(a))``. Sampling is repeated up to
    ``constraint_config.max_samples`` times (at least once) until
    ``satisfies`` accepts the sample. If no sample is accepted, the last one
    drawn is kept.

    Args:
        Xs: batched node features. When ``feat_vals`` is non-empty, each
            graph's features are replaced by a one-hot of their row-wise
            argmax (assumes per-graph features are 2-D, nodes x features --
            TODO confirm).
        adjs: batched relaxed adjacency matrices.
        constraint_config: provides ``max_samples`` and is forwarded to
            ``satisfies``.
        adj_vals: admissible adjacency values; only their min/max are used.
        feat_vals: if non-empty, features are discretized. The values
            themselves are not used.

    Returns:
        ``(smpld_Xs, smpld_adjs)`` stacked over the batch dimension.
    """
    to_sample_X = len(feat_vals) > 0
    adj_lo, adj_hi = min(adj_vals), max(adj_vals)
    # Always draw at least one sample. With max_samples <= 0 the loop would
    # never run, leaving ``smpld_adj`` unbound (first graph) or stale from the
    # previous graph.
    n_tries = max(1, constraint_config.max_samples)
    smpld_Xs, smpld_adjs = [], []
    for i in range(Xs.shape[0]):
        if to_sample_X:
            x = Xs[i]
            smpld_x = torch.zeros_like(x)
            rows = torch.arange(x.shape[0], device=x.device)
            smpld_x[rows, x.argmax(dim=1)] = 1
        else:
            smpld_x = Xs[i]
        # The clamp does not depend on the attempt, so do it once per graph.
        adj = torch.clamp(adjs[i], min=adj_lo, max=adj_hi)
        for _ in range(n_tries):
            smpld_adj = torch.floor(adj) + torch.bernoulli(torch.frac(adj))
            if satisfies(smpld_adj, smpld_x, constraint_config, ngraph=i):
                break
        smpld_adjs.append(smpld_adj)
        smpld_Xs.append(smpld_x)
    return torch.stack(smpld_Xs), torch.stack(smpld_adjs)

def repeated_round(Xs, adjs, constraint_config, adj_vals=(0, 1), feat_vals=()):
    """Alternate deterministic rounding and projection until constraints hold.

    For each graph ``i`` in the batch, the current adjacency (and, when
    ``feat_vals`` is non-empty, the current features) is clamped to the
    admissible range and rounded with ``torch.round``. If ``satisfies``
    accepts the result, it is kept. Otherwise ``projop.project`` is applied
    and the loop repeats, up to ``constraint_config.max_samples`` times.

    Args:
        Xs: batched node features.
        adjs: batched relaxed adjacency matrices.
        constraint_config: provides ``max_samples`` and is forwarded to
            ``satisfies`` and ``projop.project``.
        adj_vals: admissible adjacency values; only their min/max are used.
        feat_vals: admissible feature values. If non-empty, features are
            clamped to their min/max and rounded as well.

    Returns:
        ``(smpld_Xs, smpld_adjs)`` stacked over the batch dimension.

    NOTE(review): if the budget runs out, the returned ``x``/``adj`` are the
    output of the last projection and are not re-rounded. With
    ``max_samples <= 0`` the inputs are returned unchanged. Confirm both
    behaviours are intended.
    """
    to_sample_X = len(feat_vals) > 0
    adj_lo, adj_hi = min(adj_vals), max(adj_vals)
    feat_range = (min(feat_vals), max(feat_vals)) if to_sample_X else None
    smpld_Xs, smpld_adjs = [], []
    for i in range(Xs.shape[0]):
        adj, x = adjs[i], Xs[i]
        for _ in range(constraint_config.max_samples):
            adj = torch.round(torch.clamp(adj, min=adj_lo, max=adj_hi))
            if to_sample_X:
                # Round the *current* (possibly projected) features. Rounding
                # Xs[i] here would discard every projection of x, while adj is
                # carried forward.
                x = torch.round(torch.clamp(x, min=feat_range[0], max=feat_range[1]))
            if satisfies(adj, x, constraint_config, ngraph=i):
                break
            x, adj = projop.project(x, adj, constraint_config, adjndim=i)
        smpld_adjs.append(adj)
        smpld_Xs.append(x)
    return torch.stack(smpld_Xs), torch.stack(smpld_adjs)


def heur_val_round(Xs, adjs, constr_config, adj_vals=(0, 1), feat_vals=(), hidden_Hs=True):
    """Greedily round relaxed graphs into graphs that respect node valencies.

    For every graph in the batch:

    * nodes whose feature row has no entry > 0.5 are treated as hidden
      (presumably implicit hydrogens -- TODO confirm). Their feature rows and
      adjacency rows/columns are zeroed, so they never receive bonds;
    * every remaining node gets a one-hot feature at its row-wise argmax, and
      its valency budget is ``one_hot @ valencies``;
    * remaining nodes are visited in random order. For each node ``j``,
      not-yet-bonded neighbours are scanned in decreasing order of
      ``frac(A[j])`` and bonded with value ``floor(A[j, node])`` while
      neither endpoint exceeds its valency. Bonds are written symmetrically.

    Args:
        Xs: batched node features; indexing assumes shape
            ``(batch, n_nodes, n_feats)``.
        adjs: batched relaxed adjacency, assumed ``(batch, n_nodes, n_nodes)``,
            clamped to ``[min(adj_vals), max(adj_vals)]`` first.
        constr_config: object whose ``params[0]`` holds one valency per
            feature column.
        adj_vals: admissible adjacency values; only their min/max are used.
        feat_vals: unused; kept for signature parity with the other rounders.
        hidden_Hs: unused; hidden nodes are always masked out.

    Returns:
        ``(smpld_Xs, smpld_adjs)``: stacked one-hot features (dtype of ``Xs``)
        and stacked ``int64`` adjacency matrices.

    NOTE(review): bond values come from ``floor`` while priorities come from
    ``frac``. With the default ``adj_vals=(0, 1)`` only entries exactly equal
    to 1 have a non-zero bond value, and those have priority 0, which also
    ends the scan. Confirm this is the intended scheme.
    """
    valencies = torch.tensor(constr_config.params[0], dtype=Xs.dtype, device=Xs.device)
    adjs = torch.clamp(adjs, min=min(adj_vals), max=max(adj_vals))
    smpld_Xs, smpld_adjs = [], []
    for i in range(Xs.shape[0]):
        X, A = Xs[i].clone(), adjs[i].clone()
        # A node is hidden when none of its features exceeds 0.5.
        atoms_Hs = ~(X > 0.5).any(dim=1)
        atoms_others = ~atoms_Hs
        X[atoms_Hs] = 0
        smpld_x = torch.zeros_like(X)
        smpld_x[atoms_others, X.argmax(dim=1)[atoms_others]] = 1
        # Hidden nodes can never be bonded.
        A[atoms_Hs, :] = 0
        A[:, atoms_Hs] = 0
        # smpld_x is fixed from here on, so per-node budgets and the
        # bond-value / priority split can be computed once per graph.
        node_valency = smpld_x @ valencies
        bond_vals, bond_probs = torch.floor(A), torch.frac(A)
        atom_inds = torch.where(atoms_others)[0]
        satisfied = False
        # NOTE(review): there is no iteration cap. The restart only triggers
        # when a row already exceeds its budget before being visited.
        while not satisfied:
            satisfied = True
            smpld_a = torch.zeros_like(A)
            atom_inds = atom_inds[torch.randperm(len(atom_inds))]
            for j in atom_inds:
                valency_j = node_valency[j]
                a_vals, a_probs = bond_vals[j], bond_probs[j]
                current_valency = smpld_a[j].sum()
                if current_valency > valency_j:
                    satisfied = False
                    break
                _, psorted_nodes = torch.sort(a_probs, descending=True)
                # Skip neighbours already bonded to j. The mask must be taken
                # in sorted order: masking the permutation with the per-node
                # ``smpld_a[j] == 0`` filters the wrong entries.
                psorted_nodes = psorted_nodes[smpld_a[j][psorted_nodes] == 0]
                # Greedy knapsack: maximise total priority within valency_j.
                selected_nodes = []
                for node in psorted_nodes:
                    a_node = a_vals[node]
                    if (a_node > 0
                            and current_valency + a_node <= valency_j
                            and smpld_a[node].sum() + a_node <= node_valency[node]):
                        selected_nodes.append(node.item())
                        current_valency += a_node
                    if current_valency == valency_j or a_probs[node] == 0:
                        break
                if selected_nodes:
                    smpld_a[j, selected_nodes] = a_vals[selected_nodes]
                    smpld_a[selected_nodes, j] = a_vals[selected_nodes]  # undirected
        row_sums = smpld_a.sum(dim=1)
        if torch.any(row_sums[atoms_others] > node_valency[atoms_others]):
            warnings.warn(f"heur_val_round: graph {i} exceeds node valencies after rounding")
        smpld_adjs.append(smpld_a.to(torch.int64))
        smpld_Xs.append(smpld_x)
    return torch.stack(smpld_Xs), torch.stack(smpld_adjs)
