import glob
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import sklearn.metrics as metrics
from sklearn.naive_bayes import GaussianNB
import numpy as np
import re

import pandas as pd
import glob as glob
import traceback
import numpy as np
from secret_utils import parse_secrets, DICT_GET_PRIOR
from sklearn.model_selection import KFold
import sklearn.metrics as metrics
from sklearn.feature_extraction.text import CountVectorizer
from tqdm import tqdm
from sklearn.linear_model import SGDClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from joblib import Parallel, delayed
from datetime import datetime
from collections import Counter
import argparse
import gc
from xgboost import XGBClassifier
import sys
import os
from catboost import CatBoostClassifier
from sklearn.ensemble import StackingClassifier


import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--nonmembers", type=int, default=64,
                    help='Number of non-member secrets sampled per member secret')
# Required: without it every path below would silently start with "None/".
parser.add_argument('--base_path', type=str, required=True,
                    help='Base path containing filtered_step_1/ (input) and filtered_step_2/ (output)')
args = parser.parse_args()
if args.nonmembers < 1:
    # Per-member rank computation reshapes the non-member scores into
    # (-1, nonmembers) rows, which needs at least one non-member per member.
    parser.error('--nonmembers must be at least 1')
base_path = args.base_path
print(args)

out_path = 'filtered_step_2'

# glob returns files in arbitrary order; sort so the concatenated row order
# (and every id derived from it) is reproducible across runs.
files = sorted(glob.glob(f'{base_path}/filtered_step_1/*.pkl'))
print(files)
if not files:
    raise FileNotFoundError(f'No .pkl files found in {base_path}/filtered_step_1')
os.makedirs(f'{base_path}/{out_path}', exist_ok=True)

filtered_df = pd.concat([pd.read_pickle(f) for f in files])

from collections import Counter
class TextPositionVectorizer():
    """Encode strings as per-position one-hot indicators of frequent n-grams.

    ``fit`` picks, for every start position ``i < max_len``, the ``max_ngram``
    most frequent *complete* substrings ``a[i:i + ngram_len]`` of the training
    strings (truncated substrings near the end of a string are ignored).
    ``transform`` then marks, position by position, which of those substrings
    each string has there. ``max_len`` is the 90th-percentile training length
    minus ``ngram_len``, floored at 0.

    NOTE(review): this file defines ``TextPositionVectorizer`` a second time
    further down (that version also counts truncated substrings in ``fit``).
    The later definition is the one bound when the blind-baseline factories
    run, so this class is currently unused.
    """

    def __init__(self, ngram_len=1, max_ngram=50):
        self.ngram_len = ngram_len  # length of each positional substring
        self.max_ngram = max_ngram  # vocabulary size kept per position

    def fit(self, x):
        """Learn the per-position vocabularies from the strings in ``x``; return ``self``."""
        x = list(x)  # iterated once per position, so materialize one-shot iterables
        if not x:
            raise ValueError('TextPositionVectorizer.fit() requires at least one string')
        self.max_len = max(0, int(np.percentile([len(a) for a in x], 90)) - self.ngram_len)
        self.max_values = []
        for i in range(self.max_len):
            grams = (a[i:i + self.ngram_len] for a in x)
            counts = Counter(g for g in grams if len(g) == self.ngram_len)
            self.max_values.append([gram for gram, _ in counts.most_common(self.max_ngram)])
        return self

    def transform(self, x):
        """Return a boolean matrix of shape ``(len(x), n_features)``.

        Columns follow the order of ``get_feature_names_out()``. A position
        ``i`` greater than a string's length yields an all-False group. The
        output equals stacking ``parse_val`` over every position, but it fills
        one preallocated matrix and uses dict lookups instead of list scans.
        Empty input and a zero-width vocabulary both give correctly shaped
        empty matrices.
        """
        x = list(x)
        vocab_index = [{gram: col for col, gram in enumerate(values)} for values in self.max_values]
        offsets, n_features = [], 0
        for values in self.max_values:
            offsets.append(n_features)
            n_features += len(values)
        out = np.zeros((len(x), n_features), dtype=bool)
        for row, a in enumerate(x):
            for i in range(self.max_len):
                if len(a) < i:
                    break  # same cut-off as parse_val; later positions are further out
                col = vocab_index[i].get(a[i:i + self.ngram_len])
                if col is not None:
                    out[row, offsets[i] + col] = True
        return out

    def fit_transform(self, x):
        """Fit on ``x`` and return its encoding."""
        x = list(x)
        return self.fit(x).transform(x)

    def get_feature_names_out(self, input_features=None):
        """Return ``(position, substring)`` pairs, one per output column.

        ``input_features`` is accepted for scikit-learn compatibility and ignored.
        """
        return [(i, c) for i, v in enumerate(self.max_values) for c in v]

    def parse_val(self, a, i):
        """Return the one-hot indicator of string ``a`` at position ``i``."""
        res = np.zeros(len(self.max_values[i]), dtype=bool)
        if len(a) < i:
            return res
        gram = a[i:i + self.ngram_len]
        if gram in self.max_values[i]:
            res[self.max_values[i].index(gram)] = True
        return res

