import torch
import itertools
from graph_reconstruction.buildingblocks import *
import matplotlib.pyplot as plt
import neptune
import io
import statistics
from metrics import match_graphs

def get_degree(x):
    """Return the arg-max position inside feature columns 58..64 for every row.

    Inputs with fewer than three dimensions are promoted by prepending
    singleton axes, so the result always has two dimensions.  Rows whose
    maximum over the window is NaN are reported as ``-1``.

    NOTE(review): the name suggests columns 58:65 hold a one-hot node degree
    (0..6) -- confirm against the feature encoder.
    """
    for _ in range(3 - x.dim()):
        x = x.unsqueeze(0)
    window = x[:, :, 58:65]
    best_value, best_index = torch.max(window, dim=2)
    # NaN propagates through max, so a NaN maximum marks an invalid/padded row.
    best_index.masked_fill_(torch.isnan(best_value), -1)
    return best_index

def properly_structure_deg1(deg1s):
    """Put the neighbour rows of every 1-hop block into a canonical order.

    Row 0 of each block (the centre) stays in place.  The remaining rows are
    sorted in descending lexicographic order of ``[degree index, *features]``,
    where the degree index comes from ``get_degree``.

    Raises AssertionError unless the sorted neighbour part has shape
    ``(batch, 36, 140)``.
    """
    centers = deg1s[:, 0, :]
    neighbors = deg1s[:, 1:, :]
    # Prepend the degree index as an extra leading column so it dominates the sort.
    keyed = torch.cat((get_degree(neighbors).unsqueeze(2), neighbors), dim=-1)

    ordered_blocks = []
    for sample in range(deg1s.shape[0]):
        rows = keyed[sample].tolist()
        ranked = sorted(enumerate(rows), key=lambda pair: pair[1], reverse=True)
        order = [position for position, _ in ranked]
        # Drop the helper degree column again after reordering.
        ordered_blocks.append(keyed[sample][order][:, 1:])

    all_sorted = torch.stack(ordered_blocks)

    assert(all_sorted.shape[1]==36)
    assert(all_sorted.shape[2]==140)
    return torch.cat((centers.unsqueeze(1), all_sorted), dim=1)

def properly_structure_deg2(deg2s):
    """Put every 2-hop block into a canonical row order.

    Layout of the 36 non-centre rows: rows 0..5 are the first-hop slots and
    rows ``6 + 5*g .. 10 + 5*g`` are the five second-hop slots that belong to
    first-hop slot ``g``.

    Each second-hop group is sorted in descending lexicographic order of
    ``[degree index, *features]``.  The six (first-hop row + its sorted group)
    subtrees are then sorted in descending order as well, and the result is
    written back in the same layout.

    The result lives on the CUDA device.  Note that the rows are rebuilt via
    ``torch.tensor`` from Python floats.

    Raises AssertionError unless the reordered part has shape
    ``(batch, 36, 140)``.
    """
    centers = deg2s[:, 0, :].cuda()
    others = deg2s[:, 1:, :]
    # Prepend the degree index as an extra leading column so it dominates the sort.
    keyed = torch.cat((get_degree(others).unsqueeze(2), others), dim=-1)

    def descending_order(items):
        # Indices of ``items`` in descending order (stable for equal items).
        return sorted(range(len(items)), key=lambda k: items[k], reverse=True)

    reordered = []
    for sample in range(deg2s.shape[0]):
        rows = keyed[sample].tolist()
        first_hop = rows[0:6]

        # Sort the rows inside each second-hop group.
        groups = []
        for g in range(6):
            group = rows[6 + 5 * g:11 + 5 * g]
            groups.append([group[k] for k in descending_order(group)])

        # Sort whole subtrees: first-hop row followed by its sorted children.
        subtrees = [[first_hop[g]] + groups[g] for g in range(6)]
        order = descending_order(subtrees)

        flat = [subtrees[j][0] for j in order]
        for j in order:
            flat = flat + groups[j]

        # Drop the helper degree column again.
        reordered.append(torch.tensor(flat)[:, 1:])

    all_sorted = torch.stack(reordered).cuda()
    assert(all_sorted.shape[1]==36)
    assert(all_sorted.shape[2]==140)
    return torch.cat((centers.unsqueeze(1), all_sorted), dim=1)

