import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import scipy
import torch.nn.functional as F
import torch.storage
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
from torchvision import transforms
from PIL import Image
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
import torchvision.models as models
import time
from pathlib import Path
# This is for the progress bar.
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import csv
import torch
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from sklearn.metrics import average_precision_score
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score
#from openTSNE import TSNE
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as np
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets
import torch
import numpy as np
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from model import _RealNVP
from sklearn import preprocessing
#from torchmetrics.functional import auroc
#from torchmetrics import AUROC,AveragePrecision
class CustomDataset(Dataset):
    """Map-style dataset pairing a feature array with per-sample targets.

    ``X`` is kept as ``self.data`` and ``y`` as ``self.targets``; both are
    indexed with the same index when an item is requested.
    """

    def __init__(self, X, y):
        self.data = X
        self.targets = y

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for ``idx``.

        A tensor index is first converted with ``tolist()``. The selected
        features are wrapped with ``torch.from_numpy``; the target is returned
        exactly as stored.
        """
        index = idx.tolist() if torch.is_tensor(idx) else idx
        features = torch.from_numpy(self.data[index])
        target = self.targets[index]
        return features, target
def get_device():
    """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'


# Module-level device string, resolved once when the module is executed.
device = get_device()
def read_noise_data(file, normalization='z-score', seed=42, noise_rate=0):
    """Load a labelled dataset and build a train split contaminated with anomalies.

    The file is read from ``.npz`` (arrays ``'X'`` and ``'y'``) or, via pandas,
    from a ``pkl``/``csv`` table whose last column is the label. For tables,
    infinities are replaced by NaN and NaNs are forward-filled. Rows are then
    shuffled with ``np.random.RandomState(seed)``.

    Samples with ``y == 0`` are treated as normal and ``y == 1`` as anomalous.
    The first half of the normal samples goes to the training split together
    with the first ``int(n_train_normal * noise_rate)`` anomalies; every
    remaining sample goes to the test split.

    Args:
        file: Path to the dataset; the suffix selects the loader.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (z-score with ``std + 1e-4``). Any other value
            leaves the data unnormalized. Statistics come from the training
            split only.
        seed: Seed of the shuffling RNG.
        noise_rate: Anomalies added to the training split, expressed as a
            fraction of the number of normal training samples.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training standard deviation (zeros replaced by 1) for
        ``'z-score'`` and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file suffix is not npz, pkl or csv.
    """
    if file.endswith('.npz'):
        # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
        # run arbitrary code -- only load files from trusted sources.
        # The context manager closes the underlying archive handle.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            # NOTE: read_pickle has the same trusted-input requirement.
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(
                f'Unsupported file type: {file!r} (expected .npz, pkl or csv)')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # Shuffle once with a dedicated RNG so the split is reproducible.
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    # Train-test splitting.
    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    ab_train = int(split * noise_rate)
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    train_ab_idx, test_ab_idx = anom_idx[:ab_train], anom_idx[ab_train:]
    train_idx = np.hstack([train_norm_idx, train_ab_idx])
    test_idx = np.hstack([test_norm_idx, test_ab_idx])
    x_train, y_train = x[train_idx], y[train_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # Normalization (statistics are always estimated on the training split).
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds



def read_gauss_data(file, normalization='z-score', train_level=1, test_level=2, seed=42):
    """Load a labelled dataset, split it, and perturb it with Gaussian noise.

    The file is read from ``.npz`` (arrays ``'X'`` and ``'y'``) or, via pandas,
    from a ``pkl``/``csv`` table whose last column is the label. Rows are
    shuffled with ``np.random.RandomState(seed)``. The first half of the
    normal samples (``y == 0``) forms the training split; the remaining normal
    samples plus all anomalies (``y == 1``) form the test split.

    Zero-mean Gaussian noise is then added: variance ``train_level`` for the
    training samples and for the anomalies, variance ``test_level`` for the
    normal test samples. All noise is drawn from the seeded RNG, so the result
    is fully determined by ``seed``.

    Args:
        file: Path to the dataset; the suffix selects the loader.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (z-score with ``std + 1e-4``). Any other value
            leaves the data unnormalized. Statistics come from the (noisy)
            training split only.
        train_level: Noise variance for training samples and anomalies.
        test_level: Noise variance for normal test samples.
        seed: Seed for both the shuffle and the noise.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training standard deviation (zeros replaced by 1) for
        ``'z-score'`` and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file suffix is not npz, pkl or csv.
    """
    if file.endswith('.npz'):
        # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
        # run arbitrary code -- only load files from trusted sources.
        # The context manager closes the underlying archive handle.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            # NOTE: read_pickle has the same trusted-input requirement.
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(
                f'Unsupported file type: {file!r} (expected .npz, pkl or csv)')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # Shuffle once with a dedicated RNG so the split is reproducible.
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    # Real-valued noise is written back into ``x`` below; with an integer (or
    # object) array that assignment would silently truncate or mis-type it.
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)

    # Train-test splitting.
    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]

    data_dim = x.shape[1]
    y_train = y[train_norm_idx]

    # Draw all noise from the seeded RNG (not the global np.random state) so
    # that ``seed`` makes the whole function deterministic.
    noise_train = rng.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(train_norm_idx), data_dim),
    )
    # NOTE(review): anomalies receive train_level noise although they are only
    # used in the test split -- confirm this is intended rather than test_level.
    noise_test_abnorm = rng.normal(
        loc=0.0,
        scale=np.sqrt(train_level),
        size=(len(anom_idx), data_dim),
    )
    noise_test_norm = rng.normal(
        loc=0.0,
        scale=np.sqrt(test_level),
        size=(len(test_norm_idx), data_dim),
    )
    # The three index sets are disjoint, so the update order is irrelevant.
    x[anom_idx] = x[anom_idx] + noise_test_abnorm
    x[test_norm_idx] = x[test_norm_idx] + noise_test_norm
    x[train_norm_idx] = x[train_norm_idx] + noise_train

    x_train = x[train_norm_idx]
    test_idx = np.hstack([test_norm_idx, anom_idx])
    x_test = x[test_idx]
    y_test = y[test_idx]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # Normalization (statistics are always estimated on the training split).
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds


