from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder
import pandas as pd
import pickle
import numpy as np
from bilevel.ExpertsAbstract import Expert
import pickle

def save_ob(name, obj):
    """Serialize ``obj`` to the file at path ``name``.

    The file is opened in binary write mode (truncating any existing file)
    and the object is written with ``pickle.HIGHEST_PROTOCOL``.
    """
    with open(name, mode='wb') as out_file:
        pickle.dump(obj, out_file, protocol=pickle.HIGHEST_PROTOCOL)
        
def load_ob(name):
    """Read and return the pickled object stored in the file at path ``name``.

    Security note: unpickling can execute arbitrary code; only use this on
    files produced by trusted code (e.g. ``save_ob``).
    """
    with open(name, mode='rb') as in_file:
        return pickle.load(in_file)

def numeric_scaler(df, cols):
    """Return a copy of ``df`` with ``cols`` min-max scaled.

    A fresh ``MinMaxScaler`` (default feature range) is fitted on the
    selected columns of the copy and used to transform them. The input
    frame is left untouched, and the fitted scaler is discarded, so the
    transform cannot be inverted later from this return value alone.
    """
    scaled_frame = df.copy()
    scaler = MinMaxScaler()
    scaled_frame[cols] = scaler.fit_transform(scaled_frame[cols])
    return scaled_frame

def ordinal_encoder(df, cols):
    """Return a copy of ``df`` with ``cols`` replaced by ordinal category codes.

    A fresh ``OrdinalEncoder`` is fitted on the selected columns of the copy
    and used to transform them. The input frame is not modified, and the
    fitted encoder is not returned.
    """
    encoded_frame = df.copy()
    encoder = OrdinalEncoder()
    encoded_frame[cols] = encoder.fit_transform(encoded_frame[cols])
    return encoded_frame

def one_hot(df, cols):
    """Append one-hot indicator columns for each column name in ``cols``.

    For every ``col`` in ``cols`` (processed in order), ``pd.get_dummies``
    builds indicator columns named ``<col>_<value>`` (no level dropped), and
    they are concatenated to the right of the running frame. The original
    columns are kept. The input frame is not mutated; if ``cols`` is empty,
    the same ``df`` object is returned.
    """
    expanded = df
    for col in cols:
        indicator_frame = pd.get_dummies(expanded[col], prefix=col, drop_first=False)
        expanded = pd.concat([expanded, indicator_frame], axis=1)
    return expanded

def _last_cumloss(curve):
    """Return the final cumulative loss in ``curve``, or 0.0 if it is empty."""
    # An empty curve means the group was never active; its cumulative loss is 0.
    return curve[-1] if len(curve) else 0.0


def BaselinevsAnh(baseline_imp, Anh) -> pd.DataFrame:
    """Compare per-group final cumulative losses of a baseline against Anh.

    For each group index ``gnum`` of ``Anh.cuml_loss_curve``, this reads the
    baseline curve ``baseline_imp.cumloss_groupwise_oridge[gnum]`` and Anh's
    curve ``Anh.cuml_loss_curve[gnum]`` and records one row containing:

    - ``Tg``: the length of the baseline curve (rounds the group was active),
    - the baseline's last cumulative loss,
    - Anh's last cumulative loss,
    - the symmetric relative difference ``2|b - a| / (|b| + |a|)`` in percent
      (0 when both losses are 0).

    An empty curve contributes a final cumulative loss of 0.0.

    Args:
        baseline_imp: Object exposing ``name`` (str, used as a column-name
            prefix) and ``cumloss_groupwise_oridge`` (indexable by group,
            each entry a sequence of cumulative losses).
        Anh: Object exposing ``cuml_loss_curve`` (iterable/indexable by
            group, each entry a sequence of cumulative losses).

    Returns:
        DataFrame with columns ``['Tg', baseline_imp.name + 'loss',
        'Anh cuml loss', 'relative diff %']`` and one row per group.
    """
    columns = ['Tg', baseline_imp.name + 'loss', 'Anh cuml loss', 'relative diff %']
    rows = []
    for gnum, anh_curve in enumerate(Anh.cuml_loss_curve):
        baseline_curve = baseline_imp.cumloss_groupwise_oridge[gnum]
        Tg = len(baseline_curve)  # number of rounds this group is active
        baseline_end = _last_cumloss(baseline_curve)
        anh_end = _last_cumloss(anh_curve)
        denom = abs(baseline_end) + abs(anh_end)
        # Guard the 0/0 case: identical zero losses have zero relative difference.
        rel_diff = 2 * abs(baseline_end - anh_end) / denom if denom else 0.0
        rows.append([Tg, baseline_end, anh_end, rel_diff * 100])
    # Build the frame once instead of growing it row by row.
    return pd.DataFrame(rows, columns=columns)

def save_pickle(obj, file_dest):
    """Pickle ``obj`` to the file at ``file_dest`` using the default protocol.

    The file is opened in binary write mode (truncating any existing file)
    inside a ``with`` block, so the handle is closed even if pickling raises
    part-way through.
    """
    with open(file_dest, "wb") as pickle_out:
        pickle.dump(obj, pickle_out)

def load_pickle(file_source):
    """Unpickle and return the object stored in the file at ``file_source``.

    The file is opened inside a ``with`` block, so the handle is closed
    whether or not unpickling succeeds.

    Security note: ``pickle.load`` can execute arbitrary code. Only use this
    on files produced by trusted code (e.g. ``save_pickle``), never on
    external or user-supplied data.
    """
    with open(file_source, "rb") as pickle_in:
        return pickle.load(pickle_in)

def fill_subsequence_losses(expert: Expert, A_t: np.ndarray) -> list[np.ndarray]:
    """Return each group's cumulative loss over the rounds where it is active.

    ``A_t`` is indexed as ``A_t[:, g]`` for each of its ``A_t.shape[1]``
    groups. The per-round losses ``np.array(expert.loss_tarr)`` are masked
    with that column (cast to bool), and the running sum of the selected
    losses is taken. The result is one ``np.cumsum`` array per group, in
    column order.

    Assumes ``A_t`` has one row per entry of ``expert.loss_tarr`` (boolean
    indexing requires matching lengths) — TODO confirm against callers.
    """
    n_groups = A_t.shape[1]
    per_round_losses = np.array(expert.loss_tarr)
    return [
        np.cumsum(per_round_losses[A_t[:, group].astype(bool)])
        for group in range(n_groups)
    ]
   