import os
import sys
sys.path.insert(0, './')
import numpy as np

import torch
import torch.nn as nn

from torchvision import datasets, transforms
from sklearn import preprocessing

from .Utility import SubsetRandomSampler, SubsetSampler

import pandas as pd  


def load_adult_data(root):
    """Load and preprocess the Adult census-income CSV at ``root``.

    The file is read without a header row and its 15 columns are named
    after the Adult attributes listed in ``header``. Preprocessing:

    * rows with a missing value (written ``'?'`` in the raw file) are dropped;
    * the ``salary`` column is encoded as integer class indices, assigned in
      sorted order of the distinct label values;
    * the columns in ``categorical_columns`` are one-hot encoded;
    * the columns in ``normalize_columns`` are standardized with
      ``sklearn.preprocessing.StandardScaler``.

    Progress and the encoded label values are printed to stdout.

    Args:
        root: Path of the headerless CSV file.

    Returns:
        A ``pd.DataFrame`` of floats with integer column labels; the last
        column is the encoded label and the remaining columns are features.
    """
    header = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
              'occupation', 'relationship', 'race', 'sex', 'capital-gain',
              'capital-loss', 'hours-per-week', 'native-country', 'salary']
    df = pd.read_csv(root, index_col=False, skipinitialspace=True, header=None, names=header)
    # '?' is the missing-value marker; turn it into NaN so dropna removes those rows.
    df = df.replace('?', np.nan)
    df.dropna(inplace=True)
    categorical_columns = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
                           'native-country']
    normalize_columns = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
    label_column = ['salary']

    def convert_to_int(columns):
        # Replace each column's values with integer indices, in place on df.
        for column in columns:
            values = df[column]
            if pd.api.types.is_string_dtype(values):
                # The UCI adult.test file spells labels '<=50K.' / '>50K.'; strip
                # the trailing period so both spellings land in the same class.
                values = values.str.rstrip('.')
            # Sorting makes the value -> index mapping independent of row order
            # (e.g. '<=50K' -> 0, '>50K' -> 1).
            unique_values = sorted(values.unique().tolist())
            mapping = {val: indx for indx, val in enumerate(unique_values)}
            df[column] = values.map(mapping).astype(int)
            print(column + " done!")

    def convert_to_onehot(data, columns):
        # Replace each categorical column with its one-hot indicator columns.
        dummies = pd.get_dummies(data[columns])
        data = data.drop(columns, axis=1)
        return pd.concat([data, dummies], axis=1)

    def show_unique_values(columns):
        for column in columns:
            uniq = df[column].unique().tolist()
            print(column + " has " + str(len(uniq)) + " values" + " : " + str(uniq))

    convert_to_int(label_column)
    df = convert_to_onehot(df, categorical_columns)
    show_unique_values(label_column)

    # NOTE(review): the scaler is fit on every row, including rows a later
    # 80/20 split may hold out for testing, so test statistics leak into the
    # normalization. Fitting on training rows only would need the split here.
    scaler = preprocessing.StandardScaler()
    df[normalize_columns] = scaler.fit_transform(df[normalize_columns])

    # to_numpy(dtype=float) converts the mixed float/int/bool frame directly
    # instead of going through an intermediate object array.
    label = df["salary"].to_numpy(dtype=float)
    data = df.drop("salary", axis=1).to_numpy(dtype=float)

    # The label goes last so callers can separate features and label by position.
    return pd.DataFrame(np.concatenate([data, label.reshape(-1, 1)], axis=1))



