# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
from sklearn.preprocessing import MinMaxScaler
from util.env import get_device, set_device
from util.preprocess import build_loc_net, construct_data
from util.net_struct import get_feature_map, get_fc_graph_struc
from datasets.TimeDataset import TimeDataset, Transform
from models.proactive_anomaly_detection import proactive_anomaly_detection
from train import train
from test  import test
from evaluate import get_score_2
import os
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import random
from util.data import *


def norm(train, test):
    """Min-max scale ``train`` and ``test`` into [0, 1] using ``train``'s statistics only.

    The scaler is fitted on ``train`` alone, so values in ``test`` may fall outside
    [0, 1]. Both returned frames carry ``train``'s column labels and a fresh
    default index.

    Returns:
        (scaled_train, scaled_test, fitted_scaler)
    """
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(train)
    columns = train.columns

    def _scaled(frame):
        # Re-wrap the ndarray produced by the scaler, keeping train's column labels.
        return pd.DataFrame(scaler.transform(frame), columns=columns)

    scaled_train = _scaled(train)
    scaled_test = _scaled(test)
    return scaled_train, scaled_test, scaler


class Main():
    """End-to-end driver: load a dataset, build the graph model, train and/or evaluate it.

    Construction reads ``./data/{dataset}/{dataset}_train.csv`` and
    ``./data/{dataset}/{dataset}_test.csv`` (last column = label). It then:

    - min-max scales the features with statistics from the training split;
    - builds the fully-connected graph edge index;
    - builds the sliding-window datasets and dataloaders;
    - builds the ``proactive_anomaly_detection`` model.
    """

    def __init__(self, train_config, env_config, save_folder, debug=False):
        """
        Args:
            train_config: dict of training hyper-parameters. Keys read here:
                'slide_win', 'slide_stride', 'seed', 'batch', 'val_ratio',
                'model', 'decay', 'dim', 'topk'.
            env_config: dict with at least 'dataset' and 'device'; ``run`` also
                reads 'load_model_path'.
            save_folder: run identifier, stored on the instance.
            debug: unused; kept for interface compatibility.
        """
        self.train_config = train_config
        self.env_config = env_config
        self.save_folder = save_folder

        ##### Dataset #####
        dataset = self.env_config['dataset']
        self.dataset = dataset

        ##### Original train/test data; the last column is the label
        train_orig = pd.read_csv(f'./data/{dataset}/{dataset}_train.csv', sep=',', index_col=0)
        test_orig = pd.read_csv(f'./data/{dataset}/{dataset}_test.csv', sep=',', index_col=0)

        train_df = train_orig.iloc[:, :-1]
        test_df = test_orig.iloc[:, :-1]
        test_labels = test_orig.iloc[:, -1]
        # Raw (un-normalized) training features; passed to get_score_2 in `detection`.
        self.train_orig = train_df

        ##### Normalization (scaler fitted on the training split only)
        train_df, test_df, self.normalizer = norm(train_df, test_df)

        ##### Feature map / graph structure #####
        feature_map = get_feature_map(dataset)
        fc_struc = get_fc_graph_struc(dataset)

        set_device(env_config['device'])
        self.device = get_device()

        # The edge index is built from the CSV column names, before they are renamed below.
        fc_edge_index = build_loc_net(fc_struc, list(train_df.columns), feature_map=feature_map)
        fc_edge_index = torch.tensor(fc_edge_index, dtype=torch.long)

        self.feature_map = feature_map

        train_df.columns = feature_map
        test_df.columns = feature_map
        train_dataset_indata = construct_data(train_df, feature_map, labels=0)
        test_dataset_indata = construct_data(test_df, feature_map, labels=test_labels.tolist())

        cfg = {
            'slide_win': train_config['slide_win'],
            'slide_stride': train_config['slide_stride'],
        }

        ##### Transform fitted on the concatenated (normalized) train + test data
        total_orig = pd.concat([train_df, test_df], axis=0)
        self.transform = Transform(total_orig)

        train_dataset = TimeDataset(train_dataset_indata, fc_edge_index, self.transform, mode='train', config=cfg)
        test_dataset = TimeDataset(test_dataset_indata, fc_edge_index, self.transform, mode='test', config=cfg)
        train_dataloader, val_dataloader = self.get_loaders(
            train_dataset, train_config['seed'], train_config['batch'], val_ratio=train_config['val_ratio'])

        self.train_dataset = train_dataset
        self.train_dataloader = train_dataloader

        self.val_dataloader = val_dataloader

        self.test_dataset = test_dataset
        self.test_dataloader = DataLoader(test_dataset, batch_size=train_config['batch'], shuffle=False, num_workers=0)

        ##### Model definition #####
        edge_index_sets = [fc_edge_index]
        print(train_config['model'])
        self.model_name = train_config['model']
        self.model = proactive_anomaly_detection(
            edge_index_sets, input_dim=train_config['slide_win'], decay=train_config['decay'],
            dim=train_config['dim'], topk=train_config['topk'], trans=self.transform, pred_len=1,
            graph=True, adaptive_gcn_option=True,
        ).to(self.device)

    ##### Run ######
    def run(self, save_folder):
        """Train (unless a checkpoint is configured), then test and run detection.

        Checkpoint selection:
        - If ``env_config['load_model_path']`` is non-empty, that checkpoint is
          loaded and training is skipped.
        - Otherwise the model is trained via ``train`` using this instance's
          ``train_config``. The checkpoint path from ``get_save_path`` is passed
          to it, and that path is then loaded.

        Returns:
            With a configured checkpoint:
                (denorm_test_result, (score_list, node_error), conti_ind, cate_ind)
            After training: the same tuple followed by the ``train_log`` returned by ``train``.
        """
        use_pretrained = len(self.env_config['load_model_path']) > 0

        if use_pretrained:
            model_save_path = self.env_config['load_model_path']
            print("Get pretrained model")
        else:
            model_save_path = self.get_save_path(save_folder)[0]
            self.train_log = train(self.model, model_save_path, config=self.train_config,
                                   train_dataloader=self.train_dataloader, val_dataloader=self.val_dataloader,
                                   test_dataloader=self.test_dataloader, test_dataset=self.test_dataset,
                                   train_dataset=self.train_dataset,
                                   feature_map=self.feature_map, dataset_name=self.env_config['dataset'])

        ##### test(predict) part
        # map_location lets a checkpoint saved on another device load onto self.device.
        self.model.load_state_dict(torch.load(model_save_path, map_location=self.device))
        best_model = self.model.to(self.device)
        self.denorm_test_result, self.test_result = test(best_model, self.test_dataloader, self.train_config, "test")

        print(f'Test MSE of conti          : {self.denorm_test_result[0]}')
        print(f'Test Cross-Entropy of cate : {self.denorm_test_result[1]}')

        ##### test(Detection) part
        figure_save_folder = self.get_save_path(save_folder)
        self.conti_ind, self.cate_ind, _ = self.transform.col_index()
        self.dect_res = self.detection(self.test_result, figure_save_folder, transform=self.transform)

        if use_pretrained:
            return self.denorm_test_result, self.dect_res, self.conti_ind, self.cate_ind
        return self.denorm_test_result, self.dect_res, self.conti_ind, self.cate_ind, self.train_log

    ##### Function for Detection part
    def detection(self, test_result, figure_save_folder, transform=None):
        """Undo the data transforms on predictions, save them, and score detection.

        Args:
            test_result: sequence stacked by ``np.array`` into [pred, ground, labels].
                Labels are read from ``[2, :, 0]``.
            figure_save_folder: list as returned by ``get_save_path``. Element 2
                (+ '.npy') receives the de-normalized predictions; element 1 is
                passed to ``get_score_2``.
            transform: object providing ``inverse_transform``; defaults to ``self.transform``.

        Returns:
            (mean_score_list, node_error). ``node_error`` is the per-column MSE
            between de-normalized predictions and ground truth.
        """
        if transform is None:
            transform = self.transform

        np_test_result = np.array(test_result)
        test_pred = np_test_result[0]
        test_ground = np_test_result[1]
        test_labels = np_test_result[2, :, 0]

        ##### Inverse Transform for one-hot embedding and separation training
        test_pred_init = transform.inverse_transform(test_pred)
        test_ground_init = transform.inverse_transform(test_ground)

        ##### Inverse Transform for minmax normalizer
        test_pred_init = self.normalizer.inverse_transform(test_pred_init)
        test_ground_init = self.normalizer.inverse_transform(test_ground_init)

        ##### Save predict value
        np.save(figure_save_folder[2] + '.npy', test_pred_init)

        node_error = self._node_mse(test_pred_init, test_ground_init)

        ##### Get Detection Score
        mean_score_list = get_score_2(test_pred_init, test_labels, self.train_orig, self.dataset, figure_save_folder[1])

        return mean_score_list, node_error

    @staticmethod
    def _node_mse(pred, ground):
        """Return the mean squared error of each column of two 2-D arrays, as a list."""
        diff = np.asarray(pred) - np.asarray(ground)
        return list(np.mean(diff ** 2, axis=0))

    ##### Function for getting dataloader
    def get_loaders(self, train_dataset, seed, batch, val_ratio=0.1):
        """Split ``train_dataset`` into a shuffled train loader and a contiguous validation loader.

        A contiguous block of ``int(len * val_ratio)`` samples is held out for
        validation. Its start offset is drawn from the global ``random`` module;
        all remaining samples form the training loader.

        Args:
            seed: unused; randomness comes from the global ``random`` state.

        Raises:
            ValueError: if ``int(len * (1 - val_ratio))`` is not positive.
        """
        dataset_len = len(train_dataset)
        train_use_len = int(dataset_len * (1 - val_ratio))
        val_use_len = int(dataset_len * val_ratio)
        if train_use_len <= 0:
            raise ValueError(
                f"training split is empty (dataset length {dataset_len}, val_ratio {val_ratio})")

        val_start_index = random.randrange(train_use_len)
        val_end_index = val_start_index + val_use_len
        indices = torch.arange(dataset_len)

        train_sub_indices = torch.cat([indices[:val_start_index], indices[val_end_index:]])
        val_sub_indices = indices[val_start_index:val_end_index]

        train_dataloader = DataLoader(Subset(train_dataset, train_sub_indices), batch_size=batch, shuffle=True)
        val_dataloader = DataLoader(Subset(train_dataset, val_sub_indices), batch_size=batch, shuffle=False)

        return train_dataloader, val_dataloader

    ##### Function for making path
    def get_save_path(self, save_folder):
        """Derive the output paths for a run and create their parent directories.

        ``save_folder`` must look like
        ``{dataset}_{model}_{dim}_{out_layer_num}_{topk}_{decay}_{seed}``. The model
        name may itself contain underscores, since the last five fields are split
        off from the right. The value is also stored in ``env_config['save_path']``.

        Returns:
            [checkpoint ``.pt`` path, error-distribution path prefix, results path prefix]

        Raises:
            ValueError: if ``save_folder`` has fewer than seven ``_``-separated fields.
        """
        self.env_config['save_path'] = save_folder

        if len(save_folder.split("_")) < 7:
            raise ValueError(f"unexpected save_folder format: {save_folder!r}")
        # prefix == "{dataset}_{model}"
        prefix, dim, out_layer_num, topk, decay, seed_num = save_folder.rsplit("_", 5)
        pt_name = "_".join((dim, out_layer_num, topk, decay))

        paths = [f'./pretrained/{prefix}/{pt_name}_{seed_num}.pt',
                 f'./error_dist/{prefix}/{pt_name}/{seed_num}',
                 f'./results/{prefix}/{pt_name}/{seed_num}']
        for path in paths:
            Path(path).parent.mkdir(parents=True, exist_ok=True)

        return paths



