import os
import argparse
import numpy as np

import torch
from torch_geometric.utils import remove_self_loops, scatter


def node_homophily(edge_index, y):
    """Compute, for every node, the share of its edges that join equal labels.

    Adapted from PyG's homophily utilities:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/utils/_homophily.py

    Args:
        edge_index: Unpacks into two index tensors ``(row, col)`` with one
            entry per edge (presumably a ``(2, num_edges)`` tensor).
        y: Node labels, indexable by the node ids stored in ``edge_index``.

    Returns:
        A tensor with ``y.size(0)`` entries. Entry ``v`` is the mean, over all
        edges whose ``col`` endpoint is ``v``, of the indicator
        ``y[row] == y[col]``. The value for a node that never appears in
        ``col`` is whatever ``scatter(..., reduce='mean')`` yields for an
        empty group.
    """
    src, dst = edge_index
    same_label = y[src] == y[dst]

    # 1.0 for edges whose endpoints share a label, 0.0 otherwise.
    agreement = torch.zeros(src.size(0), device=src.device)
    agreement[same_label] = 1.

    # Average the per-edge indicator over each edge's `col` endpoint.
    return scatter(agreement, dst, 0, dim_size=y.size(0), reduce='mean')

def calculate_homophily_ratio(data_list):
    """Return the homophily ratio of node 0 for every graph in ``data_list``.

    Node 0 of each graph is assumed to be the center node. Self loops are
    removed before the ratio is computed; a graph left without any edge
    contributes ``0``.

    Args:
        data_list: Iterable of graph objects exposing ``edge_index`` and
            ``node_labels`` (the latter is squeezed along its last dim).

    Returns:
        A list with one ratio per graph, in input order.
    """
    center_ratios = []
    for graph in data_list:
        raw_edges = graph.edge_index
        labels = graph.node_labels.squeeze(-1)
        edges, _ = remove_self_loops(raw_edges)

        if edges.shape[1] == 0:
            # No edges besides self loops: fall back to a ratio of zero.
            center_ratios.append(0)
            continue

        per_node = node_homophily(edges, labels)
        center_ratios.append(per_node[0].item())

    return center_ratios

def calculate_accuracy(data_list):
    """Return one correctness flag per sample in ``data_list``.

    Each sample must expose ``y_pred`` and ``y``; their element-wise
    comparison has to hold exactly one element, because ``.item()`` is used
    to turn it into a Python scalar.

    Returns:
        A list of ``(y_pred == y).item()`` values, in input order.
    """
    return [(sample.y_pred == sample.y).item() for sample in data_list]

def calculate_sd(data_list, thres=(0, 0.2, 0.4, 0.6, 0.8, 1)):
    """Measure how much accuracy varies across homophily-ratio groups.

    Samples are bucketed by their center-node homophily ratio using the
    consecutive pairs of ``thres`` as bin edges, and the accuracy of each
    bucket is computed.

    Bin boundaries:
        * the first and the last bin are closed: ``[lower, upper]``;
        * the second bin is open: ``(lower, upper)``, because its lower edge
          already belongs to the closed first bin;
        * every other bin is half-open: ``[lower, upper)``.

    With at least three bins, a ratio equal to an interior threshold
    therefore lands in exactly one group (assuming the interior thresholds
    are distinct). With exactly two bins both are closed, so a ratio equal
    to the shared edge is counted in both.

    Args:
        data_list: Samples accepted by ``calculate_homophily_ratio`` and
            ``calculate_accuracy``.
        thres: Sequence of ``k + 1`` bin edges defining ``k`` groups.

    Returns:
        ``(group_acc, sd, cv)`` where ``group_acc`` is the list of per-group
        mean accuracies, ``sd`` is ``np.std(group_acc)`` and ``cv`` is ``sd``
        divided by the overall mean accuracy. A group without any sample has
        a NaN accuracy (NumPy emits a RuntimeWarning), which propagates to
        ``sd`` and ``cv``.
    """
    ratio = np.array(calculate_homophily_ratio(data_list))
    acc = np.array(calculate_accuracy(data_list))

    last = len(thres) - 2
    group_acc = []
    for i, (lower, upper) in enumerate(zip(thres[:-1], thres[1:])):
        if i == 0 or i == last:
            flags = np.logical_and(ratio >= lower, ratio <= upper)
        elif i == 1:
            # The closed first bin already owns this bin's lower edge.
            flags = np.logical_and(ratio > lower, ratio < upper)
        else:
            # Includes the second-to-last bin: its lower edge must be
            # inclusive, otherwise ratios equal to that threshold would be
            # dropped from every group.
            flags = np.logical_and(ratio >= lower, ratio < upper)

        group_acc.append(acc[flags].mean())

    sd = np.std(group_acc)
    cv = sd / acc.mean()

    return group_acc, sd, cv


