from sklearn.linear_model import SGDOneClassSVM
from sklearn.kernel_approximation import Nystroem
from sklearn.ensemble import IsolationForest
from sklearn.pipeline import make_pipeline
import numpy as np
import pandas as pd
import torch

def generating_latent_vector(datamodule, model):
    """Encode every full batch yielded by ``datamodule`` into latent vectors.

    ``datamodule`` is iterated directly and must expose ``batch_size`` and
    ``dataset`` (the callers in this file pass dataloaders). Only the first
    ``len(dataset) // batch_size`` batches are used, so a trailing partial
    batch is ignored.

    Each item is unpacked as ``(features, _)``. Along the last axis of
    ``features``, the final column is read as the fraud label, the
    second-to-last as the per-step id, and all remaining columns as the
    model input.

    When ``model`` has a ``lob_embed`` attribute, the first 20 input features
    go through ``model.lob_embed.encode``; the result is repeated along the
    sequence axis and concatenated with the remaining features before being
    passed to ``model.encode`` (assumes ``lob_embed.encode`` returns one
    vector per sequence -- TODO confirm).

    Returns:
        latent_vectors: array of shape ``(n, model.d_model)``.
        fraudulent: array of shape ``(n,)`` holding the maximum label found
            in each sequence.
        ids: array of shape ``(n, model.seq_len)`` with the per-step ids.
        Here ``n = (len(dataset) // batch_size) * batch_size``.
    """
    batch_size = datamodule.batch_size
    nb_batches = len(datamodule.dataset) // batch_size
    nb_rows = nb_batches * batch_size
    latent_vectors = np.zeros((nb_rows, model.d_model))
    fraudulent = np.zeros(nb_rows)
    ids = np.zeros((nb_rows, model.seq_len))
    add_lob_embed = hasattr(model, 'lob_embed')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Pure inference: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        # zip with range stops after the last full batch without fetching
        # the trailing partial batch from the loader.
        for count, (features, _) in zip(range(nb_batches), datamodule):
            rows = slice(count * batch_size, (count + 1) * batch_size)
            # Labels are read before moving the batch to the compute device.
            frauds = features[:, :, -1].numpy()
            features = features.to(device)
            model_input = features[:, :, :-2]
            if add_lob_embed:
                lob_repr = model.lob_embed.encode(model_input[..., :20])
                manual_feat = model_input[..., 20:]
                lob_repr_expand = lob_repr.unsqueeze(1).expand(-1, manual_feat.shape[1], -1)
                concat_input = torch.cat([lob_repr_expand, manual_feat], dim=-1)
                context = model.encode(concat_input)
            else:
                context = model.encode(model_input)

            ids[rows, :] = features[:, :, -2].detach().cpu().numpy()
            latent_vectors[rows, :] = context.detach().cpu().numpy()
            # The row maximum is 0 exactly when every label is 0, so no
            # separate "contains a fraud" mask is needed.
            fraudulent[rows] = np.max(frauds, axis=1)

    return latent_vectors, fraudulent, ids

