import os
import gc

import numpy as np
import torch

from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelBinarizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, precision_recall_curve, average_precision_score
from imblearn.under_sampling import RandomUnderSampler, NearMiss
from imblearn.over_sampling import RandomOverSampler, SMOTE
import logging
from tqdm import tqdm, trange
import joblib
import cupy as cp
from joblib import dump, load

from copy import copy
import pickle


import warnings
# Silence every warning process-wide (sklearn/imblearn/py_boost included).
# NOTE(review): this also hides useful diagnostics (convergence, NaN, dtype
# casts); consider narrowing to specific categories.
warnings.filterwarnings('ignore')

# Pin the process to GPU 0, numbered in PCI bus order.
# NOTE(review): these variables only take effect if they are set before a CUDA
# context is created; torch and cupy are already imported above — confirm that
# neither import initialises CUDA, or move these lines before those imports.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#from py_boost import Callback
from py_boost import GradientBoosting, SketchBoost, TLPredictor, TLCompiledPredictor

# Strategies for handling multiple outputs. These are star imports; presumably
# they provide RandomProjectionSketch, which calc() uses — verify against py_boost.
from py_boost.multioutput.sketching import *
from py_boost.multioutput.target_splitter import *
from py_boost import Callback

# NOTE(review): this rebinds the bare name `load`, shadowing the
# `from joblib import dump, load` import above; the visible code only calls
# joblib through the `joblib.` prefix, so nothing relies on the joblib `load`.
from evaluate import load
# NOTE(review): `metric` is not referenced anywhere in the visible code, yet it
# is created at import time (this may need network access) — consider removing.
metric = load("accuracy")

class WarmStart(Callback):
    """py_boost callback that continues boosting from an already fitted model.

    Use it by passing ``callbacks=[WarmStart(prev_model)]`` to
    ``GradientBoosting``.

    Before training:
        The new booster's ``base_score`` is copied from the previous model.
        The train/valid ``'ensemble'`` buffers are filled with the previous
        model's predictions, computed with post-processing disabled.

    After training:
        The previous model's trees are prepended to the newly built ones.
        The iteration and best-round counters are shifted by the number of
        old trees.
    """

    def __init__(self, model):
        """Keep a CPU-side shallow copy of *model* for warm starting.

        Args:
            model: a fitted py_boost model. NOTE: ``model.to_cpu()`` is called
                on the caller's object itself, not only on the stored copy.
        """

        model.to_cpu()
        # Shallow copy: trees are shared with the caller's model, but the
        # post-processing override below only affects the copy.
        self.model = copy(model)
        # Identity post-processing, so predict() returns un-transformed outputs
        # to seed the new booster's ensemble.
        self.model.postprocess_fn = lambda x: x

    def before_train(self, build_info):
        """Initialise the new booster from the stored model's predictions.

        Reads ``build_info['data'][split]['features_cpu']`` for the train and
        valid splits. Writes ``build_info['model'].base_score`` and the
        ``'ensemble'`` entry of each split.

        For ``valid``, ``features_cpu`` is iterated, so ``'ensemble'`` becomes
        a list. Presumably it holds one array per eval set — confirm against
        py_boost.
        """

        build_info['model'].base_score = cp.asarray(self.model.base_score)

        train = build_info['data']['train']
        train['ensemble'] = cp.asarray(self.model.predict(train['features_cpu']))

        valid = build_info['data']['valid']
        valid['ensemble'] = [cp.asarray(self.model.predict(x)) for x in valid['features_cpu']]

        # Presumably predict() moved the stored trees to the GPU; release them
        # back to host memory while the new booster trains.
        self.model.to_cpu()

        return

    def after_train(self, build_info):
        """Merge the stored (old) trees into the freshly trained model.

        Steps:
            1. Prepend the old trees to the new ones.
            2. Shift ``build_info['num_iter']`` and the early-stopping
               ``best_round`` by the number of old trees.
            3. Drop the stored model.

        After this call ``self.model`` is ``None``, so the instance cannot be
        reused for another fit.
        """

        build_info['model'].models = self.model.models + build_info['model'].models
        # update the actual iteration
        build_info['num_iter'] = build_info['num_iter'] + len(self.model.models)
        # update the actual best round
        # NOTE(review): assumes the early-stopping callback is the LAST entry of
        # the model's callback pipeline — confirm against py_boost's ordering.
        early_stop = build_info['model'].callbacks.callbacks[-1]
        early_stop.best_round = early_stop.best_round + len(self.model.models)

        # not to store old trees multiple times
        self.model = None

        return

