import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import scipy
from sklearn.neighbors import LocalOutlierFactor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
from torchvision import transforms
# from PIL import Image
import os
from collections import OrderedDict
# import matplotlib.pyplot as plt
import torchvision.models as models
# This is for the progress bar.
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import precision_score, f1_score
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve
import csv
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score, auc
# from openTSNE import TSNE
import torchvision
from torch.utils import data
from torchvision import transforms
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from scipy.io import loadmat
from scipy.spatial import distance
import click
import json
from os import path as osp
from .tools import *


class CustomDataset(Dataset):
    """Map-style dataset pairing rows of ``X`` with entries of ``y``.

    ``X`` is stored as ``self.data`` and ``y`` as ``self.targets``. Items of
    ``X`` are converted with ``torch.from_numpy`` on access, so ``X[idx]``
    must be a NumPy array; targets are returned exactly as stored.
    """

    def __init__(self,
                 X,
                 y):
        self.data, self.targets = X, y

    def __len__(self):
        """Return the number of samples (the length of ``self.data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample_tensor, target)`` for position ``idx``.

        A tensor index is first converted to plain Python via ``tolist()``.
        """
        position = idx.tolist() if torch.is_tensor(idx) else idx
        sample = torch.from_numpy(self.data[position])
        target = self.targets[position]
        return sample, target


class FourLayer(nn.Module):
    """Fully connected network with four linear layers in two stages.

    ``feature_extractor`` is ``fc1 -> ReLU -> fc2 -> ReLU`` and maps
    ``input_dim`` to ``num_hidden_nodes[1]`` features. ``classifier`` is
    ``fc3 -> ReLU -> fc4`` and maps those features to ``num_classes`` outputs
    through a hidden layer of width ``int(num_hidden_nodes[1] / 2)``.

    Args:
        input_dim: Width of the input vectors.
        num_classes: Width of the final output.
        num_hidden_nodes: Indexable pair giving the widths of fc1 and fc2.
    """

    def __init__(self,
                 input_dim=2,
                 num_classes=1,
                 num_hidden_nodes=(2, 2)
                 ):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_hidden_nodes = num_hidden_nodes

        width_one = self.num_hidden_nodes[0]
        width_two = self.num_hidden_nodes[1]

        # A single in-place ReLU module is shared by all three activation
        # slots. (The original author noted nn.Tanh() as the alternative used
        # for text anomaly detection.)
        activ = nn.ReLU(inplace=True)

        extractor_layers = OrderedDict()
        extractor_layers['fc1'] = nn.Linear(self.input_dim, width_one)
        extractor_layers['relu1'] = activ
        extractor_layers['fc2'] = nn.Linear(width_one, width_two)
        extractor_layers['relu2'] = activ
        self.feature_extractor = nn.Sequential(extractor_layers)

        # Feature width expected by the classifier; used to reshape its input.
        self.size_final = width_two

        head_width = int(width_two / 2)
        head_layers = OrderedDict()
        head_layers['fc3'] = nn.Linear(width_two, head_width)
        head_layers['relu3'] = activ
        head_layers['fc4'] = nn.Linear(head_width, self.num_classes)
        self.classifier = nn.Sequential(head_layers)

    def forward(self, input):
        """Run both stages and return the classifier output."""
        return self.half_forward_end(self.half_forward_start(input))

    def half_forward_start(self, input):
        """Return the output of the feature extractor only."""
        return self.feature_extractor(input)

    def half_forward_end(self, input):
        """Reshape ``input`` to ``(-1, size_final)`` and apply the classifier."""
        flattened = input.view(-1, self.size_final)
        return self.classifier(flattened)


