import pandas as pd
import numpy as np
import torch
import os
from tqdm import tqdm
import logging
import sys
from datetime import datetime
from optuna.pruners import BasePruner
import optuna
import yaml
from types import SimpleNamespace
from typing import Dict
import argparse

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torchmetrics.functional import calibration_error

def get_col_dict(file: str = '220106_자료/KAMIR-V 1Y FU DATA_Cumulative Death_20220106 변수 정리.xlsx',
                 sheet_name: str = 'Questionnaire',
                 n_rows: int = 610) -> Dict[str, str]:
    """Build a ``{code: description}`` mapping from an Excel variable dictionary.

    Reads the first ``n_rows`` rows of sheet ``sheet_name`` in ``file``. Keys
    come from the '기관코드' column and values from the '변수 설명 ' column
    (variable description; note the trailing space in that header). Rows with
    a missing description are skipped.

    A code that appears more than once is disambiguated with a numeric suffix:
    the second occurrence of ``X`` is stored as ``X.1``, the third as ``X.2``,
    and so on.

    Args:
        file: path to the Excel workbook.
        sheet_name: worksheet to read.
        n_rows: number of leading rows to use.

    Returns:
        Dict mapping (possibly suffixed) codes to their descriptions.
    """
    cols_info = pd.read_excel(file, sheet_name=sheet_name).iloc[:n_rows]

    col_dict = {}
    duplicated = {}  # original code -> number of repeats seen so far
    for col, info in zip(cols_info['기관코드'], cols_info['변수 설명 ']):
        # pd.isna is an explicit "missing cell" test. Probing with np.isnan
        # inside a bare except also hid unrelated errors and silently dropped
        # descriptions that happened to be numeric.
        if pd.isna(info):
            continue
        if col in col_dict:
            duplicated[col] = duplicated.get(col, 0) + 1
            col = f"{col}.{duplicated[col]}"
        col_dict[col] = info

    return col_dict

def merge_config(config: SimpleNamespace, args: argparse.Namespace) -> SimpleNamespace:
    """Attach the YAML data config named by ``args.data_config`` and apply CLI overrides.

    The file ``data_config/<args.data_config>.yaml`` is parsed and its
    top-level keys become attributes of ``config.data``. CLI overrides are then
    applied via ``set_config_from_args``, which also deletes the consumed
    options from ``args``. The remaining ``args`` namespace is stored as
    ``runner_option`` on the result.

    Returns:
        The config object returned by ``set_config_from_args``.
    """
    yaml_path = "data_config/" + args.data_config + '.yaml'
    with open(yaml_path, encoding='UTF-8') as yaml_file:
        parsed = yaml.load(yaml_file, Loader=yaml.FullLoader)
    config.data = SimpleNamespace(**parsed)

    updated = set_config_from_args(config, args)
    updated.runner_option = args
    return updated

def load_data_config(config_loc: str):
    """Load a YAML data-config file into a fresh config namespace.

    Args:
        config_loc: path to the YAML file (read as UTF-8).

    Returns:
        A ``SimpleNamespace`` with two attributes:
        ``data``, whose attributes are the YAML mapping's top-level keys, and
        ``runner_option``, whose ``save_data`` attribute is False.
    """
    with open(config_loc, encoding='UTF-8') as handle:
        raw = yaml.load(handle, Loader=yaml.FullLoader)
    return SimpleNamespace(
        data=SimpleNamespace(**raw),
        runner_option=SimpleNamespace(save_data=False),
    )
    