# Give every row a unique integer id (0..n-1) so later steps can select rows
# with isin() regardless of the duplicated index left by pd.concat.
filtered_df = filtered_df.assign(id=range(len(filtered_df)))

# Blind baseline
def get_roc_auc(y_true, y_pred_proba):
    """Return the area under the ROC curve of ``y_pred_proba`` against ``y_true``.

    The curve points come from ``metrics.roc_curve``; the area from ``metrics.auc``.
    """
    roc_points = metrics.roc_curve(y_true, y_pred_proba)
    false_positive_rate, true_positive_rate = roc_points[0], roc_points[1]
    return metrics.auc(false_positive_rate, true_positive_rate)

def get_tpr_metric(y_true, y_pred_proba, fpr_budget):
    """Return the ROC true-positive rate linearly interpolated at an FPR budget.

    ``fpr_budget`` is given in percent (``1`` means a false-positive rate of 0.01).
    """
    roc_fpr, roc_tpr, _thresholds = metrics.roc_curve(y_true, y_pred_proba)
    target_fpr = fpr_budget / 100
    return np.interp(target_fpr, roc_fpr, roc_tpr)


from collections import Counter
class TextPositionVectorizer():
    """Encode strings as per-position one-hot indicators of frequent n-grams.

    ``fit`` picks, for every start position ``i < max_len``, the ``max_ngram``
    most frequent substrings ``a[i:i + ngram_len]`` of the training strings.
    ``transform`` then marks, position by position, which of those substrings
    each string has there. ``max_len`` is the 90th-percentile training length
    minus ``ngram_len``, floored at 0.

    This definition replaces an earlier same-named class in this file. Unlike
    that one, ``fit`` here also counts truncated substrings, so the empty
    string ``''`` can become a feature at positions that some training
    strings do not reach.

    NOTE(review): the position range stops one short of the last full n-gram
    start of a 90th-percentile-length string (``L - ngram_len`` positions
    rather than ``L - ngram_len + 1``). This is possibly an off-by-one; it is
    kept as-is to preserve the features.
    """

    def __init__(self, ngram_len=1, max_ngram=50):
        self.ngram_len = ngram_len  # length of each positional substring
        self.max_ngram = max_ngram  # vocabulary size kept per position

    def fit(self, x):
        """Learn the per-position vocabularies from the strings in ``x``; return ``self``."""
        x = list(x)  # iterated once per position, so materialize one-shot iterables
        if not x:
            raise ValueError('TextPositionVectorizer.fit() requires at least one string')
        self.max_len = max(0, int(np.percentile([len(a) for a in x], 90)) - self.ngram_len)
        self.max_values = [
            [gram for gram, _ in Counter(a[i:i + self.ngram_len] for a in x).most_common(self.max_ngram)]
            for i in range(self.max_len)
        ]
        return self

    def transform(self, x):
        """Return a boolean matrix of shape ``(len(x), n_features)``.

        Columns follow the order of ``get_feature_names_out()``. A position
        ``i`` greater than a string's length yields an all-False group. The
        output equals stacking ``parse_val`` over every position, but it fills
        one preallocated matrix and uses dict lookups instead of list scans.
        Empty input and a zero-width vocabulary both give correctly shaped
        empty matrices.
        """
        x = list(x)
        vocab_index = [{gram: col for col, gram in enumerate(values)} for values in self.max_values]
        offsets, n_features = [], 0
        for values in self.max_values:
            offsets.append(n_features)
            n_features += len(values)
        out = np.zeros((len(x), n_features), dtype=bool)
        for row, a in enumerate(x):
            for i in range(self.max_len):
                if len(a) < i:
                    break  # same cut-off as parse_val; later positions are further out
                col = vocab_index[i].get(a[i:i + self.ngram_len])
                if col is not None:
                    out[row, offsets[i] + col] = True
        return out

    def fit_transform(self, x):
        """Fit on ``x`` and return its encoding."""
        x = list(x)
        return self.fit(x).transform(x)

    def get_feature_names_out(self, input_features=None):
        """Return ``(position, substring)`` pairs, one per output column.

        ``input_features`` is accepted for scikit-learn compatibility and ignored.
        """
        return [(i, c) for i, v in enumerate(self.max_values) for c in v]

    def parse_val(self, a, i):
        """Return the one-hot indicator of string ``a`` at position ``i``."""
        res = np.zeros(len(self.max_values[i]), dtype=bool)
        if len(a) < i:
            return res
        gram = a[i:i + self.ngram_len]
        if gram in self.max_values[i]:
            res[self.max_values[i].index(gram)] = True
        return res