def get_gt_37_deg1(gt_fms, gt_ams):
    """Build the canonical 37-row 1-hop block of every node of a graph.

    For node ``i`` the block consists of the node's own feature row, the
    feature rows of every ``x != i`` with ``gt_ams[i][x] == 1``, and zero rows
    padding the block up to 37 rows.  The stacked blocks are passed through
    ``properly_structure_deg1`` before being returned.

    Args:
        gt_fms: per-node feature rows, indexable by node id
            (assumes 140 features per row -- the padding width is fixed).
        gt_ams: square adjacency structure; ``gt_ams.shape[0]`` is the number
            of nodes.

    Raises:
        RuntimeError: if a node together with its neighbours needs more than
            37 rows.
    """
    num_nodes = gt_ams.shape[0]
    blocks = []

    for i in range(num_nodes):
        rows = [gt_fms[i]]
        rows += [gt_fms[x] for x in range(num_nodes) if gt_ams[i][x] == 1 and x != i]
        curr = torch.stack(rows)

        missing = 37 - curr.shape[0]
        if missing < 0:
            raise RuntimeError(
                f"node {i} has {curr.shape[0] - 1} neighbours; at most 36 fit in a block"
            )
        # A single (missing, 140) allocation also covers missing == 0, where
        # stacking an empty list of padding rows would raise.  The padding is
        # created on the same device as the feature rows.
        padding = torch.zeros((missing, 140), device=curr.device)
        blocks.append(torch.cat((curr, padding), dim=0))

    return properly_structure_deg1(torch.stack(blocks, dim=0))

def get_gt_37_deg2(gt_fms, gt_ams):
    """Build the canonical 37-row 2-hop block of every node of a graph.

    Block layout for centre ``i``: the centre row, then six first-hop slots
    (neighbours of ``i``, zero-padded), then for every actual neighbour ``j``
    five second-hop slots (neighbours of ``j`` other than ``j`` and ``i``,
    zero-padded), and finally zero rows up to 37.  The stacked blocks are
    passed through ``properly_structure_deg2`` before being returned.

    Padding rows are allocated on the CUDA device.
    """
    num_nodes = gt_ams.shape[0]

    def zero_rows(count):
        # ``count`` <= 0 yields no rows.
        return [torch.zeros(140).cuda() for _ in range(count)]

    blocks = []
    for center in range(num_nodes):
        first_hop = [x for x in range(num_nodes) if gt_ams[center][x] and x != center]

        rows = [gt_fms[center]]
        rows.extend(gt_fms[first_hop])
        rows.extend(zero_rows(6 - len(first_hop)))

        for neighbor in first_hop:
            second_hop = [
                x for x in range(num_nodes)
                if gt_ams[neighbor][x] and x != neighbor and x != center
            ]
            rows.extend(gt_fms[second_hop])
            rows.extend(zero_rows(5 - len(second_hop)))

        # Slots of missing first-hop neighbours are filled up here.
        rows.extend(zero_rows(37 - len(rows)))
        blocks.append(torch.stack(rows))

    return properly_structure_deg2(torch.stack(blocks, dim=0))

def get_AX_deg1(batched_bbs):  # works :D
    """Return the centre row of the normalised ``A @ X`` for each 1-hop block.

    Every row gets the coefficient ``(degree index + 1) ** -0.5``; coefficients
    equal to 1 and infinite coefficients are replaced by 0.  The centre's
    coefficient (row 0) multiplies all coefficients of its block, and the
    resulting row vector is multiplied with the block's feature matrix.

    Returns a tensor with one ``1 x features`` slice per block.
    """
    zero = torch.tensor(0.0)

    coeff = torch.pow(get_degree(batched_bbs) + 1, -0.5)
    coeff = torch.where(coeff == 1, zero, coeff)
    coeff = torch.where(torch.isinf(coeff), zero, coeff)

    # The centre coefficient is simply column 0 of the table computed above.
    center_coeff = coeff[:, 0].unsqueeze(1)
    first_rows = center_coeff * coeff
    return torch.bmm(first_rows.unsqueeze(1), batched_bbs)

