#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging
import numpy as np
import glob
import os
from scipy import io
import pandas as pd
import yaml
import importlib
from sklearn.metrics import average_precision_score, roc_auc_score, precision_recall_fscore_support

# Model types that load_yaml() maps onto the shared 'CLASSIC.yaml' config.
BASELINE_MODELS = [
    'OCSVM', 'KNN', 'IForest', 'LOF', 'PCA', 'ECOD',
    'DeepSVDD', 'GOAD', 'ICL', 'NeuTraL',
]  # 'AutoEncoder' is currently disabled


def _stems(paths):
    """Return the file name of each path without directory or extension."""
    return [os.path.splitext(os.path.basename(path))[0] for path in paths]


# Dataset files found under ./Data when this module is imported, grouped by
# extension, each paired with the list of dataset names (file stems).
npz_files = glob.glob(os.path.join('./Data', '*.npz'))
npz_datanames = _stems(npz_files)

mat_files = glob.glob(os.path.join('./Data', '*.mat'))
mat_datanames = _stems(mat_files)

dat_files = glob.glob(os.path.join('./Data', '*.data'))
dat_datanames = _stems(dat_files)

arff_files = glob.glob(os.path.join('./Data', '*.arff'))
arff_datanames = _stems(arff_files)

def get_parser():
    """Build the command-line parser for an experiment run.

    Valued options other than ``--dataname``, ``--model_type`` and
    ``--train_ratio`` default to ``None``; boolean switches use
    ``store_true`` and therefore default to ``False``.

    Returns:
        argparse.ArgumentParser: the configured parser. Nothing is parsed
        here; the caller invokes ``parse_args`` itself.
    """
    parser = argparse.ArgumentParser()

    # Run selection.
    parser.add_argument('--dataname', type=str, default='Hepatitis')
    parser.add_argument('--model_type', type=str, default='DRL')
    parser.add_argument('--exp_name', type=str, default=None)
    parser.add_argument('--train_ratio', type=float, default=1.0)

    # Experiment hyperparameters, listed in registration order.
    # A (name, type) pair is a valued option defaulting to None;
    # a bare name is a store_true switch.
    experiment_options = (
        ('num_heads', int),
        ('depth', int),
        ('hidden_dim', int),
        ('learning_rate', float),
        ('mlp_ratio', float),
        ('dropout_prob', float),
        ('drop_col_prob', float),
        ('temperature', float),
        ('sim_type', str),
        ('num_repeat', int),
        'is_weight_sharing',
        'use_pos_enc_as_query',
        ('num_latents', int),
        ('num_adapters', int),
        'use_vq_loss_as_score',
        ('beta', float),
        ('shrink_thred', float),
        ('latent_loss_weight', float),
        ('entropy_loss_weight', float),
        'use_entropy_loss_as_score',
        'use_mask_token',
        'not_use_power_of_two',
        'num_memories_not_use_power_of_two',
        'num_memories_twice',
        'is_recurrent',
        ('contamination_ratio', float),
    )
    for option in experiment_options:
        if isinstance(option, str):
            parser.add_argument(f'--{option}', action='store_true')
        else:
            opt_name, opt_type = option
            parser.add_argument(f'--{opt_name}', type=opt_type, default=None)

    return parser


def load_yaml(args):
    """Load the YAML configuration for a run and apply all overrides.

    The file read is ``configs/CLASSIC.yaml`` when ``args.model_type`` is
    listed in ``BASELINE_MODELS`` and ``configs/<model_type>.yaml``
    otherwise. Its ``default`` section supplies ``model_config`` and
    ``train_config``. A top-level section named after ``args.dataname``,
    when present, overrides keys that already exist in either dictionary.
    Command-line values are applied last.

    Args:
        args: parsed command-line namespace.

    Returns:
        tuple: ``(model_config, train_config)`` dictionaries.
    """
    if args.model_type in BASELINE_MODELS:
        config_file = 'CLASSIC.yaml'
    else:
        config_file = args.model_type + '.yaml'

    with open(f'configs/{config_file}', 'r') as stream:
        configs = yaml.safe_load(stream)

    defaults = configs['default']
    model_config = defaults['model_config']
    train_config = defaults['train_config']

    # Dataset-specific values replace defaults, but only for keys that the
    # default dictionaries already define.
    if args.dataname in configs:
        for key, value in configs[args.dataname].items():
            for section in (model_config, train_config):
                if key in section:
                    section[key] = value

    model_config = replace_transformer_config(args, model_config)

    train_config['model_type'] = args.model_type
    train_config['dataset_name'] = args.dataname
    train_config['train_ratio'] = args.train_ratio
    # NOTE(review): get_parser() registers no --base_path option, so this
    # attribute is presumably attached to args by the caller - confirm.
    train_config['base_path'] = args.base_path

    # A learning rate given on the command line wins over the configured one.
    cli_lr = args.learning_rate
    train_config['learning_rate'] = train_config['learning_rate'] if cli_lr is None else cli_lr

    # store_true switches: False unless passed on the command line.
    train_config['not_use_power_of_two'] = args.not_use_power_of_two
    train_config['num_memories_not_use_power_of_two'] = args.num_memories_not_use_power_of_two
    train_config['num_memories_twice'] = args.num_memories_twice
    # None unless --contamination_ratio is given.
    train_config['contamination_ratio'] = args.contamination_ratio

    # train_config is passed because get_input_dim reads 'data_dir' from it.
    model_config['num_features'] = get_input_dim(args, train_config)

    return model_config, train_config


