
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from skmultilearn.model_selection import iterative_train_test_split as multi_split
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn import preprocessing
from sklearn import metrics
from imblearn.over_sampling import RandomOverSampler 

from sklearn.metrics import silhouette_score

from sklearn.cluster import KMeans



def load_data(file_name, test_s = 0, rnd = 1, red = 0, scale = 0):          #load data function
    """Load a dataset with a ``y`` label column and ``X*`` feature columns.

    Names ending in 'v' (e.g. ``.csv``, any case) are read with ``pd.read_csv``.
    Anything else goes to ``pd.read_excel``: the first column is tried as the index,
    and the file is re-read without an index column if that hides ``y``.
    An ``Xconst`` column of ones is then added.

    Only columns whose label starts with 'X' are kept as features. The result always
    contains exactly one intercept column named ``Xconst``:
    - the first all-ones feature column is renamed to it;
    - any further all-ones columns are dropped;
    - if there is no all-ones column, ``Xconst`` is appended.

    Parameters
    ----------
    file_name : str or os.PathLike
        Path of the data file.
    test_s : float
        Test fraction. ``0`` means no split: train and test are both the full data.
    rnd : int
        ``random_state`` for the split.
    red : any
        Unused.
    scale : int
        ``1`` standardises the features with ``StandardScaler``. ``Xconst`` is reset
        to 1 afterwards.

    Returns
    -------
    tuple
        ``(x_train, y_train, x_test, y_test, xorg, y)``. ``xorg`` is the unscaled
        feature frame and ``y`` the full label Series.

    Side effects: sets ``pd.options.display.show_dimensions = False`` and
    ``np.set_printoptions(suppress=True)``.
    """
    # str() lets pathlib.Path inputs and upper-case extensions through.
    if str(file_name).lower().endswith('v'):
        df = pd.read_csv(file_name)
    else:
        df = pd.read_excel(file_name, index_col=0)
        if 'y' not in df.columns:
            df = pd.read_excel(file_name)
        df['Xconst'] = 1

    y = df['y']
    x = df.drop('y', axis = 1)
    # Keep only 'X'-prefixed features; str() guards non-string labels.
    # A single drop also avoids KeyErrors on duplicated labels.
    x = x.drop(columns = [c for c in x.columns if not str(c).startswith('X')])

    pd.options.display.show_dimensions = False
    np.set_printoptions(suppress=True)

    # A column is an intercept only if every value is 1.
    # The old "column sum == n" test also matched e.g. [0, 2, 1].
    const_mask = (x == 1).all()
    if not const_mask.any():
        x['Xconst'] = 1
    else:
        const_cols = list(x.columns[const_mask])
        # Keep one intercept. Renaming every all-ones column to 'Xconst'
        # produced duplicate, perfectly collinear columns.
        x = x.drop(columns = const_cols[1:]).rename(columns = {const_cols[0]: 'Xconst'})

    xorg = x.copy()

    if scale == 1:
        x = pd.DataFrame(preprocessing.StandardScaler().fit_transform(x), columns = x.columns, index = x.index)
        x['Xconst'] = 1

    if test_s > 0:
        if (y == 1).any():
            x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = test_s, random_state = rnd)
        else:
            x_train, x_test, y_train, y_test = multi_split(x, y, test_size = test_s, random_state = rnd)
    else:
        x_train = x_test = x
        y_train = y_test = y

    return x_train, y_train, x_test, y_test, xorg, y