class DPAD:
    """Anomaly scorer based on distances between learned embeddings.

    A ``FourLayer`` network is trained on ``train_x`` with a pairwise-distance
    loss (see ``training``). The embeddings of the training data are then
    cached in ``self.train_c`` and new samples are scored by their mean
    Euclidean distance to the ``k`` nearest cached embeddings (see
    ``decision_function``).

    Args:
        train_x: Training samples; must expose ``.shape`` and yield NumPy
            rows when indexed (rows go through ``CustomDataset``).
        test_x: Test samples, same layout as ``train_x``.
        test_y: Targets for ``test_x``; only used to build ``self.test_loader``.
        gamma: Scale inside ``exp(-gamma * squared_distance)`` in the loss.
        lamb: Weight of the parameter-norm penalty in the loss.
        k: Number of nearest training embeddings averaged when scoring.
        hidden_dims: Hidden widths forwarded to ``FourLayer`` as
            ``num_hidden_nodes``.
        num_classes: Output width of the network (the embedding dimension).
        bs: Batch size of the train/test loaders; also the loss divisor.
        n_epochs: Number of passes over the training loader.
        learning_rate: Optimizer learning rate.
        adam: ``1`` selects Adam (with AMSGrad), ``0`` selects SGD with
            momentum 0.9.
        device: Torch device string used for the model and the batches.
    """

    def __init__(self, train_x, test_x, test_y, gamma, lamb, k, hidden_dims,
                 num_classes=128,
                 bs=8192,
                 n_epochs=200,
                 learning_rate=1e-3,
                 adam=1,
                 device='cuda:0'):
        self.train_x = train_x
        self.test_x = test_x
        self.test_y = test_y
        self.gamma = gamma
        self.lamb = lamb
        self.k = k
        self.hidden_dims = hidden_dims
        self.num_classes = num_classes

        self.bs = bs
        self.n_epochs = n_epochs
        self.learning_rate = learning_rate
        self.adam = adam
        self.device = device

        # Training targets are all-zero placeholders; the loss never reads them.
        train_dataset = CustomDataset(train_x, np.zeros(train_x.shape[0]))
        test_dataset = CustomDataset(test_x, test_y)

        self.train_loader = DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True,
                                       num_workers=0)
        self.test_loader = DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False,
                                      num_workers=0)

        n_dim = train_x.shape[-1]
        self.net = FourLayer(input_dim=n_dim, num_classes=self.num_classes,
                             num_hidden_nodes=hidden_dims)

        # Embeddings of the training data; filled in by get_train_c().
        self.train_c = None

    def training(self):
        """Train ``self.net`` and cache the training embeddings.

        For every batch the loss is the sum over all pairs of
        ``exp(-gamma * d^2) * d^2`` (``d`` being the Euclidean distance
        between two embeddings, with the exponential weight detached from the
        graph), divided by ``self.bs``, plus ``lamb`` times the sum over all
        parameter tensors of ``|‖param‖_2 - 1|``.

        Raises:
            ValueError: If ``self.adam`` is neither ``0`` nor ``1``.
        """
        # Validate before touching the device so a bad flag fails fast instead
        # of surfacing later as an unbound ``optimizer``.
        if self.adam not in (0, 1):
            raise ValueError(
                "adam must be 1 (Adam) or 0 (SGD), got {!r}".format(self.adam))

        model = self.net.to(self.device)
        weight_decay = 0

        if self.adam == 1:
            optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate,
                                         weight_decay=weight_decay, amsgrad=True)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate,
                                        weight_decay=weight_decay, momentum=0.9)

        model.train()
        for _epoch in range(self.n_epochs):
            for inputs, _ in self.train_loader:
                inputs = inputs.float().to(self.device)
                outputs = model(inputs)
                flat = outputs.view(outputs.shape[0], -1)

                sq_dists = torch.pow(torch.cdist(flat, flat, p=2), 2)
                # The weight is detached: gradients flow only through the
                # squared distance, not through the exponential factor.
                weights = torch.exp(sq_dists * (-self.gamma)).detach()
                # NOTE(review): the divisor is the configured batch size, not
                # the actual number of rows in this batch, so a short final
                # batch is scaled like a full one. Kept as in the original;
                # confirm this is intended.
                loss = torch.sum(weights * sq_dists) / self.bs

                norm_penalty = sum(abs(torch.norm(param, p=2) - 1)
                                   for param in model.parameters())
                loss = loss + norm_penalty * self.lamb

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        self.get_train_c()

    def get_train_c(self):
        """Embed the whole training loader and store the result in ``train_c``.

        The network is switched to eval mode and run without gradients; the
        per-batch outputs are concatenated along dimension 0.

        Raises:
            RuntimeError: If the training loader yields no batches.
        """
        self.net.eval()
        chunks = []
        with torch.no_grad():
            for x, _ in self.train_loader:
                chunks.append(self.net(x.float().to(self.device)))
        if not chunks:
            raise RuntimeError("train_loader yielded no batches; cannot build train_c")
        self.train_c = torch.cat(chunks, dim=0)

    def decision_function(self, test_x, batch_size=1024):
        """Score samples by mean distance to their nearest training embeddings.

        Args:
            test_x: Samples to score, convertible to a float32 tensor whose
                first dimension indexes samples.
            batch_size: Number of samples embedded and compared per step.

        Returns:
            A 1-D NumPy array with one score per row of ``test_x``: the mean
            Euclidean distance between the sample's embedding and its ``k``
            nearest embeddings in ``self.train_c``. An empty array is
            returned when ``test_x`` has no rows.

        Raises:
            RuntimeError: If called before ``train_c`` has been computed.
        """
        import warnings

        if self.train_c is None:
            raise RuntimeError(
                "train_c is not set; call training() (or get_train_c()) "
                "before decision_function()")

        reference = self.train_c.view(self.train_c.shape[0], -1)

        # topk cannot select more neighbours than there are training
        # embeddings, so fall back to using all of them.
        k = self.k
        if k > reference.shape[0]:
            warnings.warn(
                "k ({}) exceeds the number of training embeddings ({}); "
                "using {} instead".format(k, reference.shape[0], reference.shape[0]),
                stacklevel=2)
            k = reference.shape[0]

        model = self.net
        model.eval()

        samples = torch.as_tensor(test_x, dtype=torch.float32)
        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(samples),
            batch_size=batch_size, shuffle=False, num_workers=0)

        scores = []
        with torch.no_grad():
            for (x,) in loader:
                embedded = model(x.to(self.device))
                embedded = embedded.view(embedded.shape[0], -1)
                dists = torch.cdist(embedded, reference, p=2)
                nearest, _ = torch.topk(dists, k=k, dim=1, largest=False)
                scores.append(torch.sum(nearest, dim=1) / k)

        if not scores:
            return np.empty(0, dtype=np.float32)
        return torch.cat(scores).cpu().numpy()