def generating_reconstruction_vector(datamodule, model):
    """Compute a squared reconstruction error for every full batch.

    ``datamodule`` is iterated directly and must expose ``batch_size`` and
    ``dataset`` (the caller in this file passes a dataloader). Only the first
    ``len(dataset) // batch_size`` batches are used, so a trailing partial
    batch is ignored.

    Each item is unpacked as ``(features, _)``. Along the last axis of
    ``features``, the final column is read as the fraud label, the
    second-to-last as the per-step id, and all remaining columns as the
    model input.

    When ``model`` has a ``lob_embed`` attribute, the first 20 input features
    go through ``model.lob_embed.encode``; the result is repeated along the
    sequence axis and concatenated with the remaining features, and the error
    is measured against that concatenated tensor (assumes
    ``lob_embed.encode`` returns one vector per sequence -- TODO confirm).

    Returns:
        errors: array of shape ``(n,)``; squared difference between the model
            input and ``model(input)``, summed over axes 1 and 2.
        fraudulent: array of shape ``(n,)`` holding the maximum label found
            in each sequence.
        ids: array of shape ``(n, model.seq_len)`` with the per-step ids.
        Here ``n = (len(dataset) // batch_size) * batch_size``.
    """
    batch_size = datamodule.batch_size
    nb_batches = len(datamodule.dataset) // batch_size
    nb_rows = nb_batches * batch_size
    errors = np.zeros(nb_rows)
    fraudulent = np.zeros(nb_rows)
    ids = np.zeros((nb_rows, model.seq_len))
    add_lob_embed = hasattr(model, 'lob_embed')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Pure inference: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        # zip with range stops after the last full batch without fetching
        # the trailing partial batch from the loader.
        for count, (features, _) in zip(range(nb_batches), datamodule):
            rows = slice(count * batch_size, (count + 1) * batch_size)
            # Labels are read before moving the batch to the compute device.
            frauds = features[:, :, -1].numpy()
            features = features.to(device)
            model_input = features[:, :, :-2]
            if add_lob_embed:
                lob_repr = model.lob_embed.encode(model_input[..., :20])
                manual_feat = model_input[..., 20:]
                lob_repr_expand = lob_repr.unsqueeze(1).expand(-1, manual_feat.shape[1], -1)
                target = torch.cat([lob_repr_expand, manual_feat], dim=-1)
            else:
                target = model_input
            reconstruction = model(target)
            error = (target - reconstruction).pow(2).sum(1).sum(1)

            errors[rows] = error.detach().cpu().numpy()
            ids[rows, :] = features[:, :, -2].detach().cpu().numpy()
            # The row maximum is 0 exactly when every label is 0, so no
            # separate "contains a fraud" mask is needed.
            fraudulent[rows] = np.max(frauds, axis=1)

    return errors, fraudulent, ids

def generating_sequence_vector(datamodule, seq_len, enc_in):
    """Flatten the raw input features of every full batch into one row per sequence.

    ``datamodule`` is iterated directly and must expose ``batch_size`` and
    ``dataset`` (the caller in this file passes dataloaders). Only the first
    ``len(dataset) // batch_size`` batches are used, so a trailing partial
    batch is ignored.

    Each item is unpacked as ``(features, _)``. Along the last axis of
    ``features``, the final column is read as the fraud label, the
    second-to-last as the per-step id, and all remaining columns are
    flattened into a row of length ``seq_len * enc_in``.

    Returns:
        sequences: array of shape ``(n, seq_len * enc_in)``.
        fraudulent: array of shape ``(n,)`` holding the maximum label found
            in each sequence.
        ids: array of shape ``(n, seq_len)`` with the per-step ids.
        Here ``n = (len(dataset) // batch_size) * batch_size``.
    """
    batch_size = datamodule.batch_size
    nb_batches = len(datamodule.dataset) // batch_size
    nb_rows = nb_batches * batch_size
    sequences = np.zeros((nb_rows, seq_len * enc_in))
    fraudulent = np.zeros(nb_rows)
    ids = np.zeros((nb_rows, seq_len))

    # zip with range stops after the last full batch without fetching the
    # trailing partial batch from the loader.
    for count, (features, _) in zip(range(nb_batches), datamodule):
        rows = slice(count * batch_size, (count + 1) * batch_size)
        frauds = features[:, :, -1].numpy()
        sequences[rows, :] = features[:, :, :-2].reshape((batch_size, -1)).numpy()
        ids[rows, :] = features[:, :, -2].numpy()
        # The row maximum is 0 exactly when every label is 0, so no separate
        # "contains a fraud" mask is needed.
        fraudulent[rows] = np.max(frauds, axis=1)
    return sequences, fraudulent, ids