# Fixed homophily-ratio bin edges per dataset, used by the script below when
# --quantile is not passed. Each entry holds six edges, i.e. five groups for
# calculate_sd. A repeated end value (0, 0 or 1, 1) yields a single-point
# first/last group, since calculate_sd treats the first and last bins as
# closed intervals.
THRES_DICT = {
    'cora': [0, 0.333, 0.5, 0.75, 1, 1],
    'citeseer': [0, 0, 0.5, 0.667, 1, 1],
    'pubmed': [0, 0, 0.5, 0.8, 1, 1],
}

def _quantile_thresholds(ratio, num_groups):
    """Derive ``num_groups + 1`` bin edges from the quantiles of ``ratio``.

    The quantiles of ``ratio`` are evaluated on the grid 0.00, 0.01, ..., 0.99.
    ``upper`` is the smallest grid level whose quantile equals 1 and ``lower``
    is the largest grid level whose quantile equals 0. Depending on how the
    two compare, the edges are either ``num_groups`` quantiles evenly spaced
    in ``[0, upper]`` followed by 1, or ``num_groups - 1`` quantiles evenly
    spaced in ``[lower, upper]`` wrapped between 0 and 1.

    Args:
        ratio: 1-D array of homophily ratios.
        num_groups: Number of groups to produce; must be at least 2.

    Returns:
        A list of ``num_groups + 1`` bin edges.

    Raises:
        ValueError: If no grid quantile equals 0 or none equals 1, in which
            case ``lower`` / ``upper`` are undefined.
    """
    grid = np.arange(100) / 100
    ratio_quantiles = np.quantile(ratio, q=grid)

    at_one = np.flatnonzero(ratio_quantiles == 1)
    at_zero = np.flatnonzero(ratio_quantiles == 0)
    if at_one.size == 0 or at_zero.size == 0:
        raise ValueError(
            'cannot derive quantile thresholds: the homophily ratios need '
            'percentiles equal to both 0 and 1'
        )
    upper = grid[at_one[0]]
    lower = grid[at_zero[-1]]

    if (1 - upper) / (num_groups - 1) >= lower + 0.01:
        qs = np.linspace(0, upper, num_groups)
        return np.quantile(ratio, q=qs).tolist() + [1]

    qs = np.linspace(lower, upper, num_groups - 1)
    return [0] + np.quantile(ratio, q=qs).tolist() + [1]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="homo")
    parser.add_argument("--data_name", type=str, default="cora")
    parser.add_argument("--type", type=str, default="metric")
    parser.add_argument('--quantile', action='store_true')
    parser.add_argument("--num_groups", type=int, default=5)
    args = parser.parse_args()

    pred_dir = './results/predictions'

    # os.listdir returns entries in arbitrary order; sort so that the
    # reference file picked in quantile mode is deterministic.
    files = sorted(
        x for x in os.listdir(pred_dir)
        if args.data_name in x and x.endswith('.pt')
    )
    ori_files = [x for x in files if 'augFalse' in x]
    prefixes = sorted({x.split('-r.')[0] for x in files})

    # NOTE(security): torch.load unpickles arbitrary Python objects. Only run
    # this script on prediction files that come from a trusted source.
    if args.quantile:
        if args.num_groups < 2:
            parser.error('--num_groups must be at least 2 with --quantile')

        candidates = [x for x in ori_files if 'r.0' in x]
        if not candidates:
            parser.error(
                f"no 'augFalse' prediction file containing 'r.0' found for "
                f"'{args.data_name}' in {pred_dir}"
            )
        data_list = torch.load(os.path.join(pred_dir, candidates[0]))
        ratio = np.array(calculate_homophily_ratio(data_list))

        # determine thresholds
        thres = _quantile_thresholds(ratio, args.num_groups)

    elif args.data_name in THRES_DICT:
        thres = THRES_DICT[args.data_name]

    else:
        parser.error(
            f"no fixed thresholds for '{args.data_name}' "
            f"(known: {', '.join(THRES_DICT)}); pass --quantile instead"
        )

    print(f'thres: {thres}')

    if args.type == 'metric':
        for p in prefixes:
            # Compare the exact prefix: startswith() would also pull in runs
            # whose prefix merely extends this one.
            target_files = [x for x in files if x.split('-r.')[0] == p]
            sd_list = []
            cv_list = []
            for f in target_files:
                data_list = torch.load(os.path.join(pred_dir, f))
                _, sd, cv = calculate_sd(data_list, thres)
                sd_list.append(sd)
                cv_list.append(cv)

            print(f"Prefix: {p}")
            print(f'sd = {np.mean(sd_list):.4f} +/- {np.std(sd_list):.4f}')
            print(f'cv = {np.mean(cv_list):.4f} +/- {np.std(cv_list):.4f}')

    elif args.type == 'figure':
        # Placeholder: figure generation is not implemented.
        pass




    # ori_files = [x for x in files if 'augFalse' in x]
    # aug_files = [x for x in files if 'augTrue' in x]

    

    # print(f'thres={thres}')
    # # fname = [x for x in ori_files if 'r.0' in x][0]
    # # data_list = torch.load(f'./results/predictions/{fname}')
    # # print(f'Homophily rate: {np.array(calculate_homophily_ratio(data_list)).mean()}')

    # ori_acc_list = []
    # ori_sd_list = []
    # ori_cv_list = []
    # for f in ori_files:
    #     data_list = torch.load(f'./results/predictions/{f}')
    #     acc, sd, cv = calculate_sd(data_list, thres)
    #     ori_acc_list.append(acc)
    #     ori_sd_list.append(sd)
    #     ori_cv_list.append(cv)

    # print(f"{ori_files[0].split('-r.')[0]}")
    # # print(f'acc = {np.mean(ori_acc_list):.4f} +/- {np.std(ori_acc_list):.4f}')
    # print(f'sd = {np.mean(ori_sd_list):.4f} +/- {np.std(ori_sd_list):.4f}')
    # print(f'cv = {np.mean(ori_cv_list):.4f} +/- {np.std(ori_cv_list):.4f}')

    # aug_acc_list = []
    # aug_sd_list = []
    # aug_cv_list = []
    # for f in aug_files:
    #     data_list = torch.load(f'./results/predictions/{f}')
    #     acc, sd, cv = calculate_sd(data_list, thres)
    #     aug_acc_list.append(acc)
    #     aug_sd_list.append(sd)
    #     aug_cv_list.append(cv)

    # print(f"{aug_files[0].split('-r.')[0]}")
    # # print(f'acc = {np.mean(aug_acc_list):.4f} +/- {np.std(aug_acc_list):.4f}')
    # print(f'sd = {np.mean(aug_sd_list):.4f} +/- {np.std(aug_sd_list):.4f}')
    # print(f'cv = {np.mean(aug_cv_list):.4f} +/- {np.std(aug_cv_list):.4f}')