import pandas as pd
import glob as glob
import traceback
import numpy as np
from secret_utils import parse_secrets, DICT_GET_PRIOR
from sklearn.model_selection import KFold
import sklearn.metrics as metrics
from sklearn.feature_extraction.text import CountVectorizer
from tqdm import tqdm
from joblib import Parallel, delayed
from datetime import datetime
from collections import Counter
import argparse
import gc
from xgboost import XGBClassifier
import sys
import os
from catboost import CatBoostClassifier
import re

# Command-line interface.
# --files, --output_file and --base_path are mandatory: without --files the
# regex compilation below fails with a TypeError, and without the other two
# the script would read from / write to paths containing the literal "None".
parser = argparse.ArgumentParser(description='Filter secrets')
parser.add_argument('--files', type=str, required=True, help='File regex')
parser.add_argument('--output_file', type=str, required=True, help='Output file')
parser.add_argument('--rank_filter', type=float, default=0.01, help='Filter ranks')
parser.add_argument('--nonmembers', type=int, default=127, help='Number of nonmembers')
parser.add_argument('--base_path', type=str, required=True, help='Base path')
args = parser.parse_args()

base_path = args.base_path
# Output directory for this filtering step; created up front so the final
# to_pickle() cannot fail on a missing directory after all the work is done.
os.makedirs(f'{base_path}/filtered_step_1', exist_ok=True)
n_jobs = 128

# --files is a regular expression that must match the *whole* file name
# (not the full path) of the candidates found under <base_path>/secrets.
test = re.compile(args.files)
files = [
    f for f in glob.glob(f'{base_path}/secrets/*.csv.xz')
    if test.fullmatch(os.path.basename(f)) is not None
]
print(files)

def load_dataset(f):
    """Load one xz-compressed secrets CSV and prepare it for filtering.

    Rows whose ``secret_type`` is not a key of ``DICT_GET_PRIOR`` are dropped.
    A ``dataset`` column is derived from the file name, and a ``secret``
    column is cut out of ``string`` using the ``start``/``end`` offsets, which
    are removed afterwards.

    Args:
        f: Path of a ``*.csv.xz`` file providing at least the columns
            ``secret_type``, ``string``, ``start`` and ``end``.

    Returns:
        The prepared DataFrame, or ``None`` when the file has no rows, has no
        row with a supported secret type, or could not be processed (the error
        and traceback are written to stderr in that case).
    """
    try:
        curr_df = pd.read_csv(f, compression='xz')
        if len(curr_df) == 0:
            return None
        # .copy() makes the filtered frame independent, so the column
        # assignments below do not write into a slice of the raw frame.
        curr_df = curr_df[curr_df['secret_type'].apply(lambda x: x in DICT_GET_PRIOR)].copy()
        if len(curr_df) == 0:
            return None
        # Dataset name: file name without its last two '_'-separated fields,
        # truncated at the first '----'.
        file_name = os.path.basename(f)
        curr_df['dataset'] = '_'.join(file_name.split('_')[:-2]).split('----')[0]
        # Column-wise zip instead of a row-wise DataFrame.apply: same slices,
        # without building a Series object for every row.
        curr_df['secret'] = [
            text[start:end]
            for text, start, end in zip(curr_df['string'], curr_df['start'], curr_df['end'])
        ]
        curr_df = curr_df.drop(columns=['start', 'end'])
        # astype returns a new frame; the result has to be kept, otherwise
        # the conversion silently does nothing.
        curr_df = curr_df.astype({
            'string': 'str',
            'secret': 'str',
            'secret_type': 'category',
            'dataset': 'category'
        }, copy=False)
        gc.collect()
        return curr_df
    except Exception as e:
        # One unreadable or malformed file must not abort the whole parallel
        # load: report it and let the caller drop the None result.
        print(f'[{f}] Error: {e}', file=sys.stderr, flush=True)
        traceback.print_exc()
        return None
        
# Load every selected file in parallel; each result is a DataFrame or None.
loaded_frames = Parallel(n_jobs=32)(delayed(load_dataset)(path) for path in files)

if not loaded_frames:
    print('Empty', file=sys.stderr, flush=True)
    sys.exit(0)