def isolation_forest_detection(information, datamodule, model, test_data_path, **kwargs):
    """Score test rows with an Isolation Forest fitted on latent vectors.

    The forest is fitted on the latent vectors of the train and validation
    dataloaders, then applied to the test sequences whose label is 0 or 1
    (sequences with any other label are not scored). Each sequence score is
    assigned to all ids of that sequence, and scores are averaged per id.

    Args:
        information: column selection applied to the parquet table read from
            ``test_data_path``; the selected frame must contain an ``index``
            column, which is the merge key.
        datamodule: object providing ``train_dataloader()``,
            ``val_dataloader()`` and ``test_dataloader()``.
        model: model handed to ``generating_latent_vector``.
        test_data_path: path of the parquet file holding the test table.
        **kwargs: accepted and ignored.

    Returns:
        The selected test table left-merged with a ``score`` column; rows
        whose ``index`` received no score hold NaN.
    """
    print("Computing latent vectors on train set...")
    train_dl = datamodule.train_dataloader()
    train_set, _, _ = generating_latent_vector(train_dl, model)

    print("Computing latent vectors on valid set...")
    val_dl = datamodule.val_dataloader()
    valid_set, _, _ = generating_latent_vector(val_dl, model)
    train_valid_set = np.concatenate((train_set, valid_set))
    del train_set, valid_set, _

    print("Computing latent vectors on test set...")
    test_dl = datamodule.test_dataloader()
    test_set, fraudulent_test, ids = generating_latent_vector(test_dl, model)

    # Keep only sequences labelled 0 or 1, clean ones first then frauds.
    # Gathering them into a single array lets one scoring call handle the case
    # where one of the two groups is empty (scikit-learn rejects 0-row input).
    kept_rows = np.concatenate((np.flatnonzero(fraudulent_test == 0),
                                np.flatnonzero(fraudulent_test == 1)))
    scored_set = test_set[kept_rows]
    all_ids = ids[kept_rows, :]
    del test_set, fraudulent_test, ids

    # Training the Isolation Forest on the fraud-free set (train + valid sets)
    print("Training Isolation Forest model...")
    iso_forest = IsolationForest(n_estimators=500, contamination=0.05, random_state=42)
    iso_forest.fit(train_valid_set)
    del train_valid_set

    # Scoring the sequences in the test set; the sign is flipped so that a
    # higher score means more anomalous.
    print("Scoring test set with Isolation Forest...")
    if len(scored_set):
        scores = -iso_forest.decision_function(scored_set)
    else:
        scores = np.zeros(0)

    # Repeat each sequence score once per id so it lines up with the
    # flattened id matrix, then average the scores received by each id.
    scores = np.repeat(scores, repeats=all_ids.shape[1])
    scores_df = pd.DataFrame({'index': all_ids.flatten(), 'score': scores}).groupby('index').mean().reset_index()
    test_set = pd.read_parquet(test_data_path, engine='fastparquet')[information]
    test_set = test_set.merge(scores_df, how='left', on='index')
    return test_set