def read_OD_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset for the setting where train and test coincide.

    The file is read from ``.npz`` (arrays ``'X'`` and ``'y'``) or, via pandas,
    from a ``pkl``/``csv`` table whose last column is the label. Rows are
    shuffled with ``np.random.RandomState(seed)`` and the *whole* dataset is
    normalized and returned as both the training and the test split.

    Args:
        file: Path to the dataset; the suffix selects the loader.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (z-score with ``std + 1e-4``). Any other value
            leaves the data unnormalized.
        seed: Seed of the shuffling RNG.

    Returns:
        ``(x, y, x, y, data_dim, sds)``: the normalized features and labels
        (repeated for the test slots), the number of features, and the
        per-feature standard deviation (zeros replaced by 1) for ``'z-score'``
        or ``None`` for every other normalization.

    Raises:
        NotImplementedError: If the file suffix is not npz, pkl or csv.
    """
    if file.endswith('.npz'):
        # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
        # run arbitrary code -- only load files from trusted sources.
        # The context manager closes the underlying archive handle.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            # NOTE: read_pickle has the same trusted-input requirement.
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(
                f'Unsupported file type: {file!r} (expected .npz, pkl or csv)')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # Shuffle once with a dedicated RNG so the ordering is reproducible.
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    # No hold-out split: every sample is used for both fitting and scoring.
    x_train = x
    y_train = y

    # Must be initialised here: only the z-score branch assigns it, and the
    # return statement reads it for every normalization mode.
    sds = None
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)

    return x_train, y_train, x_train, y_train, x.shape[1], sds

def read_data(file, normalization='z-score', seed=42):
    """Load a labelled dataset and split it into clean-train / mixed-test sets.

    The file is read from ``.npz`` (arrays ``'X'`` and ``'y'``) or, via pandas,
    from a ``pkl``/``csv`` table whose last column is the label. For tables,
    infinities are replaced by NaN and NaNs are forward-filled. Rows are then
    shuffled with ``np.random.RandomState(seed)``.

    The first half of the normal samples (``y == 0``) forms the training
    split; the remaining normal samples plus all anomalies (``y == 1``) form
    the test split.

    Args:
        file: Path to the dataset; the suffix selects the loader.
        normalization: ``'min-max'``, ``'z-score'``, ``'scale'`` (divide by
            255) or ``'ours'`` (z-score with ``std + 1e-4``). Any other value
            leaves the data unnormalized. Statistics come from the training
            split only.
        seed: Seed of the shuffling RNG.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, sds)`` where ``sds`` is
        the per-feature training standard deviation (zeros replaced by 1) for
        ``'z-score'`` and ``None`` otherwise.

    Raises:
        NotImplementedError: If the file suffix is not npz, pkl or csv.
    """
    if file.endswith('.npz'):
        # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
        # run arbitrary code -- only load files from trusted sources.
        # The context manager closes the underlying archive handle.
        with np.load(file, allow_pickle=True) as data:
            x, y = data['X'], data['y']
        y = np.array(y, dtype=int)
    else:
        if file.endswith('pkl'):
            # NOTE: read_pickle has the same trusted-input requirement.
            func = pd.read_pickle
        elif file.endswith('csv'):
            func = pd.read_csv
        else:
            raise NotImplementedError(
                f'Unsupported file type: {file!r} (expected .npz, pkl or csv)')

        df = func(file)
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # DataFrame.ffill replaces the deprecated fillna(method='ffill').
        df.ffill(inplace=True)
        x = df.values[:, :-1]
        y = np.array(df.values[:, -1], dtype=int)

    # Shuffle once with a dedicated RNG so the split is reproducible.
    rng = np.random.RandomState(seed)
    idx = rng.permutation(np.arange(len(x)))
    x, y = x[idx], y[idx]

    # Train-test splitting.
    norm_idx = np.where(y == 0)[0]
    anom_idx = np.where(y == 1)[0]
    split = int(0.5 * len(norm_idx))
    train_norm_idx, test_norm_idx = norm_idx[:split], norm_idx[split:]
    test_idx = np.hstack([test_norm_idx, anom_idx])

    x_train, y_train = x[train_norm_idx], y[train_norm_idx]
    x_test, y_test = x[test_idx], y[test_idx]
    data_dim = x_train.shape[1]

    print(f'Original size: [{x.shape}], Normal/Anomaly: [{len(norm_idx)}/{len(anom_idx)}] \n'
          f'After splitting: training/testing [{len(x_train)}/{len(x_test)}]')

    sds = None
    # Normalization (statistics are always estimated on the training split).
    if normalization == 'min-max':
        minmax_scaler = MinMaxScaler()
        minmax_scaler.fit(x_train)
        x_train = minmax_scaler.transform(x_train)
        x_test = minmax_scaler.transform(x_test)

    elif normalization == 'z-score':
        mus = np.mean(x_train, axis=0)
        sds = np.std(x_train, axis=0)
        sds[sds == 0] = 1  # constant features: avoid division by zero
        x_train = (x_train - mus) / sds
        x_test = (x_test - mus) / sds

    elif normalization == 'scale':
        x_train = x_train / 255
        x_test = x_test / 255

    elif normalization == 'ours':
        mean = np.mean(x_train, 0)
        std = np.std(x_train, 0)
        x_train = (x_train - mean) / (std + 1e-4)
        x_test = (x_test - mean) / (std + 1e-4)

    return x_train, y_train, x_test, y_test, data_dim, sds

def contribution_calculation(model,train_loader,pow,std):
    """Estimate per-output-dimension weights from the model's input Jacobian.

    For every batch the Jacobian of the batch-summed model outputs with
    respect to the inputs is computed, right-multiplied by ``std``, and taken
    in absolute value. These matrices are averaged over all samples; each
    output dimension's row is reduced to its L2 norm ``n``, mapped to
    ``1 / (n + 1)`` and finally normalised with a softmax, so the returned
    weights sum to 1.

    Args:
        model: Callable as ``model(x, 0)`` returning ``(outputs, _)``. The
            gradient reshaping assumes the output width equals the input
            width -- TODO confirm for models other than the flow used here.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        pow: Unused; kept for interface compatibility.
        std: Square matrix the Jacobian is right-multiplied by (the caller
            passes a diagonal matrix of inverse feature standard deviations).

    Returns:
        ``(weights, je, 0, 0, x, y)`` where ``weights`` is the softmax-
        normalised contribution vector, ``je`` is a placeholder zero vector
        with one entry per output dimension, and ``x`` / ``y`` are lists of
        the input batches (on ``device``, with ``requires_grad`` set) and
        label batches that were visited.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    std = std.to(device).float()
    # Device placement and mode do not change between batches; set them once.
    model = model.to(device)
    model.train()

    jac_sum = None
    num = 0
    # Built locally instead of relying on a module-level ``je`` that only
    # exists once the experiment loop has run; same value ``training`` returns.
    je = None
    x, y = [], []
    for batch in tqdm(train_loader):
        imgs, lab = batch
        imgs = imgs.float().to(device)
        imgs.requires_grad = True
        outputs, _ = model(imgs, 0)

        batch_size = outputs.size(0)
        num += batch_size
        # Summing over the batch lets one backward pass per output dimension
        # recover that dimension's gradient for every sample at once.
        outVector = torch.sum(outputs, 0).view(-1)
        outdim = outVector.size()[0]
        if je is None:
            je = torch.zeros(outdim)
        x.append(imgs)
        y.append(lab)

        # Shape (outdim, batch, outdim) -> (batch, outdim, outdim).
        jac = torch.stack(
            [torch.autograd.grad(outVector[i], imgs,
                                 retain_graph=True,
                                 create_graph=False)[0].view(batch_size, outdim)
             for i in range(outdim)],
            dim=0)
        jac = torch.abs(jac.permute(1, 0, 2) @ std)

        batch_total = torch.sum(jac, dim=0)
        if jac_sum is None:
            jac_sum = batch_total
        else:
            jac_sum += batch_total

    if jac_sum is None:
        raise ValueError('contribution_calculation received an empty loader')

    jac_sum = jac_sum / num
    jac_sum = torch.norm(jac_sum, dim=1)
    jac_sum = 1 / (jac_sum + 1)
    jac_sum = torch.exp(jac_sum) / torch.sum(torch.exp(jac_sum))
    return jac_sum, je, 0, 0, x, y


