import os
import gc

import numpy as np
import torch

from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelBinarizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc
import logging
from tqdm import tqdm, trange
import joblib
import cupy as cp
from joblib import dump, load

from copy import copy
import pickle


import warnings
# NOTE(review): this silences *all* warnings process-wide (including numerical
# warnings from numpy/sklearn); consider narrowing to specific categories.
warnings.filterwarnings('ignore')

# set the device to run
# Order CUDA devices by PCI bus id and expose only device 0 to this process.
# NOTE(review): these only take effect if no CUDA context exists yet; torch and
# cupy are imported above, which is presumably fine because both initialise
# CUDA lazily — TODO confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#from py_boost import Callback
from py_boost import GradientBoosting, SketchBoost, TLPredictor, TLCompiledPredictor

# strategies to deal with multiple outputs
from py_boost.multioutput.sketching import *
from py_boost.multioutput.target_splitter import *
from py_boost import Callback

# Accuracy metric from the HuggingFace `evaluate` package, used in calc().
# Imported module-qualified: `from evaluate import load` would silently rebind
# the name `load` that is imported from joblib at the top of this file.
import evaluate
metric = evaluate.load("accuracy")

class WarmStart(Callback):
    """Callback that resumes py_boost training from a previously fitted model.

    Before training, the new model's base score and the train/valid
    ``'ensemble'`` predictions are initialised from the old model. After
    training, the old trees are prepended to the newly grown ones, and the
    iteration / best-round counters are shifted by the number of old trees.

    NOTE(review): relies on py_boost internals (the ``build_info`` dict layout
    and ``model.callbacks.callbacks``) — confirm against the installed version.
    Its use in calc() is currently commented out.
    """
    
    def __init__(self, model):
        """Keep a shallow copy of a fitted model to warm-start from.

        Args:
            model: previously fitted py_boost model. ``to_cpu()`` is called on
                the caller's object itself, before the copy is taken.
        """
        
        model.to_cpu()
        # shallow copy: attributes (e.g. the ``models`` list) are shared with ``model``
        self.model = copy(model)
        # identity post-processing, presumably so predict() returns raw ensemble
        # scores rather than transformed outputs (e.g. probabilities) — confirm
        self.model.postprocess_fn = lambda x: x
        
    def before_train(self, build_info):
        """Seed the new run with the old model's base score and predictions."""
        
        build_info['model'].base_score = cp.asarray(self.model.base_score)
        
        # old-model predictions on the CPU copy of the train features, as a cupy array
        train = build_info['data']['train']
        train['ensemble'] = cp.asarray(self.model.predict(train['features_cpu']))
        
        # one prediction array per validation set
        valid = build_info['data']['valid']
        valid['ensemble'] = [cp.asarray(self.model.predict(x)) for x in valid['features_cpu']]
        
        # NOTE(review): presumably predict() can move the trees to GPU; move them back
        self.model.to_cpu()
        
        return 
    
    def after_train(self, build_info):
        """Prepend the old trees to the trained model and shift its counters."""
        
        # old trees first, then the newly grown ones
        build_info['model'].models = self.model.models + build_info['model'].models
        # update the actual iteration
        build_info['num_iter'] = build_info['num_iter'] + len(self.model.models)
        # update the actual best round
        # NOTE(review): assumes the last registered callback is early stopping — confirm
        early_stop = build_info['model'].callbacks.callbacks[-1]
        early_stop.best_round = early_stop.best_round + len(self.model.models)
        
        # not to store old trees multiple times
        # (the callback cannot be reused after this: self.model is None)
        self.model = None
        
        return

