import os
import torch
import torch.nn as nn
import numpy as np
from data_load.data_loader import Dataset_kpi, Dataset_yahoo
from .train_basis import train_basis
from torch.utils.data import DataLoader
from torch import optim
import csv
import math
import pandas as pd
from model.model import Model, param_weight, TransformerEncoder
from tools.tool import EarlyStopping, adjust_learning_rate
from gluonts.torch.util import copy_parameters
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.utils import column_or_1d
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix, average_precision_score


from pyod.models.knn import KNN
from pyod.models.pca import PCA
from pyod.models.lof import LOF
from pyod.models.cblof import CBLOF
from pyod.models.mcd import MCD
from pyod.models.lscp import LSCP
from pyod.models.ecod import ECOD
#from pyod.models.suod import SUOD
# from pyod.models.auto_encoder import 

def reconstruct_label(timestamp, label):
    """Place labels on an evenly spaced timeline, filling missing slots with 0.

    The timestamps are sorted, the smallest gap between two consecutive
    timestamps is taken as the sampling interval, and every label is written
    at position ``(t - t_min) // interval`` of a zero-initialised array.

    Args:
        timestamp: Integer-convertible timestamps, one per label, in any order.
        label: Integer-convertible labels aligned with ``timestamp``.

    Returns:
        A 1-D ``int64`` array of length ``(t_max - t_min) // interval + 1``.
        With fewer than two timestamps no interval can be derived, so the
        labels are returned unchanged (as an ``int64`` copy).

    Raises:
        ValueError: If two timestamps are equal (the interval would be 0).
    """
    timestamp = np.asarray(timestamp, np.int64)
    label = np.asarray(label, np.int64)
    print(len(label))

    if timestamp.size < 2:
        # np.diff would be empty and np.min would raise; nothing to re-grid.
        return label.copy()

    index = np.argsort(timestamp)
    timestamp_sorted = timestamp[index]
    interval = np.min(np.diff(timestamp_sorted))
    if interval <= 0:
        # A zero interval would turn the floor division below into garbage.
        raise ValueError("timestamps must be unique to derive a sampling interval")

    # Reorder the labels the same way the timestamps were sorted.
    label = np.take_along_axis(label, index, axis=0)

    offsets = timestamp_sorted - timestamp_sorted[0]
    idx = offsets // interval

    # np.int was removed in NumPy 1.24; use an explicit fixed-width dtype.
    new_label = np.zeros(shape=(offsets[-1] // interval + 1,), dtype=np.int64)
    new_label[idx] = label

    return new_label


class train_anomal_repr(train_basis):
    def __init__(self, args):
        """Run the ``train_basis`` initialiser with ``args`` and store ``args``.

        The model and encoder are not built here; ``train_model`` and
        ``test_anomal_detection`` create them once the sequence length is known.
        """
        super().__init__(args)
        # Kept on the instance because every other method reads its settings.
        self.args = args
    
    
    def load_data(self, flag):
        """Create the dataset for one split and wrap it in a ``DataLoader``.

        Args:
            flag: Split name forwarded to the dataset class. ``'TRAIN'`` and
                ``'VAL'`` are shuffled; any other value is not.

        Returns:
            A ``(dataset, dataloader)`` tuple. The loader always drops the
            last incomplete batch and loads in the main process.

        Raises:
            KeyError: If ``self.args.tasks`` is not one of the known task names.
        """
        cfg = self.args
        dataset_classes = {
            'anomal_detect_kpi': Dataset_kpi,
            'anomal_detect_yahoo': Dataset_yahoo,
        }
        dataset_cls = dataset_classes[cfg.tasks]
        dataset = dataset_cls(
            cfg.root_path, cfg.tasks, cfg.dataset, cfg.file_name, flag,
            cfg.feature, cfg.target, cfg.pred_len, cfg.scale,
        )
        print(len(dataset))
        dataloader = DataLoader(
            dataset,
            shuffle=flag in ('TRAIN', 'VAL'),
            batch_size=cfg.batch_size,
            drop_last=True,
            num_workers=0,
        )
        return dataset, dataloader
        
    def ininial_model(self, length):
        """Return a new ``Model`` built from ``self.args`` and ``length``.

        Args:
            length: Passed through as the second ``Model`` argument; callers in
                this class pass the number of time steps of one sample.
        """
        return Model(self.args, length)
    def ininial_encoder(self, length):
        """Return a new ``TransformerEncoder`` configured from ``self.args``.

        Args:
            length: Used as the encoder's ``max_len``.
        """
        cfg = self.args
        return TransformerEncoder(
            feat_dim=cfg.input_dim,
            max_len=length,
            d_model=cfg.encoding_dim,
            n_heads=cfg.num_heads,
            num_layers=cfg.num_layers,
            dim_feedforward=cfg.dim_feedforward,
        )
    
    def val_model(self, val_loader, val_data, criterion):
        """Return the mean validation loss over ``val_loader``.

        The loss has the same four terms as the training loss:
        ``loss1`` is the ``params``-weighted, depth-discounted score between the
        encoder output of the input and of each latent variable (squashed with
        a sigmoid when ``args.scope`` is set), ``loss2`` and ``loss4`` are
        ``criterion`` applied to the model and encoder reconstructions, and
        ``loss3`` is ``1 / (1 + exp(kl_p * alpha))``.

        Args:
            val_loader: Iterable yielding ``(batch_x, batch_y)`` pairs; only
                ``batch_x`` is used.
            val_data: Unused; kept for interface compatibility.
            criterion: Reconstruction loss, called as ``criterion(pred, batch_x)``.

        Returns:
            The average of the per-batch losses (``nan`` if the loader is empty).
        """
        batch_losses = []
        # Both networks take part in the loss, so both must be in eval mode;
        # the encoder's previous mode is restored afterwards.
        encoder_was_training = self.encoder.training
        self.model.eval()
        self.encoder.eval()
        with torch.no_grad():
            for batch_x, _ in val_loader:
                batch_x = torch.nan_to_num(batch_x).float().to(self.device)
                res, kl_p, latent_variables, params = self.model(batch_x)
                output, target = self.encoder(batch_x)
                # Stack per-layer scores instead of writing into a buffer sized
                # by args.batch_size, so the batch size need not match it.
                weighted_scores = []
                for layer in range(self.args.decomp_layers):
                    pos_rep, _ = self.encoder(latent_variables[layer])
                    score = -F.logsigmoid(torch.mean(torch.bmm(output, pos_rep.transpose(2, 1)), -1) / 50)
                    # Deeper layers are discounted by (layer + 1) ** beta.
                    weighted_scores.append(score / math.pow(layer + 1, self.args.beta))
                nce1 = torch.stack(weighted_scores)
                loss1 = torch.mean(torch.sum(params.permute(2, 0, 1) * nce1, 0), -1).sum()
                # Mirror the optional squashing applied to loss1 during training
                # so that training and validation losses are comparable.
                if getattr(self.args, 'scope', False):
                    loss1 = torch.sigmoid(loss1)
                loss2 = criterion(res, batch_x)
                loss3 = 1 / (1 + torch.exp(kl_p * self.args.alpha))
                loss4 = criterion(target, batch_x)
                loss = loss1 + loss2 + loss3 + loss4
                batch_losses.append(loss.item())
        val_loss = np.average(batch_losses)
        self.model.train()
        self.encoder.train(encoder_was_training)
        return val_loss
                
    def train_model(self, setting):
        """Jointly train the decomposition model and the encoder.

        Builds ``self.model`` and ``self.encoder`` for the sample length of the
        training set, optimises both with a single Adam optimiser and, after
        every epoch, hands the validation loss and the encoder to
        ``EarlyStopping`` together with the checkpoint directory.

        The per-batch loss is the sum of:
          * ``loss1``: ``params``-weighted score between the encoder output of
            the input and of each latent variable, discounted per layer by
            ``(layer + 1) ** beta`` and squashed by a sigmoid if ``args.scope``;
          * ``loss2``: MSE between the model reconstruction and the input;
          * ``loss3``: ``1 / (1 + exp(kl_p * alpha))``;
          * ``loss4``: MSE between the encoder's second output and the input.

        Args:
            setting: Sub-directory name under ``self.args.checkpoints`` used as
                the checkpoint location.
        """
        train_data, train_loader = self.load_data('TRAIN')
        val_data, val_loader = self.load_data('VAL')
        length = len(train_data[0][0])
        self.model = self.ininial_model(length).to(self.device)
        self.encoder = self.ininial_encoder(length).to(self.device)
        path = os.path.join(self.args.checkpoints, setting)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)

        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
        model_optim = optim.Adam(
            list(self.model.parameters()) + list(self.encoder.parameters()),
            lr=self.args.learning_rate,
        )
        criterion = nn.MSELoss()
        # Number of optimisation steps per epoch (batches, not samples).
        train_steps = len(train_loader)
        for epoch in range(self.args.epochs):
            # Both networks are optimised, so both are put in training mode.
            self.model.train()
            self.encoder.train()
            epoch_loss = []
            for j, (batch_x, _) in enumerate(train_loader):
                model_optim.zero_grad()
                batch_x = torch.nan_to_num(batch_x).float().to(self.device)
                res, kl_p, latent_variables, params = self.model(batch_x)
                output, target = self.encoder(batch_x)
                # Stack per-layer scores rather than filling a pre-allocated
                # buffer whose shape depended on args.batch_size.
                weighted_scores = []
                for layer in range(self.args.decomp_layers):
                    pos_rep, _ = self.encoder(latent_variables[layer])
                    score = -F.logsigmoid(torch.mean(torch.bmm(output, pos_rep.transpose(2, 1)), -1) / 50)
                    weighted_scores.append(score / math.pow(layer + 1, self.args.beta))
                nce1 = torch.stack(weighted_scores)
                loss1 = (params.permute(2, 0, 1) * nce1).sum(0).mean(-1).sum()
                if self.args.scope:
                    # torch.sigmoid replaces the deprecated F.sigmoid.
                    loss1 = torch.sigmoid(loss1)
                loss2 = criterion(res, batch_x)
                loss3 = 1 / (1 + torch.exp(kl_p * self.args.alpha))
                loss4 = criterion(target, batch_x)
                loss = loss1 + loss2 + loss3 + loss4
                loss.backward()
                epoch_loss.append(loss.item())
                model_optim.step()
                if (j + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss1: {2:.7f}, loss2: {3:.7f}, loss3: {4:.7f}, loss4: {5:.7f}".format(
                        j + 1, epoch + 1, loss1.item(), loss2.item(), loss3.item(), loss4.item()))
            epoch_loss = np.average(epoch_loss)
            vali_loss = self.val_model(val_loader, val_data, criterion)
            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f}, Vali Loss: {3:.7f}".format(epoch + 1, train_steps, epoch_loss, vali_loss))
            # Only the encoder is handed to EarlyStopping; the test routine
            # loads 'checkpoint.pth' from this directory into an encoder.
            early_stopping(vali_loss, self.encoder, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args)

    
    
    def test_anomal_detection(self, setting):
        """Evaluate the saved encoder with an ECOD detector and log the metrics.

        Loads ``checkpoint.pth`` from ``<args.checkpoints>/<setting>`` into a
        fresh encoder, encodes the test series, fits ECOD on the resulting
        representations and compares its predicted labels with the ground truth.

        * ``anomal_detect_kpi``: reads ``phase2_ground_truth.hdf`` from
          ``args.root_path`` (columns ``KPI ID``, ``timestamp``, ``value``,
          ``label``). Each KPI series is standardised, cut into windows of
          ``args.pred_len`` points, scored, and its labels are re-gridded with
          ``reconstruct_label``. The metrics averaged over all series and the
          confusion matrix summed over all series are appended to
          ``kpi_result.csv``.
        * any other task: scores the whole ``Dataset_yahoo`` test split at once
          and appends one row to ``yahoo_result.csv``.

        Args:
            setting: Sub-directory name under ``self.args.checkpoints`` that
                holds the checkpoint.
        """
        # Imported locally: StandardScaler is not among the imports visible at
        # the top of this file.
        from sklearn.preprocessing import StandardScaler

        def _binary_metrics(y_true, y_pred):
            """Return (acc, auc, auprc, recall, precision, f1) for 0/1 labels."""
            y_true = column_or_1d(y_true)
            y_pred = column_or_1d(y_pred)
            # NOTE(review): AUC/AUPRC are computed from hard 0/1 predictions, as
            # before; detector.decision_function scores may be what was intended.
            return (
                accuracy_score(y_true, y_pred),
                roc_auc_score(y_true, y_pred),
                average_precision_score(y_true, y_pred),
                recall_score(y_true, y_pred),
                precision_score(y_true, y_pred),
                f1_score(y_true, y_pred),
            )

        def _encode(batch):
            """Encode ``batch`` without autograd and flatten to 2-D features."""
            with torch.no_grad():
                rep, _ = self.model(batch)
            # Use the real feature width rather than a hard-coded 32.
            return rep.reshape(-1, rep.shape[-1]).cpu().numpy()

        test_data, _ = self.load_data('TEST')
        path = os.path.join(self.args.checkpoints, setting)
        length = len(test_data[0][0])
        self.model = self.ininial_encoder(length).to(self.device)
        best_model_path = os.path.join(path, 'checkpoint.pth')
        # map_location lets a checkpoint saved on another device load here.
        self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))
        # Inference mode: train-only behaviour (e.g. dropout, if the encoder
        # has any) must be off so the representations are deterministic.
        self.model.eval()
        window_size = self.args.pred_len
        self.detector = ECOD()

        if self.args.tasks == 'anomal_detect_kpi':
            df_raw = pd.read_hdf(os.path.join(self.args.root_path, 'phase2_ground_truth.hdf'))
            df_raw = df_raw.set_index(['KPI ID', 'timestamp']).sort_index()
            scaler = StandardScaler()
            per_kpi_metrics = []
            # Summed over all series so the CSV row describes the same data as
            # the averaged metrics (previously only the last series was kept).
            total_confusion = np.zeros((2, 2), dtype=np.int64)
            for name, df in df_raw.groupby(level=0):
                data = df['value'].to_numpy()
                label = df['label'].to_numpy()
                print(len(data), len(label))
                timestamp = df.index.get_level_values(1)
                print(len(timestamp), len(label))
                # Each series is standardised on its own statistics.
                value = scaler.fit_transform(np.expand_dims(data, -1))
                # Keep only whole windows; the trailing remainder is dropped.
                usable = (len(value) // window_size) * window_size
                if usable == 0:
                    # Series shorter than one window cannot be encoded.
                    continue
                windows = torch.tensor(np.array(value[:usable]).reshape(-1, window_size, 1)).float().to(self.device)
                label = label[:usable]
                timestamp = timestamp[:usable]
                test_rep = _encode(windows)
                print('encoding done.')
                # The detector is re-fitted from scratch for every series.
                self.detector.fit(test_rep)
                print('train_done')
                pred_label = self.detector.predict(test_rep)
                # Put predictions and ground truth on the same gap-free grid.
                pred_label = reconstruct_label(timestamp, pred_label)
                label = reconstruct_label(timestamp, label)
                print(pred_label)
                per_kpi_metrics.append(_binary_metrics(label, pred_label))
                # Fixed labels keep the matrix 2x2 even if a class is absent.
                total_confusion += confusion_matrix(column_or_1d(label), column_or_1d(pred_label), labels=[0, 1])

            if per_kpi_metrics:
                mean_acc, mean_auc, mean_auprc, mean_recall, mean_precision, mean_f1 = np.mean(per_kpi_metrics, axis=0)
            else:
                mean_acc = mean_auc = mean_auprc = mean_recall = mean_precision = mean_f1 = float('nan')

            # newline='' is what the csv module expects for files it writes.
            with open('kpi_result.csv', 'a+', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([self.args.method, mean_acc, mean_auc, mean_auprc, mean_recall, mean_precision, mean_f1, total_confusion])
            print(mean_auc, mean_recall, mean_precision, mean_f1)
            print(len(df_raw.groupby(level=0)))
        else:
            dataset = Dataset_yahoo(self.args.root_path, None, None, flag='TEST')
            samples = dataset[:]
            label = np.array(samples[1])
            data = np.array(samples[0])
            print(label.shape)
            print(len(label))

            test_rep = _encode(torch.tensor(data).float().to(self.device))
            label = label.reshape(-1, 1)
            print('encoding done.')
            self.detector.fit(test_rep)

            print('train_done')
            pred_label = self.detector.predict(test_rep)
            print(pred_label)
            acc, auc, auprc, recall, precision, f1 = _binary_metrics(label, pred_label)
            c = confusion_matrix(column_or_1d(label), column_or_1d(pred_label))
            with open('yahoo_result.csv', 'a+', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([self.args.method, acc, auc, auprc, recall, precision, f1, c])
