import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import scipy
from sklearn.neighbors import LocalOutlierFactor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
#from torchvision import transforms
# from PIL import Image
import os
from collections import OrderedDict
# import matplotlib.pyplot as plt
#import torchvision.models as models
# This is for the progress bar.
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import precision_score, f1_score
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve
import csv
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score, auc
# from openTSNE import TSNE
#import torchvision
from torch.utils import data
#from torchvision import transforms
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
#import torchvision.transforms as transforms
#import torchvision.datasets as datasets
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from scipy.io import loadmat
from scipy.spatial import distance
#import click
import json
from os import path as osp
#from .tools import *


class CustomDataset(Dataset):
    """Map-style dataset pairing feature rows with their targets.

    Args:
        X: Indexable collection of samples (NumPy array, tensor, or anything
            else ``torch.as_tensor`` can convert row by row).
        y: Indexable collection of targets, aligned with ``X``.
    """

    def __init__(self,
                 X,
                 y):
        self.data = X
        self.targets = y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample_tensor, target)`` for ``idx``.

        ``idx`` may be an int, a list of ints, or a tensor of indices (which
        is converted to a plain Python value first). The target is returned
        exactly as stored, without conversion.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # ``as_tensor`` matches ``from_numpy`` for ndarrays (no copy, same
        # dtype) but also accepts tensors, lists and NumPy scalars, which
        # ``from_numpy`` rejects with a TypeError.
        return torch.as_tensor(self.data[idx]), self.targets[idx]


class FourLayer(nn.Module):
    """Fully connected network with four linear layers.

    The network is split into two stages:

    * ``feature_extractor``: ``fc1 -> ReLU -> fc2 -> ReLU`` mapping
      ``input_dim`` to ``num_hidden_nodes[1]`` features.
    * ``classifier``: ``fc3 -> ReLU -> fc4`` mapping those features through a
      half-width layer to ``num_classes`` outputs.

    Args:
        input_dim: Size of each input row.
        num_classes: Size of each output row.
        num_hidden_nodes: Indexable whose first two entries are the widths of
            ``fc1`` and ``fc2``.
    """

    def __init__(self,
                 input_dim=2,
                 num_classes=1,
                 num_hidden_nodes=(2, 2)
                 ):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_hidden_nodes = num_hidden_nodes

        width_first = self.num_hidden_nodes[0]
        width_second = self.num_hidden_nodes[1]
        width_half = int(width_second / 2)

        # A single in-place ReLU module is registered under every
        # activation name, so all three slots refer to the same instance.
        shared_relu = nn.ReLU(True)

        # Linear layers are created in fc1..fc4 order so that parameter
        # initialisation consumes the random generator in that order.
        extractor_layers = OrderedDict()
        extractor_layers['fc1'] = nn.Linear(self.input_dim, width_first)
        extractor_layers['relu1'] = shared_relu
        extractor_layers['fc2'] = nn.Linear(width_first, width_second)
        extractor_layers['relu2'] = shared_relu
        self.feature_extractor = nn.Sequential(extractor_layers)

        # Width the classifier expects after flattening the features.
        self.size_final = width_second

        head_layers = OrderedDict()
        head_layers['fc3'] = nn.Linear(width_second, width_half)
        head_layers['relu3'] = shared_relu
        head_layers['fc4'] = nn.Linear(width_half, self.num_classes)
        self.classifier = nn.Sequential(head_layers)

    def forward(self, input):
        """Run both stages and return the classifier output."""
        return self.half_forward_end(self.half_forward_start(input))

    def half_forward_start(self, input):
        """Return the output of the feature-extractor stage only."""
        hidden = self.feature_extractor(input)
        return hidden

    def half_forward_end(self, input):
        """Apply the classifier stage to features reshaped to ``(-1, size_final)``."""
        flattened = input.view(-1, self.size_final)
        return self.classifier(flattened)