def f1_max(pred, target):
    """
    F1 score with the optimal threshold.

    This function first enumerates all possible thresholds for deciding positive and negative
    samples, and then picks the threshold with the maximal F1 score.

    Precision and recall are computed per row and averaged across rows at every
    global threshold: precision is averaged over the rows that have at least
    one prediction at or above the threshold, recall over all ``B`` rows.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        Tensor: 0-dim tensor with the maximal F1 (NaN entries are ignored).
    """
    # rank each row's predictions from highest to lowest
    order = pred.argsort(descending=True, dim=1)
    # targets reordered to follow that per-row ranking
    target = target.gather(1, order)
    # per-row precision@k: positives among the top-k divided by k
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    # per-row recall@k; the 1e-10 guards rows with no positive targets
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # flag each row's top-ranked entry, then scatter the flags back to the
    # entries' original (unsorted) positions
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    # global ranking of every entry across all rows
    all_order = pred.flatten().argsort(descending=True)
    # turn per-row sorted column indices into indices of the flattened (B*N) tensor
    order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    # inv_order[j] = position of flattened entry j in the row-sorted, flattened layout
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    # top-of-row flags, listed in global order
    is_start = is_start.flatten()[all_order]
    # slot of each globally ranked entry inside the flattened precision/recall arrays
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # change of that row's precision when the threshold passes this entry; a
    # row's first entry contributes its full value (the value read at
    # all_order - 1 is discarded for it)
    all_precision = precision[all_order] - \
                    torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1])
    # running sum of row precisions / number of rows started so far = mean precision
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    # same incremental trick for recall
    all_recall = recall[all_order] - \
                 torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1])
    # mean recall over all B rows
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    # 1e-10 avoids 0/0 when precision and recall are both zero
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    
    return all_f1[~torch.isnan(all_f1)].max()

def caculate_metric(pred_y, labels, pred_prob):
    """Compute binary-classification metrics, treating class ``1`` as positive.

    A sample whose label is ``1`` counts as TP when ``pred_y`` equals it and as
    FN otherwise; any other label counts as TN when ``pred_y`` equals it and as
    FP otherwise.

    Args:
        pred_y: predicted labels, one per sample (a trailing singleton axis is
            tolerated).
        labels: true labels, same number of samples as ``pred_y``.
        pred_prob: per-class scores indexed ``[sample, class]``; column ``1`` is
            used as the positive-class score for the ROC curve.

    Returns:
        ``(Precision, Sensitivity, Specificity, F1, AUC, MCC)``. Each ratio is
        ``0`` when its denominator is zero.

    Raises:
        ValueError: if ``pred_y`` and ``labels`` hold a different number of
            samples.
    """
    labels = np.asarray(labels).ravel()
    pred_y = np.asarray(pred_y).ravel()
    if labels.shape != pred_y.shape:
        raise ValueError(
            f'pred_y has {pred_y.size} samples but labels has {labels.size}')

    # Vectorised confusion counts instead of a per-sample Python loop. They are
    # converted to Python ints so the MCC denominator product below cannot
    # overflow int64 on large test sets.
    is_pos = labels == 1
    correct = pred_y == labels
    tp = int(np.count_nonzero(is_pos & correct))
    fn = int(np.count_nonzero(is_pos & ~correct))
    tn = int(np.count_nonzero(~is_pos & correct))
    fp = int(np.count_nonzero(~is_pos & ~correct))

    precision = float(tp) / (tp + fp) if tp + fp else 0
    # sensitivity == recall
    sensitivity = float(tp) / (tp + fn) if tp + fn else 0
    specificity = float(tn) / (tn + fp) if tn + fp else 0

    mcc_denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if mcc_denom == 0:
        mcc = 0
    else:
        mcc = float(tp * tn - fp * fn) / np.sqrt(float(mcc_denom))

    if sensitivity + precision == 0:
        f1 = 0
    else:
        f1 = 2 * sensitivity * precision / (sensitivity + precision)

    # ROC and AUC: class 1 is the positive class, so its score column is used.
    fpr, tpr, _ = roc_curve(labels, np.asarray(pred_prob)[:, 1])
    auc_value = auc(fpr, tpr)

    return precision, sensitivity, specificity, f1, auc_value, mcc


