import os
import gc

import numpy as np
import torch

from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelBinarizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
import logging
from tqdm import tqdm, trange
import joblib
import cupy as cp
from joblib import dump, load

from copy import copy
import pickle


import warnings
# Silence every Python warning process-wide to keep training output readable.
# NOTE(review): this also hides deprecation/numerical warnings from numpy,
# sklearn, torch and py_boost — consider narrowing by category or module.
warnings.filterwarnings('ignore')

# set the device to run
# Number GPUs in PCI bus order and expose only GPU 0 to this process. These
# variables only take effect if set before a CUDA context is created; torch and
# cupy are imported above, which presumably does not initialize CUDA yet —
# TODO confirm, or move these two lines above the torch/cupy imports.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#from py_boost import Callback
from py_boost import GradientBoosting, SketchBoost, TLPredictor, TLCompiledPredictor

# strategies to deal with multiple outputs
from py_boost.multioutput.sketching import *
from py_boost.multioutput.target_splitter import *
from py_boost import Callback

# Import the module namespaced so evaluate's `load` does not shadow the
# `load` bound by `from joblib import dump, load` above.
import evaluate
# Hugging Face `evaluate` accuracy metric, used as metric.compute(...) in calc().
metric = evaluate.load("accuracy")

class WarmStart(Callback):
    """py_boost callback that continues training from an already fitted model.

    Before training, the previous model's base score and its predictions on the
    train/valid features are written into the new run's ``build_info`` as the
    starting ensemble. After training, the previous model's trees are prepended
    to the newly built trees, and the iteration counters are shifted to match.
    """
    
    def __init__(self, model):
        """Keep a shallow copy of ``model`` with output post-processing disabled.

        Args:
            model: previously fitted py_boost model to warm-start from.
                ``model.to_cpu()`` is called on the caller's object itself,
                before it is copied.
        """
        
        model.to_cpu()
        self.model = copy(model)
        # Identity post-processing so predict() returns the model's untransformed
        # outputs, which then seed the new ensemble (presumably raw pre-activation
        # scores — TODO confirm against py_boost's postprocess_fn semantics).
        self.model.postprocess_fn = lambda x: x
        
    def before_train(self, build_info):
        """Seed the new run with the previous model's base score and predictions.

        Args:
            build_info: dict supplied by py_boost. This method reads
                ``build_info['model']``, ``build_info['data']['train']['features_cpu']``
                and ``build_info['data']['valid']['features_cpu']`` (a list of
                arrays, one per eval set), and writes ``'ensemble'`` entries as
                CuPy arrays.
        """
        
        build_info['model'].base_score = cp.asarray(self.model.base_score)
        
        train = build_info['data']['train']
        train['ensemble'] = cp.asarray(self.model.predict(train['features_cpu']))
        
        valid = build_info['data']['valid']
        valid['ensemble'] = [cp.asarray(self.model.predict(x)) for x in valid['features_cpu']]
        
        # Return the stored model to CPU in case predict() moved it to the GPU
        # (assumption — TODO confirm py_boost's predict behaviour).
        self.model.to_cpu()
        
        return 
    
    def after_train(self, build_info):
        """Prepend the previous trees and shift the iteration bookkeeping.

        Args:
            build_info: dict supplied by py_boost; ``build_info['model']`` and
                ``build_info['num_iter']`` are updated in place.
        """
        
        build_info['model'].models = self.model.models + build_info['model'].models
        # update the actual iteration
        build_info['num_iter'] = build_info['num_iter'] + len(self.model.models)
        # update the actual best round
        # NOTE(review): assumes the early-stopping callback is the LAST callback
        # registered on the model — TODO confirm for the py_boost version in use.
        early_stop = build_info['model'].callbacks.callbacks[-1]
        early_stop.best_round = early_stop.best_round + len(self.model.models)
        
        # not to store old trees multiple times
        # (they now live in build_info['model'].models; consequently this callback
        # cannot be reused for a second fit)
        self.model = None
        
        return