def _char_ngram_counter(max_features, ngram_range=(1, 2)):
    """Return a factory for a case-sensitive character n-gram CountVectorizer."""
    def make():
        return CountVectorizer(max_features=max_features, analyzer='char',
                               ngram_range=ngram_range, lowercase=False)
    return make


def _position_vectorizer():
    """Return a factory for a unigram TextPositionVectorizer keeping 32 values per position."""
    def make():
        return TextPositionVectorizer(ngram_len=1, max_ngram=32)
    return make


def _catboost(*, balanced, **params):
    """Return a factory for a silent CatBoostClassifier built with ``params``.

    When ``balanced`` is true, the positive (member) class also gets
    ``scale_pos_weight=args.nonmembers``, the number of non-members drawn
    per member.
    """
    def make():
        weighting = {'scale_pos_weight': args.nonmembers} if balanced else {}
        return CatBoostClassifier(verbose=0, **weighting, **params)
    return make


# Blind-baseline steps, applied in order. Each entry is a pair of zero-argument
# factories (feature converter, classifier), so every CV fold of every step
# gets fresh, unfitted objects. Repeated entries rerun the same configuration
# on the already-filtered data.
blind_clfs = [
    (_char_ngram_counter(100), _catboost(balanced=True, colsample_bylevel=0.1)),
    (_position_vectorizer(), _catboost(balanced=True, colsample_bylevel=0.1)),
    (_position_vectorizer(), _catboost(balanced=False, colsample_bylevel=0.1)),
    (_char_ngram_counter(500), _catboost(balanced=True, colsample_bylevel=0.1)),
    (_position_vectorizer(), _catboost(balanced=True, colsample_bylevel=0.1)),
    (_char_ngram_counter(5000, ngram_range=(1, 3)), _catboost(balanced=True, depth=10)),
    (_position_vectorizer(), _catboost(balanced=False, depth=10)),
    (_char_ngram_counter(1000), _catboost(balanced=False, depth=10)),
    (_char_ngram_counter(1000), _catboost(balanced=False, depth=10)),
]


def _to_dense(matrix):
    """Return ``matrix`` as a dense array.

    Sparse outputs (anything exposing ``toarray``, e.g. from CountVectorizer)
    are densified; dense ndarrays (e.g. from TextPositionVectorizer) pass
    through unchanged.
    """
    return matrix.toarray() if hasattr(matrix, 'toarray') else matrix


def _collect_candidates(df, n_nonmembers):
    """Collect member rows of ``df`` and draw matched non-members for them.

    Returns ``(ids, members, nonmembers)``:
    - ``ids[k]`` and ``members[k]`` are the ``id`` and ``secret`` of the k-th
      row, with rows visited group by group.
    - ``nonmembers[n_nonmembers * k : n_nonmembers * (k + 1)]`` are drawn via
      ``DICT_GET_PRIOR`` for that row's ``(secret_type, extra)`` group.

    Raises ValueError if a prior returns a different number of non-members
    than requested, since that would silently shift the member/non-member
    pairing for every later group.
    """
    ids, members, nonmembers = [], [], []
    # NOTE(review): pandas groupby drops NaN keys by default (dropna=True), so
    # a row whose 'secret_type' or 'extra' is NaN never enters `ids` and is
    # removed by the first blind step. Confirm whether that is intended.
    for (secret_type, extra), group in df.groupby(['secret_type', 'extra'], sort=False, observed=True):
        ids.extend(group['id'])
        members.extend(group['secret'])
        expected = n_nonmembers * len(group)
        sampled = list(DICT_GET_PRIOR[secret_type](extra, expected))
        if len(sampled) != expected:
            raise ValueError(
                f'Prior for secret_type={secret_type!r}, extra={extra!r} returned '
                f'{len(sampled)} non-members, expected {expected}'
            )
        nonmembers.extend(sampled)
    return ids, members, nonmembers