def oc_svm(information, datamodule, model, test_data_path, **kwargs):
    """Score test rows with a one-class SVM fitted on latent vectors.

    A Nystroem kernel approximation followed by ``SGDOneClassSVM`` is fitted
    on the latent vectors of the train and validation dataloaders, then
    applied to the test sequences whose label is 0 or 1 (sequences with any
    other label are not scored). Each sequence score is assigned to all ids
    of that sequence, and scores are averaged per id.

    Args:
        information: column selection applied to the parquet table read from
            ``test_data_path``; the selected frame must contain an ``index``
            column, which is the merge key.
        datamodule: object providing ``train_dataloader()``,
            ``val_dataloader()`` and ``test_dataloader()``.
        model: model handed to ``generating_latent_vector``.
        test_data_path: path of the parquet file holding the test table.
        **kwargs: accepted and ignored.

    Returns:
        The selected test table left-merged with a ``score`` column; rows
        whose ``index`` received no score hold NaN.
    """
    # Computing latent vectors of all data sets
    print("Computing latent vectors on train set...")
    train_dl = datamodule.train_dataloader()
    train_set, _, _ = generating_latent_vector(train_dl, model)

    print("Computing latent vectors on valid set...")
    val_dl = datamodule.val_dataloader()
    valid_set, _, _ = generating_latent_vector(val_dl, model)
    train_valid_set = np.concatenate((train_set, valid_set))
    del train_set, valid_set, _

    print("Computing latent vectors on test set...")
    test_dl = datamodule.test_dataloader()
    test_set, fraudulent_test, ids = generating_latent_vector(test_dl, model)

    # Keep only sequences labelled 0 or 1, clean ones first then frauds.
    # Gathering them into a single array lets one scoring call handle the case
    # where one of the two groups is empty (scikit-learn rejects 0-row input).
    kept_rows = np.concatenate((np.flatnonzero(fraudulent_test == 0),
                                np.flatnonzero(fraudulent_test == 1)))
    scored_set = test_set[kept_rows]
    all_ids = ids[kept_rows, :]
    del test_set, fraudulent_test, ids

    # Training the OC-SVM on the fraud-free set consisting in the concatenation of the train and valid sets
    print("Training OC-SVM model...")
    # Kernel width: 10 x the "scale" heuristic 1 / (n_features * variance).
    # The feature count is the second axis; the first axis is the number of
    # samples and would make the kernel width depend on the dataset size.
    nb_dim = train_valid_set.shape[1]
    var = train_valid_set.var()
    gamma = 10 / (nb_dim * var)
    transform = Nystroem(gamma=gamma, random_state=42)
    clf_sgd = SGDOneClassSVM(shuffle=True, fit_intercept=True, random_state=42, tol=1e-4)
    pipe_sgd = make_pipeline(transform, clf_sgd)
    oc_classifier = pipe_sgd.fit(train_valid_set)
    del train_valid_set

    # Dissimilarity scoring of the sequences in test set; the sign is flipped
    # so that a higher score means more anomalous.
    if len(scored_set):
        scores = -oc_classifier.score_samples(scored_set)
    else:
        scores = np.zeros(0)

    # Repeat each sequence score once per id so it lines up with the
    # flattened id matrix, then average the scores received by each id.
    scores = np.repeat(scores, repeats=all_ids.shape[1])
    scores_df = pd.DataFrame({'index': all_ids.flatten(), 'score': scores}).groupby('index').mean().reset_index()
    test_set = pd.read_parquet(test_data_path, engine='fastparquet')[information]
    test_set = test_set.merge(scores_df, how='left', on='index')
    return test_set

def recon_err_classification(information, datamodule, model, test_data_path, **kwargs):
    """Use the reconstruction error of each test sequence as its anomaly score.

    Only sequences labelled 0 or 1 are kept (clean ones first, then frauds).
    Each sequence error is assigned to all ids of that sequence and errors
    are averaged per id before being merged onto the test table.

    Args:
        information: column selection applied to the parquet table read from
            ``test_data_path``; the selected frame must contain an ``index``
            column, which is the merge key.
        datamodule: object providing ``test_dataloader()``.
        model: model handed to ``generating_reconstruction_vector``.
        test_data_path: path of the parquet file holding the test table.
        **kwargs: accepted and ignored.

    Returns:
        The selected test table left-merged with a ``score`` column; rows
        whose ``index`` received no score hold NaN.
    """
    print("Computing reconstruction vectors on test set...")
    seq_errors, seq_labels, seq_ids = generating_reconstruction_vector(
        datamodule.test_dataloader(), model
    )

    # Boolean masks selecting the two label groups that get scored
    is_clean = seq_labels == 0
    is_fraud = seq_labels == 1

    ordered_ids = np.concatenate((seq_ids[is_clean, :], seq_ids[is_fraud, :]))
    ordered_scores = np.concatenate((seq_errors[is_clean], seq_errors[is_fraud]))
    del seq_errors, seq_labels, seq_ids

    # One copy of the sequence score per id column, aligned with ordered_ids
    steps_per_sequence = ordered_ids.shape[1]
    tiled_scores = np.repeat(ordered_scores.reshape((-1, 1)), repeats=steps_per_sequence, axis=1)

    per_step = pd.DataFrame({'index': ordered_ids.flatten(), 'score': tiled_scores.flatten()})
    mean_scores = per_step.groupby('index').mean().reset_index()

    selected = pd.read_parquet(test_data_path, engine='fastparquet')[information]
    return selected.merge(mean_scores, how='left', on='index')