def f1_max(pred, target):
    """
    F1 score with the optimal threshold.

    This function first enumerates all possible thresholds for deciding positive and negative
    samples, and then pick the threshold with the maximal F1 score.

    The threshold is shared by all rows. Each of the B*N scores is tried in
    descending order as the cutoff. At a cutoff, precision is averaged over the
    rows that already have at least one entry predicted positive, and recall is
    averaged over all B rows.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        0-d Tensor: the largest F1 over all cutoffs, ignoring NaN values.
        Presumably raises RuntimeError if every value is NaN (max of an empty
        tensor) — TODO confirm.
    """
    # Per-row column indices, highest score first; targets reordered to match.
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    # Precision / recall of each row when its top-(j+1) entries are predicted
    # positive. The 1e-10 avoids dividing by zero for rows with no positives.
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # Mark each row's top-ranked entry, scattered back to unsorted coordinates.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    # Global ranking of all B*N scores.
    all_order = pred.flatten().argsort(descending=True)
    # Turn per-row column indices into indices into the flattened (B*N) tensor.
    order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    # inv_order[flat index] = position of that entry in the row-sorted, flattened arrays.
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    # In global order: True where a row receives its first predicted positive.
    is_start = is_start.flatten()[all_order]
    # Global order expressed as positions in the row-sorted precision/recall arrays.
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Change in the owning row's precision when this entry turns positive.
    # `all_order - 1` is the same row's previous cutoff; for a row's first entry
    # it is masked to zero by `is_start`, so a wrapped index of -1 is never used.
    all_precision = precision[all_order] - \
                    torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1])
    # Running sum of current per-row precisions / number of rows with at least
    # one positive = mean precision at each cutoff.
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - \
                 torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1])
    # Mean recall over all B rows; rows with no positive yet contribute 0.
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    
    # Drop NaNs (e.g. 0/0 in the precision average) before taking the maximum.
    return all_f1[~torch.isnan(all_f1)].max()

def _load_features(data_folder, split, method, model_name):
    """Load one feature split ('train' / 'val' / 'test') as a float32 array.

    ``np.asarray`` copies only when the stored dtype is not already float32.
    The original-precision array is therefore released immediately, instead of
    coexisting with an ``astype`` copy at fit/predict time.
    """
    path = f'{data_folder}/{split}_ConSuf10k_{method}_{model_name}.npy'
    return np.asarray(np.load(path, allow_pickle=True), dtype=np.float32)


def _load_labels(data_folder, split):
    """Load the labels of one split, shifted down by one.

    The stored labels presumably start at 1, so the shift makes them 0-based
    (TODO confirm).
    """
    return np.load(f'{data_folder}/y_{split}_ConSuf10k.npy', allow_pickle=True) - 1


def _score_predictions(pred, y_true, binarizer):
    """Score a class-score matrix against integer labels.

    Args:
        pred: numpy array of shape (n_samples, n_classes) from ``model.predict``.
        y_true: integer labels of shape (n_samples,).
        binarizer: fitted LabelBinarizer used to one-hot encode ``y_true`` for
            ``f1_max``. Assumes its columns line up with the columns of ``pred``
            (all labels 0..n_classes-1 seen during fit) — TODO confirm.

    Returns:
        Five values in percent, rounded to 3 decimals:
        [F1 micro, F1 macro, F1 weighted, F1 max, accuracy].
    """
    label_pred = np.argmax(pred, axis=1)
    target = torch.from_numpy(binarizer.transform(y_true))
    raw_scores = [
        f1_score(y_true, label_pred, average='micro'),
        f1_score(y_true, label_pred, average='macro'),
        f1_score(y_true, label_pred, average='weighted'),
        float(f1_max(torch.from_numpy(pred), target)),
        metric.compute(predictions=label_pred, references=y_true)['accuracy'],
    ]
    # Scale first, then round, so the logged percentages carry no float
    # artifacts. All five metrics are rounded the same way.
    return [round(100.0 * score, 3) for score in raw_scores]