# Penalty names accepted by ``training`` (see ``_jacobian_penalty``).
_SUPPORTED_PNAL = ('L_2', 'L_1sq', 'L2cL1', 'L2-L1', 'overL2cL1', 'overL2-L1',
                   'half/half', 'half-half')


def _jacobian_penalty(jac, PNAL):
    """Reduce a batch-averaged absolute Jacobian matrix to a scalar penalty.

    ``jac`` has one row per output dimension. ``PNAL`` selects the reduction:

    * ``'L_2'``: sum of row L2 norms.
    * ``'L_1sq'``: sum over rows of ``sqrt(row sum)``.
    * ``'L2cL1'`` / ``'L2-L1'``: with ``s = jac ** 2``, the sum of row norms
      ``sum_r sqrt(sum(s[r]))`` divided by / minus ``sqrt(sum(s))``.
    * ``'overL2cL1'`` / ``'overL2-L1'``: the same after mapping ``jac`` to
      ``1 / (jac + 1)``.
    * ``'half/half'`` / ``'half-half'``: rows are sorted by L2 norm; the sum
      of the smaller half's norms is divided by / reduced by the sum of the
      larger half's squared norms.
    """
    if PNAL == 'L_2':
        return torch.sum(torch.norm(jac, dim=1))
    if PNAL == 'L_1sq':
        return torch.sum(torch.sqrt(torch.sum(jac, dim=1)))
    if PNAL in ('L2cL1', 'L2-L1', 'overL2cL1', 'overL2-L1'):
        if PNAL.startswith('over'):
            jac = 1 / (jac + 1)
        squared = torch.pow(jac, 2)
        frobenius = torch.sqrt(torch.sum(squared))
        row_norm_sum = torch.sum(torch.sqrt(torch.sum(squared, dim=1)))
        if PNAL.endswith('cL1'):
            return row_norm_sum / frobenius
        return row_norm_sum - frobenius
    if PNAL in ('half/half', 'half-half'):
        row_norms = torch.norm(jac, p=2, dim=1)
        sorted_indices = torch.argsort(row_norms)
        half_size = len(row_norms) // 2
        min_half_sum = row_norms[sorted_indices[:half_size]].sum()
        max_half_square_sum = (row_norms[sorted_indices[half_size:]] ** 2).sum()
        if PNAL == 'half/half':
            return min_half_sum / max_half_square_sum
        return min_half_sum - max_half_square_sum
    raise ValueError(f'Unsupported PNAL: {PNAL!r}; expected one of {_SUPPORTED_PNAL}')