class DPAD:
    """Nearest-neighbour scorer operating on embeddings of a ``FourLayer`` net.

    ``training`` fits the network on ``train_x`` alone and caches the
    embeddings of the training set. ``decision_function`` then scores new
    samples by their mean distance to the ``k`` closest cached embeddings
    (presumably a larger score means a more anomalous sample; confirm against
    the caller's evaluation code).

    Args:
        train_x: Training samples; must expose ``.shape`` and row indexing.
        test_x: Test samples. Stored and wrapped in ``self.test_loader``; not
            otherwise used by the methods of this class.
        test_y: Test labels, handled like ``test_x``.
        gamma: Decay rate of the ``exp(-gamma * d^2)`` weight applied to each
            pairwise squared distance in the training loss.
        lamb: Weight of the parameter-norm penalty in the training loss.
        k: Number of nearest training embeddings averaged per score.
        hidden_dims: Hidden widths forwarded to ``FourLayer``.
        num_classes: Dimensionality of the embedding produced by the network.
        bs: Batch size of both loaders; also the divisor of the pairwise loss.
        n_epochs: Number of passes over the training data.
        learning_rate: Optimizer learning rate.
        adam: ``1`` selects Adam (with AMSGrad), ``0`` selects SGD with
            momentum 0.9. Any other value is rejected by ``training``.
        device: Torch device string the network and batches are moved to.
    """

    def __init__(self, train_x, test_x, test_y, gamma, lamb, k, hidden_dims,
                 num_classes=128,
                 bs=8192,
                 n_epochs=200,
                 learning_rate=1e-3,
                 adam=1,
                 device='cuda:0'):
        self.train_x, self.test_x, self.test_y = train_x, test_x, test_y
        self.gamma, self.lamb, self.k = gamma, lamb, k
        self.hidden_dims = hidden_dims
        self.num_classes = num_classes

        self.bs = bs
        self.n_epochs = n_epochs
        self.learning_rate = learning_rate
        self.adam = adam
        self.device = device

        # Training data is unlabeled; zeros are placeholder targets.
        train_dataset = CustomDataset(train_x, np.zeros(train_x.shape[0]))
        test_dataset = CustomDataset(test_x, test_y)

        self.train_loader = DataLoader(dataset=train_dataset, batch_size=bs,
                                       shuffle=True, num_workers=0)
        self.test_loader = DataLoader(dataset=test_dataset, batch_size=bs,
                                      shuffle=False, num_workers=0)

        n_dim = train_x.shape[-1]
        self.net = FourLayer(input_dim=n_dim, num_classes=self.num_classes,
                             num_hidden_nodes=hidden_dims)

        # Embeddings of the training set; filled in by ``get_train_c``.
        self.train_c = None

    def _build_optimizer(self):
        """Create the optimizer selected by ``self.adam`` for ``self.net``."""
        weight_decay = 0
        if self.adam == 1:
            return torch.optim.Adam(self.net.parameters(), lr=self.learning_rate,
                                    weight_decay=weight_decay, amsgrad=True)
        if self.adam == 0:
            return torch.optim.SGD(self.net.parameters(), lr=self.learning_rate,
                                   weight_decay=weight_decay, momentum=0.9)
        # Previously this case left the optimizer unbound and only failed
        # after the first forward pass.
        raise ValueError(
            "adam must be 1 (Adam) or 0 (SGD), got {!r}".format(self.adam))

    def training(self):
        """Fit ``self.net`` and cache the training embeddings in ``self.train_c``.

        Per batch the loss is::

            sum_ij exp(-gamma * d_ij^2).detach() * d_ij^2 / bs
            + lamb * sum_p | ||p||_2 - 1 |

        where ``d_ij`` is the Euclidean distance between the embeddings of
        samples ``i`` and ``j`` of the batch and ``p`` ranges over the network
        parameters. The exponential weight is detached, so gradients flow
        only through the squared distances.

        Raises:
            ValueError: If ``self.adam`` is neither 0 nor 1.
        """
        model = self.net.to(self.device)
        optimizer = self._build_optimizer()

        decay = self.gamma
        penalty_weight = self.lamb

        for _ in range(self.n_epochs):
            model.train()
            for inputs, _ in self.train_loader:
                inputs = inputs.float().to(self.device)
                embeddings = model(inputs)
                flat = embeddings.view(embeddings.shape[0], -1)

                sq_dists = torch.pow(torch.cdist(flat, flat, p=2), 2)
                weights = torch.exp(sq_dists * (-decay)).detach()
                weighted = weights * sq_dists
                pairwise_term = torch.sum(torch.sum(weighted, dim=1), dim=0)
                # NOTE(review): normalised by the configured batch size, not
                # by the number of rows actually in this batch, so a short
                # final batch contributes less. Kept as in the original.
                loss = pairwise_term / self.bs

                norm_penalty = 0
                for _, param in model.named_parameters():
                    norm_penalty += abs(torch.norm(param, p=2) - 1)
                loss = loss + norm_penalty * penalty_weight

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        self.get_train_c()

    def get_train_c(self):
        """Embed every training sample and store the result in ``self.train_c``.

        Rows follow the (shuffled) iteration order of ``self.train_loader``;
        the nearest-neighbour lookup in ``decision_function`` does not depend
        on that order.
        """
        self.net.eval()
        chunks = []
        with torch.no_grad():
            for x, _ in self.train_loader:
                chunks.append(self.net(x.float().to(self.device)))
        # One concatenation per call instead of extending a list row by row.
        self.train_c = torch.cat(chunks, dim=0)

    def decision_function(self, test_x):
        """Score samples by mean distance to their nearest training embeddings.

        Args:
            test_x: Samples to score; must expose ``.shape`` and be
                convertible by ``torch.Tensor``.

        Returns:
            1-D NumPy array with one score per row of ``test_x``: the mean
            Euclidean distance from the row's embedding to its ``k`` closest
            rows of ``self.train_c``. If ``k`` exceeds the number of cached
            training embeddings, all of them are used.

        Raises:
            RuntimeError: If no training embeddings have been computed yet.
        """
        if self.train_c is None:
            raise RuntimeError(
                "training embeddings are missing; call training() or "
                "get_train_c() before decision_function()")

        model = self.net.to(self.device)
        model.eval()

        reference = self.train_c.view(self.train_c.shape[0], -1)
        # topk raises when k is larger than the dimension it selects from.
        n_neighbors = min(self.k, reference.shape[0])

        # The zero targets are placeholders required by TensorDataset.
        eval_set = torch.utils.data.TensorDataset(
            torch.Tensor(test_x), torch.zeros(test_x.shape[0]))
        eval_loader = torch.utils.data.DataLoader(
            eval_set, batch_size=1024, shuffle=False, num_workers=0)

        batch_scores = []
        with torch.no_grad():
            for x, _ in eval_loader:
                embedded = model(x.float().to(self.device))
                dists = torch.cdist(embedded.view(embedded.shape[0], -1),
                                    reference, p=2)
                nearest, _ = torch.topk(dists, k=n_neighbors, dim=1,
                                        largest=False)
                batch_scores.append(torch.sum(nearest, dim=1) / n_neighbors)

        if not batch_scores:
            # No input rows: nothing to concatenate.
            return np.zeros(0, dtype=np.float32)
        return torch.cat(batch_scores).cpu().numpy()

        # targets = torch.stack(targets)
        # roc_auc = roc_auc_score(targets, test_dists_all)
        # precision, recall, _ = precision_recall_curve(targets, test_dists_all)
        # auc_pr = average_precision_score(targets, test_dists_all)
        # # recall of abnormal data ================================
        #
        # # calculate f1
        # f1 = f1_calculator(targets, test_dists_all)
        # # threshold = 0.5
        # # pred_lab = np.where(test_dists_all > threshold, 1, 0)
        # # recall_num = 0
        # # for pl, l in zip(pred_lab, targets):
        # #     if pl == 1 and l == 1:
        # #         recall_num += 1
        # # recall_rate = 2 * recall_num / len(targets)
        # # print(fpr)
        # #roc_auc = auc(fpr, tpr)
        # return roc_auc, auc_pr, f1, train_score_list