# Keep only the files that actually produced rows.
usable_frames = [frame for frame in loaded_frames if frame is not None and len(frame) > 0]
if not usable_frames:
    print('Empty dataset')
    sys.exit(0)

column_dtypes = {
    'string': 'str',
    'secret': 'str',
    'secret_type': 'category',
    'dataset': 'category'
}
df = pd.concat(usable_frames).astype(column_dtypes, copy=False)
# Drop the per-file frames so that only the concatenated frame stays alive.
del loaded_frames, usable_frames

# Diagnostics: dataset names, the frame itself, its columns and dtypes.
for summary in (df['dataset'].unique(), df, df.columns, df.dtypes):
    print(summary)

# Cap the working set at 300k randomly chosen rows.
row_cap = 300_000
if len(df) > row_cap:
    df = df.sample(n=row_cap)
gc.collect()


def parse_filter(secret_type, dataset, curr_df):
    """Expand one (secret_type, dataset) group into one row per parsed secret.

    ``parse_secrets`` is called with the group's ``secret`` and ``string``
    columns plus the secret type. Its result is iterated as a mapping from a
    ``(prefix, suffix, extra)`` key to an iterable of secrets.

    Args:
        secret_type: Secret type of the group; copied into every output row.
        dataset: Dataset name of the group; copied into every output row.
        curr_df: Rows of the group; must provide ``secret`` and ``string``.

    Returns:
        A DataFrame with the columns ``prefix``, ``suffix``, ``extra``,
        ``secret_type``, ``dataset`` and ``secret``. A missing ``extra`` is
        stored as the string ``'None'``. If parsing raises, the traceback is
        printed and the rows collected so far are returned.
    """
    column_names = ('prefix', 'suffix', 'extra', 'secret_type', 'dataset', 'secret')
    collected = {name: [] for name in column_names}
    print(secret_type, dataset, len(curr_df), file=sys.stderr, flush=True)
    try:
        parsed = parse_secrets(curr_df['secret'], curr_df['string'], secret_type)
        for key, secrets in parsed.items():
            prefix, suffix, extra = key
            extra_label = 'None' if extra is None else extra
            for secret in secrets:
                row = (prefix, suffix, extra_label, secret_type, dataset, secret)
                for name, value in zip(column_names, row):
                    collected[name].append(value)
            gc.collect()
    except Exception:
        print('Skipping', secret_type, dataset, file=sys.stderr, flush=True)
        traceback.print_exc()
    return pd.DataFrame(collected)


# Parse every (secret_type, dataset) group in parallel and stack the results.
filtered_df = Parallel(n_jobs=32)(
    delayed(parse_filter)(secret_type, dataset, curr_df)
    for (secret_type, dataset), curr_df in df.groupby(['secret_type', 'dataset'], observed=True)
)
filtered_df = pd.concat(filtered_df, copy=False)

print('Initial Size:', len(filtered_df))
print(filtered_df['secret_type'].value_counts())


# Remove duplicates: a secret seen more than once inside any single
# (dataset, secret_type) group is dropped from the whole frame.
frequent_secrets = set()
for group_key, curr_df in filtered_df.groupby(['dataset', 'secret_type'], observed=True):
    print(group_key)
    freq = curr_df['secret'].value_counts()
    # Only set membership matters afterwards, so the counts are not sorted.
    frequent_secrets.update(freq[freq > 1].index.tolist())
# .copy() detaches the filtered frame so that adding the 'id' column below
# does not write into a slice of the unfiltered frame.
filtered_df = filtered_df[~filtered_df['secret'].isin(frequent_secrets)].copy()
print('After removing duplicates:', len(filtered_df))
print(filtered_df['secret_type'].value_counts())

# Row identifier used by the blind-baseline steps to select the survivors
# (the index itself is not unique after the concat above).
filtered_df['id'] = range(len(filtered_df))


# Blind baseline
def get_roc_auc(y_true, y_pred_proba):
    """Return the area under the ROC curve of scores against binary labels.

    Args:
        y_true: Ground-truth labels, passed to ``metrics.roc_curve``.
        y_pred_proba: Scores for the positive class, one per label.

    Returns:
        The value of ``metrics.auc`` over the computed ROC curve.
    """
    false_positive_rate, true_positive_rate, _thresholds = metrics.roc_curve(y_true, y_pred_proba)
    area = metrics.auc(false_positive_rate, true_positive_rate)
    return area

