from algorithms.slad import SLAD
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import scipy
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
from torchvision import transforms
from PIL import Image
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
import torchvision.models as models
from sklearn.metrics import average_precision_score
# This is for the progress bar.
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import csv
import torch
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score
#from openTSNE import TSNE
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as np
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets
import torch
import numpy as np
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from sklearn import preprocessing
from pathlib import Path
def read_noise_data(file, normalization='z-score', seed=42, noise_rate=0):
    """Load a labelled dataset and build a contaminated train / test split.

    The data is read from an ``.npz`` archive (arrays ``X`` and ``y``), a
    pickle or a CSV file (last column is the label).  Rows are shuffled with
    ``np.random.RandomState(seed)``.  The first half of the label-0 rows plus
    the first ``int(split * noise_rate)`` label-1 rows form the training set;
    every remaining label-0 / label-1 row goes to the test set.

    Args:
        file: Path of the dataset; the extension selects the loader.
        normalization: ``'min-max'`` (MinMaxScaler fitted on the training
            rows), ``'z-score'``, ``'scale'`` (divide by 255) or ``'ours'``
            (``(x - mean) / (std + 1e-4)``).  Any other value leaves the
            features unnormalised.
        seed: Seed of the shuffling generator.
        noise_rate: Fraction, relative to the number of training normals,
            of anomalies mixed into the training set.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where
        ``data_dim`` is the number of feature columns and ``sds`` is the
        per-feature standard deviation used by ``'z-score'`` (zeros replaced
        by 1) or ``None`` for every other mode.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True / read_pickle can execute code
        # embedded in the file -- only load datasets from trusted sources.
        archive = np.load(file, allow_pickle=True)
        x, y = archive['X'], archive['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            loader = pd.read_pickle
        elif file.endswith('csv'):
            loader = pd.read_csv
        else:
            raise NotImplementedError(
                f'unsupported dataset file {file!r}; expected .npz, pkl or csv')

        df = loader(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    ab_train = int(split * noise_rate)
    train_idx = np.hstack([norm_idx[:split], anom_idx[:ab_train]])
    test_idx = np.hstack([norm_idx[split:], anom_idx[ab_train:]])
    x_train, y_train = x[train_idx], y[train_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # normalization (statistics always come from the training rows only)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds


def read_after_gauss_data(file, normalization='z-score', train_level=1, test_level=2, seed=42):
    """Load a dataset, normalise it, then add Gaussian noise to the features.

    NOTE(review): this module defines ``read_after_gauss_data`` again further
    down; the later definition replaces this one when the file is executed.

    Rows are shuffled with ``np.random.RandomState(seed)``.  The first half of
    the label-0 rows is the training set; the remaining label-0 rows followed
    by all label-1 rows form the test set.  Normalisation statistics are taken
    from the clean training rows.  Zero-mean Gaussian noise is then added to
    the normalised features, drawn from the global ``np.random`` state:
    variance ``train_level`` for training rows and anomalies, variance
    ``test_level`` for the normal test rows.

    Previously the noise was written into the raw array after the train/test
    copies had been taken, so it never reached the returned data.

    Args:
        file: Path of an ``.npz`` (arrays ``X``/``y``), pickle or CSV file
            (last column is the label).
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` or
            ``'ours'``; any other value skips normalisation.
        train_level: Noise variance for training rows and anomalies.
        test_level: Noise variance for normal test rows.
        seed: Seed of the shuffling generator.

    Returns:
        ``(x_train, y_train, x_test, y_test)``.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True / read_pickle can execute code
        # embedded in the file -- only load datasets from trusted sources.
        archive = np.load(file, allow_pickle=True)
        x, y = archive['X'], archive['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            loader = pd.read_pickle
        elif file.endswith('csv'):
            loader = pd.read_csv
        else:
            raise NotImplementedError(
                f'unsupported dataset file {file!r}; expected .npz, pkl or csv')

        df = loader(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]

    # normalization on the clean data (statistics from the training rows)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    # Draw order (train, anomalies, test normals) matches the original code so
    # a caller-seeded global RNG yields the same noise values as before.
    n_features = x.shape[1]
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(train_norm_idx), n_features))
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(anom_idx), n_features))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level), size=(len(test_norm_idx), n_features))

    # Apply the noise to the arrays that are actually returned; x_test rows
    # are ordered "test normals, then anomalies".
    x_train = x_train + noise_train
    x_test = x_test + np.vstack([noise_test_norm, noise_test_abnorm])

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test

def read_gauss_data(file, normalization='z-score', train_level=1, test_level=2, seed=42):
    """Load a dataset, add Gaussian noise to the raw features, then normalise.

    Rows are shuffled with ``np.random.RandomState(seed)``.  The first half of
    the label-0 rows is the training set; the remaining label-0 rows followed
    by all label-1 rows form the test set.  Zero-mean Gaussian noise from the
    global ``np.random`` state is added before normalisation: variance
    ``train_level`` for training rows and anomalies, variance ``test_level``
    for the normal test rows.  Normalisation statistics are then computed on
    the noisy training rows.

    Args:
        file: Path of an ``.npz`` (arrays ``X``/``y``), pickle or CSV file
            (last column is the label).
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` or
            ``'ours'``; any other value skips normalisation.
        train_level: Noise variance for training rows and anomalies.
        test_level: Noise variance for normal test rows.
        seed: Seed of the shuffling generator.

    Returns:
        ``(x_train, y_train, x_test, y_test)``.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True / read_pickle can execute code
        # embedded in the file -- only load datasets from trusted sources.
        archive = np.load(file, allow_pickle=True)
        x, y = archive['X'], archive['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            loader = pd.read_pickle
        elif file.endswith('csv'):
            loader = pd.read_csv
        else:
            raise NotImplementedError(
                f'unsupported dataset file {file!r}; expected .npz, pkl or csv')

        df = loader(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    # Writing float noise into an integer/bool array would silently truncate
    # it, so make sure the feature matrix can hold fractional values.
    if x.dtype.kind in 'iub':
        x = x.astype(float)

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    # Draw order (train, anomalies, test normals) is kept so a caller-seeded
    # global RNG reproduces the same noise as before.
    n_features = x.shape[1]
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(train_norm_idx), n_features))
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(anom_idx), n_features))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level), size=(len(test_norm_idx), n_features))

    x[anom_idx] = x[anom_idx] + noise_test_abnorm
    x[test_norm_idx] = x[test_norm_idx] + noise_test_norm
    x[train_norm_idx] = x[train_norm_idx] + noise_train

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    # normalization (statistics from the noisy training rows)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test