# Custom dataset class for reading the Adult data from adults.csv (e.g. /data/adults.csv).
class IndexedAdults(torch.utils.data.Dataset):
    """One split of the preprocessed Adult table; items carry their index.

    The table produced by ``load_adult_data`` is split deterministically,
    without shuffling: the first 80% of rows form the ``'train'`` split, and
    the remaining 20% form the split for any other ``split`` value.

    Attributes:
        data: ``pd.DataFrame`` for this split. The last column is the label
            and the other columns are features. Treat it as read-only after
            construction, because ``__getitem__`` serves samples from NumPy
            arrays cached from it in ``__init__``.
        transform: Optional callable applied to each float32 feature vector.
    """

    def __init__(self, root, split='train', transform=None):
        """Load the CSV and keep the rows of the requested split.

        Args:
            root: Directory containing ``adults.csv``, or the CSV file itself.
            split: ``'train'`` for the first 80% of rows; anything else
                selects the last 20%.
            transform: Optional callable applied to the feature vector.

        Raises:
            FileNotFoundError: If the CSV file does not exist.
        """
        path = self._resolve_csv_path(root)
        self.data = load_adult_data(path)
        # 0.8 as train, 0.2 as test
        split_idx = int(len(self.data) * 0.8)
        if split == 'train':
            self.data = self.data.iloc[:split_idx].reset_index(drop=True)
        else:
            self.data = self.data.iloc[split_idx:].reset_index(drop=True)
        self.transform = transform
        self._build_cache()

    @staticmethod
    def _resolve_csv_path(root):
        """Return the CSV path for ``root`` (a directory or the file itself)."""
        # An existing file is used as-is; anything else is treated as the
        # directory that holds 'adults.csv'.
        path = root if os.path.isfile(root) else os.path.join(root, 'adults.csv')
        if not os.path.exists(path):
            raise FileNotFoundError(f"File {path} not found.")
        return path

    def _build_cache(self):
        """Cache features (float32) and labels as arrays to skip per-item pandas lookups."""
        values = self.data.to_numpy()
        self._features = values[:, :-1].astype(np.float32)
        self._labels = values[:, -1]

    def __getitem__(self, index):
        """Return ``(features, label, index)`` for one row of this split."""
        # Copy so a transform that mutates its input cannot corrupt the cache.
        features = self._features[index].copy()
        label = int(self._labels[index])
        if self.transform is not None:
            features = self.transform(features)
        return features, label, index

    def __len__(self):
        return len(self.data)

