import numpy as np
import torch
from pyod.models.ecod import ECOD
from pyod.models.lof import LOF
from pyod.models.deep_svdd import DeepSVDD
from pyod.models.ocsvm import OCSVM
from pyod.models.knn import KNN
from pyod.models.auto_encoder import AutoEncoder
from pyod.models.iforest import IsolationForest
from pyod.models.dif import DIF
from pyod.models.copod import COPOD
import torch.nn as nn
import time
import pandas as pd
import numpy as np
import scipy
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
from torchvision import transforms
from PIL import Image
import os
from collections import OrderedDict
from pyod.models.kde import KDE
from pyod.models.knn import KNN
import matplotlib.pyplot as plt
import torchvision.models as models
from sklearn.metrics import average_precision_score,roc_auc_score
# This is for the progress bar.
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
from pyod.models.kpca import KPCA
import csv
import torch
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score
#from openTSNE import TSNE
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as np
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets
import torch
import numpy as np
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from sklearn import preprocessing
from pathlib import Path

def read__data(file, normalization='z-score', seed=42):
    """Load a labelled tabular dataset and build a one-class train/test split.

    After a seeded shuffle, the training set holds the first half of the
    normal samples (label 0). The test set holds the remaining normal samples
    plus every anomaly (label 1). Samples with any other label are dropped.
    Normalization statistics are fitted on the training set only.

    Args:
        file: Path (str or path-like) to a ``.npz`` file containing arrays
            ``X`` and ``y``, or to a ``pkl``/``csv`` file whose last column is
            the label.
        normalization: ``'min-max'``, ``'z-score'`` (default), ``'scale'``
            (divide by 255) or ``'ours'`` (z-score with a 1e-4 std offset).
            Any other value leaves the features unnormalized.
        seed: Seed of the shuffle that determines the split.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)`` of numpy arrays.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    file = os.fspath(file)

    if file.endswith('.npz'):
        # NOTE: allow_pickle=True (and pd.read_pickle below) can execute
        # arbitrary code from a crafted file -- only load trusted datasets.
        with np.load(file, allow_pickle=True) as npz:
            x, y = npz['X'], npz['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            reader = pd.read_pickle
        elif file.endswith('csv'):
            reader = pd.read_csv
        else:
            raise NotImplementedError(f'Unsupported dataset file type: {file}')

        df = reader(file)
        # Treat +/-inf as missing, then fill each gap from the previous row.
        df = df.replace([np.inf, -np.inf], np.nan).ffill()
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # Seeded shuffle, then train/test split.
    rng = np.random.RandomState(seed)
    perm = rng.permutation(np.arange(len(x)))
    x, y = x[perm], y[perm]

    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    # Normalization (statistics come from the training split only).
    if normalization == 'min-max':
        scaler = MinMaxScaler()
        scaler.fit(x_train)
        x_train = scaler.transform(x_train)
        x_test = scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test





# Benchmark driver: for each listed dataset, fits a PyOD detector on the
# normal-only training split and scores the held-out test split over several
# differently-seeded splits, reporting AUROC / AUPRC per run and on average.
folder_path = Path('datasets')  # dataset root (e.g. for folder_path.rglob('*'))
N_RUNS = 5

for file_path in ["datasets/36_speech.npz"]:
    # Dataset name without directory or extension, e.g. "36_speech".
    dataset_name = Path(file_path).stem
    avg_auroc, avg_auprc = [], []
    for n in range(N_RUNS):
        # Vary the split per run (first run keeps read__data's default seed 42);
        # with a fixed seed every repetition would be identical.
        x_train, y_train, x_test, y_test = read__data(file_path, seed=42 + n)

        # Other PyOD detectors previously tried here: KNN, DeepSVDD,
        # IsolationForest, OCSVM, COPOD.
        model = ECOD()
        model.fit(x_train)
        time1 = time.time()
        score = model.decision_function(x_test)
        time2 = time.time()
        print("### TIME: " + str(time2 - time1))

        # Scores were computed on x_test, so they are evaluated against y_test.
        AUROC = roc_auc_score(y_test, score)
        ap_score = average_precision_score(y_test, score)
        avg_auroc.append(AUROC)
        avg_auprc.append(ap_score)
        print(f"{dataset_name} run {n}: AUROC={AUROC:.4f} AUPRC={ap_score:.4f}")

    print(f"{dataset_name}: AUROC {np.mean(avg_auroc):.4f} +/- {np.std(avg_auroc):.4f}, "
          f"AUPRC {np.mean(avg_auprc):.4f} +/- {np.std(avg_auprc):.4f}")