def _out_of_fold_scores(members, nonmembers, n_nonmembers, f_converter, f_clf):
    """Score every member and non-member with a model that never trained on it.

    Members are split with a shuffled KFold (at most 3 folds). Each member's
    ``n_nonmembers`` paired non-members follow it into the same fold. Returns
    ``(scores_members, scores_nonmembers)``: column 1 of ``predict_proba``
    (the member class, labelled 1 in training).
    """
    fold = KFold(n_splits=min(3, len(members)), shuffle=True)
    scores_members = np.zeros(len(members))
    scores_nonmembers = np.zeros(len(nonmembers))
    for train_index, test_index in tqdm(fold.split(members)):
        # Flat positions in `nonmembers` of the non-members paired with each member.
        train_nm = [n_nonmembers * i + j for i in train_index for j in range(n_nonmembers)]
        test_nm = [n_nonmembers * i + j for i in test_index for j in range(n_nonmembers)]
        converter = f_converter()
        x_train = _to_dense(converter.fit_transform(
            [members[i] for i in train_index] + [nonmembers[k] for k in train_nm]))
        x_test_1 = _to_dense(converter.transform([members[i] for i in test_index]))
        x_test_0 = _to_dense(converter.transform([nonmembers[k] for k in test_nm]))

        y_train = np.array([1] * len(train_index) + [0] * len(train_nm))
        clf = f_clf()
        clf.fit(x_train, y_train)
        scores_members[test_index] += clf.predict_proba(x_test_1)[:, 1]
        scores_nonmembers[test_nm] += clf.predict_proba(x_test_0)[:, 1]
    return scores_members, scores_nonmembers


# Run each blind-baseline step on whatever survived the previous one.
for blind_step, (f_converter, f_clf) in enumerate(blind_clfs):
    print(f'Blind step {blind_step} (converter: {f_converter()}, classifier: {f_clf()})')

    ids, members, nonmembers = _collect_candidates(filtered_df, args.nonmembers)
    print(len(members), len(nonmembers))
    if len(members) < 2:
        # KFold needs at least 2 splits, so no out-of-fold scores can be
        # computed; keep the remaining rows instead of crashing.
        print(f'Only {len(members)} candidate(s) left; skipping the remaining blind steps.')
        break

    scores_members, scores_nonmembers = _out_of_fold_scores(
        members, nonmembers, args.nonmembers, f_converter, f_clf)
    scores = np.concatenate([scores_members, scores_nonmembers])
    y = [1] * len(scores_members) + [0] * len(scores_nonmembers)
    roc_auc = get_roc_auc(y, scores)
    # ranks[k]: how many of member k's paired non-members score at least as high as member k.
    ranks = (scores_nonmembers.reshape(-1, args.nonmembers) >= scores_members.reshape(-1, 1)).sum(-1)
    print(f'ROC AUC: {roc_auc}, Rank: {ranks.mean()} ({np.median(ranks)})')

    # Keep only members whose rank lies strictly inside the 5%-95% band of
    # their own non-members; members at either extreme are dropped.
    low_rank, high_rank = int(args.nonmembers * 0.05), int(args.nonmembers * (1 - 0.05))
    ids = [row_id for row_id, r in zip(ids, ranks) if low_rank < r < high_rank]
    filtered_df = filtered_df[filtered_df['id'].isin(ids)]

    print(f'After removing with the step {blind_step} of the blind baseline:', len(filtered_df))
    print(filtered_df['secret_type'].value_counts())
    print(filtered_df['dataset'].value_counts())


print('Final size:', len(filtered_df))
print('Saving...')
# Write one pickle per distinct value of the 'dataset' column.
output_dir = f'{base_path}/{out_path}'
for dataset_name, dataset_rows in filtered_df.groupby('dataset', observed=True):
    dataset_rows.to_pickle(f'{output_dir}/{dataset_name}.pkl')

print(f'DONE: {datetime.now()}')