import os
import random
import torch
import numpy as np
from torch_scatter import scatter
from torch_geometric.utils import add_remaining_self_loops, degree, remove_self_loops


def propagate2(x, edge_index):
    """Run one step of GCN-style symmetric-normalized propagation.

    Self-loops are added for nodes that lack one (``add_remaining_self_loops``).
    Each edge (source -> target) then carries the weight
    ``deg[source] ** -0.5 * deg[target] ** -0.5``. Node degrees are counted on
    the target row of the looped edge list. Every edge sends
    ``weight * x[source]`` to its target, and messages are summed per target
    node.

    Args:
        x: node feature matrix; ``x.size(0)`` is the number of nodes.
            NOTE(review): assumed to be (num_nodes, num_features), since the
            weights are broadcast as a column vector.
        edge_index: 2 x E edge list (source row, target row).
            NOTE(review): the weight only equals D^-1/2 (A+I) D^-1/2 when the
            graph is undirected.

    Returns:
        Tensor with one aggregated row per node.
    """
    num_nodes = x.size(0)
    looped_index, _ = add_remaining_self_loops(edge_index, num_nodes=num_nodes)
    source, target = looped_index

    # Every node has at least its self-loop, so the degree is never zero here.
    inv_sqrt_deg = degree(target, num_nodes, dtype=x.dtype).pow(-0.5)
    norm = inv_sqrt_deg[source] * inv_sqrt_deg[target]

    messages = x[source] * norm.view(-1, 1)
    return scatter(messages, target, dim=0, dim_size=num_nodes, reduce='add')