def dmg_data(df, prc_stp, how, by, rnd = 0, dmg_indx = None):         #damage data function
    """Remove ("damage") a share of the positive rows (``y == 1``) of ``df``.

    Damage is applied incrementally over the levels 0.65, 0.75, 0.85 and 0.95, so
    results are comparable across levels. For example, the 75% setting contains
    every row removed at 65% plus enough additional rows to reach 75%. At each
    level ``int(level * number_of_positives)`` positives are removed in total.

    Parameters
    ----------
    df : DataFrame
        Data with a binary ``y`` column.
    prc_stp : float
        Level at which to stop. It must equal one of the levels above; any other
        value applies every level, ending at 0.95.
    how : str
        ``'rnd'`` removes uniformly random positives (MAR). Any other value is MNAR:
        positives are visited in order of ``by`` (ascending for ``'left'``,
        descending otherwise). Each visited positive is removed with a probability
        that falls linearly from 0.9 to 0.8 as the level's target is approached.
    by : Series
        Ordering key indexed like ``df``. Only used when ``how != 'rnd'``.
    rnd : int
        Seed for ``np.random.seed`` (global RNG, as before) and ``DataFrame.sample``.
    dmg_indx : any
        Ignored; kept only for backward compatibility.

    Returns
    -------
    DataFrame
        ``df`` without the removed rows. ``df`` itself is not modified.

    Raises
    ------
    ValueError
        In MNAR mode, if ``by`` does not index enough positive rows of ``df`` to
        reach a level. Previously this case looped forever.
    """
    np.random.seed(rnd)
    n_pos = df.y.sum()
    removed = []                                  # cumulative labels of removed rows

    for prc in [0.65, 0.75, 0.85, 0.95]:
        remaining = df.drop(labels = removed)
        num_rmv = len(removed)
        dmg = int(prc * n_pos)                    # total positives removed at this level

        if how == 'rnd':        #MAR damaging
            positives = remaining.loc[remaining.y == 1]
            # Keep labels in a list: np.concatenate with [] coerced them to float.
            removed.extend(positives.sample(dmg - num_rmv, random_state = rnd).index)
        else:                   #MNAR damaging
            order = by.drop(labels = removed).sort_values(ascending = how == 'left')
            if num_rmv < dmg:
                is_pos = remaining.loc[order.index, 'y'].to_numpy() == 1
                candidates = list(order.index[is_pos])
                taken = set()                     # rows already picked at this level
                while num_rmv < dmg:
                    # Skip rows picked in an earlier pass. Revisiting them appended
                    # duplicates and under-counted the damage actually applied.
                    pool = [row for row in candidates if row not in taken]
                    if not pool:
                        raise ValueError("`by` does not index enough positive rows to reach the requested damage level")
                    thrs = 0.9 - num_rmv * 0.1 / dmg
                    for row in pool:
                        if np.random.rand() <= thrs:
                            removed.append(row)
                            taken.add(row)
                            num_rmv += 1
                            thrs -= 0.1 / dmg
                            if num_rmv == dmg:
                                break

        if prc == prc_stp:
            break

    return df.drop(labels = removed)


def calc_acc_auc(y, proba, y_train = None, proba_train = None):       #function to calculate performance
    """Compute classification metrics from predicted class-1 probabilities.

    When training data is supplied, the decision threshold is tuned on it. The
    threshold chosen is the value in ``np.linspace(0, 1, 100)`` with the highest
    training F1; ties go to the lowest threshold. Otherwise the threshold is 0.5.

    Parameters
    ----------
    y : array-like
        True binary labels of the evaluation set.
    proba : array-like
        Predicted probability of class 1 for each element of ``y``.
    y_train, proba_train : array-like, optional
        Training labels and probabilities, used only for threshold tuning.
        ``None`` (the default) disables tuning. So does a ``y_train`` with fewer
        than two elements, such as the old default ``[0]``.

    Returns
    -------
    tuple
        ``(balanced accuracy, ROC AUC, average precision, precision, recall, F1)``.
        ROC AUC and average precision use ``proba`` directly; the other metrics
        use the thresholded predictions.
    """
    # None sentinels replace the previous mutable list defaults.
    if y_train is not None and proba_train is not None and len(y_train) > 1:
        train_proba = np.asarray(proba_train)     # lets plain lists be compared with >=
        thresholds = np.linspace(0, 1, 100)
        f1_scores = [metrics.f1_score(y_train, (train_proba >= t).astype(int)) for t in thresholds]
        thresh = thresholds[np.argmax(f1_scores)]
    else:
        thresh = 0.5

    y_pred = np.zeros(len(y))
    y_pred[np.asarray(proba) >= thresh] = 1

    return (
        metrics.balanced_accuracy_score(y, y_pred),
        metrics.roc_auc_score(y, proba),
        metrics.average_precision_score(y, proba),
        metrics.precision_score(y, y_pred),
        metrics.recall_score(y, y_pred),
        metrics.f1_score(y, y_pred),
    )


  