# reference: Cifar10.py
def adults(batch_size, root='./data/adults.csv', valid_ratio=None, shuffle=True, augmentation=True, train_subset=None, test_subset=None,
           mislabel_ratio=0., mislabel_seed=0, class_subset_path=None, is_split=False, split_seed=0, is_shadow=False, shadow_ratio=0.8, shadow_seed=0,
           member_train=False, member_test=False, nonmember_train=False, nonmember_test=False, num_worker=0):
    """Build train/valid/test DataLoaders for the Adult dataset.

    The signature mirrors the Cifar10.py loader. ``augmentation``,
    ``mislabel_ratio``, ``mislabel_seed`` and ``class_subset_path`` are
    accepted for compatibility but are not used here.

    Args:
        batch_size: Batch size for every loader.
        root: Directory containing ``adults.csv``, or the path of that file.
        valid_ratio: Fraction of the selected training indices held out for
            validation, taken from the front of the index list. ``None`` or
            ``0`` means no validation loader.
        shuffle: Use ``SubsetRandomSampler`` for training indices (otherwise
            ``SubsetSampler``).
        train_subset, test_subset: Optional index arrays restricting the
            training / test splits. They are overwritten when ``is_split`` is True.
        is_split: Select a half of each split, drawn with ``split_seed``.
        split_seed: NumPy seed for the ``is_split`` selection.
        is_shadow: With ``is_split``, use the complementary halves instead,
            subsampled to ``shadow_ratio`` of their size with ``shadow_seed``.
        member_train, member_test: Keep the first / second half of the
            training indices. Mutually exclusive.
        nonmember_train, nonmember_test: Keep the first / second half of the
            test indices. Mutually exclusive.
        num_worker: Number of DataLoader worker processes.

    Returns:
        ``(train_loader, valid_loader, test_loader, classes)``. ``valid_loader``
        is ``None`` when there is no validation split, and ``classes`` is the
        sorted list of label values present in the training split.

    Raises:
        ValueError: If both flags of a mutually exclusive pair are set.
    """
    import copy

    if member_train and member_test:
        raise ValueError('member_train and member_test cannot be both True.')
    if nonmember_train and nonmember_test:
        raise ValueError('nonmember_train and nonmember_test cannot be both True.')

    # IndexedAdults appends 'adults.csv' to the root it receives. A root that
    # already names that file (as the default does) is reduced to its
    # directory; otherwise it would resolve to .../adults.csv/adults.csv.
    if os.path.basename(root) == 'adults.csv':
        root = os.path.dirname(root)

    # No feature transforms yet. None (rather than identity lambdas) keeps the
    # datasets picklable, which DataLoader workers need when num_worker > 0
    # and worker processes are spawned.
    transform_train = None
    transform_valid = None
    transform_test = None

    trainset = IndexedAdults(root, split='train', transform=transform_train)
    # The validation set draws from the same rows as the training set; a
    # shallow copy avoids re-reading and re-encoding the CSV.
    validset = copy.copy(trainset)
    validset.transform = transform_valid
    testset = IndexedAdults(root, split='test', transform=transform_test)

    classes = sorted(set(trainset.data.iloc[:, -1].tolist()))

    if is_split:
        num_train_data, num_test_data = len(trainset), len(testset)
        np.random.seed(split_seed)
        train_subset = np.random.choice(num_train_data, size=num_train_data//2, replace=False)
        np.random.seed(split_seed)
        test_subset = np.random.choice(num_test_data, size=num_test_data//2, replace=False)
        if is_shadow:
            # Shadow data: the complement of the halves chosen above, subsampled.
            train_subset = np.array(list(set(np.arange(num_train_data)) - set(train_subset)))
            test_subset = np.array(list(set(np.arange(num_test_data)) - set(test_subset)))
            np.random.seed(shadow_seed)
            train_subset = np.random.choice(train_subset, size=int(len(train_subset) * shadow_ratio), replace=False)
            np.random.seed(shadow_seed)
            test_subset = np.random.choice(test_subset, size=int(len(test_subset) * shadow_ratio), replace=False)

    if train_subset is None:
        train_indices = list(range(len(trainset)))
    else:
        train_indices = np.random.permutation(train_subset)
    if member_train:
        train_indices = train_indices[:len(train_indices) // 2]
    if member_test:
        train_indices = train_indices[len(train_indices) // 2:]
    train_instance_num = len(train_indices)
    print('%d instances are picked from the training set' % train_instance_num)

    if test_subset is None:
        test_indices = list(range(len(testset)))
    else:
        test_indices = test_subset
    if nonmember_train:
        test_indices = test_indices[:len(test_indices) // 2]
    if nonmember_test:
        test_indices = test_indices[len(test_indices) // 2:]
    test_instance_num = len(test_indices)
    print('%d instances are picked from the test set' % test_instance_num)
    test_sampler = SubsetSampler(test_indices)

    def make_loader(dataset, sampler):
        # Every loader shares batch size, worker count and pinning.
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                                           num_workers=num_worker, pin_memory=True)

    train_sampler_cls = SubsetRandomSampler if shuffle else SubsetSampler
    if valid_ratio is not None and valid_ratio > 0.:
        split_pt = int(train_instance_num * valid_ratio)
        train_idx, valid_idx = train_indices[split_pt:], train_indices[:split_pt]
        train_loader = make_loader(trainset, train_sampler_cls(train_idx))
        valid_loader = make_loader(validset, SubsetSampler(valid_idx))
    else:
        train_loader = make_loader(trainset, train_sampler_cls(train_indices))
        valid_loader = None
    test_loader = make_loader(testset, test_sampler)

    return train_loader, valid_loader, test_loader, classes