def set_config_from_args(config: SimpleNamespace, args: argparse.Namespace) -> SimpleNamespace:
    """Copy every non-None CLI override from ``args`` onto ``config``, then drop it from ``args``.

    ``n_jobs`` is set on the top-level config and also on ``config.model``
    when that section already has an ``n_jobs`` attribute. Every other
    override goes to one attribute of one config section (see the table
    below). Afterwards all of these option names are deleted from ``args``,
    whether or not they were None.

    Returns:
        The same ``config`` object, mutated in place.
    """
    if args.n_jobs is not None:
        config.n_jobs = args.n_jobs
        if hasattr(config.model, "n_jobs"):
            config.model.n_jobs = args.n_jobs

    # (option on args, section of config, attribute inside that section)
    routing = (
        ("n_trials", "optuna", "n_trials"),
        ("alpha", "self_training", "alpha"),
        ("delta", "self_training", "delta"),
        ("threshold", "self_training", "threshold"),
        ("hparams", "model", "hparams"),
        ("fast_dev_run", "model", "fast_dev_run"),
        ("device", "model", "device"),
        ("gpus", "model", "gpus"),
        ("batch_size", "model", "batch_size"),
        ("early_stopping_patience", "model", "early_stopping_patience"),
        ("n_splits", "KFold", "n_splits"),
        ("dist_threshold", "data_editor", "dist_threshold"),
    )
    for option, section, attribute in routing:
        override = getattr(args, option)
        if override is not None:
            setattr(getattr(config, section), attribute, override)

    # Strip the consumed options so args only carries the remaining runner options.
    consumed = ("n_jobs",) + tuple(option for option, _, _ in routing)
    for option in consumed:
        delattr(args, option)

    return config
    
class StreamToLogger(object):
    """
    File-like shim that forwards text written to it to a logger.

    Each ``write`` strips trailing whitespace from the chunk, splits it into
    lines and logs every right-stripped line at ``level``. Nothing is carried
    over between ``write`` calls, and ``flush`` does nothing.
    """
    def __init__(self, logger, level):
        self.logger = logger  # destination logger
        self.level = level    # level used for every forwarded line
        self.linebuf = ''     # currently unused by write()

    def write(self, buf):
        cleaned = [line.rstrip() for line in buf.rstrip().splitlines()]
        for line in cleaned:
            self.logger.log(self.level, line)

    def flush(self):
        # Nothing is buffered, so there is nothing to flush.
        return None