def calc(data_folder, model_name, warm_start='', num_trees=3000, leaf=10, method='method_3', data_type='features', lr=0.03, hessian=True, max_depth=4):
    """Train py_boost GradientBoosting on ConSuf10k features and log test metrics.

    For each rerun seed, this function:
    - fits a 'crossentropy' GradientBoosting model on the train split, with the
      val split passed as ``eval_sets`` (``es=300``);
    - dumps the model with joblib under ``models/``;
    - logs F1 micro/macro/weighted, F1 max and accuracy on the test split to a
      file under ``logs/``.

    Mean +- std over all reruns is logged at the end.

    Args:
        data_folder: directory holding ``{split}_ConSuf10k_{method}_{model_name}.npy``
            feature files and ``y_{split}_ConSuf10k.npy`` label files.
        model_name: feature-extractor name used in file names.
        warm_start: currently unused; kept for interface compatibility.
        num_trees: passed as ``ntrees``.
        leaf: passed as ``min_data_in_leaf``.
        method: feature-extraction tag used in file names.
        data_type: tag used only in log/model file names and the log.
        lr: learning rate.
        hessian: passed as ``use_hess``.
        max_depth: maximum tree depth.
    """
    n_reruns = 1
    # Global seeding kept so the rerun seeds are reproducible.
    np.random.seed(42)
    seeds = np.random.randint(1000, size=n_reruns)

    train = _load_features(data_folder, 'train', method, model_name)
    val = _load_features(data_folder, 'val', method, model_name)
    # Cast test too, consistent with what the model was fitted on.
    test = _load_features(data_folder, 'test', method, model_name)
    print('Train:', train.shape)
    print('Val:', val.shape)
    print('Test:', test.shape)
    print('Data loaded!')

    # Neither FileHandler nor joblib.dump creates missing directories. Create
    # them up front so a long training run cannot fail only at save time.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('models', exist_ok=True)
    run_tag = (f'ConSuf10k_{method}_{model_name}_{data_type}_trees_{num_trees}'
               f'_leaf_{leaf}_lr_{lr}_hessian_{hessian}_max_depth_{max_depth}')

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(f'logs/log_{run_tag}.log', 'w', 'utf-8')
    logger.addHandler(handler)
    try:
        y_train = _load_labels(data_folder, 'train')
        print('y_train', y_train.shape, np.count_nonzero(y_train > 4))
        lb = LabelBinarizer().fit(y_train)
        y_val = _load_labels(data_folder, 'val')
        y_test = _load_labels(data_folder, 'test')
        print('Train/Val/Test labels loaded')

        logger.info(f'data type - {data_type}')
        q = np.zeros((len(seeds), 5))
        for i, seed in enumerate(seeds):
            print('Start new model')
            # To continue from a saved model, add callbacks=[WarmStart(prev_model)].
            sketch = RandomProjectionSketch(1)
            model = GradientBoosting(
                'crossentropy',
                ntrees=num_trees, lr=lr, verbose=500, es=300, lambda_l2=1, gd_steps=1,
                subsample=1, colsample=1, min_data_in_leaf=leaf, use_hess=hessian,
                max_bin=256, max_depth=max_depth, seed=seed,
                multioutput_sketch=sketch, debug=True,
            )
            print('Start model fit')
            model.fit(train, y_train, eval_sets=[{'X': val, 'y': y_val}])
            joblib.dump(model, f'models/model_{run_tag}_seed_{seed}.pkl')
            print('Model fitted')

            pred = model.predict(test)
            print('Predictions:', pred.shape)
            q[i] = _score_predictions(pred, y_test, lb)
            logger.info(f'F1 micro quality - {q[i][0]}')
            logger.info(f'F1 macro quality - {q[i][1]}')
            logger.info(f'F1 weighted quality - {q[i][2]}')
            logger.info(f'F1 max quality - {q[i][3]}')
            logger.info(f'Accuracy - {q[i][4]}')
            # Free the model (it can be large) before the next rerun.
            del model, sketch, pred
            gc.collect()

        logger.info(f'seeds - {seeds.tolist()}, num trees - {num_trees}')
        logger.info(f'Mean est. accuracy : {np.mean(q[:, 4])} +- {np.std(q[:, 4])}')
        logger.info(f'Mean est. F1 max quality : {np.mean(q[:, 3])} +- {np.std(q[:, 3])}')
        logger.info(f'Mean est. F1 weighted quality : {np.mean(q[:, 2])} +- {np.std(q[:, 2])}')
        logger.info(f'Mean est. F1 macro quality : {np.mean(q[:, 1])} +- {np.std(q[:, 1])}')
        logger.info(f'Mean est. F1 micro quality : {np.mean(q[:, 0])} +- {np.std(q[:, 0])}')
    finally:
        # Detach and close the per-run file handler. Otherwise repeated calls
        # leak open files and also write to every previous run's log.
        logger.removeHandler(handler)
        handler.close()


