import torch
import numpy as np
import torch.nn.functional as F
from utils import fair_metric
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, roc_auc_score, roc_curve, precision_recall_curve, accuracy_score, f1_score

# def evaluate(args, data, encoder):
#     accs, auc_rocs, paritys, equalitys = {}, {}, {}, {}
#     feat, output = encoder(data.x, data.edge_index)

#     probs = torch.sigmoid(output.squeeze())
#     probs_train = torch.sigmoid(output[data.train_mask].squeeze())
#     probs_val = torch.sigmoid(output[data.val_mask].squeeze())
#     probs_test = torch.sigmoid(output[data.test_mask].squeeze())

#     pred = (probs > 0.5).type_as(data.y)
#     pred_train = (probs_train > 0.5).type_as(data.y)
#     pred_val = (probs_val > 0.5).type_as(data.y)
#     pred_test = (probs_test > 0.5).type_as(data.y)

#     num_pred_0 = (pred == 0).sum().item()
#     num_pred_1 = (pred == 1).sum().item()
#     print(f"Test set predicted 0: {num_pred_0}, predicted 1: {num_pred_1}")
#     print(f"Test set predicted 0 ratio: {num_pred_0 / len(pred):.3f}")
#     print(f"Test set predicted 1 ratio: {num_pred_1 / len(pred):.3f}")

#     accs["all"] = pred.eq(data.y).sum().item() / data.y.size(0)
#     accs['train'] = pred_train.eq(data.y[data.train_mask]).sum().item() / data.train_mask.sum().item()
#     accs['val'] = pred_val.eq(data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()
#     accs['test'] = pred_test.eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()

#     auc_rocs["all"] = roc_auc_score(data.y.cpu().numpy(), pred.detach().cpu().numpy())
#     auc_rocs['train'] = roc_auc_score(data.y[data.train_mask].cpu().numpy(), pred_train.detach().cpu().numpy())
#     auc_rocs['val'] = roc_auc_score(data.y[data.val_mask].cpu().numpy(), pred_val.detach().cpu().numpy())
#     auc_rocs['test'] = roc_auc_score(data.y[data.test_mask].cpu().numpy(), pred_test.detach().cpu().numpy()) 

#     paritys["all"], equalitys["all"] = fair_metric(pred.cpu().numpy(), data.y.cpu().numpy(), data.sens_labels.cpu().numpy())
#     paritys['train'], equalitys['train'] = fair_metric(pred_train.cpu().numpy(), data.y[data.train_mask].cpu().numpy(), data.sens_labels[data.train_mask].cpu().numpy())
#     paritys['val'], equalitys['val'] = fair_metric(pred_val.cpu().numpy(), data.y[data.val_mask].cpu().numpy(), data.sens_labels[data.val_mask].cpu().numpy())
#     paritys['test'], equalitys['test'] = fair_metric(pred_test.cpu().numpy(), data.y[data.test_mask].cpu().numpy(), data.sens_labels[data.test_mask].cpu().numpy())

#     accs["all"], accs['train'], accs['val'], accs['test'] = accs["all"] *100, accs['train']*100, accs['val']* 100, accs['test'] * 100
#     auc_rocs["all"], auc_rocs['train'], auc_rocs['val'], auc_rocs['test']  = auc_rocs["all"] *100, auc_rocs['train']* 100, auc_rocs['val']* 100, auc_rocs['test']* 100
#     paritys["all"], paritys['train'], paritys['val'], paritys['test'] = paritys["all"]*100, paritys['train'] *100, paritys['val'] * 100, paritys['test'] * 100
#     equalitys["all"], equalitys['train'], equalitys['val'], equalitys['test'] = equalitys["all"] *100, equalitys['train'] * 100, equalitys['val'] * 100, equalitys['test'] * 100
#     return accs, auc_rocs, paritys, equalitys


