import os
import sys

sys.path.append(os.pardir)

import random

from copy import deepcopy

from sklearn.preprocessing import (LabelEncoder)

import torchvision.transforms as transforms

from load.load_casual_dataset.dataset import generate_observational_dataset

import torch
import numpy as np

# Plain conversion of an image / ndarray into a tensor.
tp = transforms.ToTensor()

# Tensor conversion followed by a 3-channel Normalize with mean=std=0.5,
# i.e. (x - 0.5) / 0.5 per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Tensor conversion only, wrapped in a Compose pipeline (same effect as `tp`).
transform_fn = transforms.Compose([transforms.ToTensor()])

# Root directory of the shared datasets, relative to the current working
# directory (other layouts have used './load/share_dataset/' or
# '../../../share_dataset/').
DATA_PATH = '../share_dataset/'

# Supported dataset names, grouped by data modality.
IMAGE_DATA = [
    'mnist', 'cifar10', 'cifar100', 'cifar20',
    'utkface', 'facescrub', 'places365',
]
TABULAR_DATA = [
    'breast_cancer_diagnose', 'diabetes', 'adult_income',
    'criteo', 'credit', 'nursery', 'avazu',
]
GRAPH_DATA = ['cora']
CAUSAL_DATA = ['observational']
TEXT_DATA = [
    'news20', 'cola_public', 'SST-2', 'STS-B', 'MRPC', 'MNLI',
    'QNLI', 'QQP', 'WNLI', 'RTE', 'MMLU',
]

def dataset_partition(args, index, dst, half_dim):
    """Return the feature columns of ``dst`` owned by party ``index``.

    The feature columns of ``dst[0]`` are split into ``args.k`` contiguous,
    consecutive ranges whose widths are ``args.dataset_split['dims'][0..k-1]``:
    party 0 owns the first ``dims[0]`` columns, party 1 the next ``dims[1]``
    columns, and so on. Only datasets listed in ``CAUSAL_DATA`` are supported.

    Args:
        args: Config namespace; reads ``k`` (number of parties), ``dataset``
            and ``dataset_split['dims']`` (at least ``k`` column widths).
        index: Party index; must satisfy ``0 <= index < args.k``.
        dst: Tuple whose first element is sliced along its second axis
            (``dst[0][:, start:end]``); the remaining elements are ignored.
        half_dim: Unused; kept for signature compatibility with callers.

    Returns:
        ``dst`` itself when ``args.k == 1`` (centralized training), otherwise
        ``(dst[0][:, start:end], None)`` for this party's column range.

    Raises:
        AssertionError: if ``args.dataset`` is not in ``CAUSAL_DATA`` or
            ``index`` is out of range.
    """
    if args.k == 1:  # Centralized training: a single party owns every feature.
        return dst
    assert args.dataset in CAUSAL_DATA, f"dataset not supported {args.dataset}"
    # Validate before slicing: a negative index would otherwise wrap around in
    # the bounds lookup and silently yield a wrong (typically empty) slice.
    assert 0 <= index < args.k, "invalid party index"

    widths = [int(args.dataset_split['dims'][ik]) for ik in range(args.k)]
    # Party ``index`` owns the half-open column range [start, end).
    start = sum(widths[:index])
    end = start + widths[index]
    return (dst[0][:, start:end], None)



