import pandas as pd
import os
import numpy as np
import random
import torch
from torch_geometric.data import Data
from torch_geometric.utils import from_scipy_sparse_matrix
import scipy.sparse as sp
from typing import Optional, Callable, List, Tuple, Dict, Any

def index_to_mask(node_num: int, index: torch.Tensor) -> torch.Tensor:
    """Return a boolean vector of length ``node_num`` that is True at ``index``."""
    selected = torch.full((node_num,), False, dtype=torch.bool)
    selected[index] = True
    return selected

def sys_normalized_adjacency(adj: sp.coo_matrix) -> sp.coo_matrix:
    """Return ``D^-1/2 (A + I) D^-1/2`` as a COO matrix.

    ``D`` is the diagonal of row sums of ``A + I``. A row sum of exactly zero
    is replaced by one before the inverse square root is taken, and any
    infinite scaling factor is reset to zero.
    """
    base = sp.coo_matrix(adj)
    with_loops = base + sp.eye(base.shape[0])
    degrees = np.array(with_loops.sum(1))
    # Adding the boolean mask bumps zero row sums to one and leaves the rest.
    degrees = degrees + (degrees == 0) * 1
    scale = np.power(degrees, -0.5).flatten()
    scale[np.isinf(scale)] = 0.
    half_inverse = sp.diags(scale)
    scaled_rows = half_inverse.dot(with_loops)
    return scaled_rows.dot(half_inverse).tocoo()