def evaluate_sens(args, data, encoder):
    """Score how well ``encoder`` predicts the binary sensitive attribute.

    Runs one forward pass in eval mode without gradient tracking, thresholds
    the sigmoid of the returned logits at 0.5, and compares the result with
    ``data.sens_labels`` on the union of the train/val/test masks.

    Args:
        args: Unused here.
        data: Graph object exposing ``x``, ``edge_index``, ``sens_labels`` and
            boolean ``train_mask`` / ``val_mask`` / ``test_mask``.
        encoder: Callable with ``.eval()`` whose output for
            ``(data.x, data.edge_index)`` is a pair whose second element holds
            the logits. It is left in eval mode.

    Returns:
        dict with keys ``'auc'``, ``'acc'`` and ``'f1'`` on a 0-1 scale (not
        multiplied by 100). Also prints them.
    """
    encoder.eval()
    with torch.no_grad():
        _, logits = encoder(data.x, data.edge_index)
        # A single-logit head may yield [N, 1]; flatten it so sklearn receives
        # 1-D arrays (column vectors trigger DataConversionWarning). Unlike a
        # bare squeeze(), this keeps an N == 1 graph 1-D.
        if logits.dim() > 1 and logits.size(-1) == 1:
            logits = logits.squeeze(-1)
        probs = torch.sigmoid(logits)  # probabilities, label 1 treated as positive
        pred_sens = (probs > 0.5).long()  # hard binary sensitive-attribute predictions

    all_mask = data.train_mask | data.val_mask | data.test_mask
    y_true = data.sens_labels[all_mask].cpu().numpy()
    y_pred = pred_sens[all_mask].cpu().numpy()
    y_prob = probs[all_mask].cpu().numpy()

    acc = accuracy_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_prob)
    f1 = f1_score(y_true, y_pred)

    print(f'[Sensitive Info] AUC: {auc:.2f}  ACC: {acc:.2f}  F1: {f1:.2f}')
    return {'auc': auc, 'acc': acc, 'f1': f1}

def _auc_or_nan(y_true, y_score):
    """Return ``roc_auc_score(y_true, y_score)``, or NaN if ``y_true`` has fewer than 2 classes.

    ROC-AUC is undefined for a single-class split (sklearn raises ValueError).
    """
    if np.unique(y_true).size < 2:
        return float('nan')
    return roc_auc_score(y_true, y_score)