def build_trainer(model_config, train_config):
    """Instantiate the trainer selected by ``train_config['model_type']``.

    Only ``'MemPAE'`` has a trainer wired up in this function.

    Args:
        model_config: dictionary forwarded unchanged to the trainer.
        train_config: dictionary forwarded unchanged to the trainer; its
            ``'model_type'`` entry selects the trainer class.

    Returns:
        The constructed ``Trainer`` instance.

    Raises:
        ValueError: if ``model_type`` has no trainer registered here.
            Previously this case fell through to an unbound ``Trainer``
            name and failed with ``UnboundLocalError``.
    """
    model_type = train_config['model_type']
    if model_type != 'MemPAE':
        raise ValueError(f"Unknown model type {model_type}")

    # Imported inside the function, so the model package is only loaded
    # when this model type is actually requested.
    from models.MemPAE.Trainer import Trainer
    return Trainer(model_config, train_config)

def get_input_dim(args, model_config):
    """Return the number of input features of dataset ``args.dataname``.

    The dataset name is matched against the module-level name lists in the
    order ``.npz``, ``.mat``, ``.data``, ``.arff``, and the file is read
    from ``model_config['data_dir']``.

    * ``.npz`` / ``.mat``: the size of the last axis of the ``'X'`` entry.
    * ``.arff``: every column except the last is one-hot encoded with
      ``pandas.get_dummies``; the width of the encoded frame is returned.
    * ``.data``: see the review note in the body.

    Args:
        args: namespace providing ``dataname``.
        model_config: dictionary providing ``'data_dir'``.

    Returns:
        int: the feature dimension.

    Raises:
        ValueError: if the dataset name is in none of the known lists.
    """
    name = args.dataname

    def _path(extension):
        # Resolved lazily so an unknown dataset never touches 'data_dir'.
        return os.path.join(model_config['data_dir'], name + extension)

    if name in npz_datanames:
        # np.load keeps the archive open; the context manager closes it.
        with np.load(_path('.npz')) as archive:
            return archive['X'].shape[-1]

    if name in mat_datanames:
        return io.loadmat(_path('.mat'))['X'].shape[-1]

    if name in dat_datanames:
        table = pd.read_csv(_path('.data'), header=None)
        # NOTE(review): a headerless frame has integer column labels, so the
        # 'X' lookup raises KeyError, exactly as before. The feature/label
        # layout of .data files still has to be defined.
        return table['X'].shape[-1]

    if name in arff_datanames:
        records, _ = io.arff.loadarff(_path('.arff'))
        frame = pd.DataFrame(records)
        # The last column is excluded (the original code read it as labels).
        return pd.get_dummies(frame.iloc[:, :-1]).shape[-1]

    raise ValueError(f"Unknown dataset {args.dataname}")

def replace_transformer_config(args, model_config):
    """Overlay command-line transformer settings onto ``model_config``.

    Valued options replace the configured value only when they are not
    ``None``. The boolean switches are always copied from ``args``.

    Args:
        args: namespace holding the command-line values.
        model_config: dictionary updated in place.

    Returns:
        dict: the same ``model_config`` object.
    """
    def cli_or_config(key):
        # Prefer the command-line value; fall back to the configured one.
        cli_value = getattr(args, key)
        return model_config[key] if cli_value is None else cli_value

    for key in ('num_heads', 'depth', 'hidden_dim', 'mlp_ratio', 'dropout_prob'):
        model_config[key] = cli_or_config(key)

    for switch in ('is_weight_sharing', 'use_pos_enc_as_query', 'use_mask_token'):
        model_config[switch] = getattr(args, switch)

    model_config['temperature'] = cli_or_config('temperature')

    return model_config


def aucPerformance(score, labels):
    """Compute ROC-AUC and average precision for scores against labels.

    Args:
        score: predicted scores, passed as ``y_score`` to scikit-learn.
        labels: ground-truth labels, passed as ``y_true`` to scikit-learn.

    Returns:
        tuple: ``(roc_auc, average_precision)``.
    """
    return roc_auc_score(labels, score), average_precision_score(labels, score)

def F1Performance(score, target):
    """Compute the binary F1 score after thresholding ``score``.

    The threshold is the percentile of the scores that matches the share of
    samples whose label equals 0. Samples scoring strictly above it are
    predicted as 1, all others as 0.

    Args:
        score: scores; singleton dimensions are squeezed away.
        target: ground-truth labels compared element-wise against 0.

    Returns:
        The F1 value from ``precision_recall_fscore_support`` with
        ``average='binary'``.
    """
    # Share of samples labelled 0.
    zero_share = (target == 0).sum() / len(target)
    flat_scores = np.squeeze(score)
    cutoff = np.percentile(flat_scores, 100 * zero_share)

    predicted = np.zeros(len(flat_scores))
    predicted[flat_scores > cutoff] = 1

    # The result is (precision, recall, f1, support); only f1 is returned.
    _, _, f1_value, _ = precision_recall_fscore_support(target, predicted, average='binary')
    return f1_value

def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes bare messages to a file and the console.

    Calling this function again for the same logger first removes and
    closes the handlers a previous call attached. Without that, each call
    stacked another file/stream pair, so every record was emitted several
    times and the old log files stayed open. Handlers attached by other
    code are left untouched.

    Args:
        filename: path of the log file; it is opened in ``"w"`` mode, so an
            existing file is truncated.
        verbosity: 0 for DEBUG, 1 for INFO, 2 for WARNING.
        name: logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        logging.Logger: the configured logger.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter("%(message)s")

    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # Discard only the handlers this function installed earlier.
    stale_handlers = [
        handler for handler in logger.handlers
        if getattr(handler, '_installed_by_get_logger', False)
    ]
    for handler in stale_handlers:
        logger.removeHandler(handler)
        handler.close()

    file_handler = logging.FileHandler(filename, "w")
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setFormatter(formatter)
        # Marker used by later calls to recognise handlers created here.
        handler._installed_by_get_logger = True
        logger.addHandler(handler)

    return logger