def training(model,train_loader,lr,adam=1,wd=0,pun=0.01, grad_pun=0.01,PNAL=None,std=None):
    """Run one epoch of maximum-likelihood training with a Jacobian penalty.

    Each batch minimises the mean negative log-density under a standard
    normal prior (``-(log p(z) + sldj)``) plus ``grad_pun`` times a penalty
    computed from the batch-averaged absolute input Jacobian (see
    ``_jacobian_penalty``).

    Args:
        model: Callable as ``model(x, sldj=0)`` returning ``(outputs, sldj)``.
            The gradient reshaping assumes the output width equals the input
            width -- TODO confirm for models other than the flow used here.
        train_loader: Iterable of ``(inputs, _)`` batches.
        lr: Learning rate.
        adam: ``1`` for Adam (AMSGrad, weight decay ``wd``), ``0`` for SGD
            with momentum 0.8.
        wd: Weight decay, used by the Adam optimizer only.
        pun: Unused; kept for interface compatibility.
        grad_pun: Weight of the Jacobian penalty; ``0`` skips the Jacobian
            computation entirely.
        PNAL: Name of the penalty reduction; required when ``grad_pun != 0``.
        std: Optional square matrix the Jacobian is right-multiplied by.
            ``None`` leaves the Jacobian unscaled.

    Returns:
        A zero vector with one entry per output dimension, or ``None`` if the
        loader yielded no batches.

    Raises:
        ValueError: If ``adam`` is neither 0 nor 1, or if ``grad_pun != 0``
            and ``PNAL`` is not a supported penalty name.
    """
    if grad_pun != 0 and PNAL not in _SUPPORTED_PNAL:
        # Without a reduction the penalty stays a matrix and cannot be added
        # to the scalar loss; fail early with a clear message instead.
        raise ValueError(f'Unsupported PNAL: {PNAL!r}; expected one of {_SUPPORTED_PNAL}')

    model = model.to(device)
    model.device = device
    model.train()
    if std is not None:
        std = std.to(device).float()

    if adam == 1:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd, amsgrad=True)
    elif adam == 0:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.8)
    else:
        raise ValueError(f'adam must be 0 or 1, got {adam!r}')

    je = None
    # Defined up front so the summary print below also works for an empty loader.
    loss = 0.0
    jac = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        imgs, _ = batch
        imgs = imgs.to(device).float()
        imgs.requires_grad = True  # needed to differentiate outputs w.r.t. inputs
        outputs, sldj = model(imgs, sldj=0)

        # Log-density of the latent code under a standard normal, per sample.
        log_likelihood = -0.5 * (torch.pow(outputs, 2) + torch.log(torch.tensor(torch.pi * 2)))
        sample_likelihood = torch.sum(log_likelihood, dim=1)

        batch_size = outputs.size(0)
        outVector = torch.sum(outputs, 0).view(-1)
        outdim = outVector.size()[0]
        if je is None:
            je = torch.zeros(outdim)

        jac = 0
        if grad_pun != 0:
            # create_graph=True keeps the Jacobian differentiable so the
            # penalty contributes to the parameter gradients.
            jac = torch.stack(
                [torch.autograd.grad(outVector[i], imgs,
                                     retain_graph=True,
                                     create_graph=True)[0].view(batch_size, outdim)
                 for i in range(outdim)],
                dim=0)
            jac = jac.permute(1, 0, 2)  # (batch, outdim, outdim)
            if std is not None:
                jac = torch.matmul(jac, std)
            jac = torch.mean(torch.abs(jac), dim=0)
            jac = _jacobian_penalty(jac, PNAL)

        loss = torch.mean(-(sample_likelihood + sldj))
        loss += jac * grad_pun
        loss.backward()
        optimizer.step()

    # Reports the values of the last processed batch.
    print(f" Train |  loss = {loss:.5f}  grad={jac*grad_pun:.5f}")
    return je