def get_tpr_metric(y_true, y_pred_proba, fpr_budget):
    """Return the true-positive rate at a given false-positive budget.

    Args:
        y_true: Ground-truth labels, passed to ``metrics.roc_curve``.
        y_pred_proba: Scores for the positive class, one per label.
        fpr_budget: False-positive rate in percent; it is divided by 100
            before being looked up on the ROC curve.

    Returns:
        The TPR linearly interpolated on the ROC curve at ``fpr_budget / 100``.
    """
    false_positive_rate, true_positive_rate, _thresholds = metrics.roc_curve(y_true, y_pred_proba)
    target_fpr = fpr_budget / 100
    return np.interp(target_fpr, false_positive_rate, true_positive_rate)


from collections import Counter
class TextPositionVectorizer():
    """One-hot encode which n-gram occurs at each character offset of a string.

    ``fit`` records, for every offset ``i`` below ``max_len``, the
    ``max_ngram`` most common values of ``s[i:i + ngram_len]`` over the
    training strings. ``transform`` emits, per string, one boolean flag for
    each recorded ``(offset, n-gram)`` pair.

    Args:
        ngram_len: Number of characters in each n-gram.
        max_len: Kept for backward compatibility; ``fit`` replaces it with a
            value derived from the training data.
        max_ngram: Maximum number of distinct n-grams kept per offset.
    """

    def __init__(self, ngram_len=1, max_len=50, max_ngram=50):
        self.ngram_len = ngram_len
        # Placeholder only: fit() overwrites it with the data-driven value.
        self.max_len = max_len
        self.max_ngram = max_ngram

    def fit(self, x):
        """Learn the per-offset n-gram vocabularies from the strings in ``x``.

        The number of encoded offsets is the 90th percentile of the string
        lengths minus ``ngram_len``, clamped at zero so that short (or no)
        training strings yield an empty vocabulary instead of an error.

        Returns:
            ``self``.
        """
        samples = list(x)
        if samples:
            span = int(np.percentile([len(s) for s in samples], 90)) - self.ngram_len
        else:
            span = 0
        self.max_len = max(span, 0)
        self.max_values = [
            [
                gram
                for gram, _ in Counter(
                    s[pos:pos + self.ngram_len] for s in samples
                ).most_common(self.max_ngram)
            ]
            for pos in range(self.max_len)
        ]
        return self

    def transform(self, x):
        """Encode the strings in ``x`` with the fitted vocabularies.

        Returns:
            A boolean array of shape ``(len(x), n_features)`` whose columns
            follow the order of ``get_feature_names_out()``. The array has
            zero columns when the fitted vocabulary is empty.
        """
        samples = list(x)
        # n-gram -> column within its offset, plus each offset's first column.
        lookups = [
            {gram: col for col, gram in enumerate(vocab)}
            for vocab in self.max_values
        ]
        offsets = []
        n_features = 0
        for vocab in self.max_values:
            offsets.append(n_features)
            n_features += len(vocab)

        encoded = np.zeros((len(samples), n_features), dtype=bool)
        width = self.ngram_len
        for row, text in enumerate(samples):
            for pos, lookup in enumerate(lookups):
                if len(text) < pos:
                    # Same cut-off as parse_val(); later offsets are even
                    # further past the end of the string.
                    break
                col = lookup.get(text[pos:pos + width])
                if col is not None:
                    encoded[row, offsets[pos] + col] = True
        return encoded

    def fit_transform(self, x):
        """Fit on ``x`` and return its encoding."""
        samples = list(x)  # materialise once so both passes see the same data
        return self.fit(samples).transform(samples)

    def get_feature_names_out(self):
        """Return the ``(offset, n-gram)`` pair behind each output column."""
        return [(i, c) for i, v in enumerate(self.max_values) for c in v]

    def parse_val(self, a, i):
        """Return the one-hot boolean vector of string ``a`` at offset ``i``.

        The vector has one entry per n-gram recorded for offset ``i``. It is
        all False when ``len(a) < i`` or when the n-gram found at that offset
        was not recorded during ``fit``.
        """
        vocab = self.max_values[i]
        res = np.zeros(len(vocab), dtype=bool)
        # NOTE(review): when len(a) == i the (empty) slice is still looked up,
        # while longer offsets are skipped -- confirm this is intended.
        if len(a) < i:
            return res
        gram = a[i:i + self.ngram_len]
        if gram in vocab:
            res[vocab.index(gram)] = True
        return res