def get_AX_deg2(batched_bbs, gcn_model):  # works :D
    """Run two normalised propagation steps over each 37-row 2-hop block.

    The normalisation matrix is ``(d_i * d_j) ** -0.5`` with
    ``d = degree index + 1`` (values equal to 1 are zeroed first, infinite
    results are replaced by 0), masked by a fixed 37 x 37 template adjacency:
    row 0 is linked to rows 0..6, row ``i`` in 1..6 is linked to rows
    ``5*i + 2 .. 5*i + 6``, and every row has a self loop.

    Returns ``A @ act(A @ X @ W_list[0]) @ W_list[1]`` using ``gcn_model.act``
    and ``gcn_model.W_list``.  The template is allocated on the CUDA device.
    """
    deg = get_degree(batched_bbs) + 1
    deg = torch.where(deg == 1, torch.tensor(0.0), deg)

    # Outer product d_i * d_j per block, then the inverse square root.
    norm = torch.pow(torch.bmm(deg.unsqueeze(2), deg.unsqueeze(1)), -0.5)
    norm = torch.where(torch.isinf(norm), torch.tensor(0.0), norm)

    template = torch.zeros((37, 37)).cuda()
    template[0, :7] = 1
    template[:7, 0] = 1
    for hub in range(1, 7):
        first_child = 5 * hub + 2
        template[hub, first_child:first_child + 5] = 1
        template[first_child:first_child + 5, hub] = 1
    template.fill_diagonal_(1)

    all_adj = norm * template
    hidden = gcn_model.act(torch.bmm(all_adj, batched_bbs) @ gcn_model.W_list[0])
    return (all_adj @ hidden) @ gcn_model.W_list[1]
    
def convert_to_bb(bb_deg2s):
    """Turn padded 37-row 2-hop blocks into ``NewBuildingBlock`` objects.

    For each block the all-zero rows are dropped to obtain ``X``.  A square
    adjacency ``A`` over the remaining rows is filled with: centre <-> every
    non-empty first-hop slot (rows 1..6), first-hop slot ``i`` <-> each of its
    non-empty second-hop slots (rows ``5*i + 2 .. 5*i + 6``), and self loops.
    Second-hop rows whose degree index exceeds 1 are reported as
    ``connections``.

    NOTE(review): first-hop rows are addressed in ``A`` by their slot number,
    which matches their position in ``X`` only if the non-empty first-hop
    slots are packed at the front -- confirm that callers guarantee this.
    """
    blocks = []

    for sample in range(bb_deg2s.shape[0]):
        bb = bb_deg2s[sample]
        occupied = torch.any(bb != 0, dim=1)
        num_nodes = torch.sum(occupied).item()
        A = torch.zeros((num_nodes, num_nodes)).cuda()
        X = bb[occupied].cuda()
        degrees = get_degree(X.unsqueeze(0)).squeeze(0)

        # Centre <-> first-hop neighbours.
        first_hop_slots = [slot for slot in range(1, 7) if torch.any(bb[slot] != 0)]
        for slot in first_hop_slots:
            A[0, slot] = 1
            A[slot, 0] = 1

        second_hop_start = len(first_hop_slots) + 1
        connections = torch.where(degrees[second_hop_start:] > 1)[0] + second_hop_start

        # First-hop <-> second-hop; second-hop rows are numbered consecutively.
        next_row = second_hop_start
        for parent in range(1, 7):
            for slot in range(5 * parent + 2, 5 * parent + 7):
                if torch.any(bb[slot] != 0):
                    A[parent, next_row] = 1
                    A[next_row, parent] = 1
                    next_row += 1

        for node in range(num_nodes):
            A[node, node] = 1

        blocks.append(NewBuildingBlock(A, X, 0, connections, 2))

    return blocks