def evaluate(args, data, encoder):
    """Evaluate target-label accuracy, ROC-AUC and fairness of ``encoder``.

    Runs a single no-grad forward pass in eval mode, thresholds the sigmoid of
    the logits at 0.5, prints the prediction balance over all split nodes, and
    computes per-split metrics.

    Args:
        args: Unused here.
        data: Graph object exposing ``x``, ``edge_index``, ``y``,
            ``sens_labels`` and boolean ``train_mask`` / ``val_mask`` /
            ``test_mask``.
        encoder: Callable with ``.eval()`` whose output for
            ``(data.x, data.edge_index)`` is a pair whose second element holds
            the logits. It is left in eval mode.

    Returns:
        ``(accs, auc_rocs, paritys, equalitys)``: four dicts keyed by
        ``'all'`` (union of the three masks), ``'train'``, ``'val'`` and
        ``'test'``, every value multiplied by 100. ``auc_rocs`` entries are NaN
        for splits containing a single class. ``paritys`` / ``equalitys`` are
        the two values returned by ``utils.fair_metric`` (presumably parity /
        equality-of-opportunity gaps; confirm in utils).
    """
    encoder.eval()
    # Pure inference: don't build an autograd graph for the full-graph pass.
    with torch.no_grad():
        _, output = encoder(data.x, data.edge_index)
    # Flatten a [N, 1] logit column to [N]; unlike a bare squeeze(), this keeps
    # an N == 1 graph 1-D so the boolean masks still index it.
    if output.dim() > 1 and output.size(-1) == 1:
        output = output.squeeze(-1)
    probs = torch.sigmoid(output)

    all_mask = data.train_mask | data.val_mask | data.test_mask
    splits = {
        'all': all_mask,
        'train': data.train_mask,
        'val': data.val_mask,
        'test': data.test_mask,
    }

    # Prediction balance over every node that belongs to some split
    # (not only the test split).
    pred_all = (probs[all_mask] > 0.5).type_as(data.y)
    num_pred_0 = (pred_all == 0).sum().item()
    num_pred_1 = (pred_all == 1).sum().item()
    print(f"Train/val/test nodes predicted 0: {num_pred_0}, predicted 1: {num_pred_1}")
    print(f"Train/val/test nodes predicted 0 ratio: {num_pred_0 / len(pred_all):.3f}")
    print(f"Train/val/test nodes predicted 1 ratio: {num_pred_1 / len(pred_all):.3f}")

    accs, auc_rocs, paritys, equalitys = {}, {}, {}, {}
    for name, mask in splits.items():
        y = data.y[mask]
        prob = probs[mask]
        pred = (prob > 0.5).type_as(data.y)
        y_np = y.cpu().numpy()

        accs[name] = pred.eq(y).sum().item() / mask.sum().item() * 100
        auc_rocs[name] = _auc_or_nan(y_np, prob.cpu().numpy()) * 100
        parity, equality = fair_metric(
            pred.cpu().numpy(),
            y_np,
            data.sens_labels[mask].cpu().numpy(),
        )
        paritys[name] = parity * 100
        equalitys[name] = equality * 100
    return accs, auc_rocs, paritys, equalitys

import torch
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score

