import pandas as pd
import re
from urllib.parse import urlparse
import os

# Substrings whose occurrence counts become one feature column each; each
# column is named after the substring itself. This order fixes the column
# order of the headerless X_*.csv files, so do not reorder.
_COUNTED_TOKENS = ['@', '?', '-', '=', '.', '#', '%', '+', '$', '!', '*', ',', '//']

# URL-shortener domains. Entries duplicated in the former hand-written
# alternation appear once here; matching stays case-sensitive and unanchored.
# NOTE(review): unanchored means e.g. 't.co' also matches inside
# 'microsoft.com'; kept as-is to preserve the feature's definition.
_SHORTENER_DOMAINS = (
    'bit.ly', 'goo.gl', 'shorte.st', 'go2l.ink', 'x.co', 'ow.ly', 't.co', 'tinyurl', 'tr.im',
    'is.gd', 'cli.gs', 'yfrog.com', 'migre.me', 'ff.im', 'tiny.cc', 'url4.eu', 'twit.ac',
    'su.pr', 'twurl.nl', 'snipurl.com', 'short.to', 'BudURL.com', 'ping.fm', 'post.ly',
    'Just.as', 'bkite.com', 'snipr.com', 'fic.kr', 'loopt.us', 'doiop.com', 'short.ie',
    'kl.am', 'wp.me', 'rubyurl.com', 'om.ly', 'to.ly', 'bit.do', 'lnkd.in', 'db.tt', 'qr.ae',
    'adf.ly', 'bitly.com', 'cur.lv', 'tinyurl.com', 'ity.im', 'q.gs', 'po.st', 'bc.vc',
    'twitthis.com', 'u.to', 'j.mp', 'buzurl.com', 'cutt.us', 'u.bb', 'yourls.org',
    'prettylinkpro.com', 'scrnch.me', 'filoops.info', 'vzturl.com', 'qr.net', '1url.com',
    'tweez.me', 'v.gd', 'link.zip.net',
)
_SHORTENER_RE = re.compile('|'.join(re.escape(domain) for domain in _SHORTENER_DOMAINS))

# Any one alternative matching anywhere in the URL sets having_ip_address=1.
# The hex-IPv4 line used to lack its trailing '|', which silently glued it to
# the IPv6 alternative so that neither could ever match on its own; an exact
# duplicate of the first alternative (mislabelled "with port") was dropped.
_IP_ADDRESS_RE = re.compile(
    r'(([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.([01]?\d\d?|2[0-4]\d|25[0-5])\.'
    r'([01]?\d\d?|2[0-4]\d|25[0-5])\/)|'  # dotted-quad IPv4 followed by '/'
    r'((0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\.(0x[0-9a-fA-F]{1,2})\/)|'  # IPv4 in hexadecimal
    r'(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}|'  # uncompressed IPv6 (8 groups)
    r'([0-9]+(?:\.[0-9]+){3}:[0-9]+)|'  # IPv4 with port
    r'((?:(?:\d|[01]?\d\d|2[0-4]\d|25[0-5])\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d|\d)(?:\/\d{1,2})?)'  # IPv4, optional /CIDR
)


def _parse_url(url):
    """Return ``urlparse(url)``, or ``None`` when urllib rejects the URL.

    ``urlparse`` raises ``ValueError`` for inputs such as an unbalanced ``[``
    in the network location; a single such row used to abort the whole run.
    """
    try:
        return urlparse(url)
    except ValueError:
        return None


def _abnormal_url(url):
    """Return 1 if the URL's parsed hostname occurs literally in the URL, else 0.

    urllib lowercases the hostname, so a mixed-case host yields 0.
    NOTE(review): when there is no hostname (no scheme, or the URL was
    rejected by urllib) the string 'None' is searched for, exactly as the
    original ``str(hostname)`` did; this yields 1 only if the URL contains 'None'.
    """
    parsed = _parse_url(url)
    hostname = str(parsed.hostname if parsed is not None else None)
    # Plain substring test: the old re.search(hostname, url) treated the
    # hostname as a regex ('.' matched anything, '(' raised re.error).
    return int(hostname in url)


def _is_https(url):
    """Return 1 if the URL's (lowercased) scheme is exactly 'https', else 0."""
    parsed = _parse_url(url)
    return int(parsed is not None and parsed.scheme == 'https')


def _digit_count(url):
    """Return how many characters of ``url`` satisfy ``str.isnumeric``."""
    return sum(ch.isnumeric() for ch in url)


def _letter_count(url):
    """Return how many characters of ``url`` satisfy ``str.isalpha``."""
    return sum(ch.isalpha() for ch in url)


def _uses_shortener(url):
    """Return 1 if any known shortener domain occurs anywhere in ``url``, else 0."""
    return int(_SHORTENER_RE.search(url) is not None)


def _has_ip_address(url):
    """Return 1 if ``url`` contains an IP-address-like pattern, else 0."""
    return int(_IP_ADDRESS_RE.search(url) is not None)