# Blind baseline
#
# Each (vectorizer, classifier) pair below is trained to tell the candidate
# secrets ("members") from samples produced by DICT_GET_PRIOR ("non-members").
# A member is kept only when enough of its own non-members score at least as
# high as it does. The pairs are applied twice, each step working on the
# survivors of the previous one:
#   pass 1 trains one classifier per secret type,
#   pass 2 trains a single classifier over all secret types together.
blind_clfs = [
    (
        lambda: CountVectorizer(max_features=100, analyzer='char', ngram_range=(1, 2), lowercase=False),
        lambda: CatBoostClassifier(verbose=0, scale_pos_weight=args.nonmembers)
    ),
    (
        lambda: TextPositionVectorizer(ngram_len=1, max_ngram=50),
        lambda: CatBoostClassifier(verbose=0, scale_pos_weight=args.nonmembers)
    ),
    (
        lambda: TextPositionVectorizer(ngram_len=2, max_ngram=50),
        lambda: CatBoostClassifier(verbose=0, scale_pos_weight=args.nonmembers)
    ),
    (
        lambda: TextPositionVectorizer(ngram_len=3, max_ngram=50),
        lambda: XGBClassifier(scale_pos_weight=args.nonmembers),
    ),
    (
        lambda: CountVectorizer(max_features=1000, analyzer='char', ngram_range=(1, 3), lowercase=False),
        lambda: CatBoostClassifier(verbose=0, scale_pos_weight=args.nonmembers)
    ),
    (
        lambda: CountVectorizer(max_features=100, analyzer='char', ngram_range=(1, 2), lowercase=False),
        lambda: XGBClassifier()
    ),
    (
        lambda: TextPositionVectorizer(ngram_len=1, max_ngram=50),
        lambda: XGBClassifier()
    ),
]

# A member survives a step only if strictly more than this many of its
# args.nonmembers non-members score at least as high as the member itself.
rank_threshold = int(1 + args.nonmembers * args.rank_filter)


def _to_dense(features):
    """Return ``features`` as a dense array.

    Objects exposing ``toarray`` (the sparse output of CountVectorizer) are
    converted; anything else is returned unchanged.
    """
    return features.toarray() if hasattr(features, 'toarray') else features


def _blind_baseline_ranks(members, nonmembers, f_converter, f_clf, n_nonmembers):
    """Score members against their non-members with k-fold cross-validation.

    ``members[i]`` is paired with ``nonmembers[n_nonmembers * i]`` up to
    ``nonmembers[n_nonmembers * (i + 1) - 1]``; a member and its non-members
    always land in the same fold, so every sample is scored by a model that
    did not see it during training.

    Args:
        members: Candidate secrets (label 1); at least two are required.
        nonmembers: Prior samples (label 0), ``n_nonmembers`` per member in
            member order. NOTE(review): this assumes the prior returns
            exactly the requested number of samples -- verify against
            DICT_GET_PRIOR.
        f_converter: Zero-argument factory of a fresh vectorizer.
        f_clf: Zero-argument factory of a fresh classifier.
        n_nonmembers: Number of non-members per member.

    Returns:
        ``(roc_auc, ranks)`` where ``ranks[i]`` counts the non-members of
        member ``i`` whose score is at least the score of that member.
    """
    fold = KFold(n_splits=min(3, len(members)), shuffle=True)
    scores_members = np.zeros(len(members))
    scores_nonmembers = np.zeros(len(nonmembers))
    for train_index, test_index in tqdm(fold.split(members)):
        train_neg = [n_nonmembers * i + j for i in train_index for j in range(n_nonmembers)]
        test_neg = [n_nonmembers * i + j for i in test_index for j in range(n_nonmembers)]

        converter = f_converter()
        x_train = _to_dense(converter.fit_transform(
            [members[i] for i in train_index] + [nonmembers[k] for k in train_neg]
        ))
        x_test_1 = _to_dense(converter.transform([members[i] for i in test_index]))
        x_test_0 = _to_dense(converter.transform([nonmembers[k] for k in test_neg]))
        y_train = np.array([1] * len(train_index) + [0] * len(train_neg))

        clf = f_clf()
        clf.fit(x_train, y_train)
        scores_members[test_index] += clf.predict_proba(x_test_1)[:, 1]
        scores_nonmembers[test_neg] += clf.predict_proba(x_test_0)[:, 1]

    scores = np.concatenate([scores_members, scores_nonmembers])
    y = [1] * len(scores_members) + [0] * len(scores_nonmembers)
    roc_auc = get_roc_auc(y, scores)
    ranks = (scores_nonmembers.reshape(-1, n_nonmembers) >= scores_members.reshape(-1, 1)).sum(-1)
    return roc_auc, ranks