def load_dataset_per_party(args, index):
    """Load the observational causal dataset and return party ``index``'s share.

    The dataset comes from ``generate_observational_dataset(save_dir=args.data_path)``.
    Its ``perturbation_label`` column is dropped. The first
    ``args.dataset_split["n_train"]`` rows form the training set and the
    remaining rows the test set. Both are converted to float tensors, moved to
    ``args.device`` and reduced to this party's feature columns via
    ``dataset_partition``.

    Side effects on ``args``: sets ``classes`` to ``[None] * args.num_classes``,
    ``idx_train`` and ``idx_test`` to ``None``, and ``B_true`` to the second
    value returned by ``generate_observational_dataset``.

    Args:
        args: Config namespace (see the attributes read above plus ``dataset``,
            ``need_auxiliary``, ``device``).
        index: Party index forwarded to ``dataset_partition``.

    Returns:
        ``(args, half_dim, train_dst, test_dst)`` where ``half_dim`` is ``-1``
        and each ``*_dst`` is the tuple returned by ``dataset_partition``.

    Raises:
        AssertionError: if ``args.dataset != "observational"`` or
            ``args.need_auxiliary == 1`` (auxiliary data is not supported).
    """
    print('load_dataset_per_party  args.need_auxiliary = ', args.need_auxiliary)
    args.classes = [None] * args.num_classes

    half_dim = -1
    args.idx_train = None
    args.idx_test = None
    assert args.dataset == "observational", "dataset not supported yet"
    # Reject unsupported configurations before paying for dataset generation.
    assert args.need_auxiliary != 1, "need_auxiliary not supported for {} dataset".format(args.dataset)

    n_train = args.dataset_split["n_train"]
    X, B_true = generate_observational_dataset(
        save_dir=args.data_path
    )
    # NOTE(review): X presumably is a pandas DataFrame (uses .drop(axis=1) and
    # .to_numpy()); confirm against generate_observational_dataset.
    X_numpy = X.drop("perturbation_label", axis=1).to_numpy().astype(float)
    # The second tuple slot is always None here (presumably the label slot).
    train_dst = (torch.tensor(X_numpy[:n_train]).to(args.device), None)
    test_dst = (torch.tensor(X_numpy[n_train:]).to(args.device), None)
    args.B_true = B_true

    train_dst = dataset_partition(args, index, train_dst, half_dim)
    test_dst = dataset_partition(args, index, test_dst, half_dim)
    return args, half_dim, train_dst, test_dst


def process_dense_feats(data, feats):
    """Return log-scaled copies of the dense feature columns ``feats`` of ``data``.

    Missing values are filled with ``0.0``. Each value ``x`` then becomes
    ``log(x + 1)`` when ``x > -1`` and ``-1`` otherwise, because the logarithm
    is undefined at or below ``x = -1``. ``data`` itself is not modified.

    Args:
        data: DataFrame containing every column named in ``feats``.
        feats: List of column names to transform.

    Returns:
        A new DataFrame holding only the ``feats`` columns, in the given order,
        with float values.
    """
    # logging.info(f"Processing feats: {feats}")
    # Selecting a list of columns already yields a new frame (and fillna returns
    # another), so copying all of ``data`` first is unnecessary.
    d = data[feats].fillna(0.0)
    for f in feats:
        col = d[f].astype(float)
        # Mask out-of-domain values (x <= -1) to NaN so np.log neither warns nor
        # produces -inf/NaN for them; those NaNs are then replaced by -1.
        d[f] = np.log(col.where(col > -1) + 1).fillna(-1)
    return d


def process_sparse_feats(data, feats):
    """Label-encode the sparse columns ``feats`` into one shared integer id space.

    Missing values are first filled with the string ``"-1"``. Each column is
    label-encoded with ``LabelEncoder`` (codes ``0 .. n_f - 1``) and then
    shifted by the total number of categories of all preceding columns, so ids
    from different columns never collide. For example, with 2 and 3 categories,
    column 1 maps to {0, 1} and column 2 to {2, 3, 4}. ``data`` itself is not
    modified.

    Args:
        data: DataFrame containing every column named in ``feats``.
        feats: List of column names to encode.

    Returns:
        A new DataFrame holding only the ``feats`` columns, in the given order.
    """
    # logging.info(f"Processing feats: {feats}")
    # Column selection plus fillna already produce a new frame; no full copy needed.
    d = data[feats].fillna("-1")
    offset = 0
    for f in feats:
        encoder = LabelEncoder()
        d[f] = encoder.fit_transform(d[f]) + offset
        # The encoder was fitted on this very column, so every code
        # 0..len(classes_)-1 occurs: this equals the column's distinct-id count.
        offset += len(encoder.classes_)
    return d


def prepare_poison_target_list(args):
    """Choose a random poisoning target class and store it on ``args``.

    The class is drawn uniformly from ``0 .. args.num_classes - 1`` (inclusive)
    using the module-level ``random`` generator and is written to
    ``args.target_label``.
    """
    last_class = args.num_classes - 1
    args.target_label = random.randint(0, last_class)


def get_dataset_path(dataset_split):
    """Return ``(train_set_file, test_set_file)`` from a dataset-split mapping.

    Both keys must be present. If either one is missing, ``(None, None)`` is
    returned instead.
    """
    if 'train_set_file' not in dataset_split or 'test_set_file' not in dataset_split:
        return None, None
    return dataset_split['train_set_file'], dataset_split['test_set_file']