def are_compatible_pos(xs, xs_size, ys, ys_size, pos):
    """Find, per batch element, where block ``x`` can be glued into block ``y``.

    A neighbour row of ``y`` is a candidate when it equals the centre row of
    ``x`` (row 0) and its degree index equals ``xs_size - 1``.  Gluing is
    allowed when such a candidate exists and, in addition, the centre row of
    ``y`` equals row ``pos`` of ``x`` while some neighbour row of ``x`` has
    degree index ``ys_size - 1``.

    Returns:
        An int64 tensor with one entry per batch element: the row index inside
        ``ys`` (1-based with respect to the neighbour rows) of the first
        candidate, or 0 when gluing is not possible.

    NOTE(review): the ``x``-side degree test looks at all neighbour rows of
    ``x`` rather than only row ``pos`` -- kept as in the original, confirm
    this is intended.
    """
    nn_xs = xs_size[:, None] - 1
    nn_ys = ys_size[:, None] - 1
    d_xs = get_degree(xs[:, 1:])
    d_ys = get_degree(ys[:, 1:])

    # Computed once and reused for both the feasibility test and the arg-max.
    ys_candidates = torch.logical_and(
        torch.all(xs[:, 0:1] == ys[:, 1:], dim=2),
        d_ys == nn_xs,
    )
    xs_candidates = torch.logical_and(
        torch.all(ys[:, 0:1] == xs[:, pos:pos + 1], dim=2),
        d_xs == nn_ys,
    )
    can_glue = torch.logical_and(
        torch.any(xs_candidates, dim=1),
        torch.any(ys_candidates, dim=1),
    )

    # argmax over 0/1 values picks the first candidate; +1 converts the
    # neighbour offset into a row index of ``ys``.  A plain scalar keeps the
    # result on the inputs' device instead of forcing CUDA.
    first_candidate = torch.argmax(ys_candidates.to(dtype=torch.int64), dim=1) + 1
    return first_candidate * can_glue.int()

def revert_one_hot_encoding_multiple(X):
    """Decode every one-hot feature row of ``X`` back into feature values.

    The lists returned by ``possible_feature_values()`` are concatenated into
    one 140-entry vocabulary; for each row the vocabulary entries at the
    positions holding a value ``> 0`` are collected.  Rows of length zero are
    skipped.

    Returns a tuple with one list of decoded values per processed row.
    Raises AssertionError if the vocabulary does not have 140 entries.
    """
    vocabulary = [value for group in possible_feature_values() for value in group]
    assert(len(vocabulary)==140)

    decoded_rows = []
    for row_id in range(X.shape[0]):
        row = X[row_id]
        if row.shape[0] == 0:
            continue
        decoded_rows.append([vocabulary[pos] for pos in range(140) if row[pos] > 0])

    return tuple(decoded_rows)

def get_atom_type(x):
    """Map each row of ``x`` to the feature value selected by columns 0..27.

    The arg-max position over the first 28 columns of a row is looked up in
    the flattened ``possible_feature_values()`` table.  The lookup table and
    the result live on the CUDA device.

    NOTE(review): presumably the first 28 columns one-hot encode the atom
    type -- confirm against the feature encoder.
    """
    flat_values = list(itertools.chain.from_iterable(possible_feature_values()))
    lookup = torch.tensor(flat_values).cuda()
    _, winners = torch.max(x[:, 0:28], dim=1)
    return lookup[winners.cuda()]