# Pass 1: one classifier per secret type.
for blind_step, (f_converter, f_clf) in enumerate(blind_clfs):
    print(f'Blind step {blind_step} (converter: {f_converter()}, classifier: {f_clf()})')
    if filtered_df.empty:
        # Nothing left to group: pd.concat([]) below would raise.
        print('No candidates left, skipping the remaining per-type blind steps', file=sys.stderr, flush=True)
        break

    new_filtered_df = []
    for secret_type, secret_type_df in filtered_df.groupby('secret_type', sort=False, observed=True):
        if len(secret_type_df) < 2:
            # Too few samples to cross-validate: keep the type unfiltered.
            new_filtered_df.append(secret_type_df)
            continue
        nonmembers = []
        members = []
        ids = []

        for extra, curr_df in secret_type_df.groupby('extra', sort=False, observed=True):
            ids.extend(curr_df['id'])
            members.extend(curr_df['secret'])
            nonmembers.extend(DICT_GET_PRIOR[secret_type](extra, args.nonmembers * len(curr_df)))

        print(secret_type, len(members), len(nonmembers))
        roc_auc, ranks = _blind_baseline_ranks(members, nonmembers, f_converter, f_clf, args.nonmembers)
        print(f'ROC AUC: {roc_auc}, Rank: {ranks.mean()} ({np.median(ranks)})')
        ids = [a for a, r in zip(ids, ranks) if r > rank_threshold]
        new_filtered_df.append(secret_type_df[secret_type_df['id'].isin(ids)])
    filtered_df = pd.concat(new_filtered_df)

    print(f'After removing with the step {blind_step} of the blind baseline per secret type:', len(filtered_df))
    print(filtered_df['secret_type'].value_counts())


# Pass 2: a single classifier over all secret types together.
for blind_step, (f_converter, f_clf) in enumerate(blind_clfs):
    print(f'Blind step {blind_step} (converter: {f_converter()}, classifier: {f_clf()})')
    if len(filtered_df) < 2:
        # KFold needs at least two samples; keep what is left unfiltered,
        # as pass 1 does for secret types with fewer than two rows.
        print('Fewer than 2 candidates left, skipping the remaining blind steps', file=sys.stderr, flush=True)
        break

    nonmembers = []
    members = []
    ids = []
    for (secret_type, extra), curr_df in filtered_df.groupby(['secret_type', 'extra'], sort=False, observed=True):
        ids.extend(curr_df['id'])
        members.extend(curr_df['secret'])
        nonmembers.extend(DICT_GET_PRIOR[secret_type](extra, args.nonmembers * len(curr_df)))
    print(len(members), len(nonmembers))

    roc_auc, ranks = _blind_baseline_ranks(members, nonmembers, f_converter, f_clf, args.nonmembers)
    print(f'ROC AUC: {roc_auc}, Rank: {ranks.mean()} ({np.median(ranks)})')
    ids = [a for a, r in zip(ids, ranks) if r > rank_threshold]
    filtered_df = filtered_df[filtered_df['id'].isin(ids)]

    print(f'After removing with the step {blind_step} of the blind baseline:', len(filtered_df))
    print(filtered_df['secret_type'].value_counts())


# Persist the candidates that survived every filtering step.
print('Final size:', len(filtered_df))
print('Saving...')
output_path = f'{base_path}/filtered_step_1/{args.output_file}.pkl'
filtered_df.to_pickle(output_path)

finished_at = datetime.now()
print(f'DONE: {finished_at}')