def seed_everything(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is only read at interpreter start-up, so setting it
        here affects child Python processes, not the current one.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seeder(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    # Trade speed for reproducibility: no TF32, no autotuned algorithm choice.
    cudnn = torch.backends.cudnn
    cudnn.allow_tf32 = False
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = True


def fair_metric(pred, labels, sens):
    """Return (statistical parity difference, equal opportunity difference).

    With ``pred`` holding 0/1 predictions (assumed — confirm with callers),
    the two values are::

        parity   = |mean(pred | sens == 0)              - mean(pred | sens == 1)|
        equality = |mean(pred | sens == 0, labels == 1) - mean(pred | sens == 1, labels == 1)|

    Args:
        pred: per-sample predictions, indexed along the first axis.
        labels: per-sample ground-truth labels; ``1`` marks the positive class.
        sens: per-sample sensitive attribute; only the values 0 and 1 are grouped.
        Each argument may be a NumPy array, a list, or a torch tensor (CPU or
        CUDA; tensors are detached and copied to host memory).

    Returns:
        Tuple of two Python floats. A group with no members yields ``nan``
        (NumPy emits a RuntimeWarning).
    """
    def _to_numpy(a):
        # Bring tensors to host memory so boolean masks and means are plain NumPy.
        if isinstance(a, torch.Tensor):
            return a.detach().cpu().numpy()
        return np.asarray(a)

    pred = _to_numpy(pred)
    labels = _to_numpy(labels)
    sens = _to_numpy(sens)

    idx_s0 = sens == 0
    idx_s1 = sens == 1
    positive = labels == 1
    idx_s0_y1 = idx_s0 & positive
    idx_s1_y1 = idx_s1 & positive

    # Vectorized group means. The built-in ``sum`` iterated the arrays element
    # by element in Python. Means accumulate in float64 for stable precision.
    parity = abs(pred[idx_s0].mean(dtype=np.float64)
                 - pred[idx_s1].mean(dtype=np.float64))
    equality = abs(pred[idx_s0_y1].mean(dtype=np.float64)
                   - pred[idx_s1_y1].mean(dtype=np.float64))
    return float(parity), float(equality)


def node_homophily(sens, edge_index):
    """Average, over nodes with outgoing edges, of the fraction of same-`sens` edges.

    For every node v, count its outgoing edges (v as the source row of
    ``edge_index``) and how many of them end at a node with an equal ``sens``
    value. Take the ratio per node and average it over nodes whose out-degree
    is positive.

    Args:
        sens: 1-D tensor of per-node attribute values, one entry per node.
        edge_index: 2 x E tensor of (source, target) node indices.

    Returns:
        Python float; ``nan`` if no node has an outgoing edge.
    """
    source, target = edge_index
    matches = (sens[source] == sens[target]).float()
    ones = torch.ones_like(source, dtype=torch.float)

    matched = torch.zeros_like(sens, dtype=torch.float).index_add_(0, source, matches)
    out_deg = torch.zeros_like(sens, dtype=torch.float).index_add_(0, source, ones)

    # clamp avoids 0/0 for isolated nodes; they are excluded from the mean anyway.
    ratios = matched / out_deg.clamp(min=1)
    has_edges = out_deg > 0
    return ratios[has_edges].mean().item()


def sens_shuffle(sens, edge_index, features, topk_indices, sens_idx, col_max, col_min, max_iter=10):
    """Reflect feature columns of selected neighbours to push sensitive homophily toward 0.5.

    Each sweep visits the nodes in random order. For node v, its out-neighbours
    (self-loops removed) are split into those whose value in column
    ``sens_idx`` equals v's ("same", with v itself counted once) and the rest
    ("diff"). ``|same - diff| // 2`` randomly chosen neighbours from the
    majority side get every column in ``topk_indices`` replaced by
    ``col_max + col_min - x``. After the sweep, ``node_homophily`` of column
    ``sens_idx`` is recomputed. If it moved further from 0.5, the sweep is
    undone and the procedure stops.

    Args:
        sens: only its length is used, as the number of nodes ``n``.
        edge_index: 2 x E (source, target) edge list.
        features: node feature matrix; it is cloned, never modified in place.
        topk_indices: iterable of feature columns to reflect.
            NOTE(review): the homophily of column ``sens_idx`` can only change
            if ``sens_idx`` is among these columns — confirm with callers.
        sens_idx: column of ``features`` holding the sensitive attribute.
        col_max, col_min: per-column maxima/minima used by the reflection.
        max_iter: maximum number of sweeps.

    Returns:
        The perturbed copy of ``features``.
    """
    n = sens.shape[0]
    new_features = features.clone()

    src, dst = remove_self_loops(edge_index)[0]

    # Build a CSR-style neighbour index once: dst_by_src[ptr[v]:ptr[v + 1]]
    # holds the out-neighbours of node v. The previous per-node scan
    # ``dst[src == v]`` cost O(E) per node, i.e. O(n * E) per sweep. The sort
    # is stable, so each node's neighbours keep their edge-list order. The
    # random draws below therefore pick exactly the same nodes as before.
    dst_by_src = dst[torch.sort(src, stable=True).indices]
    ptr = [0] + torch.bincount(src, minlength=n).cumsum(0).tolist()

    prev_h = node_homophily(new_features[:, sens_idx], edge_index)

    for _ in range(max_iter):
        # Snapshot so that a sweep which worsens homophily can be rolled back.
        temp = new_features.clone()

        order = torch.randperm(n)
        for i in order:
            node = int(i)
            neighbors = dst_by_src[ptr[node]:ptr[node + 1]]
            if neighbors.numel() == 0:
                continue

            sens_col = new_features[:, sens_idx]
            same_mask = sens_col[neighbors] == sens_col[node]
            n_match = int(same_mask.sum())
            n_diff = neighbors.numel() - n_match
            n_same = n_match + 1  # the node itself counts on the "same" side

            n_flip = abs(n_same - n_diff) // 2
            if n_flip == 0:
                continue

            # n_flip > 0 implies n_same != n_diff: draw from the majority side.
            cand = neighbors[same_mask] if n_same > n_diff else neighbors[~same_mask]
            flip_indices = cand[torch.randperm(cand.size(0))][:n_flip]
            for idx in topk_indices:
                # Reflection inside [col_min, col_max]; swaps 0 and 1 when min=0, max=1.
                new_features[flip_indices, idx] = col_max[idx] + col_min[idx] - new_features[flip_indices, idx]

        curr_h = node_homophily(new_features[:, sens_idx], edge_index)

        # Undo this sweep and stop as soon as homophily moves away from 0.5.
        if abs(prev_h - 0.5) < abs(curr_h - 0.5):
            new_features = temp
            break

        prev_h = curr_h

    return new_features


def sens_shuffle_noisy(sens, edge_index, features, topk_indices, sens_idx, col_max, col_min, max_iter=10):
    """Like ``sens_shuffle``, but steered by a private copy of ``sens`` instead of a feature column.

    A working copy ``labels = sens.clone()`` decides which neighbours are
    "same" or "diff". When neighbours are chosen, their labels are flipped
    with ``1 - label`` and their ``topk_indices`` feature columns are replaced
    by ``col_max + col_min - x``. Sweeps continue while ``node_homophily`` of
    ``labels`` does not move further from 0.5. The sweep that makes it worse
    is undone for the features and ends the loop.

    Args:
        sens: 1-D per-node sensitive attribute; its length is the number of
            nodes. NOTE(review): ``1 - label`` assumes 0/1 values — confirm.
        edge_index: 2 x E (source, target) edge list.
        features: node feature matrix; it is cloned, never modified in place.
        topk_indices: iterable of feature columns to reflect.
        sens_idx: unused here; kept for signature parity with ``sens_shuffle``.
        col_max, col_min: per-column maxima/minima used by the reflection.
        max_iter: maximum number of sweeps.

    Returns:
        The perturbed copy of ``features``.
    """
    n = sens.shape[0]
    new_features = features.clone()
    labels = sens.clone()

    src, dst = remove_self_loops(edge_index)[0]

    # Build a CSR-style neighbour index once: dst_by_src[ptr[v]:ptr[v + 1]]
    # holds the out-neighbours of node v. The previous per-node scan
    # ``dst[src == v]`` cost O(E) per node, i.e. O(n * E) per sweep. The sort
    # is stable, so each node's neighbours keep their edge-list order. The
    # random draws below therefore pick exactly the same nodes as before.
    dst_by_src = dst[torch.sort(src, stable=True).indices]
    ptr = [0] + torch.bincount(src, minlength=n).cumsum(0).tolist()

    prev_h = node_homophily(labels, edge_index)

    for _ in range(max_iter):
        # Snapshot of the features only. ``labels`` is not rolled back, but it
        # is not returned, and the loop ends right after a rollback.
        temp = new_features.clone()

        order = torch.randperm(n)
        for i in order:
            node = int(i)
            neighbors = dst_by_src[ptr[node]:ptr[node + 1]]
            if neighbors.numel() == 0:
                continue

            same_mask = labels[neighbors] == labels[node]
            n_match = int(same_mask.sum())
            n_diff = neighbors.numel() - n_match
            n_same = n_match + 1  # the node itself counts on the "same" side

            n_flip = abs(n_same - n_diff) // 2
            if n_flip == 0:
                continue

            # n_flip > 0 implies n_same != n_diff: draw from the majority side.
            cand = neighbors[same_mask] if n_same > n_diff else neighbors[~same_mask]
            flip_indices = cand[torch.randperm(cand.size(0))][:n_flip]
            labels[flip_indices] = 1 - labels[flip_indices]
            for idx in topk_indices:
                # Reflection inside [col_min, col_max]; swaps 0 and 1 when min=0, max=1.
                new_features[flip_indices, idx] = col_max[idx] + col_min[idx] - new_features[flip_indices, idx]

        curr_h = node_homophily(labels, edge_index)

        # Undo this sweep's feature changes and stop once homophily moves away from 0.5.
        if abs(prev_h - 0.5) < abs(curr_h - 0.5):
            new_features = temp
            break

        prev_h = curr_h

    return new_features