def only_detection(information, datamodule, test_data_path, seq_len, enc_in, detection_method, **kwargs):
    """Score test rows with a detector fitted directly on flattened input sequences.

    No learned representation is involved: the detector is fitted on the raw
    sequences of the train and validation dataloaders, then applied to the
    test sequences whose label is 0 or 1 (sequences with any other label are
    not scored). Each sequence score is assigned to all ids of that sequence,
    and scores are averaged per id.

    Args:
        information: column selection applied to the parquet table read from
            ``test_data_path``; the selected frame must contain an ``index``
            column, which is the merge key.
        datamodule: object providing ``train_dataloader()``,
            ``val_dataloader()`` and ``test_dataloader()``.
        test_data_path: path of the parquet file holding the test table.
        seq_len: number of steps per sequence.
        enc_in: number of input features per step.
        detection_method: ``'oc_svm'`` or ``'isolation_forest_detection'``.
        **kwargs: accepted and ignored.

    Returns:
        The selected test table left-merged with a ``score`` column; rows
        whose ``index`` received no score hold NaN.

    Raises:
        ValueError: if ``detection_method`` is not one of the supported names.
    """
    # Fail before any data is loaded rather than after all the vectors have
    # been computed.
    supported_methods = ('oc_svm', 'isolation_forest_detection')
    if detection_method not in supported_methods:
        raise ValueError(
            f"Unknown detection_method {detection_method!r}; expected one of {supported_methods}"
        )

    print("Computing sequence vectors on train set...")
    train_dl = datamodule.train_dataloader()
    train_set, _, _ = generating_sequence_vector(train_dl, seq_len, enc_in)

    print("Computing sequence vectors on valid set...")
    val_dl = datamodule.val_dataloader()
    valid_set, _, _ = generating_sequence_vector(val_dl, seq_len, enc_in)
    train_valid_set = np.concatenate((train_set, valid_set))
    del train_set, valid_set, _

    print("Computing sequence vectors on test set...")
    test_dl = datamodule.test_dataloader()
    test_set, fraudulent_test, ids = generating_sequence_vector(test_dl, seq_len, enc_in)

    # Keep only sequences labelled 0 or 1, clean ones first then frauds.
    # Gathering them into a single array lets one scoring call handle the case
    # where one of the two groups is empty (scikit-learn rejects 0-row input).
    kept_rows = np.concatenate((np.flatnonzero(fraudulent_test == 0),
                                np.flatnonzero(fraudulent_test == 1)))
    scored_set = test_set[kept_rows]
    all_ids = ids[kept_rows, :]
    del test_set, fraudulent_test, ids

    if detection_method == 'oc_svm':
        # Training the OC-SVM on the fraud-free set consisting in the concatenation of the train and valid sets
        print("Training OC-SVM model...")
        # Kernel width: 10 x the "scale" heuristic 1 / (n_features * variance).
        # The feature count is the second axis; the first axis is the number
        # of samples and would make the kernel width depend on dataset size.
        nb_dim = train_valid_set.shape[1]
        var = train_valid_set.var()
        gamma = 10 / (nb_dim * var)
        transform = Nystroem(gamma=gamma, random_state=42)
        clf_sgd = SGDOneClassSVM(shuffle=True, fit_intercept=True, random_state=42, tol=1e-4)
        pipe_sgd = make_pipeline(transform, clf_sgd)
        oc_classifier = pipe_sgd.fit(train_valid_set)
        score_fn = oc_classifier.score_samples
    else:
        # Training the Isolation Forest on the fraud-free set (train + valid sets)
        print("Training Isolation Forest model...")
        iso_forest = IsolationForest(n_estimators=500, contamination=0.05, random_state=42)
        iso_forest.fit(train_valid_set)
        print("Scoring test set with Isolation Forest...")
        score_fn = iso_forest.decision_function
    del train_valid_set

    # The sign is flipped so that a higher score means more anomalous.
    if len(scored_set):
        scores = -score_fn(scored_set)
    else:
        scores = np.zeros(0)

    # Repeat each sequence score once per id so it lines up with the
    # flattened id matrix, then average the scores received by each id.
    scores = np.repeat(scores, repeats=all_ids.shape[1])
    scores_df = pd.DataFrame({'index': all_ids.flatten(), 'score': scores}).groupby('index').mean().reset_index()
    test_set = pd.read_parquet(test_data_path, engine='fastparquet')[information]
    test_set = test_set.merge(scores_df, how='left', on='index')
    return test_set