def setup_logger(name, args, config, save_dir = 'log'):
    """Configure logger ``name`` with a stdout handler and optional file logging.

    Args:
        name: logger name passed to ``logging.getLogger``.
        args: namespace providing ``save_log`` and, when saving, ``model_type``.
        config: namespace providing ``data.target`` (used only when saving).
        save_dir: directory for the log file. It is created, including any
            missing parents, when needed.

    Returns:
        The configured ``logging.Logger`` (level DEBUG).

    Side effects when ``args.save_log`` is truthy:
        - Writes a UTF-8 log file named
          ``{save_dir}/{timestamp}-{config.data.target}-{args.model_type}.txt``.
        - Replaces ``sys.stdout`` / ``sys.stderr`` with ``StreamToLogger``
          wrappers at INFO / ERROR level, so later ``print`` output also goes
          through this logger.

    NOTE(review): every call adds new handlers. Calling this twice for the
    same ``name`` duplicates log lines. Calling it again after the stdout
    redirection binds the new console handler to the redirected stream,
    which feeds back into the logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s\n")

    # Bound to the sys.stdout object current *now*, before the redirection
    # below, so console output does not loop back into the logger.
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if args.save_log:
        # makedirs also handles nested paths; exist_ok avoids a check-then-create race.
        os.makedirs(save_dir, exist_ok=True)
        now = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        # Explicit UTF-8: logged text (targets, column names) may be non-ASCII,
        # and the platform default encoding may not be able to represent it.
        fh = logging.FileHandler(
            f"{save_dir}/{now}-{config.data.target}-{args.model_type}.txt",
            encoding='utf-8',
        )
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

        sys.stdout = StreamToLogger(logger, logging.INFO)
        sys.stderr = StreamToLogger(logger, logging.ERROR)

    return logger

class Pruner(BasePruner):
    """Optuna pruner that never prunes a trial; it only logs each consultation."""

    def __init__(self, logger) -> None:
        # Logger that receives a message every time prune() is called.
        self.logger = logger

    def prune(self, study: "Study", trial: "FrozenTrial") -> bool:
        """Log ``"invoked"`` at INFO level and return False (do not prune)."""
        self.logger.info("invoked")
        keep_running = True
        return not keep_running

def getColIdx(numeric_cols, data):
    """Split the column positions of ``data`` into numeric and categorical ones.

    Args:
        numeric_cols: container of column names to treat as numeric
            (membership is tested with ``in``).
        data: DataFrame-like object exposing ``.columns``.

    Returns:
        ``(numeric_idx, category_idx)``: two lists of 0-based column
        positions, both in ascending order. A position is numeric iff its
        column name is in ``numeric_cols``; every other position is
        categorical.
    """
    columns = list(data.columns)
    numeric_idx = [idx for idx, col in enumerate(columns) if col in numeric_cols]

    # Build the complement from range() instead of a set difference: set
    # iteration order is an implementation detail and can come out
    # non-ascending, which would misalign positions with column order.
    numeric_pos = set(numeric_idx)
    category_idx = [idx for idx in range(len(columns)) if idx not in numeric_pos]

    return numeric_idx, category_idx

def getDimPerIdx(category_idx, data):
    """Return the number of distinct values in each listed column of ``data``.

    Args:
        category_idx: iterable of 0-based column positions.
        data: DataFrame-like object whose ``.values`` is a 2-D array.

    Returns:
        List of ints, one per entry of ``category_idx``, in the same order.

    NOTE(review): distinctness follows Python ``set`` semantics on the array
    elements. Separate NaN objects are not merged, so a column with missing
    values may report an inflated count. Presumably callers pass
    already-encoded data — TODO confirm.
    """
    # ``.values`` can build a fresh array on every access (e.g. mixed dtypes),
    # so materialize it once rather than once per column.
    values = data.values
    return [len(set(values[:, idx])) for idx in category_idx]

def plot_pl_info(data: Dict, n_bins: int = 10, threshold: float = 0.6):
    """Print calibration and pseudo-label (PL) statistics and draw a reliability plot.

    Args:
        data: mapping with
            - ``'preds_proba'``: array of class probabilities. The predicted
              class and its confidence are taken with ``argmax(1)`` /
              ``max(1)`` (presumably shape (n_samples, n_classes) — TODO
              confirm).
            - ``'label'``: pandas Series of true labels (``.values`` is read).
        n_bins: number of equal-width confidence bins over [0, 1].
        threshold: samples whose max probability is >= this value count as
            pseudo-labelled.

    Printed metrics:
        - MSE: per-bin squared over-confidence ``max(conf - acc, 0) ** 2``,
          weighted by bin size and averaged over all samples.
        - ECE: expected calibration error, in percent.
        - RECE: like ECE but only over-confident gaps count, in percent.
        - PL_Recall: share of all correct predictions that pass the threshold.
        - PL_Precision: accuracy among predictions that pass the threshold.
        - PL_F1: harmonic mean of the two. All three PL ratios are 0.0 when
          their denominator is zero.
        - sklearn precision / recall / F1 on the pseudo-labelled subset.
          These use sklearn's default ``average='binary'``, so binary labels
          are assumed.

    Figure (two stacked axes):
        - Top: % of samples per confidence bin, with dashed lines at the
          overall accuracy (red) and mean confidence (blue).
        - Bottom: per-bin mean confidence (red) against per-bin accuracy (blue).
    """
    probas = data['preds_proba']
    labels = data['label'].values
    preds = probas.argmax(1)
    proba_max = probas.max(1)
    n_samples = len(preds)

    # Equal-width bins over [0, 1]. With right=True, bin k (1-based) holds
    # bins[k-1] < p <= bins[k], so a confidence of exactly 0.0 falls in no bin.
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    indices = np.digitize(proba_max, bins, right=True)

    plt.cla()
    plt.clf()

    cnt = []   # number of samples per bin
    x = []     # bar label per bin
    conf = []  # mean confidence per bin (0 for an empty bin)
    y = []     # accuracy per bin (0 for an empty bin)
    mse = 0
    ece = 0
    rece = 0
    for i in range(n_bins):
        indice = np.where(indices == i + 1)[0]
        cnt.append(len(indice))
        if len(indice) == 0:
            # Label an empty bin by its centre, not 0. seaborn groups bars by
            # label, so several bins labelled 0 would collapse into one bar
            # and shift every other bar out of place.
            x.append((bins[i] + bins[i + 1]) / 2)
            conf.append(0)
            y.append(0)
            continue
        acc = accuracy_score(labels[indice], preds[indice])
        mean_conf = np.mean(proba_max[indice])
        x.append(mean_conf)
        conf.append(mean_conf)
        y.append(acc)
        gap = mean_conf - acc
        mse += max(gap, 0) ** 2 * len(indice)
        ece += np.abs(gap) * len(indice) / n_samples
        rece += max(gap, 0) * len(indice) / n_samples

    # Labels that round to the same 2-decimal value would still be merged by seaborn.
    x = np.round(x, 2)
    # Percentage of samples per bin (all zeros stay zeros if nothing was binned).
    cnt = np.array(cnt, dtype=np.float32)
    total = cnt.sum()
    if total > 0:
        cnt = cnt / total
    cnt *= 100
    mse /= n_samples
    print("MSE :", mse)
    print("ECE :", ece * 100)
    print("RECE :", rece * 100)

    correct = preds == labels
    pl_indices = proba_max >= threshold
    n_pl_correct = correct[pl_indices].sum()
    n_correct = correct.sum()
    n_pl = pl_indices.sum()
    # Guard every ratio so an empty pseudo-label set, or no correct
    # predictions at all, yields 0.0 instead of nan plus a RuntimeWarning.
    pl_recall = n_pl_correct / n_correct if n_correct > 0 else 0.0
    pl_precision = n_pl_correct / n_pl if n_pl > 0 else 0.0
    pl_denom = pl_recall + pl_precision
    pl_f1 = 2 * pl_recall * pl_precision / pl_denom if pl_denom > 0 else 0.0
    print("PL_Recall :", pl_recall)
    print("PL_Precision :", pl_precision)
    print("PL_F1 Score :", pl_f1)

    print(precision_score(labels[pl_indices], preds[pl_indices]))
    print(recall_score(labels[pl_indices], preds[pl_indices]))
    print(f1_score(labels[pl_indices], preds[pl_indices]))

    fig, axes = plt.subplots(2, 1)
    fig.set_figheight(10)
    ax1, ax2 = axes

    hist_df = pd.DataFrame({'Confidence': x, r'% of Samples': cnt})
    sns.barplot(data=hist_df, x='Confidence', y=r'% of Samples',
                color='dodgerblue', ax=ax1, alpha=0.5)
    avg_acc = accuracy_score(labels, preds)
    avg_conf = np.mean(proba_max)

    # seaborn draws categorical bars at 0, 1, ..., n-1 on an axis spanning
    # roughly [-0.5, n - 0.5]. Map a confidence value in [0, 1] linearly onto
    # the full bound, including the lower offset, so lines line up with bins.
    x_lo, x_hi = ax1.get_xbound()
    y_lo, y_hi = ax1.get_ybound()

    def to_axis(value):
        return x_lo + value * (x_hi - x_lo)

    def at_height(fraction):
        return y_lo + fraction * (y_hi - y_lo)

    ax1.axvline(x=to_axis(avg_acc), linestyle='--', color='r')
    ax1.text(to_axis(avg_acc + 0.02), at_height(0.1), 'accuracy', rotation=90, color='r')
    ax1.axvline(x=to_axis(avg_conf), linestyle='--', color='b')
    ax1.text(to_axis(avg_conf + 0.02), at_height(0.5), 'avg. confidence', rotation=90, color='b')

    sns.barplot(x=x, y=np.round(conf, 2), ax=ax2, alpha=0.5, color='red', label='conf')
    sns.barplot(x=x, y=np.array(y), ax=ax2, alpha=0.5, color='blue', label='acc')
    ax2.legend()
    ax2.set_xlabel('Confidence')