def features_combined_ultimate(xs_idxs, ys_idxs, X_cuda_filtered, list_tensors_pos):
    """Assemble 37-row blocks from one ``x`` block and several ``y`` blocks.

    For every batch element the first 7 rows of the ``x`` block are kept and,
    for each of the ``num_dangling_bits`` ``y`` blocks, 5 of its first 7 rows
    are appended (row 0 and the row given by ``list_tensors_pos`` are
    dropped).  The remainder up to 37 rows is filled with NaN.

    Args:
        xs_idxs: indices into ``X_cuda_filtered`` selecting the ``x`` blocks
            (assumed to have shape ``(batch, 1)`` so that ``repeat`` lines it
            up with ``ys_idxs`` -- TODO confirm).
        ys_idxs: 2-D index tensor of shape ``(batch, num_dangling_bits)``.
        X_cuda_filtered: tensor of candidate blocks indexed by the above.
        list_tensors_pos: lookup indexed as ``[dangling slot (1-based),
            x index, y index]`` giving the row of the ``y`` block to drop.

    Returns:
        Tensor of shape ``(batch, 37, 140)`` on the CUDA device.
    """
    # print(xs_idxs)
    # print(ys_idxs)
    
    batch_size = ys_idxs.shape[0]
    num_dangling_bits = ys_idxs.shape[1]
    
    xs = X_cuda_filtered[xs_idxs]
    ys = X_cuda_filtered[ys_idxs]
    
    # Build one (dangling slot, x index, y index) triple per (batch, slot) pair.
    expanded_xs_idxs = xs_idxs.repeat(1, num_dangling_bits)
    expanded_xs_idxs = expanded_xs_idxs.cuda()
    dang_idx = torch.arange(1, num_dangling_bits + 1).repeat(batch_size, 1)
    dang_idx = dang_idx.cuda()
    indexing = torch.stack((dang_idx, expanded_xs_idxs, ys_idxs), dim=0).cuda()
    indexing = indexing.permute(1, 2, 0).reshape(-1, 3)
    
    db = indexing[:,0]
    x_coord = indexing[:,1]
    y_coord = indexing[:,2]
    
    # Row of each y block that must be left out, reshaped back to (batch, slot).
    where_append = list_tensors_pos[db, x_coord, y_coord]
    where_append = where_append.view(batch_size, num_dangling_bits)
    
    # if num_dangling==2:
    #     breakpoint()
    
    # Boolean keep-mask over the first 7 rows of every y block: row 0 and the
    # row named by where_append are dropped.  The reshape below needs exactly
    # 5 kept rows per block, so where_append must not be 0.
    ys_idx = torch.ones( (batch_size, num_dangling_bits, 7), dtype=torch.bool, device='cuda' )
    ys_idx[:,:,0] = False
    # ys_idx[torch.arange(batch_size),torch.arange(num_dangling_bits), where_append] = False
    ys_idx[torch.arange(batch_size).unsqueeze(1), torch.arange(num_dangling_bits).unsqueeze(0), where_append] = False
    ys = ys[:, :, :7]
    ys = ys[ys_idx]
    ys=ys.reshape(batch_size,num_dangling_bits*5,140)
    # Keep the centre and the six first-hop rows of the x block.
    xs = xs.reshape(batch_size, 37, 140)
    xs = xs[:, :7]
    xs=torch.cat((xs, ys), dim=1)
    
    # if num_dangling==2:
    #     breakpoint()
    
    # zeros * nan yields NaN rows that pad the block up to 37 rows.
    nones = torch.zeros((batch_size, 37-xs.shape[1], 140)).cuda() * torch.nan
    xs=torch.cat((xs, nones), dim=1)
    
    # breakpoint()
    
    # assert(torch.sum(xs>1)==0)
    # assert(torch.sum(xs<0)==0)
    
    return xs


def plot_bb_filter(correct, wrong, use_neptune, name, run):
    """Scatter-plot the ``wrong`` and ``correct`` values and optionally log it.

    The current matplotlib figure is cleared, ``wrong`` is drawn in blue and
    ``correct`` in red (both against their positions), and the figure is
    rendered as PNG into memory.  When ``use_neptune`` is truthy the PNG is
    appended to ``run["charts/<name>"]``.
    """
    plt.clf()
    series = ((wrong, 'blue', 'wrong'), (correct, 'red', 'correct'))
    for values, colour, tag in series:
        plt.scatter(range(len(values)), values, color=colour, marker='x', label=tag)

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)

    if not use_neptune:
        return
    image = neptune.types.File.from_content(buf.getvalue(), extension="png")
    run[f"charts/{name}"].append(image)
        