def get_results_sur(model, x_train, y_train, x_test, y_test, min_lab, num_vars):        #function that implements RBE (random balanced ensembling)
    """Random balanced ensembling (RBE).

    Labels are assumed binary: the majority label is ``1 - min_lab``.

    Each run splits the majority rows into ``round(maj / min)`` random, disjoint
    chunks of about ``ceil(maj / num_chunks)`` rows. Every chunk, together with all
    minority rows, is used to fit ``model``. The fitted model's ``predict_proba`` on
    ``x_train`` and ``x_test`` is accumulated. There are ``round(maj / min)`` runs
    (one run if that ratio rounds to 1 or less), and each run samples with
    ``random_state=run``.

    ``num_vars`` is unused.

    Returns
    -------
    (DataFrame, DataFrame)
        Averaged class probabilities (columns 0 and 1) for the training and test
        sets, indexed like ``y_train`` and ``y_test``.
    """
    maj_lab = 1 - min_lab
    min_num_lab = int((y_train == min_lab).sum())
    maj_num_lab = len(y_train) - min_num_lab
    ratio = maj_num_lab / min_num_lab

    num_lab = int(np.round(ratio))                     # chunks per run
    chunk_size = int(np.ceil(maj_num_lab / num_lab))   # nominal rows per chunk
    num_run = int(round(1 * ratio)) if num_lab > 1 else 1

    num_votes = 0
    proba_train = pd.DataFrame(np.zeros([len(y_train), 2]), index = y_train.index)
    proba_test = pd.DataFrame(np.zeros([len(y_test), 2]), index = y_test.index)
    min_indx = y_train.loc[y_train == min_lab].index

    for run in range(num_run):
        # Majority rows already used in this run are relabelled in the copy,
        # so each run's chunks are disjoint.
        y_train_mod = y_train.copy()
        for _ in range(num_lab):
            pool = y_train.loc[y_train_mod == maj_lab]
            # The last chunk may be smaller. Use a per-chunk size: shrinking a
            # shared size leaked into later runs and could stop them fitting.
            n_take = min(chunk_size, len(pool))
            if n_take == 0:
                continue
            indx = pool.sample(n_take, random_state = run).index
            y_train_mod[indx] = min_lab
            indx = indx.union(min_indx)
            model.fit(x_train.loc[indx, :], y_train.loc[indx])

            proba_tr = model.predict_proba(x_train)
            proba_ts = model.predict_proba(x_test)
            num_votes += 1
            proba_train += proba_tr
            proba_test += proba_ts

    proba_train /= num_votes
    proba_test /= num_votes

    return proba_train, proba_test


def get_results_surclust(model, x_train, y_train, x_test, y_test, x_train_nons, min_lab, num_vars):           #Function to implement CBE
    """Cluster-based balanced ensembling (CBE).

    The majority rows (label ``1 - min_lab``) of ``x_train_nons`` are split by KMeans
    into ``round(maj / min)`` clusters. KMeans uses ``random_state=run`` and
    ``n_init=10``, and is refit in every run.

    Each cluster, together with all minority rows, is used to fit ``model`` on
    ``x_train``. The model's ``predict_proba`` outputs are accumulated, each weighted
    by its training average precision for class 1. The results are then divided by
    the total weight. When the ratio rounds to 1 or less, a single model is fit on
    all majority rows plus the minority rows.

    ``num_vars`` is unused.

    Returns
    -------
    (DataFrame, DataFrame)
        Weighted-average class probabilities (columns 0 and 1) for the training and
        test sets, indexed like ``y_train`` and ``y_test``.
    """
    maj_lab = 1 - min_lab
    n_min = int((y_train == min_lab).sum())
    n_maj = len(y_train) - n_min
    ratio = n_maj / n_min

    n_clusters = int(np.round(ratio))
    n_runs = int(round(1 * ratio)) if n_clusters > 1 else 1

    weights = np.zeros(n_runs * n_clusters)
    proba_train = pd.DataFrame(np.zeros([len(y_train), 2]), index = y_train.index)
    proba_test = pd.DataFrame(np.zeros([len(y_test), 2]), index = y_test.index)

    min_indx = y_train.loc[y_train == min_lab].index
    maj_mask = y_train == maj_lab
    all_maj_indx = y_train.loc[maj_mask].index
    x_train_maj = x_train_nons.loc[maj_mask, :]
    clustered = n_clusters > 1

    for run in range(n_runs):
        if clustered:
            km_labels = KMeans(n_clusters = n_clusters, random_state = run, n_init = 10).fit(x_train_maj).labels_

        for cluster in range(n_clusters):
            group = x_train_maj.loc[km_labels == cluster].index if clustered else all_maj_indx
            subset = min_indx.union(group)

            model.fit(x_train.loc[subset, :], y_train[subset])
            fitted_tr = model.predict_proba(x_train)
            fitted_ts = model.predict_proba(x_test)

            # Weight this member by its training average precision.
            score = metrics.average_precision_score(y_train, fitted_tr[:, 1])
            weights[run * n_clusters + cluster] = score
            proba_train += fitted_tr * score
            proba_test += fitted_ts * score

    total_weight = sum(weights)
    proba_train /= total_weight
    proba_test /= total_weight

    return proba_train, proba_test