def testing(model,train_loader,test_loader,jac_sum=None):
    """Score the test set and report AUROC / average precision.

    Two anomaly scores are computed per sample (higher = more anomalous):

    * plain: ``-(sum_d(-0.5 * z_d**2) + sldj)``
    * contributed: the same with each latent dimension's term weighted by
      ``jac_sum`` before summing.

    Args:
        model: Callable as ``model(x, sldj=0)`` returning ``(outputs, sldj)``;
            ``sldj`` is assumed to be one value per sample -- TODO confirm.
        train_loader: Unused; kept for interface compatibility.
        test_loader: Iterable of ``(inputs, labels)`` batches with binary
            labels (1 = anomaly).
        jac_sum: Per-dimension weights for the contributed score. ``None``
            makes the contributed score equal to the plain score.

    Returns:
        ``(roc_auc, contributed_roc_auc, sldj_auc, ap_score,
        contributed_ap_score, 0, 0)`` with the metrics as floats; ``sldj_auc``
        is always 0.

    Raises:
        ValueError: Propagated from scikit-learn's ``roc_auc_score`` when the
            labels contain only one class.
    """
    # Follow the module-level device instead of hard-coding 'cuda', so the
    # function also works on CPU-only machines.
    model = model.to(device)
    model.eval()
    weights = None if jac_sum is None else jac_sum.to(device)

    preds, contributed_preds, targets = [], [], []
    for x, y in test_loader:
        x = x.to(device).float()
        outputs, sldj = model(x, sldj=0)

        # Omits the constant term; add the Jacobian term for the true density.
        log_likelihood = -0.5 * torch.pow(outputs, 2)
        sample_likelihood = torch.sum(log_likelihood, dim=1)
        if weights is None:
            contributed_sample_likelihood = sample_likelihood
        else:
            contributed_sample_likelihood = torch.sum(log_likelihood * weights, dim=1)

        # Detach before storing so the autograd graph of each batch is freed.
        preds.append((-(sample_likelihood + sldj)).detach().cpu())
        contributed_preds.append((-(contributed_sample_likelihood + sldj)).detach().cpu())
        targets.append(y.detach().cpu())

    targets = torch.cat(targets).numpy()
    preds = torch.cat(preds).numpy()
    contributed_preds = torch.cat(contributed_preds).numpy()

    # scikit-learn metrics take (y_true, y_score); both are rank-based, so raw
    # (unbounded) scores can be passed directly.
    roc_auc = roc_auc_score(targets, preds)
    ap_score = average_precision_score(targets, preds)
    contributed_roc_auc = roc_auc_score(targets, contributed_preds)
    contributed_ap_score = average_precision_score(targets, contributed_preds)
    sldj_auc = 0

    return roc_auc, contributed_roc_auc, sldj_auc, ap_score, contributed_ap_score, 0, 0