def main():
    """Configure one ConSuf10k boosting experiment and run ``calc`` on it.

    Three disabled preprocessing recipes are kept below as bare string literals,
    which are no-ops at runtime:
    - load ``last_embs`` and ``sum_method_3_attns`` features;
    - standardize the attention features and save ``*_sc.npy`` copies;
    - save their concatenation under ``method = 'last_embs_<method>'``.
    """
    lr = 0.03
    num_trees = 10000  # also tried: 100000, 2500
    # alternatives: 'esm2_t6_8M_UR50D', 'esm2_t36_3B_UR50D', 'esm2_t48_15B_UR50D'
    model_name = 'esm2_t33_650M_UR50D'
    # alternatives: 'last_embs', 'method_4_attns'
    method = 'method_3_attns'
    data_folder = f'/data/conservation/ConSuf10k/{model_name}/features'
    data_type = 'features'  # alternative: 'embs'
    warm_start = 'batch'  # currently ignored by calc()
    hessian = True
    max_depth = 4  # also tried: 3
    leaf = 10
    # File-name suffix used only by the disabled recipes below (e.g. '_sc').
    sc = ''

    # Disabled recipe: load embeddings and attention features.
    '''
    method = 'last_embs' #'method_3_attns'
    train_e = np.load(f'{data_folder}/train_ConSuf10k_{method}_{model_name}.npy' , allow_pickle=True)
    val_e = np.load(f'{data_folder}/val_ConSuf10k_{method}_{model_name}.npy' , allow_pickle=True)
    test_e = np.load(f'{data_folder}/test_ConSuf10k_{method}_{model_name}.npy' , allow_pickle=True)
    
    method = 'sum_method_3_attns' #'last_embs' #'method_3_attns'
    train_m = np.load(f'{data_folder}/train_ConSuf10k_{method}_{model_name}{sc}.npy' , allow_pickle=True)
    val_m = np.load(f'{data_folder}/val_ConSuf10k_{method}_{model_name}{sc}.npy' , allow_pickle=True)
    test_m = np.load(f'{data_folder}/test_ConSuf10k_{method}_{model_name}{sc}.npy' , allow_pickle=True)
    print(train_m.shape)
    print(val_m.shape)
    print(test_m.shape)
    '''
    # Disabled recipe: standardize the attention features. The scaler is named
    # `scaler`, not `sc`, so it does not clobber the `sc` file-name suffix that
    # the concatenation recipe below interpolates into its output paths.
    '''
    scaler = StandardScaler()
    scaler.fit(train_m.astype(np.float16))
    train_sc = scaler.transform(train_m).astype(np.float16)  
    print(train_sc.shape)  
    np.save(f'{data_folder}/train_ConSuf10k_{method}_{model_name}_sc.npy', train_sc, allow_pickle=True)
    val_sc = scaler.transform(val_m)    
    np.save(f'{data_folder}/val_ConSuf10k_{method}_{model_name}_sc.npy', val_sc, allow_pickle=True)
    test_sc = scaler.transform(test_m)    
    np.save(f'{data_folder}/test_ConSuf10k_{method}_{model_name}_sc.npy', test_sc, allow_pickle=True)    
    '''
    # Disabled recipe: concatenate and save. Deleting the opening ''' enables
    # it; the closing #''' then becomes an ordinary comment.
    '''
    train = np.concatenate((train_e, train_m), axis = 1)
    val = np.concatenate((val_e, val_m), axis = 1)
    test = np.concatenate((test_e, test_m), axis = 1)
    print(train.shape)
    print(val.shape)
    print(test.shape)

    method = f'last_embs_{method}'
    np.save(f'{data_folder}/train_ConSuf10k_{method}_{model_name}{sc}.npy', train, allow_pickle=True) # 
    np.save(f'{data_folder}/val_ConSuf10k_{method}_{model_name}{sc}.npy', val, allow_pickle=True)
    np.save(f'{data_folder}/test_ConSuf10k_{method}_{model_name}{sc}.npy', test, allow_pickle=True)
    #'''
    calc(
        data_folder, model_name,
        warm_start=warm_start, num_trees=num_trees, leaf=leaf,
        method=method, data_type=data_type, lr=lr,
        hessian=hessian, max_depth=max_depth,
    )

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()