def read_after_gauss_data(file, normalization='z-score', train_level=1, test_level=2, seed=42):
    """Load a dataset, normalise it, then add Gaussian noise to the features.

    NOTE(review): this module defines ``read_after_gauss_data`` once more
    further down; that later definition replaces this one when the file is
    executed.

    Rows are shuffled with ``np.random.RandomState(seed)``.  The first half of
    the label-0 rows is the training set; the remaining label-0 rows followed
    by all label-1 rows form the test set.  Normalisation statistics are taken
    from the clean training rows.  Zero-mean Gaussian noise is then added to
    the normalised features, drawn from the global ``np.random`` state:
    variance ``train_level`` for training rows and anomalies, variance
    ``test_level`` for the normal test rows.

    Previously the noise was written into the raw array after the train/test
    copies had been taken, so it never reached the returned data.

    Args:
        file: Path of an ``.npz`` (arrays ``X``/``y``), pickle or CSV file
            (last column is the label).
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` or
            ``'ours'``; any other value skips normalisation.
        train_level: Noise variance for training rows and anomalies.
        test_level: Noise variance for normal test rows.
        seed: Seed of the shuffling generator.

    Returns:
        ``(x_train, y_train, x_test, y_test)``.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True / read_pickle can execute code
        # embedded in the file -- only load datasets from trusted sources.
        archive = np.load(file, allow_pickle=True)
        x, y = archive['X'], archive['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            loader = pd.read_pickle
        elif file.endswith('csv'):
            loader = pd.read_csv
        else:
            raise NotImplementedError(
                f'unsupported dataset file {file!r}; expected .npz, pkl or csv')

        df = loader(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]

    # normalization on the clean data (statistics from the training rows)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    # Draw order (train, anomalies, test normals) matches the original code so
    # a caller-seeded global RNG yields the same noise values as before.
    n_features = x.shape[1]
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(train_norm_idx), n_features))
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(anom_idx), n_features))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level), size=(len(test_norm_idx), n_features))

    # Apply the noise to the arrays that are actually returned; x_test rows
    # are ordered "test normals, then anomalies".
    x_train = x_train + noise_train
    x_test = x_test + np.vstack([noise_test_norm, noise_test_abnorm])

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test

def read_after_gauss_data(file, normalization='z-score', train_level=1, test_level=2, seed=42):
    """Load a dataset, normalise it, then add Gaussian noise to the features.

    Rows are shuffled with ``np.random.RandomState(seed)``.  The first half of
    the label-0 rows is the training set; the remaining label-0 rows followed
    by all label-1 rows form the test set.  Normalisation statistics are taken
    from the clean training rows.  Zero-mean Gaussian noise is then added to
    the normalised features, drawn from the global ``np.random`` state:
    variance ``train_level`` for training rows and normal test rows, variance
    ``test_level`` for the anomalies.

    Previously the noise was written into the raw array after the train/test
    copies had been taken, so it never reached the returned data and
    ``train_level`` / ``test_level`` had no effect.

    Args:
        file: Path of an ``.npz`` (arrays ``X``/``y``), pickle or CSV file
            (last column is the label).
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` or
            ``'ours'``; any other value skips normalisation.
        train_level: Noise variance for training rows and normal test rows.
        test_level: Noise variance for anomalous test rows.
        seed: Seed of the shuffling generator.

    Returns:
        ``(x_train, y_train, x_test, y_test)``.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True / read_pickle can execute code
        # embedded in the file -- only load datasets from trusted sources.
        archive = np.load(file, allow_pickle=True)
        x, y = archive['X'], archive['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            loader = pd.read_pickle
        elif file.endswith('csv'):
            loader = pd.read_csv
        else:
            raise NotImplementedError(
                f'unsupported dataset file {file!r}; expected .npz, pkl or csv')

        df = loader(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]

    # normalization on the clean data (statistics from the training rows)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    # Draw order (train, anomalies, test normals) matches the original code so
    # a caller-seeded global RNG yields the same noise values as before.
    n_features = x.shape[1]
    noise_train = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(train_norm_idx), n_features))
    noise_test_abnorm = np.random.normal(
        loc=0.0, scale=np.sqrt(test_level), size=(len(anom_idx), n_features))
    noise_test_norm = np.random.normal(
        loc=0.0, scale=np.sqrt(train_level), size=(len(test_norm_idx), n_features))

    # Apply the noise to the arrays that are actually returned; x_test rows
    # are ordered "test normals, then anomalies".
    x_train = x_train + noise_train
    x_test = x_test + np.vstack([noise_test_norm, noise_test_abnorm])

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    return x_train, y_train, x_test, y_test

def read_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset and split it into clean-train / mixed-test sets.

    The data is read from an ``.npz`` archive (arrays ``X`` and ``y``), a
    pickle or a CSV file (last column is the label).  Rows are shuffled with
    ``np.random.RandomState(seed)``.  The first half of the label-0 rows is
    the training set; the remaining label-0 rows followed by all label-1 rows
    form the test set.

    Args:
        file: Path of the dataset; the extension selects the loader.
        normalization: ``'min-max'`` (MinMaxScaler fitted on the training
            rows), ``'z-score'``, ``'scale'`` (divide by 255) or ``'ours'``
            (``(x - mean) / (std + 1e-4)``).  Any other value leaves the
            features unnormalised.
        seed: Seed of the shuffling generator.

    Returns:
        ``(x_train, y_train, x_test, y_test)``.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    if file.endswith('.npz'):
        # NOTE(review): allow_pickle=True / read_pickle can execute code
        # embedded in the file -- only load datasets from trusted sources.
        archive = np.load(file, allow_pickle=True)
        x, y = archive['X'], archive['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            loader = pd.read_pickle
        elif file.endswith('csv'):
            loader = pd.read_csv
        else:
            raise NotImplementedError(
                f'unsupported dataset file {file!r}; expected .npz, pkl or csv')

        df = loader(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # train-test splitting
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    # normalization (statistics always come from the training rows only)
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test
# Hyper-parameter sweep for SLAD: for every dataset, each (lr, hidden) grid
# point is trained 5 times and the best mean AUROC / AUPRC over the grid is
# appended to a text file in the results folder.
# NOTE(review): this runs at import time; consider moving it under
# `if __name__ == "__main__":` if the readers above are imported elsewhere.
folder_path = Path('datasets')
results_dir = './tuned SLAD RESULTS/'
# Create the output folder up front so the sweep cannot fail at the final
# write because the directory is missing.
os.makedirs(results_dir, exist_ok=True)

for dataset_name in ["3_backdoor"]:
    # A dedicated name is used instead of `data` so the imported
    # torch.utils `data` module is no longer overwritten by a string.
    file_path = str(folder_path / (dataset_name + '.npz'))

    best_auroc, best_auprc = 0, 0
    # Spread and hyper-parameters of the runs that produced each best mean.
    best_auroc_std, best_auprc_std = 0.0, 0.0
    best_auroc_cfg, best_auprc_cfg = None, None

    for lr in [0.01, 0.005, 0.001]:
        for hidden in [64, 128, 256]:
            avg_auroc, avg_auprc = [], []
            for i in range(0, 5):
                # read_data shuffles with a fixed default seed, so every run
                # sees the same split.
                x_train, y_train, x_test, y_test = read_data(file_path, normalization='z-score')
                model = SLAD(lr=lr, hidden_dims=hidden)
                model.fit(x_train)
                score = model.decision_function(x_test)
                # NaN scores would make the metric functions fail.
                score = np.nan_to_num(score, nan=0.0)
                AUROC = roc_auc_score(y_test, score)
                ap_score = average_precision_score(y_test, score)
                avg_auprc.append(ap_score)
                avg_auroc.append(AUROC)

            mean_auroc = np.mean(avg_auroc)
            mean_auprc = np.mean(avg_auprc)
            # Keep the std of the winning grid point; previously the std of
            # the last grid point was reported next to the best mean.
            if mean_auroc > best_auroc:
                best_auroc = mean_auroc
                best_auroc_std = np.std(avg_auroc)
                best_auroc_cfg = (lr, hidden)
            if mean_auprc > best_auprc:
                best_auprc = mean_auprc
                best_auprc_std = np.std(avg_auprc)
                best_auprc_cfg = (lr, hidden)

    # NOTE(review): the file name still carries the LAST grid point (lr and
    # hidden as left by the loops), exactly as before, so existing result
    # paths stay stable; the winning configurations are written in the line.
    with open(results_dir + dataset_name + ' lr-' + str(lr) + ' hd-' + str(hidden) + ' .txt', 'a+') as f:
        f.write("best_auroc: " + str(best_auroc) + " std: " + str(best_auroc_std)
                + " best_auprc: " + str(best_auprc) + " std: " + str(best_auprc_std)
                + " best_auroc_cfg(lr,hd): " + str(best_auroc_cfg)
                + " best_auprc_cfg(lr,hd): " + str(best_auprc_cfg) + '\n')