# (per-run log label, summary log label) for each column of calc()'s results matrix.
_METRIC_LABELS = (
    ('Sensitivity', 'Sensitivity'),
    ('Specifcity', 'Specifcity'),
    ('Precision', 'Precision'),
    ('MCC', 'MCC'),
    ('AUC', 'AUC'),
    ('F1+ quality', 'F1+'),
    ('Accuracy', 'Accuracy'),
)


def _load_binary_labels(path, threshold=5):
    """Load a label array from *path*; return 1 where value > threshold, else 0."""
    return np.where(np.load(path, allow_pickle=True) > threshold, 1, 0)


def _attach_file_handler(logger, filename):
    """Attach a UTF-8 file handler (mode 'w') for *filename* to *logger* and return it.

    The parent directory of *filename* is created first if it does not exist.
    """
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    handler = logging.FileHandler(filename, 'w', 'utf-8')
    logger.addHandler(handler)
    return handler


def calc(data_folder, model_name, warm_start='', num_trees=3000, leaf=10,
         method='method_3', data_type='features', lr=0.03, hessian=True,
         max_depth=4, label_threshold=5):
    """Train a py_boost GradientBoosting classifier on ConSuf10k data and log metrics.

    Features are read from ``{data_folder}/{split}_ConSuf10k_{method}_{model_name}.npy``
    and labels from ``{data_folder}/y_{split}_ConSuf10k.npy``. Labels are binarized
    to ``1`` where the raw value exceeds ``label_threshold`` and ``0`` otherwise.
    For each rerun seed, a model is trained with ``es=300`` early-stopping rounds
    on the validation split, saved under ``models/``, and evaluated on the test
    split. Per-seed and mean/std metrics are written to a log file under ``logs/``.

    Args:
        data_folder: directory holding the feature and label ``.npy`` files.
        model_name: embedding-model tag used in the feature file names.
        warm_start: currently unused; kept for interface compatibility.
        num_trees: maximum number of boosting rounds (``ntrees``).
        leaf: passed as ``min_data_in_leaf``.
        method: feature-extraction tag used in the feature file names.
        data_type: label written to the log and embedded in log/model file names.
        lr: learning rate.
        hessian: passed as ``use_hess``.
        max_depth: maximum tree depth.
        label_threshold: raw label values strictly greater than this become class 1.
    """
    n_reruns = 1
    # Global seed: makes the per-rerun model seeds below reproducible.
    np.random.seed(42)
    seeds = np.random.randint(1000, size=(n_reruns))
    q = np.zeros((seeds.shape[0], len(_METRIC_LABELS)))

    train = np.load(f'{data_folder}/train_ConSuf10k_{method}_{model_name}.npy', allow_pickle=True)
    val = np.load(f'{data_folder}/val_ConSuf10k_{method}_{model_name}.npy', allow_pickle=True)
    test = np.load(f'{data_folder}/test_ConSuf10k_{method}_{model_name}.npy', allow_pickle=True)

    print('Train:', train.shape)
    print('Val:', val.shape)
    print('Test:', test.shape)

    run_tag = (f'{method}_{model_name}_{data_type}_trees_{num_trees}_leaf_{leaf}'
               f'_lr_{lr}_hessian_{hessian}_max_depth_{max_depth}')

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = _attach_file_handler(logger, f'logs/log_ConSuf10k2Q_{run_tag}.log')
    try:
        print('Data loaded!')

        y_train = _load_binary_labels(f'{data_folder}/y_train_ConSuf10k.npy', label_threshold)
        print('y_train', y_train.shape)
        y_val = _load_binary_labels(f'{data_folder}/y_val_ConSuf10k.npy', label_threshold)
        print('y_val', y_val.shape)
        y_test = _load_binary_labels(f'{data_folder}/y_test_ConSuf10k.npy', label_threshold)
        print('y_test', y_test.shape)
        print('Train/Val/Test labels loaded')

        # Cast once, outside the rerun loop. The test split is cast too, so that
        # prediction sees the same dtype the model was trained on.
        train = train.astype(np.float32)
        val = val.astype(np.float32)
        test = test.astype(np.float32)

        os.makedirs('models', exist_ok=True)
        logger.info(f'data type - {data_type}')
        for i, seed in enumerate(seeds):
            print('Start new model')
            sketch = RandomProjectionSketch(1)
            model = GradientBoosting(
                'crossentropy',
                ntrees=num_trees, lr=lr, verbose=500, es=300, lambda_l2=1, gd_steps=1,
                subsample=1, colsample=1, min_data_in_leaf=leaf, use_hess=hessian,
                max_bin=256, max_depth=max_depth, seed=seed,
                multioutput_sketch=sketch, debug=True,
            )
            print('Start model fit')
            model.fit(train, y_train, eval_sets=[{'X': val, 'y': y_val}])
            joblib.dump(model, f'models/model_ConSuf10k2Q_{run_tag}_seed_{seed}.pkl')
            print('Model fitted')

            pred_prob = model.predict(test)
            print(pred_prob.shape, pred_prob[:5, :])
            pred_y = np.argmax(pred_prob, axis=1)

            precision, sensitivity, specificity, _, auc_value, mcc = caculate_metric(
                pred_y, y_test, pred_prob)

            # Same column order as _METRIC_LABELS; values are rounded, then scaled to percent.
            q[i] = [
                np.round(sensitivity, 5) * 100.0,
                np.round(specificity, 5) * 100.0,
                np.round(precision, 5) * 100.0,
                np.round(mcc, 5) * 100.0,
                np.round(auc_value, 5) * 100.0,
                np.round(f1_score(y_test, pred_y, average='binary'), 5) * 100.0,
                np.round(metric.compute(predictions=pred_y, references=y_test)['accuracy'], 5) * 100.0,
            ]
            for col, (label, _) in enumerate(_METRIC_LABELS):
                logger.info(f'{label} - {q[i][col]}')

            # free GPU/host memory before the next rerun
            del model, sketch, pred_prob, pred_y
            gc.collect()

        logger.info(f'seed - {i}, num trees - {num_trees}')
        for col, (_, label) in enumerate(_METRIC_LABELS):
            logger.info(f'Mean est. {label} : {np.mean(q[:, col])} +- {np.std(q[:, col])}')
    finally:
        # Detach and close the handler so repeated calls neither leak file
        # handles nor write duplicate lines to earlier log files.
        logger.removeHandler(handler)
        handler.close()

def main():
    """Run one fixed ConSuf10k training/evaluation experiment via calc().

    All hyper-parameters and paths are hard-coded here; edit them to run a
    different configuration.
    """
    lr = 0.03
    num_trees = 20000
    # other embedding models tried: 'esm2_t36_3B_UR50D', 'esm2_t48_15B_UR50D', 'esm2_t6_8M_UR50D'
    model_name = 'esm2_t33_650M_UR50D'
    # other feature sets tried: 'last_embs', 'method_3_attns', 'method_4_attns'
    method = 'last_embs_method_3_attns'
    data_folder = f'/data/conservation/ConSuf10k/{model_name}/features'
    # label written to the log and embedded in the log/model file names by calc()
    data_type = 'emb_features'
    # not read by calc()'s body; passed through for interface compatibility
    warm_start = 'batch'
    hessian = True
    max_depth = 4
    leaf = 10

    calc(
        data_folder,
        model_name,
        warm_start=warm_start,
        num_trees=num_trees,
        leaf=leaf,
        method=method,
        data_type=data_type,
        lr=lr,
        hessian=hessian,
        max_depth=max_depth,
    )


if __name__ == "__main__":
    main()