def f1_max(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    F1 score with the optimal threshold (F-max).

    Every prediction value is treated as a candidate threshold over the whole
    batch, scanning them from highest to lowest. At each threshold:

    - precision is averaged over the rows that already have at least one
      prediction above the threshold;
    - recall is averaged over all ``B`` rows.

    The largest non-NaN F1 over all thresholds is returned.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary (0/1) targets of shape :math:`(B, N)`

    Returns:
        Tensor: 0-d tensor holding the maximal F1 value.
    """
    # Per-row ranking of predictions, highest first, and targets in that order.
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    # Per-row precision / recall when the top-k items of the row are positive.
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # Mark the first (top-ranked) item of each row, then scatter that mark back
    # to the item's original column. `order` is a per-row permutation, so every
    # entry of the input is overwritten.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    # Global ranking of all B*N predictions.
    all_order = pred.flatten().argsort(descending=True)
    # Convert per-row sorted column indices into flat indices of `pred`.
    order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    # inv_order[flat original index] = position in the row-sorted flattened layout
    # (the layout of `precision` / `recall`).
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Change in the owning row's precision when this item is included. For a
    # row's first item the previous value is taken as 0. The `all_order - 1`
    # lookup is masked out by `is_start` in that case, including the -1
    # wrap-around at flat index 0.
    all_precision = precision[all_order] - \
                    torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1])
    # Running sum of per-row precisions / number of rows started so far
    # = mean precision over rows with at least one prediction above threshold.
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - \
                 torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1])
    # Mean recall over all B rows (rows not yet started contribute 0).
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    
    return all_f1[~torch.isnan(all_f1)].max()


def caculate_metric(pred_y, labels, pred_prob):
    """Compute binary classification metrics from hard predictions and scores.

    Label ``1`` is the positive class. Every other label value is treated as
    negative, both for the confusion counts and for the equality check
    against ``pred_y``.

    Args:
        pred_y: predicted hard labels, one per sample (array-like).
        labels: ground-truth labels, one per sample (array-like; an ``(n, 1)``
            column is accepted as well).
        pred_prob: per-class scores of shape ``(n_samples, n_classes)``.
            Column ``1`` is used as the positive-class score for ROC AUC.

    Returns:
        Tuple ``(Precision, Sensitivity, Specificity, F1, AUC, MCC)``. Ratios
        whose denominator is zero are reported as ``0``.

    Raises:
        ValueError: if ``pred_y`` and ``labels`` contain a different number of
            samples.
    """
    labels_arr = np.asarray(labels).ravel()
    pred_arr = np.asarray(pred_y).ravel()
    if labels_arr.shape[0] != pred_arr.shape[0]:
        raise ValueError(
            f'pred_y has {pred_arr.shape[0]} entries but labels has {labels_arr.shape[0]}'
        )

    # Vectorised confusion counts (previously a per-sample Python loop).
    is_pos = labels_arr == 1
    is_hit = labels_arr == pred_arr
    tp = int(np.count_nonzero(is_pos & is_hit))
    fn = int(np.count_nonzero(is_pos & ~is_hit))
    tn = int(np.count_nonzero(~is_pos & is_hit))
    fp = int(np.count_nonzero(~is_pos & ~is_hit))

    Precision = tp / (tp + fp) if tp + fp else 0
    Recall = Sensitivity = tp / (tp + fn) if tp + fn else 0
    Specificity = tn / (tn + fp) if tn + fp else 0

    # The denominator product is an exact Python int and can exceed the 64-bit
    # range on large test sets. np.sqrt() fails on such ints, so take the
    # square root in float space instead.
    mcc_denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    MCC = (tp * tn - fp * fn) / float(mcc_denom) ** 0.5 if mcc_denom else 0

    F1 = 2 * Recall * Precision / (Recall + Precision) if Recall + Precision else 0

    # ROC AUC from the positive-class column. With {0, 1} labels, roc_curve
    # uses 1 as the positive label by default.
    fpr, tpr, _ = roc_curve(labels_arr, pred_prob[:, 1])
    AUC = auc(fpr, tpr)

    return Precision, Sensitivity, Specificity, F1, AUC, MCC


def _resample(sampling, seed, X, y):
    """Rebalance ``(X, y)`` with the imbalanced-learn strategy named by *sampling*.

    Args:
        sampling: ``'under'``, ``'over'``, ``'SMOTE'`` or ``'NearMiss'``. Any
            other value returns ``X`` and ``y`` unchanged.
        seed: random state for the random under/over samplers. SMOTE is built
            without ``random_state``, so it draws from NumPy's global RNG.
            NearMiss takes no seed.
        X: feature matrix to resample.
        y: labels matching ``X``.

    Returns:
        The resampled ``(X, y)`` pair.
    """
    if sampling == 'under':
        sampler = RandomUnderSampler(random_state=seed, replacement=True)
    elif sampling == 'over':
        sampler = RandomOverSampler(random_state=seed)
    elif sampling == 'SMOTE':
        sampler = SMOTE()
    elif sampling == 'NearMiss':
        sampler = NearMiss()
    else:
        return X, y
    return sampler.fit_resample(X, y)


def calc(data_folder, model_name, warm_start='', num_trees=3000, leaf=10, method='method_3', sc = '', data_type = 'features', sampling = 'under', dataset = 'DNA', lr=0.03, hessian = True, max_depth = 4):
    """Train py_boost model(s) on precomputed features and log test-set metrics.

    Steps, for each rerun seed:
        1. Load the train/val/test feature matrices and labels from
           *data_folder*.
        2. Rebalance the training split with *sampling*.
        3. Fit a ``GradientBoosting('crossentropy')`` model, using the
           validation split as eval set (``es=300``).
        4. Save the model under ``models/`` and log Sensitivity, Specificity,
           Precision, MCC and AUC (in %) to a file under ``logs/``.

    The mean and std over reruns are logged at the end.

    Args:
        data_folder: directory holding ``{split}_{dataset}_{method}_{model_name}{sc}.npy``
            feature files and ``y_{split}_{dataset}.npy`` label files.
        model_name: embedding-model tag used in file names.
        warm_start: currently unused; kept for interface compatibility.
        num_trees: maximum number of boosting rounds (``ntrees``).
        leaf: ``min_data_in_leaf`` for the booster.
        method: feature-method tag used in file names.
        sc: optional file-name suffix (e.g. ``'_sc'`` for scaled features).
        data_type: tag used only in the log/model file names.
        sampling: ``'under'``, ``'over'``, ``'SMOTE'`` or ``'NearMiss'``. Any
            other value trains on the data as loaded.
        dataset: dataset tag used in file names.
        lr: learning rate.
        hessian: passed to py_boost as ``use_hess``.
        max_depth: maximum tree depth.
    """
    N_RERUNS = 1
    np.random.seed(42)
    seeds = np.random.randint(1000, size=(N_RERUNS))
    # One row per rerun: Sensitivity, Specificity, Precision, MCC, AUC (in %).
    q = np.zeros((seeds.shape[0], 5))
    # (sic) 'Specifcity' is kept so that existing log parsers still match.
    metric_names = ('Sensitivity', 'Specifcity', 'Precision', 'MCC', 'AUC')

    # NOTE: allow_pickle=True can execute code embedded in the file; only load
    # trusted feature files.
    train = np.load(f'{data_folder}/train_{dataset}_{method}_{model_name}{sc}.npy', allow_pickle=True)
    val = np.load(f'{data_folder}/val_{dataset}_{method}_{model_name}{sc}.npy', allow_pickle=True)
    test = np.load(f'{data_folder}/test_{dataset}_{method}_{model_name}{sc}.npy', allow_pickle=True)
    print('Train:', train.shape)

    run_tag = (f'{dataset}_{method}_{model_name}_{data_type}{sc}_sampling_{sampling}'
               f'_trees_{num_trees}_leaf_{leaf}_lr_{lr}_hessian_{hessian}_max_depth_{max_depth}')
    logger_filename = f'logs/log_{run_tag}.log'
    os.makedirs('logs', exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(logger_filename, 'w', 'utf-8')
    logger.addHandler(handler)
    # Always detach and close the handler. Otherwise the file handle leaks,
    # and later calls would also write into this run's log file.
    try:
        print('Data loaded!')

        y_train = np.load(f'{data_folder}/y_train_{dataset}.npy', allow_pickle=True)
        print('y_train', y_train.shape)
        y_val = np.load(f'{data_folder}/y_val_{dataset}.npy', allow_pickle=True)
        y_test = np.load(f'{data_folder}/y_test_{dataset}.npy', allow_pickle=True)

        logger.info(f'dataset - {dataset}, method - {method}, model - {model_name}, sampling - {sampling}')
        logger.info(f'trees - {num_trees}, scaling - {sc}, leaf - {leaf}, lr - {lr}, hessian - {hessian}, max_depth - {max_depth}')

        # The booster is fit on float32 features; evaluate on the same dtype.
        val_X = val.astype(np.float32)
        test_X = test.astype(np.float32)

        os.makedirs('models', exist_ok=True)
        for i, seed in enumerate(seeds):
            # Resample the ORIGINAL split on every rerun. Reassigning `train`
            # here would make later reruns resample already-resampled data.
            train_res, y_train_res = _resample(sampling, seed, train, y_train)

            print('Start new model')
            sketch = RandomProjectionSketch(10)
            # To warm-start from a saved model, load it with joblib.load(...)
            # and add `callbacks=[WarmStart(prev_model)]` below.
            model = GradientBoosting(
                'crossentropy',
                ntrees=num_trees, lr=lr, verbose=500, es=300, lambda_l2=1, gd_steps=1,
                subsample=1, colsample=1, min_data_in_leaf=leaf, use_hess=hessian,
                max_bin=256, max_depth=max_depth, seed=seed,
                multioutput_sketch=sketch, debug=True,
            )
            print('Start model fit')
            model.fit(train_res.astype(np.float32), y_train_res,
                      eval_sets=[{'X': val_X, 'y': y_val}])
            # Drop the resampled copy before prediction to lower peak memory.
            del train_res, y_train_res

            joblib.dump(model, f'models/model_{run_tag}_seed_{seed}.pkl')
            print('Model fitted')

            pred_prob = model.predict(test_X)
            pred_y = np.argmax(pred_prob, axis=1)
            Precision, Sensitivity, Specificity, F1, AUC, MCC = caculate_metric(pred_y, y_test, pred_prob)

            for col, value in enumerate((Sensitivity, Specificity, Precision, MCC, AUC)):
                q[i][col] = np.round(value, 5) * 100.0
            for col, name in enumerate(metric_names):
                logger.info(f'{name} - {q[i][col]}')

        # Log the seeds actually used (previously this logged the last loop
        # index under the label "seed").
        logger.info(f'seeds - {seeds.tolist()}, num trees - {num_trees}')
        for col, name in enumerate(metric_names):
            logger.info(f'Mean est. {name} : {np.mean(q[:, col])} +- {np.std(q[:, col])}')
    finally:
        logger.removeHandler(handler)
        handler.close()


def main():
    """Run one configured training/evaluation experiment via ``calc``.

    Edit the values below to choose the dataset, the embedding model, the
    feature method and the booster hyper-parameters.

    The feature files must already exist under
    ``/data/binding/<dataset>/<model_name>/features``. Earlier preprocessing
    snippets (dead code, now removed) built them in two ways:
      * StandardScaler variants: fit StandardScaler on the train split and
        save ``_sc`` versions of the train/val/test arrays.
      * Combined features: concatenate the ``last_embs`` features with the
        attention features along axis 1, and save the result under method
        ``last_embs_<method>``.
    """
    # Other datasets used: MG, MN, CA, PRO, PEP, HEM, ATP, RNA, DNA.
    dataset = 'ZN'
    # Other models used: esm2_t48_15B_UR50D, esm2_t36_3B_UR50D, esm2_t6_8M_UR50D.
    model_name = 'esm2_t33_650M_UR50D'

    config = dict(
        data_folder=f'/data/binding/{dataset}/{model_name}/features',
        model_name=model_name,
        warm_start='batch',
        num_trees=100000,
        leaf=10,
        # Alternatives: 'last_embs', 'method_3_attns', 'sum_method_3_attns',
        # 'method_4_attns', 'sum_method_4_attns', 'last_embs_sum_method_3_attns'.
        method='last_embs_method_3_attns',
        sc='',  # '_sc' selects the StandardScaler-transformed files
        # Alternatives: 'features', 'emb', 'embs'.
        data_type='emb_features',
        # Alternatives: 'under', 'over', 'NearMiss'.
        sampling='SMOTE',
        dataset=dataset,
        lr=0.03,
        hessian=True,
        max_depth=4,
    )
    calc(**config)

# Run the configured experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()