# ---------------------------------------------------------------------------
# Experiment script: grid search over learning rate and Jacobian-penalty
# weight for a _RealNVP model on several .npz datasets. Runs at import time.
# ---------------------------------------------------------------------------
# Hyperparameters forwarded to _RealNVP / training below.
act=2
bs=2048#512512
epoch=100
# NOTE: rebound by the `for lr in ...` grid loop below.
lr=1e-3
mid_dim=2048
adam=1
# NOTE: shadows the builtin `pow`; only passed through to contribution_calculation.
pow=1
PNAL='L_1sq'
learn=False
mu=torch.zeros(1)
sigma=torch.ones(1)
#data='30_satellite.npz'#bank should run  satellite:84  waveform:87
# Not referenced in the visible part of the script.
folder_path = Path('./Datasets')
j=0
# NOTE(review): the first path uses "Datasets/" and the others "datasets/";
# on a case-sensitive filesystem these are different directories -- confirm.
for file_path in ["Datasets/30_satellite.npz","datasets/7_Cardiotocography.npz","datasets/29_Pima.npz", "datasets/35_SpamBase.npz"]:
 # Dataset name = text between the first '/' and the last '.' of the path.
 first_slash_idx = file_path.find('/')
 dot_idx = file_path.rfind('.')
 # NOTE: rebinds the module-level name `data` imported from torch.utils.
 data = file_path[first_slash_idx + 1 : dot_idx]
 #print(data)
 # Grid over learning rate and penalty weight.
 for lr in [1e-3,1e-3*5,1e-2]:
   for grad_pun in [0.01,0.1,1]:#1#63.6
            # Accumulators; not updated in the visible part of the script.
            avg_auroc,avg_auprc,avg_con_auroc,avg_con_auprc,avg_f1,avg_con_f1=[],[],[],[],[],[]
            # Five repetitions per configuration. read_OD_data is called with
            # its default seed, so the data ordering is the same every time.
            for i in range (0,5):
                # read_OD_data returns the full dataset as both train and test.
                train_data,train_lab,test_data,test_lab,Input_dim,std=read_OD_data(file_path,normalization='z-score')
                # Diagonal matrix of inverse per-feature standard deviations.
                std=torch.diag(torch.tensor(1/std))
                model=_RealNVP(input_dim=Input_dim,
                 mid_dim=mid_dim, 
                 masktype=0,
                 act=act,
                 mu=mu,
                 sigma=sigma,
                 learn=learn
             )
                # Best-so-far trackers; not updated in the visible part of the script.
                best_auc,best_con_auc,best_sldj_auc,best_pz_auc,best_auprc,best_con_auprc,best_f1,best_con_f1=0,0,0,0,0,0,0,0
                for n_epoch in range(epoch):
                 print(f"[EPOCH: {n_epoch:.1f}]")
                 torch.cuda.empty_cache()
                 # NOTE: datasets and loaders are rebuilt every epoch from the
                 # same arrays.
                 train_set,test_set=CustomDataset(train_data, train_lab), CustomDataset(test_data, test_lab)
                 train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=bs,
        shuffle=True,
        num_workers=0
    )
                 test_loader = torch.utils.data.DataLoader(
        dataset=test_set, 
        batch_size=2048,
        shuffle=True,
        num_workers=0
    )           
                 # Smaller batches for the Jacobian-based contribution estimate.
                 con_loader = torch.utils.data.DataLoader(
        dataset=test_set, 
        batch_size=512,
        shuffle=True,
        num_workers=0
    )           
                 # One training epoch, then contribution weights, then evaluation.
                 je=training(
                        model=model,
                        train_loader=train_loader,
                        lr=lr,
                        adam=adam,
                        wd=0,
                        pun=0,
                        grad_pun=grad_pun,
                        PNAL=PNAL,
                        std=std      
                        )
                 contribution,ee,min_max,likelihood,datas,labels=contribution_calculation(model,con_loader,pow,std)
                 _,con_auc,sldj_auc,_,con_auprc,_,_=testing(model,train_loader,test_loader,contribution)
                 print(f" CON_AUC:  {con_auc:.3f}, CON_AUPRC: {con_auprc}")