def _load_and_featurize(data_path):
    """Read the raw CSV, clean the URLs, add the 0/1 label and every feature column.

    The input must have ``url`` and ``type`` columns. Returns a new frame
    with extra columns ``Category`` (label) followed by the feature columns.
    """
    data = pd.read_csv(data_path)
    # Rows without a URL cannot be featurized (string methods fail on NaN).
    data = data.dropna(subset=['url']).copy()

    # Binary label: every type other than 'benign' counts as malicious.
    data.loc[data['type'] != 'benign', 'type'] = 'malicious'
    data['url'] = (
        data['url']
        # Literal match: with regex=True the '.' matched any character, so
        # e.g. 'awwwards.com' was mangled to 'ards.com'.
        .str.replace('www.', '', regex=False)
        .str.strip()
        # Drop every remaining newline, tab, carriage return and space.
        .str.replace(r'[\n\t\r ]', '', regex=True)
    )
    data['Category'] = data['type'].map({'benign': 0, 'malicious': 1})

    # Feature columns; their insertion order is the column order of X_*.csv.
    data['url_len'] = data['url'].str.len()
    for token in _COUNTED_TOKENS:
        data[token] = data['url'].str.count(re.escape(token))
    data['abnormal_url'] = data['url'].map(_abnormal_url)
    data['https'] = data['url'].map(_is_https)
    data['digits'] = data['url'].map(_digit_count)
    data['letters'] = data['url'].map(_letter_count)
    data['Shortining_Service'] = data['url'].map(_uses_shortener)
    data['having_ip_address'] = data['url'].map(_has_ip_address)
    return data


def _keys_features_labels(frame):
    """Split a featurized frame into (URL keys, feature matrix, 0/1 labels)."""
    return frame['url'], frame.drop(['url', 'type', 'Category'], axis=1), frame['Category']


def preprocess():
    """Build the URL train/validation/test CSVs from ``data/url/malicious_phish.csv``.

    Paths are relative to the current working directory. The input needs a
    ``url`` and a ``type`` column. Every ``type`` other than ``'benign'`` is a
    positive (label 1); ``'benign'`` rows are negatives (label 0). URLs are
    cleaned (``'www.'`` removed, whitespace stripped) and featurized by
    ``_load_and_featurize``.

    Files written to ``data/url/preprocessed/`` (no header, no index):
        [All positive samples]
            all_pos_key.csv, all_pos_X.csv
            100% of positive samples
        [Samples for training the ML model]
            X_train_key.csv, X_train.csv, y_train.csv
            100% of positive samples and a random 80% of negative samples
        [Samples for validation]
            X_val_key.csv, X_val.csv, y_val.csv
            100% of positive samples and 10% of negative samples; the
            negatives are disjoint from the training negatives
        [Samples for testing]
            X_test_key.csv, X_test.csv, y_test.csv
            a random 10% of positive samples and the remaining 10% of negatives

    ``*_key.csv`` holds the cleaned URL strings, ``X*.csv`` the feature
    matrix and ``y_*.csv`` the 0/1 labels. All sampling and shuffling use
    ``random_state=42``.

    NOTE(review): validation and test positives come from the same positive
    set used for training, so positive-class metrics are not held-out
    estimates; presumably intended (every positive key is known up front),
    confirm with the consumer of these files.
    """
    data_path = 'data/url/malicious_phish.csv'
    out_dir = 'data/url/preprocessed'
    os.makedirs(out_dir, exist_ok=True)

    data = _load_and_featurize(data_path)
    all_pos_df = data.loc[data['Category'] == 1]
    all_neg_df = data.loc[data['Category'] == 0]

    print("=== URL dataset ===")
    print("all_pos_df shape: ", all_pos_df.shape)
    print("all_neg_df shape: ", all_neg_df.shape)

    # pos: train 100%, val 100%, test 10%
    # neg: train 80%, val 10%, test 10% (the three negative parts are disjoint)
    train_pos = all_pos_df
    train_neg = all_neg_df.sample(frac=0.8, random_state=42)
    val_pos = all_pos_df
    neg_rest = all_neg_df.drop(train_neg.index)
    val_neg = neg_rest.sample(frac=0.5, random_state=42)
    test_pos = all_pos_df.sample(frac=0.1, random_state=42)
    test_neg = neg_rest.drop(val_neg.index)

    print("train_pos shape: ", train_pos.shape)
    print("train_neg shape: ", train_neg.shape)
    print("val_pos shape: ", val_pos.shape)
    print("val_neg shape: ", val_neg.shape)
    print("test_pos shape: ", test_pos.shape)
    print("test_neg shape: ", test_neg.shape)

    # Concatenate, shuffle, then split each set into (keys, X, y).
    splits = {
        name: _keys_features_labels(
            pd.concat([pos, neg]).sample(frac=1, random_state=42).reset_index(drop=True)
        )
        for name, pos, neg in (
            ('train', train_pos, train_neg),
            ('val', val_pos, val_neg),
            ('test', test_pos, test_neg),
        )
    }
    all_pos_key, all_pos_X, _ = _keys_features_labels(all_pos_df)

    print("[ALL POSITIVE]")
    print("all_pos_key shape: ", all_pos_key.shape)
    print("all_pos_X shape: ", all_pos_X.shape)
    for name, (key, X, y) in splits.items():
        print(f"[{name.upper()}]")
        print(f"X_{name}_key shape: ", key.shape)
        print(f"X_{name} shape: ", X.shape)
        print(f"y_{name} shape: ", y.shape, ", pos: ", int((y == 1).sum()), ", neg: ", int((y == 0).sum()))

    all_pos_key.to_csv(os.path.join(out_dir, 'all_pos_key.csv'), index=False, header=False)
    all_pos_X.to_csv(os.path.join(out_dir, 'all_pos_X.csv'), index=False, header=False)
    for name, (key, X, y) in splits.items():
        key.to_csv(os.path.join(out_dir, f'X_{name}_key.csv'), index=False, header=False)
        X.to_csv(os.path.join(out_dir, f'X_{name}.csv'), index=False, header=False)
        y.to_csv(os.path.join(out_dir, f'y_{name}.csv'), index=False, header=False)

if __name__ == "__main__":
    # Executed as a script rather than imported: run the whole pipeline once,
    # reading data/url/malicious_phish.csv and writing data/url/preprocessed/.
    preprocess()