def get_layer_decomp(grad, B=None, tol=None):
    """Return a rank estimate of ``grad`` and a basis of its row space.

    Args:
        grad: 2-D ``torch.Tensor``.
        B: number of basis vectors to extract; when ``None`` it is set to
            ``numpy.linalg.matrix_rank(grad, tol=tol)``.
        tol: tolerance forwarded to ``matrix_rank`` (only used if ``B`` is
            ``None``).

    Returns:
        ``(B, R)`` where ``R`` has one row per extracted right singular
        vector (shape ``(B, grad.shape[1])``), obtained from the randomised
        ``torch.svd_lowrank``.

    Raises:
        TypeError: if ``grad`` is not a ``torch.Tensor``.
    """
    # Imported here so the function does not depend on ``np`` leaking in
    # through a wildcard import elsewhere in the module.
    import numpy as np

    if not isinstance(grad, torch.Tensor):
        raise TypeError("Expected grad to be a torch.Tensor, got {} instead.".format(type(grad)))
    grad = grad.detach().cpu().numpy()

    if B is None:
        B = np.linalg.matrix_rank(grad, tol=tol)

    # svd_lowrank returns V (not V^H); its columns are the right singular vectors.
    _, _, V = torch.svd_lowrank(torch.tensor(grad), q=B, niter=10)
    R = V.T
    return B, torch.Tensor(R).detach()

def check_if_in_span(R_K_norm, v):
    """Measure how far each vector in ``v`` lies outside the row span of ``R_K_norm``.

    ``v`` is normalised to unit L2 norm along its last axis **in place** (the
    caller's tensor is modified).  The normalised vectors are projected with
    ``R_K_norm^T R_K_norm`` and the L1 norm of ``projection - v`` is returned
    per vector.

    NOTE(review): the projection is exact only if the rows of ``R_K_norm``
    are orthonormal -- confirm with the caller.
    """
    v.div_(v.pow(2).sum(-1, keepdim=True).sqrt())

    projected = torch.einsum('ik,ij,...j->...k', R_K_norm, R_K_norm, v)

    residual = projected - v
    return torch.sum(torch.abs(residual), dim=-1)  # L1 norm

def get_best_hyperparameter(list, biggest_yes):
    """Return how many standard deviations ``biggest_yes`` lies below the mean.

    Computes ``(mean(list) - biggest_yes) / stdev(list)`` using the sample
    standard deviation from the ``statistics`` module.

    Args:
        list: sequence of numbers (at least two, as required by
            ``statistics.stdev``).  The parameter name shadows the builtin
            but is kept for keyword-argument compatibility.
        biggest_yes: reference value subtracted from the mean.

    Raises:
        AssertionError: if the standard deviation is zero; both inputs are
            printed first.
    """
    mean = statistics.mean(list)
    std_dev = statistics.stdev(list)
    if std_dev == 0:
        print(list)
        print(biggest_yes)
        # Raised explicitly: a bare ``assert`` disappears under ``python -O``
        # and the function would then silently return None.
        raise AssertionError("standard deviation is zero; cannot derive a hyperparameter")
    return (mean - biggest_yes) / std_dev
    
def l2_dist_grad(adv_grad, true_grad):
    """Return the mean L2 distance between corresponding gradient tensors.

    The two iterables are walked in lockstep; for every pair the Euclidean
    norm of the difference is taken, and the average over all pairs is
    returned as a Python float.
    """
    per_tensor = [
        torch.sqrt(((ours - theirs) ** 2).sum()).item()
        for ours, theirs in zip(adv_grad, true_grad)
    ]
    return torch.tensor(per_tensor).mean().item()

