from collections import defaultdict
import os
import pdb

import torch
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, mean_squared_error,mean_absolute_error,recall_score,precision_score,f1_score

from . import constants

def predict(clf, 
    x_test,
    y_test=None,
    return_loss=False,
    eval_batch_size=256,
    ):
    '''Make predictions by TransTabClassifier.

    The classifier is switched to eval mode and run batch by batch under
    ``torch.no_grad()``. Logits whose last dimension is 1 are passed through a
    sigmoid; any other width is passed through a softmax over the last axis.

    Parameters
    ----------
    clf: TransTabClassifier
        the classifier model to make predictions. It is called as
        ``clf(x_batch, y_batch)`` and must return ``(logits, loss)``.

    x_test: pd.DataFrame
        input tabular data.

    y_test: pd.Series, optional
        target labels for input x_test. When given, the matching slice is
        passed to ``clf`` for every batch; when ``None``, ``None`` is passed.

    return_loss: bool
        set True will return the loss if y_test is given.

    eval_batch_size: int
        the batch size for inference.

    Returns
    -------
    pred_all: np.array
        if ``return_loss=False``, return the predictions made by TransTabClassifier.
        Flattened to 1-D when the logits have a single output column; an empty
        array when ``x_test`` has no rows.

    avg_loss: float
        if ``return_loss=True``, return the mean loss of the predictions made by TransTabClassifier.
        ``nan`` when no batch produced a loss.

    '''
    clf.eval()
    pred_list, loss_list = [], []
    # Tracked explicitly so the post-loop code never reads an unbound loop variable.
    is_binary = False
    for start in range(0, len(x_test), eval_batch_size):
        stop = start + eval_batch_size
        bs_x_test = x_test.iloc[start:stop]
        # y_test is optional: only slice it when the caller supplied labels.
        bs_y_test = None if y_test is None else y_test.iloc[start:stop]
        with torch.no_grad():
            logits, loss = clf(bs_x_test, bs_y_test)

        if loss is not None:
            loss_list.append(loss.item())

        is_binary = logits.shape[-1] == 1
        if is_binary: # binary classification
            probs = logits.sigmoid()
        else: # multi-class classification
            probs = torch.softmax(logits, -1)
        pred_list.append(probs.detach().cpu().numpy())

    if return_loss:
        # np.mean([]) would emit a RuntimeWarning; return the same nan quietly.
        return np.mean(loss_list) if loss_list else float('nan')

    if not pred_list:
        # No rows to predict on: np.concatenate would raise on an empty list.
        return np.array([])

    pred_all = np.concatenate(pred_list, 0)
    if is_binary:
        pred_all = pred_all.flatten()
    return pred_all

def evaluate(ypred, y_test, metric='auc', seed=123, bootstrap=False):
    '''Score predictions with the metric named by ``metric``.

    Parameters
    ----------
    ypred: np.array
        predictions; indexed with an integer index array in bootstrap mode.

    y_test: pd.Series
        ground-truth labels; indexed with ``.iloc`` in bootstrap mode.

    metric: str
        name looked up through ``get_eval_metric_fn``.

    seed: int
        passed to ``np.random.seed`` (this reseeds NumPy's global generator).

    bootstrap: bool
        if True, score 10 resamples drawn with replacement and report the
        midpoint of the 2.5th/97.5th percentile interval instead of a single
        score. The interval bounds are clamped to ``[0, 1]``.

    Returns
    -------
    res_list: list
        a one-element list holding the score (or the interval midpoint in
        bootstrap mode).

    Raises
    ------
    ValueError
        in bootstrap mode, if the metric raised ``ValueError`` on every resample.
    '''
    np.random.seed(seed)
    eval_fn = get_eval_metric_fn(metric)
    if not bootstrap:
        return [eval_fn(y_test, ypred)]

    num_samples = len(ypred)
    stats = []
    for _ in range(10):
        sub_idx = np.random.choice(np.arange(num_samples), num_samples, replace=True)
        sub_ypred = ypred[sub_idx]
        sub_ytest = y_test.iloc[sub_idx]
        try:
            sub_res = eval_fn(sub_ytest, sub_ypred)
        except ValueError:
            print('evaluation went wrong!')
            # Skip this resample: there is no valid score to record for it.
            continue
        stats.append(sub_res)

    if not stats:
        raise ValueError(
            'metric {!r} could not be computed on any bootstrap resample'.format(metric)
        )

    alpha = 0.95
    lower = max(0, np.percentile(stats, ((1-alpha)/2) * 100))
    upper = min(1.0, np.percentile(stats, (alpha+((1.0-alpha)/2.0)) * 100))
    print('{} {:.2f} mean/interval {:.4f}({:.2f})'.format(metric, alpha, (upper+lower)/2, (upper-lower)/2))
    return [(upper+lower)/2]

def get_eval_metric_fn(eval_metric):
    '''Return the metric callable registered under ``eval_metric``.

    ``'val_loss'`` maps to ``None``; an unknown name raises ``KeyError``.
    '''
    registry = dict(
        acc=acc_fn,
        auc=auc_fn,
        rec=rec_fn,
        pre=pre_fn,
        f1=f1_fn,
        mse=mse_fn,
        mae=mae_fn,
        val_loss=None,
    )
    return registry[eval_metric]

def acc_fn(y, p):
    '''Accuracy of the labels obtained by taking the argmax of ``p`` over its last axis.'''
    predicted_labels = np.argmax(p, axis=-1)
    score = accuracy_score(y, predicted_labels)
    return score

def auc_fn(y, p):
    '''ROC AUC of scores ``p`` against labels ``y``, using one-vs-rest for multi-class.'''
    score = roc_auc_score(y, p, multi_class='ovr')
    return score

def mse_fn(y, p):
    '''Mean squared error between ``y`` and ``p``.

    A 1-D ``p`` is used as-is; otherwise the argmax over the last axis is
    compared against ``y``.
    '''
    is_flat = len(p.shape) == 1
    predicted = p if is_flat else np.argmax(p, -1)
    return mean_squared_error(y, predicted)

def mae_fn(y,p):
    '''Mean absolute error between ``y`` and ``p``.

    A 1-D ``p`` is used as-is; otherwise the argmax over the last axis is
    compared against ``y``.
    '''
    if len(p.shape)==1:
        # Both branches must report MAE; the 1-D case must not fall back to MSE.
        return mean_absolute_error(y, p)
    y_p = np.argmax(p, -1)
    return mean_absolute_error(y,y_p)

def rec_fn(y, p):
    '''Macro-averaged recall of the labels obtained by argmax of ``p`` over its last axis.'''
    predicted_labels = np.argmax(p, axis=-1)
    score = recall_score(y, predicted_labels, average='macro')
    return score


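# micro: compute the metric globally by counting the total true positives, false negatives and false positives.
# In other words, pool all classes together: for precision, sum the TP of every class and divide by the summed
# TP and FP of every class. For single-label multi-class problems, micro precision and micro recall therefore
# both equal accuracy.
#
# macro: compute the metric for each label and take their unweighted mean. This does not take label imbalance
# into account. In other words, compute the precision of each class separately, then take the arithmetic mean.
#
# weighted: compute the metric for each label and average them, weighted by support (the number of true
# instances for each label). This alters 'macro' to account for label imbalance;
# it can result in an F-score that is not between precision and recall.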


def pre_fn(y, p):
    '''Macro-averaged precision of the labels obtained by argmax of ``p`` over its last axis.'''
    predicted_labels = np.argmax(p, axis=-1)
    score = precision_score(y, predicted_labels, average='macro')
    return score

def f1_fn(y,p):
    '''Macro-averaged F1 of the labels obtained by argmax of ``p`` over its last axis.'''
    predicted_labels = np.argmax(p, axis=-1)
    score = f1_score(y, predicted_labels, average='macro')
    return score



class EarlyStopping:
    """Early stops the training if the monitored value doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, output_dir='ckpt', trace_func=print, less_is_better=False):
        """
        Args:
            patience (int): How many consecutive non-improving calls to tolerate before
                            setting ``early_stop``. A negative value disables early stopping.
                            Default: 7
            verbose (bool): If True, prints a message each time a checkpoint is saved.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            output_dir (str): Directory the checkpoint is saved into.
                            Default: 'ckpt'
            trace_func (function): trace print function.
                            Default: print
            less_is_better (bool): If True, the value passed to ``__call__`` is used as the
                            score unchanged; if False, it is negated first. In both cases a
                            larger score counts as an improvement.
                            NOTE(review): as written, a *smaller* raw value is only treated as
                            an improvement when this flag is False -- confirm the sign
                            convention callers use before relying on the flag's name.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` is valid on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = output_dir
        self.trace_func = trace_func
        self.less_is_better = less_is_better

    def __call__(self, val_loss, model):
        """Record a new monitored value, checkpointing on improvement and counting otherwise."""
        if self.patience < 0: # no early stop
            self.early_stop = False
            return

        score = val_loss if self.less_is_better else -val_loss

        if self.best_score is None:
            # First observation: always becomes the baseline and is checkpointed.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model state dict into the output directory and records ``val_loss``.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        if self.path:
            # torch.save does not create missing parent directories.
            os.makedirs(self.path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(self.path, constants.WEIGHTS_NAME))
        self.val_loss_min = val_loss