if __name__ == "__main__":
    # Only the CLI entry point needs these.
    import ast
    import itertools

    parser = argparse.ArgumentParser()

    parser.add_argument('-batch', help='batch size', type=int, default=128)
    parser.add_argument('-epoch', help='train epoch', type=int, default=100)
    parser.add_argument('-slide_win', help='slide_win', type=int, default=15)
    parser.add_argument('-dim', help='dimension', type=int, default=64)
    parser.add_argument('-slide_stride', help='slide_stride', type=int, default=5)
    parser.add_argument('-dataset', help='MSL / SWAT /...', type=str, default='MSL')
    parser.add_argument('-device', help='cuda / cpu', type=str, default='cuda')
    parser.add_argument('-random_seed', help='random seed', type=int, default=6799)
    parser.add_argument('-out_layer_num', help='outlayer num', type=int, default=1)
    parser.add_argument('-out_layer_inter_dim', help='out_layer_inter_dim', type=int, default=256)
    parser.add_argument('-decay', help='decay', type=float, default=0)
    parser.add_argument('-val_ratio', help='val ratio', type=float, default=0.1)
    parser.add_argument('-topk', help='topk num', type=int, default=15)
    parser.add_argument('--graph_structure', help='graph structure', type=str, default='topk')
    parser.add_argument('-load_model_path', help='trained model path', type=str, default='')
    parser.add_argument('-model', help='gdn, adagdn', type=str, default='proactive_anomaly_detection')
    # Boolean flags use ast.literal_eval. It accepts the same literals `eval` did
    # (True / False / 0 / 1) but cannot execute arbitrary code from the command line.
    parser.add_argument('-use_pretrained', help='use pretrained model', type=ast.literal_eval, default=True)
    parser.add_argument('-lowk', help='lowk num', type=int, default=15)
    parser.add_argument('-attention_opt', help='use graph edge attention', type=ast.literal_eval, default=True)
    parser.add_argument('-identity_opt', help='use identity matrix (MLP)', type=ast.literal_eval, default=False)
    parser.add_argument('--sp_attention_opt', help='use sparse attention', type=ast.literal_eval, default=True)

    args = parser.parse_args()

    train_config = {
        'batch': args.batch,
        'epoch': args.epoch,
        'slide_win': args.slide_win,
        'dim': args.dim,
        'slide_stride': args.slide_stride,
        'seed': args.random_seed,
        'out_layer_num': args.out_layer_num,
        'out_layer_inter_dim': args.out_layer_inter_dim,
        'decay': args.decay,
        'val_ratio': args.val_ratio,
        'topk': args.topk,
        'lowk': args.lowk,
        'model': args.model,
        'use_pretrained': args.use_pretrained,
        'graph_structure': args.graph_structure,
        'attention_opt': args.attention_opt,
        'identity_opt': args.identity_opt,
        'sp_attention_opt': args.sp_attention_opt
    }

    env_config = {
        'save_path': '',
        'dataset': args.dataset,
        'device': args.device,
        'load_model_path': args.load_model_path,
    }

    ##### hyperparameter grid #####
    dim_cand = [256]
    out_layer_num_cand = [4]
    topk_cand = [30]
    decay_cand = [0.4]
    models = ["proactive_anomaly_detection"]

    # Same ordering as nested loops over dim -> out_layer_num -> topk -> decay.
    param = [list(p) for p in itertools.product(dim_cand, out_layer_num_cand, topk_cand, decay_cand)]

    print('-------------- ', env_config['dataset'], ' --------------')
    for mod in models:
        train_config['model'] = mod

        ##### per-run metrics collected for the final summary
        train_con_mse, train_cate_mse = [], []
        val_con_mse, val_cate_mse = [], []
        test_con_mse, test_cate_mse = [], []

        error_node_list = []
        node_error_std, node_error_mean = [], []

        F1_pak_list, F1_comp_list, F1_range_list = [], [], []

        # The run index k doubles as the random seed and as the suffix of every
        # output name of the run (checkpoint, error_dist, results, figure).
        for k, (d, oln, tk, dc) in enumerate(param):

            ##### seed and environment setting
            np.random.seed(k)
            random.seed(k)
            torch.manual_seed(k)
            torch.cuda.manual_seed(k)
            torch.cuda.manual_seed_all(k)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            # NOTE(review): `torch.cuda.cudnn_enabled` is not a torch switch (cuDNN is
            # toggled via torch.backends.cudnn.enabled); this assignment has no effect.
            torch.cuda.cudnn_enabled = False
            os.environ['PYTHONHASHSEED'] = str(k)
            print(f"Random seed set as {k}")

            ##### parameter adapt
            train_config['dim'] = d
            train_config['out_layer_num'] = oln
            train_config['out_layer_inter_dim'] = 256  # olid 256
            train_config['decay'] = dc
            train_config['topk'] = tk

            ##### path setting
            save_name_model = env_config['dataset'] + "_" + train_config['model']
            save_name_setting = str(d) + "_" + str(oln) + "_" + str(tk) + "_" + str(dc)
            save_folder = save_name_model + "_" + save_name_setting + "_" + str(k)

            ##### load saved path
            if train_config['use_pretrained']:
                env_config['load_model_path'] = f"./pretrained/{save_name_model}/{save_name_setting}_{k}.pt"

            print('--------- ', k, save_name_model + "_" + save_name_setting, ' ---------')
            main = Main(train_config, env_config, save_folder, debug=False)

            ####################  run main file  ####################
            if env_config['load_model_path'] == '':
                Loss, run_vals, conti_ind, cate_ind, train_log = main.run(save_folder)
                # NOTE(review): indices 2-5 are read as train/val conti MSE and cate CE;
                # confirm against the tuple returned by train().
                print("\nTrain MSE of conti :", train_log[2])
                print("Train Cross-Entropy of cate :", train_log[3])
                print("\nValidation MSE of conti :", train_log[4])
                print("Validation Cross-Entropy of cate :", train_log[5])

                train_con_mse.append(train_log[2])
                train_cate_mse.append(train_log[3])
                val_con_mse.append(train_log[4])
                val_cate_mse.append(train_log[5])
            else:
                Loss, run_vals, conti_ind, cate_ind = main.run(save_folder)
            ##### (MSE, CE Loss), (score_list, node_error), conti_ind, cate_ind
            score_list, node_error = run_vals
            error_node_list.append(node_error)

            ####################  print Test loss and node error  ####################
            test_con_mse.append(Loss[0])
            test_cate_mse.append(Loss[1])
            print(f'Test MSE of conti          : {Loss[0]}')
            print(f'Test Cross-Entropy of cate : {Loss[1]}')

            err_std = np.std(node_error)
            err_mean = np.mean(node_error)
            print("std of error for each node:", err_std)
            print("mean of error for each node:", err_mean)

            node_error_std.append(err_std)
            node_error_mean.append(err_mean)

            ####################  print and save best detection score  ####################
            # Each score row: [method, naive F1/Pre/Rec, pak F1/Pre/Rec, comp F1/Pre/Rec, range ...]
            pak_F1, comp_F1, range_F1 = [], [], []
            print("                 Detection method / F1 / Pre / Rec / F1 pak / Pre pak / Rec pak / F1 range / Pre range / Rec range ")
            for scores in score_list:
                print(f"Detection method: {scores[0]}")
                print(f"naive: {scores[1:4]} | pak: {scores[4:7]} | comp: {scores[7:10]} | range: {scores[10:]}")
                pak_F1.append(scores[4])
                comp_F1.append(scores[7])
                range_F1.append(scores[10])
            F1_pak_list.append(pak_F1)
            F1_comp_list.append(comp_F1)
            F1_range_list.append(range_F1)

            ####################  save result figure  ####################
            # Uses the current run index k, so the figure matches this run's seed and save_folder.
            figure_save_folder = f"./error_dist/{save_name_model}"
            figure_save_name = save_name_setting + "_" + str(k)
            Path(figure_save_folder).mkdir(parents=True, exist_ok=True)

            plt.plot(node_error)
            plt.title(f"Error of each node | setting : {figure_save_name}")
            plt.savefig(f"{figure_save_folder}/{figure_save_name}.png")
            plt.clf()

        #########################  Total iteration result  #########################
        print(f'\n --------- result of total {len(param)} iteration ---------')
        if "CAT" in train_config['model']:
            if env_config['load_model_path'] == '':
                print("mean of train conti MSE :", np.mean(train_con_mse))
                print("mean of train cate loss :", np.mean(train_cate_mse))
                print("mean of val conti MSE   :", np.mean(val_con_mse))
                print("mean of val cate loss   :", np.mean(val_cate_mse), "\n")
            print("mean of test conti MSE  :", np.mean(test_con_mse))
            print("mean of test cate loss  :", np.mean(test_cate_mse), "\n")
            print("Node error std          :", np.mean(node_error_std))
            print("Node error mean         :", np.mean(node_error_mean))

        print("mean of F1 pak for iteration   : GMM   /  ECOD")
        print("                               ", np.mean(F1_pak_list, axis=0))
        print("mean of F1 comp for iteration  : GMM   /  ECOD")
        print("                               ", np.mean(F1_comp_list, axis=0))
        print("mean of F1 range for iteration : GMM   /  ECOD")
        print("                               ", np.mean(F1_range_list, axis=0), "\n")