import multiprocessing
import os
import random
import string

# Cap OpenMP / OpenBLAS / MKL at one thread each. These are set before numpy
# and scikit-learn are imported below -- presumably so the native libraries
# pick the limits up at load time and the worker processes started at the
# bottom of this file do not oversubscribe the CPU (confirm against the BLAS
# backend actually in use).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import pickle

from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

import numpy as np
from tqdm import tqdm
import time


def auc(clf, x, y):
    """Return the ROC AUC of ``clf`` on ``(x, y)``.

    ``clf`` must expose ``predict_proba``; column 1 of its output is used as
    the ranking score passed to ``roc_auc_score``.
    """
    positive_scores = clf.predict_proba(x)[:, 1]
    return roc_auc_score(y, positive_scores)


def generate_random_string(length):
    """Build a string of ``length`` characters from lowercase letters and digits.

    Each character is picked independently with ``random.choice``; a
    non-positive ``length`` yields the empty string.
    """
    alphabet = string.ascii_lowercase + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)


class DShap:
    """TMC-Shapley sampler that scores the points of ``X_train0``.

    Every sampled permutation starts from the fixed set ``(X_train1, y_train1)``
    and adds the points of ``(X_train0, y_train0)`` one at a time, refitting a
    logistic regression and recording the change in test AUC after each
    addition. Each permutation's result is pickled to its own file in
    ``output_dir``.
    """

    def __init__(self, X_train0, y_train0, X_train1, y_train1, X_test, y_test, truncation_tolerance, output_dir):
        """Store the data, create ``output_dir`` and compute the full-data score.

        Args:
            X_train0, y_train0: Points whose marginal contributions are sampled.
            X_train1, y_train1: Points that are always part of the training set.
            X_test, y_test: Held-out data on which AUC is measured.
            truncation_tolerance: Relative distance to the full-data score
                below which a score counts as "close enough" for truncation.
            output_dir: Directory receiving one pickle per sampled permutation;
                created if it does not exist.
        """
        self.X_test = X_test
        self.y_test = y_test

        self.X_train0 = X_train0
        self.X_train1 = X_train1
        self.y_train0 = y_train0
        self.y_train1 = y_train1

        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        self.X_train = np.concatenate([X_train0, X_train1])
        self.y_train = np.concatenate([y_train0, y_train1])

        self.truncation_tolerance = truncation_tolerance

        self._init_clf()
        self._init_score()

    def _init_clf(self):
        """Create the template classifier and the starting score."""
        clf = LogisticRegression(verbose=0, n_jobs=1)
        # Score used as the running score before any X_train0 point is added.
        self.random_score = 0.5
        self.clf = clf

    def _get_new_clf(self):
        """Return an unfitted copy of the template classifier."""
        return clone(self.clf)

    def _init_score(self):
        """Fit on the complete training set and store its test AUC as ``full_score``."""
        clf = self._get_new_clf()
        clf.fit(self.X_train, self.y_train)
        full_score = auc(clf, self.X_test, self.y_test)
        self.full_score = full_score

    def _tmc_shapley_one_iteration(self):
        """Sample one permutation of ``X_train0`` and pickle its marginal contributions.

        The pickle holds ``(marginal_contribs, idxs, n, new_score)``: the
        per-point score deltas (zero for points never reached because of
        truncation), the permutation, the number of points actually added and
        the last score computed.
        """
        num_points = len(self.X_train0)
        idxs = np.random.permutation(num_points)
        marginal_contribs = np.zeros(num_points)

        # Lay out this permutation's full training sequence once and fit on
        # growing prefixes of it. Calling np.append inside the loop would
        # reallocate and copy the entire batch for every added point.
        base_size = len(self.X_train1)
        X_pool = np.concatenate([self.X_train1, np.asarray(self.X_train0)[idxs]])
        y_pool = np.concatenate([self.y_train1, np.asarray(self.y_train0)[idxs]])

        truncation_counter = 0
        # NOTE(review): the running score starts at random_score (0.5), not at
        # the score of a model trained on X_train1 alone, so the first point's
        # delta also absorbs whatever X_train1 contributes -- confirm this is
        # the intended baseline.
        new_score = self.random_score

        n = 0  # stays 0 when X_train0 is empty
        for n, idx in enumerate(idxs, start=1):
            old_score = new_score
            end = base_size + n

            clf = self._get_new_clf()
            clf.fit(X_pool[:end], y_pool[:end])
            new_score = auc(clf, self.X_test, self.y_test)

            marginal_contribs[idx] = new_score - old_score

            # Truncate once the score has stayed within tolerance of the full
            # score for six consecutive additions; any miss resets the streak.
            distance_to_full_score = np.abs(new_score - self.full_score)
            if distance_to_full_score <= self.truncation_tolerance * self.full_score:
                truncation_counter += 1
                if truncation_counter > 5:
                    break
            else:
                truncation_counter = 0

        # A random file name keeps concurrent workers writing to the same
        # directory from needing any coordination.
        output_name = generate_random_string(20)
        output_path = os.path.join(self.output_dir, 'tmc_result_{}.pkl'.format(output_name))
        with open(output_path, 'wb') as f:
            pickle.dump((marginal_contribs, idxs, n, new_score), f)

    def tmc_shapley(self, iterations):
        """Run ``iterations`` independent TMC-Shapley permutations with a progress bar."""
        for _ in tqdm(range(iterations)):
            self._tmc_shapley_one_iteration()


# Run configuration. Kept at module level rather than under the ``__main__``
# guard: with the "spawn"/"forkserver" start methods each worker process
# re-imports this module without executing the guarded block, so a name
# defined only there would be undefined inside ``worker``.
WORKER_NUM = 40
TOTAL_ITERATIONS = 100000


def worker(pid, num_workers=WORKER_NUM):
    """Run this process's share of the TMC-Shapley iterations.

    Args:
        pid: Index of the worker; only used for the progress message.
        num_workers: Number of workers splitting ``TOTAL_ITERATIONS``; this
            worker runs ``TOTAL_ITERATIONS // num_workers`` iterations.
    """
    print('worker:', pid)
    data_path = 'splitted_dataset.pkl'
    # NOTE: unpickling can execute arbitrary code; only point this at a
    # dataset file you produced yourself.
    with open(data_path, 'rb') as f:
        X_train0, y_train0, X_train1, y_train1, X_test, y_test = pickle.load(f)

    # Forked workers inherit the parent's NumPy RNG state, so each one must
    # reseed. Seeding from OS entropy avoids the collisions of a small
    # time-derived seed space, where two workers could end up replaying the
    # same permutations.
    np.random.seed()

    dshap = DShap(X_train0, y_train0, X_train1, y_train1,
                  X_test, y_test, truncation_tolerance=0.025,
                  output_dir='./output')
    dshap.tmc_shapley(TOTAL_ITERATIONS // num_workers)


if __name__ == '__main__':
    worker_num = WORKER_NUM
    # The context manager terminates the pool if a task raises, instead of
    # leaving worker processes behind.
    with multiprocessing.Pool(processes=worker_num) as pool:
        results = [pool.apply_async(worker, (pid, worker_num))
                   for pid in range(worker_num)]
        # .get() re-raises any exception from the worker in the parent.
        for result in results:
            result.get()
        pool.close()
        pool.join()