def compute_distance_matrix(x: np.ndarray) -> np.ndarray:
    """Return a dense ``(n, n)`` matrix relating the rows of ``x``.

    For up to 10000 rows the result is the pairwise Euclidean distance matrix.
    For more than 10000 rows no distances are computed: the result is a random
    symmetric 0/1 adjacency in which every row draws ``min(20, n // 100)``
    partners with ``np.random.choice`` (a row never links to itself).

    Args:
        x: Array whose first axis indexes the samples.

    Returns:
        A dense ``(n, n)`` float array holding distances, or 0/1 links in the
        large-input case.
    """
    n = x.shape[0]
    if n > 10000:
        n_neighbors = min(20, n // 100)
        adj = np.zeros((n, n))
        for i in range(n):
            neighbors = np.random.choice(n, size=n_neighbors, replace=False)
            neighbors = neighbors[neighbors != i]
            # Set both directions at once so the matrix stays symmetric.
            adj[i, neighbors] = 1
            adj[neighbors, i] = 1
        return adj
    try:
        from scipy.spatial.distance import cdist
        return cdist(x, x, metric='euclidean')
    except (ImportError, ValueError, TypeError):
        # cdist is unavailable or rejected the input (for example a 1-D
        # array); fall back to NumPy, computing one row of distances at a time.
        if n == 0:
            return np.zeros((0, 0))
        flat = np.asarray(x, dtype=float).reshape(n, -1)
        dist_matrix = np.empty((n, n))
        for i in range(n):
            dist_matrix[i] = np.sqrt(((flat - flat[i]) ** 2).sum(axis=1))
        np.fill_diagonal(dist_matrix, 0)
        return dist_matrix

def _random_symmetric_adjacency(n: int, cutoff: float) -> sp.coo_matrix:
    """Return a random symmetric 0/1 adjacency with an empty diagonal.

    A directed entry is drawn wherever a uniform sample exceeds ``cutoff``;
    the matrix is then symmetrised.
    """
    draws = (np.random.rand(n, n) > cutoff).astype(float)
    symmetric = ((draws + draws.T) > 0).astype(float)
    np.fill_diagonal(symmetric, 0)
    return sp.coo_matrix(symmetric)

def build_relationship(x: np.ndarray, thresh: float = 0.25) -> sp.coo_matrix:
    """Build a node-to-node adjacency from the rows of ``x``.

    For up to 10000 rows, distances from ``compute_distance_matrix`` are
    divided by their maximum and turned into similarities ``1 - d``; two rows
    are linked when the similarity exceeds ``thresh``. For more than 10000
    rows the matrix returned by ``compute_distance_matrix`` is used as the
    adjacency directly. If all distances are zero, or if anything fails, a
    random symmetric graph is returned instead (and a ``RuntimeWarning`` is
    issued in the failure case).

    Args:
        x: Array whose first axis indexes the samples.
        thresh: Similarity threshold above which two rows are connected.

    Returns:
        The adjacency as a ``scipy.sparse.coo_matrix`` with an empty diagonal
        in the thresholded and random cases.
    """
    import warnings

    n = x.shape[0]
    try:
        if n > 10000:
            return sp.coo_matrix(compute_distance_matrix(x))
        distances = compute_distance_matrix(x)
        max_dist = distances.max()
        if max_dist == 0:
            # All rows coincide, so a threshold cannot separate them.
            return _random_symmetric_adjacency(n, 0.8)
        similarity = 1 - distances / max_dist
        adj = np.where(similarity > thresh, 1, 0)
        adj = adj - np.diag(np.diag(adj))
        return sp.coo_matrix(adj)
    except Exception as exc:
        # Best-effort fallback is kept, but make the substitution visible.
        warnings.warn(
            f"build_relationship failed ({exc!r}); returning a random graph",
            RuntimeWarning,
            stacklevel=2,
        )
        return _random_symmetric_adjacency(n, 0.7)

def extract_sensitive_attribute(data_df: pd.DataFrame, sens_attr: str,
                              feature_columns: List[str]) -> Tuple[np.ndarray, List[str], int]:
    """Split the sensitive column off a list of feature column names.

    Returns:
        ``(values, remaining, position)`` where ``values`` is the
        ``sens_attr`` column of ``data_df``, ``remaining`` is
        ``feature_columns`` without ``sens_attr``, and ``position`` is the
        index ``sens_attr`` had in ``feature_columns``. When ``sens_attr`` is
        not listed, ``remaining`` is the very same list object and
        ``position`` is -1.
    """
    try:
        position = feature_columns.index(sens_attr)
    except ValueError:
        position = -1
        remaining = feature_columns
    else:
        remaining = [name for name in feature_columns if name != sens_attr]
    return data_df[sens_attr].values, remaining, position

def prepare_graphdro_data(data: Data, sens_attr_idx: int, requires_grad: bool = True) -> Data:
    """Return a clone of ``data`` with the sensitive attribute stored as ``s``.

    With ``sens_attr_idx >= 0`` that column is cut out of ``x`` and becomes
    ``s``. With a negative index ``x`` is left alone and ``s`` is taken from
    ``data.s`` or, failing that, ``data.a``. When ``requires_grad`` is set,
    ``x``, a float copy of ``s`` and a dense ``adj_matrix`` built from
    ``edge_index`` are all flagged to require gradients. The index that was
    passed in is recorded on the result as ``sens_attr_idx``.

    Raises:
        ValueError: if the index is negative and ``data`` has neither ``s``
            nor ``a``.
    """
    result = data.clone()
    if sens_attr_idx < 0:
        for attr_name in ('s', 'a'):
            if hasattr(data, attr_name):
                result.s = getattr(data, attr_name)
                break
        else:
            raise ValueError("No sensitive attribute found in data")
    else:
        sens_column = data.x[:, sens_attr_idx]
        before = data.x[:, :sens_attr_idx]
        after = data.x[:, sens_attr_idx + 1:]
        result.x = torch.cat([before, after], dim=1)
        result.s = sens_column
    if requires_grad:
        result.x = result.x.requires_grad_(True)
        result.s = result.s.float().requires_grad_(True)
        node_count = result.x.shape[0]
        from torch_geometric.utils import to_dense_adj
        dense = to_dense_adj(result.edge_index, max_num_nodes=node_count)[0]
        result.adj_matrix = dense.requires_grad_(True)
    result.sens_attr_idx = sens_attr_idx
    return result

def reconstruct_full_features(x_non_sens: torch.Tensor, sens_attr: torch.Tensor,
                            sens_attr_idx: int) -> torch.Tensor:
    """Insert ``sens_attr`` as column ``sens_attr_idx`` of ``x_non_sens``.

    A negative ``sens_attr_idx`` means there is nothing to insert, and
    ``x_non_sens`` is returned unchanged (the same tensor object).
    """
    if sens_attr_idx < 0:
        return x_non_sens
    pieces = (
        x_non_sens[:, :sens_attr_idx],
        sens_attr.unsqueeze(1),
        x_non_sens[:, sens_attr_idx:],
    )
    return torch.cat(pieces, dim=1)

def load_german_simple(path: str = "dataset/german/german/") -> Tuple[Data, int, int]:
    """Load the German credit graph from ``german.csv`` and ``german_edges.txt``.

    ``Gender`` is encoded as 1 for ``'Female'`` and 0 otherwise; every other
    cell is parsed as a float, with unparsable cells becoming 0. The columns
    ``GoodCustomer``, ``OtherLoansAtStore`` and ``PurposeOfLoan`` are dropped
    from the features, ``Gender`` is moved out of the features into ``s``, and
    ``GoodCustomer`` becomes the label with -1 mapped to 0. The first 50% /
    next 25% / remaining rows form the train / validation / test masks.

    Args:
        path: Directory containing ``german.csv`` and ``german_edges.txt``.

    Returns:
        ``(data, number of feature columns, number of distinct labels)``, where
        ``data`` has been passed through ``prepare_graphdro_data`` with
        gradients enabled.
    """
    import csv

    def _encode(column: str, value: str) -> float:
        """Turn one CSV cell into a number (see the encoding rules above)."""
        if column == 'Gender':
            return 1 if value == 'Female' else 0
        try:
            return float(value)
        except ValueError:
            # Non-numeric cells (categorical text) are deliberately zeroed.
            return 0

    csv_path = os.path.join(path, "german.csv")
    # The csv module expects its file object to be opened with newline=''.
    with open(csv_path, 'r', newline='') as f:
        reader = csv.reader(f)
        headers = next(reader)
        data_rows = [
            [_encode(headers[i], value) for i, value in enumerate(row)]
            for row in reader
            if row  # csv.reader yields [] for blank lines; skip them
        ]
    data_array = np.array(data_rows)
    predict_attr_idx = headers.index('GoodCustomer')
    sens_attr_idx = headers.index('Gender')
    remove_cols = ['GoodCustomer', 'OtherLoansAtStore', 'PurposeOfLoan']
    keep_indices = [i for i, col in enumerate(headers) if col not in remove_cols]
    features = data_array[:, keep_indices]
    labels = data_array[:, predict_attr_idx]
    labels = np.where(labels == -1, 0, labels)
    # 'Gender' is never in remove_cols, so its original index is always kept.
    sens_attr_new_idx = keep_indices.index(sens_attr_idx)
    sens_attr = features[:, sens_attr_new_idx]
    features = np.delete(features, sens_attr_new_idx, axis=1)
    features = torch.FloatTensor(features)
    labels = torch.LongTensor(labels.astype(int))
    sens = torch.LongTensor(sens_attr.astype(int))
    edges_path = os.path.join(path, "german_edges.txt")
    edges = np.loadtxt(edges_path, dtype=int)
    edge_index = torch.LongTensor(edges.T)
    num_nodes = features.shape[0]
    num_train = int(0.5 * num_nodes)
    num_val = int(0.25 * num_nodes)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[:num_train] = True
    val_mask[num_train:num_train + num_val] = True
    test_mask[num_train + num_val:] = True
    data = Data(
        x=features,
        edge_index=edge_index,
        y=labels,
        s=sens,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask
    )
    data.sens_attr_idx = sens_attr_new_idx
    # x no longer contains the sensitive column, so no column is extracted
    # (-1). Note that prepare_graphdro_data records the index it is given, so
    # the returned object carries sens_attr_idx == -1.
    data = prepare_graphdro_data(data, -1, requires_grad=True)
    return data, features.shape[1], len(torch.unique(labels))

def load_german(dataset: str = "german",
                sens_attr: str = "Gender",
                predict_attr: str = "GoodCustomer",
                path: str = "dataset/german/german/",
                label_number: int = 1000,
                seed: int = 19,
                test_idx: bool = False,
                requires_grad: bool = True) -> Tuple[Data, int, int]:
    """Load the German dataset by delegating to ``load_german_simple``.

    Only ``path`` is forwarded; every other argument is accepted so the
    signature lines up with the other loaders, but is ignored.
    """
    result = load_german_simple(path)
    return result

def load_bail(dataset: str = "bail",
              sens_attr: str = "WHITE",
              predict_attr: str = "RECID",
              path: str = "../dataset/bail/bail/",
              label_number: int = 1000,
              seed: int = 19,
              test_idx: bool = False,
              requires_grad: bool = True) -> Tuple[Data, int, int]:
    """Load the bail graph from ``bail.csv`` and, if present, ``bail_edges.txt``.

    Features are all CSV columns except ``predict_attr`` and ``sens_attr``;
    the sensitive column is stored as ``s``. When the edge file is missing the
    graph is derived from feature similarity via ``build_relationship``. The
    adjacency is symmetrised and normalised with self-loops before being
    turned into ``edge_index``. The first 50% / next 25% / remaining rows form
    the train / validation / test masks.

    Args:
        dataset: Unused.
        sens_attr: Name of the sensitive column.
        predict_attr: Name of the label column.
        path: Directory containing the dataset files.
        label_number: Unused.
        seed: Seed applied to ``random`` and ``numpy.random``.
        test_idx: Unused.
        requires_grad: When true, the result is passed through
            ``prepare_graphdro_data`` with gradients enabled.

    Returns:
        ``(data, number of feature columns, number of distinct labels)``.
    """
    # Seed before any graph construction: build_relationship may fall back to
    # a randomly generated graph, which should be reproducible too.
    random.seed(seed)
    np.random.seed(seed)
    idx_features_labels = pd.read_csv(os.path.join(path, "bail.csv"))
    header = list(idx_features_labels.columns)
    header.remove(predict_attr)
    sens_values, non_sens_columns, sens_attr_idx = extract_sensitive_attribute(
        idx_features_labels, sens_attr, header
    )
    features = sp.csr_matrix(idx_features_labels[non_sens_columns], dtype=np.float32)
    labels = idx_features_labels[predict_attr].values
    edge_file = os.path.join(path, "bail_edges.txt")
    if os.path.exists(edge_file):
        edges_unordered = np.genfromtxt(edge_file).astype('int')
        idx = np.array(idx_features_labels.index)
        idx_map = {j: i for i, j in enumerate(idx)}
        edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                         dtype=int).reshape(edges_unordered.shape)
        adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                            shape=(labels.shape[0], labels.shape[0]),
                            dtype=np.float32)
    else:
        adj = build_relationship(features.toarray(), thresh=0.25)
    # Symmetrise: wherever the transpose holds the larger value, take it.
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    adj = sys_normalized_adjacency(adj)
    features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(labels)
    sens = torch.LongTensor(sens_values)
    edge_index = torch.stack([
        torch.tensor(adj.row, dtype=torch.long),
        torch.tensor(adj.col, dtype=torch.long)
    ], dim=0)
    num_nodes = features.shape[0]
    num_train = int(0.5 * num_nodes)
    num_val = int(0.25 * num_nodes)
    idx_train = torch.LongTensor(range(num_train))
    idx_val = torch.LongTensor(range(num_train, num_train + num_val))
    idx_test = torch.LongTensor(range(num_train + num_val, num_nodes))
    train_mask = index_to_mask(num_nodes, idx_train)
    val_mask = index_to_mask(num_nodes, idx_val)
    test_mask = index_to_mask(num_nodes, idx_test)
    data = Data(x=features, edge_index=edge_index, y=labels, s=sens,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
    data.sens_attr_idx = sens_attr_idx
    if requires_grad:
        # x was built without the sensitive column and s already holds it, so
        # no column must be extracted here (-1). Passing sens_attr_idx would
        # strip an unrelated feature and overwrite s with it.
        data = prepare_graphdro_data(data, -1, requires_grad=True)
        # prepare_graphdro_data records the index it was given; restore the
        # position the sensitive column had among the original columns.
        data.sens_attr_idx = sens_attr_idx
    return data, data.x.shape[1], len(labels.unique())

def load_credit(dataset: str = "credit",
                sens_attr: str = "Age",
                predict_attr: str = "NoDefaultNextMonth",
                path: str = "../dataset/credit/credit/",
                label_number: int = 1000,
                seed: int = 19,
                test_idx: bool = False,
                requires_grad: bool = True) -> Tuple[Data, int, int]:
    """Load the credit features from ``credit.csv`` and attach a random graph.

    Features are all CSV columns except ``predict_attr`` and ``sens_attr``;
    the sensitive column is stored as ``s``. No edge file is read: every node
    draws ``min(10, n // 100)`` random neighbours, and the resulting adjacency
    is symmetrised and normalised with self-loops. The first 50% / next 25% /
    remaining rows form the train / validation / test masks.

    Args:
        dataset: Unused.
        sens_attr: Name of the sensitive column.
        predict_attr: Name of the label column.
        path: Directory containing ``credit.csv``.
        label_number: Unused.
        seed: Seed applied to ``random`` and ``numpy.random`` before the
            random graph is drawn.
        test_idx: Unused.
        requires_grad: When true, the result is passed through
            ``prepare_graphdro_data`` with gradients enabled.

    Returns:
        ``(data, number of feature columns, number of distinct labels)``.
    """
    # Seed first: the graph below is random and must be reproducible.
    random.seed(seed)
    np.random.seed(seed)
    idx_features_labels = pd.read_csv(os.path.join(path, "credit.csv"))
    header = list(idx_features_labels.columns)
    header.remove(predict_attr)
    sens_values, non_sens_columns, sens_attr_idx = extract_sensitive_attribute(
        idx_features_labels, sens_attr, header
    )
    features = sp.csr_matrix(idx_features_labels[non_sens_columns], dtype=np.float32)
    labels = idx_features_labels[predict_attr].values
    n = features.shape[0]
    n_neighbors = min(10, n // 100)
    if n_neighbors > 0:
        # One draw per node, in node order; row i of `picks` belongs to node i.
        picks = np.stack([
            np.random.choice(n, size=n_neighbors, replace=False)
            for _ in range(n)
        ])
        src = np.repeat(np.arange(n), n_neighbors)
        dst = picks.ravel()
        not_self = src != dst
        src, dst = src[not_self], dst[not_self]
    else:
        src = dst = np.empty(0, dtype=int)
    adj = sp.coo_matrix((np.ones(src.shape[0], dtype=np.float32), (src, dst)),
                        shape=(n, n), dtype=np.float32)
    # Symmetrise: wherever the transpose holds the larger value, take it.
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    adj = sys_normalized_adjacency(adj)
    features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(labels)
    sens = torch.LongTensor(sens_values)
    edge_index = torch.stack([
        torch.tensor(adj.row, dtype=torch.long),
        torch.tensor(adj.col, dtype=torch.long)
    ], dim=0)
    num_nodes = features.shape[0]
    num_train = int(0.5 * num_nodes)
    num_val = int(0.25 * num_nodes)
    idx_train = torch.LongTensor(range(num_train))
    idx_val = torch.LongTensor(range(num_train, num_train + num_val))
    idx_test = torch.LongTensor(range(num_train + num_val, num_nodes))
    train_mask = index_to_mask(num_nodes, idx_train)
    val_mask = index_to_mask(num_nodes, idx_val)
    test_mask = index_to_mask(num_nodes, idx_test)
    data = Data(x=features, edge_index=edge_index, y=labels, s=sens,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
    data.sens_attr_idx = sens_attr_idx
    if requires_grad:
        # x was built without the sensitive column and s already holds it, so
        # no column must be extracted here (-1). Passing sens_attr_idx would
        # strip an unrelated feature and overwrite s with it.
        data = prepare_graphdro_data(data, -1, requires_grad=True)
        # prepare_graphdro_data records the index it was given; restore the
        # position the sensitive column had among the original columns.
        data.sens_attr_idx = sens_attr_idx
    return data, data.x.shape[1], len(labels.unique())

def load_nba(dataset: str = "nba",
              sens_attr: str = "country",
              predict_attr: str = "SALARY",
              path: str = "dataset/nba/nba/",
              label_number: int = 400,
              seed: int = 19,
              test_idx: bool = False,
              requires_grad: bool = True) -> Tuple[Data, int, int]:
    """Load the NBA graph from ``nba.csv`` and, if present, ``nba_relationship.txt``.

    Rows whose label is not -1, 0 or 1 are dropped, and the label is binarised
    as ``label == 1``. Features are all columns except ``user_id``,
    ``sens_attr`` and ``predict_attr``; the sensitive column is stored as
    ``s``. Edges are given in terms of ``user_id``; edges touching an id that
    is not among the kept rows are discarded. Without an edge file the
    adjacency is the identity. Nodes are shuffled with ``seed`` and split
    50% / 25% / rest into train / validation / test masks.

    Args:
        dataset: Unused.
        sens_attr: Name of the sensitive column.
        predict_attr: Name of the label column.
        path: Directory containing the dataset files.
        label_number: Unused.
        seed: Seed for ``numpy.random`` used by the split shuffle.
        test_idx: Unused.
        requires_grad: When true, the result is passed through
            ``prepare_graphdro_data`` with gradients enabled.

    Returns:
        ``(data, number of feature columns, number of distinct labels)``.
    """
    idx_features_labels = pd.read_csv(os.path.join(path, "nba.csv"))
    header = list(idx_features_labels.columns)
    header.remove("user_id")
    header.remove(sens_attr)
    header.remove(predict_attr)
    labels = idx_features_labels[predict_attr].values
    valid_mask = np.isin(labels, [-1, 0, 1])
    idx_features_labels = idx_features_labels[valid_mask].reset_index(drop=True)
    labels = idx_features_labels[predict_attr].values
    labels_bin = (labels == 1).astype(int)
    features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32)
    sens_values = idx_features_labels[sens_attr].values
    idx = np.array(idx_features_labels["user_id"], dtype=int)
    idx_map = {j: i for i, j in enumerate(idx)}
    edge_file = os.path.join(path, "nba_relationship.txt")
    if os.path.exists(edge_file):
        # atleast_2d keeps a file with a single edge as one (1, 2) row.
        edges_unordered = np.atleast_2d(np.genfromtxt(edge_file, dtype=int))
        # Unknown ids map to -1 (not None) so the filter below can drop them.
        edges = np.array(
            [idx_map.get(node, -1) for node in edges_unordered.flatten()],
            dtype=int,
        ).reshape(edges_unordered.shape)
        valid_edges = edges[(edges >= 0).all(axis=1)]
        adj = sp.coo_matrix((np.ones(valid_edges.shape[0]), (valid_edges[:, 0], valid_edges[:, 1])),
                            shape=(labels.shape[0], labels.shape[0]), dtype=np.float32)
    else:
        adj = sp.eye(labels.shape[0])
    # Symmetrise: wherever the transpose holds the larger value, take it.
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    adj = sys_normalized_adjacency(adj)
    features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(labels_bin)
    sens = torch.LongTensor(sens_values)
    np.random.seed(seed)
    num_nodes = features.shape[0]
    idx_all = np.arange(num_nodes)
    np.random.shuffle(idx_all)
    num_train = int(0.5 * num_nodes)
    num_val = int(0.25 * num_nodes)
    idx_train = torch.LongTensor(idx_all[:num_train])
    idx_val = torch.LongTensor(idx_all[num_train:num_train+num_val])
    idx_test = torch.LongTensor(idx_all[num_train+num_val:])
    train_mask = index_to_mask(num_nodes, idx_train)
    val_mask = index_to_mask(num_nodes, idx_val)
    test_mask = index_to_mask(num_nodes, idx_test)
    edge_index = torch.stack([
        torch.tensor(adj.row, dtype=torch.long),
        torch.tensor(adj.col, dtype=torch.long)
    ], dim=0)
    data = Data(x=features, edge_index=edge_index, y=labels, s=sens,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
    data.sens_attr_idx = -1
    if requires_grad:
        data = prepare_graphdro_data(data, -1, requires_grad=True)
    return data, data.x.shape[1], len(labels.unique())

def load_pokec_z(dataset: str = "pokec_z",
                 sens_attr: str = "region",
                 predict_attr: str = "I_am_working_in_field",
                 path: str = "../dataset/pokec/",
                 label_number: int = 1000,
                 seed: int = 19,
                 test_idx: bool = False) -> Tuple[Data, int, int]:
    """Load a graph from ``region_job.csv`` and ``region_job_relationship.txt``.

    Every CSV column except ``predict_attr`` is used as a node feature, and
    the ``sens_attr`` column is stored on the result as ``a``. The adjacency
    is symmetrised and normalised with self-loops. The first 50% / next 25% /
    remaining rows form the train / validation / test masks. ``dataset``,
    ``label_number`` and ``test_idx`` are accepted but not used.

    Returns:
        ``(data, number of feature columns, number of distinct labels)``.
    """
    table = pd.read_csv(os.path.join(path, "region_job.csv"))
    feature_cols = list(table.columns)
    # NOTE(review): only the label column is removed, so the sensitive column
    # stays among the features here — confirm this is intended.
    feature_cols.remove(predict_attr)
    raw_edges = np.genfromtxt(os.path.join(path, "region_job_relationship.txt")).astype('int')
    feature_matrix = sp.csr_matrix(table[feature_cols], dtype=np.float32)
    label_values = table[predict_attr].values
    # NOTE(review): edge endpoints are looked up in the DataFrame's row index,
    # so the relationship file presumably holds row positions — verify.
    row_of = {node: row for row, node in enumerate(np.array(table.index))}
    mapped = np.array([row_of.get(node) for node in raw_edges.flatten()],
                      dtype=int).reshape(raw_edges.shape)
    node_total = label_values.shape[0]
    graph = sp.coo_matrix((np.ones(mapped.shape[0]), (mapped[:, 0], mapped[:, 1])),
                          shape=(node_total, node_total),
                          dtype=np.float32)
    # Symmetrise: wherever the transpose holds the larger value, take it.
    transpose_larger = graph.T > graph
    graph = graph + graph.T.multiply(transpose_larger) - graph.multiply(transpose_larger)
    graph = sys_normalized_adjacency(graph)
    features = torch.FloatTensor(np.array(feature_matrix.todense()))
    labels = torch.LongTensor(label_values)
    sens = torch.LongTensor(table[sens_attr].values.astype(int))
    edge_index, _ = from_scipy_sparse_matrix(graph)
    random.seed(seed)
    np.random.seed(seed)
    num_nodes = features.shape[0]
    train_end = int(0.5 * num_nodes)
    val_end = train_end + int(0.25 * num_nodes)
    train_mask = index_to_mask(num_nodes, torch.LongTensor(range(train_end)))
    val_mask = index_to_mask(num_nodes, torch.LongTensor(range(train_end, val_end)))
    test_mask = index_to_mask(num_nodes, torch.LongTensor(range(val_end, num_nodes)))
    data = Data(x=features, edge_index=edge_index, y=labels, a=sens,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
    return data, data.x.shape[1], len(labels.unique())

def load_pokec_n(dataset: str = "pokec_n",
                 sens_attr: str = "region",
                 predict_attr: str = "I_am_working_in_field",
                 path: str = "../dataset/pokec/",
                 label_number: int = 1000,
                 seed: int = 19,
                 test_idx: bool = False) -> Tuple[Data, int, int]:
    """Load a graph from ``region_job_2.csv`` and ``region_job_2_relationship.txt``.

    Every CSV column except ``predict_attr`` is used as a node feature, and
    the ``sens_attr`` column is stored on the result as ``a``. The adjacency
    is symmetrised and normalised with self-loops. The first 50% / next 25% /
    remaining rows form the train / validation / test masks. ``dataset``,
    ``label_number`` and ``test_idx`` are accepted but not used.

    Returns:
        ``(data, number of feature columns, number of distinct labels)``.
    """
    table = pd.read_csv(os.path.join(path, "region_job_2.csv"))
    feature_cols = list(table.columns)
    # NOTE(review): only the label column is removed, so the sensitive column
    # stays among the features here — confirm this is intended.
    feature_cols.remove(predict_attr)
    raw_edges = np.genfromtxt(os.path.join(path, "region_job_2_relationship.txt")).astype('int')
    feature_matrix = sp.csr_matrix(table[feature_cols], dtype=np.float32)
    label_values = table[predict_attr].values
    # NOTE(review): edge endpoints are looked up in the DataFrame's row index,
    # so the relationship file presumably holds row positions — verify.
    row_of = {node: row for row, node in enumerate(np.array(table.index))}
    mapped = np.array([row_of.get(node) for node in raw_edges.flatten()],
                      dtype=int).reshape(raw_edges.shape)
    node_total = label_values.shape[0]
    graph = sp.coo_matrix((np.ones(mapped.shape[0]), (mapped[:, 0], mapped[:, 1])),
                          shape=(node_total, node_total),
                          dtype=np.float32)
    # Symmetrise: wherever the transpose holds the larger value, take it.
    transpose_larger = graph.T > graph
    graph = graph + graph.T.multiply(transpose_larger) - graph.multiply(transpose_larger)
    graph = sys_normalized_adjacency(graph)
    features = torch.FloatTensor(np.array(feature_matrix.todense()))
    labels = torch.LongTensor(label_values)
    sens = torch.LongTensor(table[sens_attr].values.astype(int))
    edge_index, _ = from_scipy_sparse_matrix(graph)
    random.seed(seed)
    np.random.seed(seed)
    num_nodes = features.shape[0]
    train_end = int(0.5 * num_nodes)
    val_end = train_end + int(0.25 * num_nodes)
    train_mask = index_to_mask(num_nodes, torch.LongTensor(range(train_end)))
    val_mask = index_to_mask(num_nodes, torch.LongTensor(range(train_end, val_end)))
    test_mask = index_to_mask(num_nodes, torch.LongTensor(range(val_end, num_nodes)))
    data = Data(x=features, edge_index=edge_index, y=labels, a=sens,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
    return data, data.x.shape[1], len(labels.unique())

def load_dataset(config: Dict[str, Any]) -> Tuple[Data, int, int]:
    """Dispatch to the loader named by ``config['dataset']``.

    Args:
        config: Mapping with a ``'dataset'`` name, a ``'data'`` sub-mapping and
            an optional top-level ``'seed'`` (default 19). From ``'data'`` the
            keys ``sens_attr``, ``predict_attr`` and ``path`` are forwarded
            when present (otherwise the loader's own defaults apply), together
            with ``label_number`` (default 1000), ``test_idx`` (default False)
            and ``requires_grad`` (default True).

    Returns:
        Whatever the selected loader returns:
        ``(data, number of feature columns, number of distinct labels)``.

    Raises:
        ValueError: if the dataset name is not one of ``german``, ``bail``,
            ``credit``, ``pokec_z``, ``pokec_n`` or ``nba``.
    """
    dataset_name = config['dataset']
    data_config = config['data']
    common_params = {
        'label_number': data_config.get('label_number', 1000),
        'seed': config.get('seed', 19),
        'test_idx': data_config.get('test_idx', False),
    }
    # Forward these only when supplied so each loader keeps its own default.
    for key in ('sens_attr', 'predict_attr', 'path'):
        if key in data_config:
            common_params[key] = data_config[key]
    requires_grad = data_config.get('requires_grad', True)
    if dataset_name == 'german':
        return load_german(requires_grad=requires_grad, **common_params)
    elif dataset_name == 'bail':
        return load_bail(requires_grad=requires_grad, **common_params)
    elif dataset_name == 'credit':
        return load_credit(requires_grad=requires_grad, **common_params)
    elif dataset_name == 'pokec_z':
        # The pokec loaders take no requires_grad argument.
        return load_pokec_z(**common_params)
    elif dataset_name == 'pokec_n':
        return load_pokec_n(**common_params)
    elif dataset_name == 'nba':
        return load_nba(requires_grad=requires_grad, **common_params)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

def compute_feature_correlation(data: Data, sens_attr_idx: int = 0) -> Dict[str, float]:
    """Correlate one feature column with every other column and with the labels.

    Args:
        data: Object exposing a 2-D feature tensor ``x`` and a label tensor
            ``y``.
        sens_attr_idx: Column of ``x`` treated as the sensitive attribute.

    Returns:
        A dict with the Pearson correlation (``np.corrcoef``) between the
        chosen column and each other column ``i`` under ``'feature_{i}'``, and
        between the chosen column and ``y`` under ``'label'``.
    """
    # .cpu() makes this work for tensors that live on an accelerator, where a
    # direct .numpy() call raises.
    features = data.x.detach().cpu().numpy()
    sens_attr = features[:, sens_attr_idx]
    correlations = {
        f'feature_{i}': np.corrcoef(sens_attr, features[:, i])[0, 1]
        for i in range(features.shape[1])
        if i != sens_attr_idx
    }
    labels = data.y.detach().cpu().numpy()
    correlations['label'] = np.corrcoef(sens_attr, labels)[0, 1]
    return correlations

def analyze_dataset(data: Data, dataset_name: str) -> Dict[str, Any]:
    """Summarise size, label, sensitive-group and split statistics of ``data``.

    Args:
        data: Graph object exposing ``num_nodes``, ``edge_index``, ``x``, ``y``
            and the three split masks; the sensitive attribute is read from
            ``s`` or, failing that, ``a``.
        dataset_name: Name stored under ``'dataset'`` in the result.

    Returns:
        A dict with node / edge / feature / class counts, ``avg_degree``
        (edges per node, 0.0 when there are no nodes), ``label_distribution``,
        ``sensitive_distribution`` (only when a sensitive attribute is
        present) and ``split_distribution``.
    """
    num_nodes = data.num_nodes
    num_edges = data.edge_index.shape[1]
    analysis = {
        'dataset': dataset_name,
        'num_nodes': num_nodes,
        'num_edges': num_edges,
        'num_features': data.x.shape[1],
        'num_classes': len(data.y.unique()),
        'avg_degree': num_edges / num_nodes if num_nodes else 0.0,
    }
    unique_labels, label_counts = data.y.unique(return_counts=True)
    analysis['label_distribution'] = {
        f'class_{label.item()}': count.item()
        for label, count in zip(unique_labels, label_counts)
    }
    # Some loaders in this module store the sensitive attribute as `a`
    # instead of `s`; accept either, preferring `s`.
    sens_name = next((name for name in ('s', 'a') if hasattr(data, name)), None)
    if sens_name is not None:
        unique_sens, sens_counts = getattr(data, sens_name).unique(return_counts=True)
        analysis['sensitive_distribution'] = {
            f'group_{sens.item()}': count.item()
            for sens, count in zip(unique_sens, sens_counts)
        }
    analysis['split_distribution'] = {
        'train': data.train_mask.sum().item(),
        'val': data.val_mask.sum().item(),
        'test': data.test_mask.sum().item()
    }
    return analysis

def load_data(dataset_name: str = "german", **kwargs) -> Tuple[Data, int, int]:
    """Load a dataset by name, passing keyword options through ``load_dataset``.

    Args:
        dataset_name: Name understood by ``load_dataset``.
        **kwargs: Entries for the ``'data'`` section of the config. A ``seed``
            keyword is placed at the top level of the config instead, because
            that is where ``load_dataset`` looks for it.

    Returns:
        ``(data, number of feature columns, number of distinct labels)``.
    """
    config: Dict[str, Any] = {
        'dataset': dataset_name,
        'data': kwargs
    }
    if 'seed' in kwargs:
        config['seed'] = kwargs.pop('seed')
    return load_dataset(config)