def graph_gradient_dist(A, X, gt_gradient, model, criterion, gt_fms=None, gt_ams=None, gt_ls=None):
    """Distance between the gradient induced by ``(A, X)`` and ``gt_gradient``.

    The model is evaluated on ``(A, X)`` with gradients enabled and the loss
    gradient with respect to the model parameters is compared to
    ``gt_gradient`` via ``l2_dist_grad``.

    * Without ``gt_ls`` both two-class one-hot labels are tried (loss summed
      and divided by 2) and the smaller distance is returned.
    * With ``gt_ls`` every graph starts with label ``[0, 1]``; graphs matched
      to ground-truth graphs by ``match_graphs`` receive the matched
      ground-truth label, and the mean loss is used.

    Label tensors are created on the CUDA device.
    """
    def distance_for(loss):
        # Gradient of ``loss`` w.r.t. the parameters, compared to the target.
        grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        return l2_dist_grad(grads, gt_gradient)

    with torch.enable_grad():
        logits_adv = model(A, X)

        if gt_ls is not None:
            y = torch.zeros((A.shape[0],2)).cuda()
            y[:, 1] = 1.
            with torch.no_grad():
                row_ids, col_ids = match_graphs(A.cuda(), X.cuda(), gt_ams.cuda(), gt_fms.cuda(), model)
                y[col_ids] = gt_ls[row_ids]

            per_item = criterion(logits_adv, y)
            return distance_for(per_item.sum() / per_item.numel())

        candidate_labels = (torch.tensor([0.,1.]).cuda(), torch.tensor([1.,0.]).cuda())
        return min(
            distance_for(criterion(logits_adv, label).sum() / 2)
            for label in candidate_labels
        )
def log_metrics(metrics, run):
    """Append every metric value to ``run["metrics/<name>"]``.

    A NaN value of the ``num_edge_frac`` metric is logged as ``1`` instead.

    Args:
        metrics: mapping from metric name to value.
        run: object supporting ``run[key].append(value)``.
    """
    # Imported here so the function does not depend on ``math`` leaking in
    # through a wildcard import elsewhere in the module.
    import math

    for metric, value in metrics.items():
        if metric == 'num_edge_frac' and math.isnan(value):
            run[f"metrics/{metric}"].append(1)
            continue
        run[f"metrics/{metric}"].append(value)
        
def compute_grads_fed_avg(model, criterion, Xs, As, gt_ls, avg_epochs, avg_lr):
    """Simulate a FedAvg client update and return the averaged pseudo-gradient.

    The model is switched to eval mode and trained for ``avg_epochs`` plain
    SGD steps (learning rate ``avg_lr``) on the mean of
    ``criterion(model(As, Xs), gt_ls)``.  The returned list holds, per
    parameter, ``-(new_weights - old_weights) / avg_lr / avg_epochs``.
    Afterwards the original weights are written back; the model stays in eval
    mode.
    """
    with torch.enable_grad():
        snapshot = [weight.data.clone() for weight in model.parameters()]

        model.eval()
        sgd = torch.optim.SGD(model.parameters(), lr=avg_lr)

        for _ in range(avg_epochs):
            sgd.zero_grad()
            step_loss = criterion(model(As, Xs), gt_ls).mean()
            step_loss.backward()
            sgd.step()

        grad = [
            -(weight.data.detach() - start) / avg_lr / avg_epochs
            for weight, start in zip(model.parameters(), snapshot)
        ]

        # Undo the local training so the caller's model is unchanged.
        for weight, start in zip(model.parameters(), snapshot):
            weight.data = start

        return grad
        
class CustomArgs:    
    """Plain container of fixed settings exposed as class attributes."""
    # presumably toggles one-hot encoding of features -- verify against the consumer
    do_ohe = True
    # name of the evaluation method; 'acc' presumably stands for accuracy -- TODO confirm
    eval_method = 'acc'