def _ekr_oversample(x_sub, y_sub, rnd):
    """Balance (x_sub, y_sub) with ADASYN, falling back to SMOTE, then RandomOverSampler."""
    for sampler_cls in (ADASYN, SMOTE):
        try:
            x_res, y_res = sampler_cls(random_state = rnd).fit_resample(x_sub, y_sub)
        except Exception:   # ordinary errors trigger the next fallback; Ctrl-C/SystemExit still propagate
            continue
        return x_res, y_res
    # Last resort; any error here propagates to the caller, as before.
    return RandomOverSampler(sampling_strategy = 'auto', random_state = rnd).fit_resample(x_sub, y_sub)


def get_results_EKR(model, x_train, y_train, x_test, y_test, num_vars):         #Function to implement EKR
    """EKR ensemble: one model per majority-class KMeans cluster, each trained on a balanced subset.

    Steps:
    1. Choose ``k`` in 2..10 by the highest silhouette score of KMeans on all of
       ``x_train``. A k whose silhouette raises ``ValueError`` scores -1.
    2. Split the majority rows (label 0; the minority label is fixed at 1) into
       ``k`` clusters.
    3. For each cluster, build a training subset from the cluster and all minority
       rows:
       - if the cluster has at least as many rows as the minority, it is
         under-sampled to the minority size;
       - otherwise the cluster-plus-minority subset is oversampled with ADASYN,
         falling back to SMOTE and then RandomOverSampler.
       Fit ``model`` on the subset and accumulate its ``predict_proba`` outputs.

    The seed is fixed at 42 and ``num_vars`` is unused.

    Returns
    -------
    (DataFrame, DataFrame)
        Class probabilities (columns 0 and 1) averaged over the cluster models, for
        the training and test sets, indexed like ``y_train`` and ``y_test``.
    """
    min_lab = 1
    rnd = 42

    scores = {}
    for k in range(2, 11):
        labels = KMeans(n_clusters = k, random_state = rnd).fit_predict(x_train)
        try:
            scores[k] = silhouette_score(x_train, labels)
        except ValueError:
            scores[k] = -1  # invalid, skip
    best_k = max(scores, key = scores.get)

    proba_train = pd.DataFrame(np.zeros([len(y_train), 2]), index = y_train.index)
    proba_test = pd.DataFrame(np.zeros([len(y_test), 2]), index = y_test.index)
    min_indx = y_train.loc[y_train == min_lab].index
    x_train_maj = x_train.loc[y_train == 1 - min_lab, :]

    maj_labels = KMeans(n_clusters = best_k, random_state = rnd).fit_predict(x_train_maj)

    num_votes = 0
    for cluster_id in range(best_k):
        cl_indx = x_train_maj.loc[maj_labels == cluster_id, :].index

        if len(cl_indx) >= len(min_indx):
            # Under-sample the cluster down to the minority size.
            if len(cl_indx) > len(min_indx):
                cl_indx = y_train.loc[cl_indx].sample(len(min_indx), random_state = rnd).index
            indx = cl_indx.union(min_indx)
            x_bl = x_train.loc[indx, :]
            y_bl = y_train.loc[indx]
        else:
            # The cluster is the smaller class here, so it is what gets oversampled.
            indx = cl_indx.union(min_indx)
            x_bl, y_bl = _ekr_oversample(x_train.loc[indx, :], y_train.loc[indx], rnd)

        model.fit(x_bl, y_bl)

        proba_tr = model.predict_proba(x_train)
        proba_ts = model.predict_proba(x_test)

        num_votes += 1
        proba_train += proba_tr
        proba_test += proba_ts

    proba_train /= num_votes
    proba_test /= num_votes

    return proba_train, proba_test