def evaluate_per_class(args, data, encoder):
    """Evaluate ``encoder`` with class-balanced metrics and fairness gaps.

    Side effects:
        Writes ``{args.dataset}_feat.npz`` (key ``representations``: the
        encoder's first output restricted to test nodes) and
        ``{args.dataset}_labels.npz`` (key ``labels``: a subgroup id per test
        node) to the current working directory. Subgroup ids are
        0 = (s=0, y=1), 1 = (s=1, y=1), 2 = (s=0, y=0), 3 = (s=1, y=0) and
        -1 for any other combination.

    For each split (``'all'`` = union of the masks, ``'train'``, ``'val'``,
    ``'test'``), overall, per-sensitive-group and per-target-class acc/AUC/F1
    are computed along with the ``fair_metric`` DP/EO values. Only the
    target-class metrics are aggregated: ``accs`` is the mean per-class recall
    (balanced accuracy). ``auc_rocs`` is the mean one-vs-rest AUC, which for
    0/1 labels equals the ordinary ROC-AUC, because the class-0 AUC is scored
    with ``1 - prob``.

    Args:
        args: Needs ``args.dataset`` (used in the output file names).
        data: Graph object exposing ``x``, ``edge_index``, ``y``,
            ``sens_labels`` and boolean ``train_mask`` / ``val_mask`` /
            ``test_mask``.
        encoder: Callable with ``.eval()`` returning ``(features, logits)``.
            It is left in eval mode.

    Returns:
        ``(accs, auc_rocs, paritys, equalitys)``: dicts keyed by split name,
        values in percent, NaN where a metric is undefined.
    """
    test_labels = data.y[data.test_mask]
    test_sens = data.sens_labels[data.test_mask]

    accs, auc_rocs, paritys, equalitys = {}, {}, {}, {}
    result = {}
    encoder.eval()
    with torch.no_grad():
        feat, output = encoder(data.x, data.edge_index)
        # Flatten a [N, 1] logit column to [N] (keeps N == 1 as 1-D, unlike squeeze()).
        if output.dim() > 1 and output.size(-1) == 1:
            output = output.squeeze(-1)
        probs = torch.sigmoid(output).cpu().numpy()
        y_all = data.y.cpu().numpy()
        sens_all = data.sens_labels.cpu().numpy()
        all_mask = data.train_mask | data.val_mask | data.test_mask
        splits = {
            'all': all_mask.cpu().numpy(),
            'train': data.train_mask.cpu().numpy(),
            'val': data.val_mask.cpu().numpy(),
            'test': data.test_mask.cpu().numpy(),
        }

        # Convert test-node representations to numpy for saving.
        representation_np = feat[data.test_mask].detach().cpu().numpy()

        # Subgroup id per test node (see docstring). Allocate on the masks'
        # device: assigning into a CPU tensor through CUDA boolean masks raises.
        labels = torch.full(
            (test_labels.shape[0],), -1, dtype=torch.int64, device=test_labels.device
        )
        for group_id, (s_val, y_val) in enumerate(((0, 1), (1, 1), (0, 0), (1, 0))):
            labels[torch.logical_and(test_sens == s_val, test_labels == y_val)] = group_id
        labels_np = labels.cpu().numpy()

        np.savez(f"{args.dataset}_feat.npz", representations=representation_np)
        np.savez(f"{args.dataset}_labels.npz", labels=labels_np)

    for split_name, mask in splits.items():
        y_true = y_all[mask]
        sens = sens_all[mask]
        prob = probs[mask]
        pred = (prob > 0.5).astype(int)
        has_both_classes = len(set(y_true)) == 2

        # Overall metrics on the split.
        acc_total = accuracy_score(y_true, pred) * 100 if len(y_true) > 0 else float('nan')
        auc_total = roc_auc_score(y_true, prob) * 100 if has_both_classes else float('nan')
        f1_total = f1_score(y_true, pred, zero_division=0) * 100 if has_both_classes else float('nan')

        # Metrics within each sensitive group.
        sens_metrics = {}
        for s in np.unique(sens):
            idx = sens == s
            group_both = len(np.unique(y_true[idx])) == 2
            acc = accuracy_score(y_true[idx], pred[idx]) * 100 if idx.sum() > 0 else float('nan')
            auc = roc_auc_score(y_true[idx], prob[idx]) * 100 if group_both else float('nan')
            f1_ = f1_score(y_true[idx], pred[idx], zero_division=0) * 100 if group_both else float('nan')
            sens_metrics[int(s)] = {'acc': acc, 'auc': auc, 'f1': f1_}

        # One-vs-rest metrics for each target class; per-class accuracy is
        # that class's recall.
        y_metrics = {}
        for yval in np.unique(y_true):
            idx = y_true == yval
            acc = accuracy_score(y_true[idx], pred[idx]) * 100 if idx.sum() > 0 else float('nan')
            auc = roc_auc_score(idx.astype(int), prob if yval == 1 else 1 - prob) * 100 if has_both_classes else float('nan')
            f1_ = f1_score(idx.astype(int), (pred == yval).astype(int), zero_division=0) * 100 if idx.sum() > 0 else float('nan')
            y_metrics[int(yval)] = {'acc': acc, 'auc': auc, 'f1': f1_}

        # Fairness metrics DP & EO.
        dp, eo = fair_metric(pred, y_true, sens)
        result[split_name] = {
            'overall': {'acc': acc_total, 'auc': auc_total, 'f1': f1_total},
            'sens_group': sens_metrics,
            'target_group': y_metrics,
            'fairness': {'dp': dp * 100, 'eo': eo * 100},
        }

    for split_name, report in result.items():
        # Average over target classes only; the sensitive-group metrics stay
        # in `report['sens_group']` but are not folded into these means.
        target_vals = list(report['target_group'].values())
        accs[split_name] = np.nanmean([v['acc'] for v in target_vals])
        auc_rocs[split_name] = np.nanmean([v['auc'] for v in target_vals])
        paritys[split_name] = report['fairness']['dp']
        equalitys[split_name] = report['fairness']['eo']

    return accs, auc_rocs, paritys, equalitys
