from discriminator import Discriminator
from data_process import generate_negative_samples, load_data

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import roc_auc_score, precision_score
import numpy as np

import os
import time

class NOD:
    """Outlier detector built around a ``Discriminator`` network.

    The network is trained as a binary classifier: every original sample is
    labelled 0 and is contrasted with samples produced by
    ``generate_negative_samples``. The raw network output on the original
    data is then used as the outlier score.

    All behaviour is driven by the ``config`` mapping passed to the
    constructor (device, loss, optimizer, scaling, sampling, saving and
    early-stopping options).
    """

    def __init__(self,config):
        """Build the network, loss function and optimizer from ``config``.

        An unrecognised ``config['loss']`` or ``config['optimizer']`` does not
        fail here (so the instance can still be used for ``predict`` only);
        the corresponding attribute is set to ``None`` and ``train`` raises a
        ``ValueError``.
        """
        super(NOD,self).__init__()
        self.config = config
        self.net = Discriminator(config).to(config['device'])

        # choose loss function
        if config['loss'] == "BCELoss":
            self.cruteon = nn.BCELoss().to(config['device'])
        elif config['loss'] == "MSELoss":
            self.cruteon = nn.MSELoss().to(config['device'])
        else:
            self.cruteon = None

        # choose optimizer
        if config['optimizer'] == "Adam":
            self.optim = optim.Adam(self.net.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay'])
        elif config['optimizer'] == "SGD":
            self.optim = optim.SGD(self.net.parameters(), lr = config['learning_rate'], momentum = config['momentum'], weight_decay = config['weight_decay'])
        else:
            self.optim = None

    def _to_tensor(self, data):
        """Convert array-like ``data`` to a float32 tensor on the configured device."""
        return torch.tensor(np.array(data), dtype=torch.float32).to(self.config['device'])

    def _build_training_batch(self, train_x, train_y):
        """Stack the original samples with freshly generated negative samples.

        Returns the combined features and labels as float32 tensors on the
        configured device.
        """
        noise_train_x, noise_train_y = generate_negative_samples(
            train_x,
            self.config['neg_rate'],
            self.config['neg_min'],
            self.config['neg_max'],
            self.config['if_neg_every_feature'],
        )
        x = np.vstack((train_x, noise_train_x))
        y = np.hstack((train_y, noise_train_y))
        x = torch.tensor(x, dtype=torch.float32).to(self.config['device'])
        y = torch.tensor(y, dtype=torch.float32).to(self.config['device'])
        return x, y

    def _score(self, inputs):
        """Run the network in eval mode without autograd and return a numpy array."""
        self.net.eval()
        with torch.no_grad():
            return self.net(inputs).cpu().numpy()

    #model train
    def train(self,init_data,init_label): #init_label only is used in supervision tasks
        """Train the network and record per-epoch diagnostics.

        Args:
            init_data: samples to train on; all of them are initially treated
                as normal points (label 0).
            init_label: ground-truth outlier labels, used only to compute the
                outlier-detection AUC that is recorded every epoch.

        Returns:
            A tuple ``(loss_list, auc_list, classification_auc_list,
            classification_precision_list, sum_time_classification_threshold,
            sum_time_paint_auc)``. The two classification lists are only
            filled when the matching ``paint_*`` config flags are set; the two
            ``sum_time_*`` values are the accumulated seconds spent on the
            respective diagnostics.

        Raises:
            ValueError: if ``config['loss']`` or ``config['optimizer']`` was
                not recognised when the instance was constructed.
        """
        if self.cruteon is None:
            raise ValueError("Unsupported loss {!r}; expected 'BCELoss' or 'MSELoss'".format(self.config['loss']))
        if self.optim is None:
            raise ValueError("Unsupported optimizer {!r}; expected 'Adam' or 'SGD'".format(self.config['optimizer']))

        # get data
        train_data = init_data.copy() #Initially, all samples are considered normal points and labeled as 0
        self.init_data = init_data

        train_label = np.zeros(train_data.shape[0]) # label

        if self.config['is_scale']:
            self.scaler = MinMaxScaler()
            self.scaler.fit(train_data)
            train_data = self.scaler.transform(train_data)
            train_data = train_data.astype(np.float32)

        train_x, train_y = train_data, train_label

        # With a fixed negative set the batch is built once, up front.
        if not self.config['if_neg_in_every_epoch']:
            x, y = self._build_training_batch(train_x, train_y)

        # The output directory does not change between epochs: create it once.
        save_model_path = os.path.join(self.config["save_dir"], self.config["version"], "model", self.config['summary_key'])
        os.makedirs(save_model_path, exist_ok=True)

        #step_4. train NOD
        loss_list = []
        auc_list, classification_auc_list, classification_precision_list = [], [], []
        sum_time_classification_threshold, sum_time_paint_auc = 0, 0
        tmp_cl_auc, tmp_cl_pre = 0, 0

        # The evaluation input is the same every epoch, so scale and convert it
        # once; the preparation time is still charged to the AUC timing.
        paint_start = time.time()
        test_data = init_data
        if self.config['is_scale']:
            test_data = self.scaler.transform(test_data)
        test_data = self._to_tensor(test_data)
        sum_time_paint_auc += (time.time() - paint_start)

        for epoch in range(1,self.config["epochs"]+1,1):

            # Optionally draw a fresh set of negative samples for this epoch.
            if self.config['if_neg_in_every_epoch']:
                x, y = self._build_training_batch(train_x, train_y)

            self.net.train()
            output = self.net(x)
            loss = self.cruteon(output,y)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            loss_value = loss.item()
            loss_list.append(loss_value)

            if epoch % self.config['save_frequency'] == 0:
                print("epoch",epoch,'loss',loss_value)
                if self.config['save_process_model']:
                    torch.save(self.net.state_dict(),os.path.join(save_model_path , 'epoch'+ str(epoch) +".mdl"))
            # There is no validation set, so the most recent weights are the
            # ones kept as "best.mdl" (overwritten every epoch).
            torch.save(self.net.state_dict(), os.path.join(save_model_path,"best.mdl"))

            # Calculate the AUC for each epoch and use it to plot and observe the changes in AUC
            paint_start = time.time()
            result = self._score(test_data)  # the predict value in test data
            od_auc = roc_auc_score(init_label, result)
            auc_list.append(od_auc)   # auc of outlier detection
            sum_time_paint_auc += (time.time() - paint_start)

            threshold_time = time.time()
            y_true = y.cpu().numpy()
            # `output` is the training-mode forward pass taken before this
            # epoch's optimizer step.
            precision_scores = output.detach().cpu().numpy()
            cl_auc = roc_auc_score(y_true, precision_scores) # auc of auxiliary binary classification
            # Label everything above the median score as 1.
            threshold = np.percentile(precision_scores, 50)
            da = np.where(precision_scores > threshold)[0]
            pred_label = np.zeros_like(precision_scores)
            pred_label[da] = 1

            # precision of auxiliary binary classification, it is deprecated
            cl_pre = (pred_label == y_true).sum() / pred_label.shape[0]

            if self.config['paint_classificaiton_auc']:
                classification_auc_list.append(cl_auc)
            if self.config['paint_classificaiton_precision']:
                classification_precision_list.append(cl_pre)
            sum_time_classification_threshold += (time.time() - threshold_time)

            # use AUC change rate to decide whether stop training
            if self.config['use_classification_auc_early_stopping']:
                if epoch % self.config['delta_epochs'] == 0:
                    if cl_auc - tmp_cl_auc < self.config['delta_threshold']:
                        break
                    tmp_cl_auc = cl_auc
            # it is deprecated
            if self.config['use_classification_precision_early_stopping']:
                if epoch % self.config['delta_epochs'] == 0:
                    if cl_pre - tmp_cl_pre < self.config['delta_threshold']:
                        break
                    tmp_cl_pre = cl_pre

        return loss_list, auc_list, classification_auc_list, classification_precision_list, sum_time_classification_threshold, sum_time_paint_auc

    def predict(self, save_model_path = False, data_for_predict = None):
        """Score samples with saved weights and threshold them into labels.

        Two modes are supported:

        * ``save_model_path`` falsy: score the data given to ``train`` (scaled
          with the scaler fitted there) using the ``best.mdl`` checkpoint
          under the configured save directory.
        * ``save_model_path`` given: load that checkpoint and score
          ``data_for_predict``. When ``config['is_scale']`` is set, the data
          is min-max scaled with a scaler fitted on ``data_for_predict``
          itself; the scaler fitted during ``train`` is left untouched.

        Results are stored on the instance rather than returned:
        ``self.labels_`` holds the raw network outputs and
        ``self.pred_labels`` holds 1 for every sample whose output is strictly
        above the ``(1 - contamination)`` percentile and 0 otherwise.

        Raises:
            ValueError: if ``save_model_path`` is given without
                ``data_for_predict``.
        """
        if not save_model_path:
            test_data = self.init_data
            if self.config['is_scale']:
                test_data = self.scaler.transform(test_data)
            save_model_path = os.path.join(self.config["save_dir"], self.config["version"], "model", self.config['summary_key'], "best.mdl")
        else:
            if data_for_predict is None:
                raise ValueError("data_for_predict is required when save_model_path is given")
            test_data = data_for_predict
            if self.config['is_scale']:
                # Use a local scaler so the one fitted in train() is not replaced.
                scaler = MinMaxScaler()
                scaler.fit(test_data)
                test_data = scaler.transform(test_data)
                test_data = test_data.astype(np.float32)
        test_data = self._to_tensor(test_data)

        # map_location lets a checkpoint saved on another device be loaded
        # onto the configured one.
        state_dict = torch.load(save_model_path, map_location=self.config['device'])
        self.net.load_state_dict(state_dict)
        result = self._score(test_data)

        threshold = np.percentile(result, (1 - self.config['contamination']) * 100)
        da = np.where(result > threshold)[0]
        pred_label = np.zeros_like(result)
        pred_label[da] = 1

        self.labels_ = result
        self.pred_labels = pred_label