def focal_loss(x, y, model, alpha=0.25, gamma=2.0):
    """Fit ``model`` on (x, y) and return per-sample focal-loss weights.

    Despite its name, the function returns weights, not a loss value::

        weight_i = alpha_t * (1 - p_t) ** gamma

    ``p_t`` is the fitted model's probability for the sample's true class, and
    ``alpha_t`` is ``alpha`` for positives (``y == 1``) and ``1 - alpha`` otherwise.
    It is assumed that ``predict_proba`` column 0 is class 0 and column 1 is class 1
    (NOTE(review): holds for sklearn-style models trained on labels {0, 1}).
    """
    model.fit(x, y)
    class_probs = model.predict_proba(x)
    positive = y == 1

    prob_true_class = np.where(positive, class_probs[:, 1], class_probs[:, 0])
    class_factor = np.where(positive, alpha, 1 - alpha)

    return class_factor * ((1 - prob_true_class) ** gamma)


def class_loss(x, y):
    """Return per-sample class-balanced weights based on the effective number of samples.

    The computation is:
    - ``beta = (n - 1) / n``, where ``n = len(y)``;
    - each class ``c`` with ``n_c`` samples gets ``1 / ((1 - beta ** n_c) / (1 - beta))``;
    - these class weights are normalised to sum to the number of classes present
      (2 for binary data, as before);
    - each sample receives its class's weight.

    Works for any hashable labels. It previously hard-coded labels 0 and 1 and
    raised ``KeyError`` when either was missing.

    ``x`` is unused; it is kept for signature compatibility.

    Returns
    -------
    numpy.ndarray
        One weight per element of ``y``.
    """
    n_samples = len(y)
    beta = (n_samples - 1) / n_samples

    counts = y.value_counts().sort_index()

    effective_num = (1 - np.power(beta, counts)) / (1 - beta)
    class_weights = 1.0 / effective_num
    class_weights = class_weights / class_weights.sum() * len(counts)

    # Map each label to its class weight; the Series index holds the labels.
    return y.map(class_weights).to_numpy()


def adjust_proba(proba, y, tau=1.0, eps=1e-12):
    """Shift probabilities by ``tau`` times the prior log-odds of the positive class.

    The computation is ``sigmoid(logit(p) + tau * log(pi / (1 - pi)))``, where ``pi``
    is ``y.mean()`` and ``p`` is ``proba`` clipped to ``[eps, 1 - eps]``.

    When ``y`` contains a single class, the prior log-odds are infinite and the
    result degenerates.

    Returns
    -------
    numpy.ndarray
        Adjusted probabilities with the shape of ``proba``.
    """
    prior = y.mean()
    clipped = np.clip(np.asarray(proba), eps, 1 - eps)

    logit = np.log(clipped / (1 - clipped))
    prior_log_odds = np.log(prior / (1 - prior))
    shifted = logit + tau * prior_log_odds

    return 1 / (1 + np.exp(-shifted))


def weight_oversample(x, y, w, seed = 42):
    """Oversample the positive class (``y == 1``) up to the size of the negative class.

    Extra positive rows are drawn with replacement. Each row's draw probability is
    proportional to ``w``, which holds one non-negative weight per positive row, in
    the order of ``y.loc[y == 1]``. The drawn rows are appended to ``x`` and ``y``.
    Their index labels are repeated, not renumbered.

    Parameters
    ----------
    x : DataFrame
        Features, indexed like ``y``.
    y : Series
        Binary labels (0/1).
    w : array-like
        Sampling weights for the positive rows. They are not modified; the old
        in-place normalisation rescaled the caller's array and failed for integer
        weights.
    seed : int
        Seed for ``np.random.default_rng``. Defaults to 42, the value previously
        hard-coded.

    Returns
    -------
    (DataFrame, Series)
        Oversampled ``x`` and ``y``. If there are already at least as many positives
        as negatives, nothing is added.
    """
    probs = np.asarray(w, dtype = float)
    probs = probs / probs.sum()

    rng = np.random.default_rng(seed)
    idx_min = y.loc[y == 1].index
    idx_maj = y.loc[y == 0].index
    # Clamp at 0: a negative size made rng.choice raise.
    n_add = max(len(idx_maj) - len(idx_min), 0)
    add_idx = rng.choice(idx_min, size = n_add, replace = True, p = probs)

    xn = pd.concat([x, x.loc[add_idx]], axis = 0)
    yn = pd.concat([y, y.loc[add_idx]], axis = 0)

    return xn, yn
        