import pandas as pd
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier
import warnings
warnings.filterwarnings('ignore')
def load_dataset(exp):
    """Read the six CSV splits belonging to experiment ``exp``.

    Each file is looked up as ``/{your_path}/{exp}_<split>.csv``.
    NOTE(review): ``your_path`` is not defined in the visible part of this
    file; it is resolved as a global at call time, so it has to be provided
    before this function is used.

    Returns:
        The tuple ``(x_train, x_idtest, x_ood, y_train, y_idtest, y_ood)``,
        one ``DataFrame`` per file, read in that order.
    """
    split_names = ('Xtrain', 'Xidtest', 'Xood', 'ytrain', 'yidtest', 'yood')
    return tuple(
        pd.read_csv(f'/{your_path}/{exp}_{split}.csv') for split in split_names
    )
def check_con_columns(data:pd.DataFrame, columns)->bool:
    """Tell whether ``data[columns]`` is stored with dtype ``float64``.

    Callers in this file use the answer to choose between the numerical and
    the categorical treatment of a column.
    """
    column_dtype = data[columns].dtype
    return column_dtype == 'float64'

def construct_init_prompts(data:pd.DataFrame) -> str:
    """Describe every column of ``data`` as one line of prompt text.

    One line is produced per column, in column order:

    * float64 columns (per ``check_con_columns``) are reported as a numerical
      range ``min ~ max``;
    * other columns whose name contains ``_`` are reported as a one-hot
      column, splitting the name at its last ``_`` into category and value;
    * all remaining columns are reported as a categorical range ``min - max``.

    Minimum and maximum are taken over the set of raw column values.

    Returns:
        The concatenation of all lines (each ends with a newline).
    """
    def value_bounds(name):
        # smallest and largest element among the distinct raw values
        distinct = set(data[name].values)
        return min(distinct), max(distinct)

    lines = []
    for name in data.columns:
        if check_con_columns(data, name):
            low, high = value_bounds(name)
            lines.append(f'-{name}: (Numerical value of {low} ~ {high})\n')
            continue
        if len(name.split('_', 1)) > 1:
            # the name contains '_', so rsplit yields exactly two pieces
            category, level = name.rsplit('_', 1)
            lines.append(f'-{name}: (One hot encoder of category:{category}, with value:{level})\n')
            continue
        low, high = value_bounds(name)
        lines.append(f'-{name}: (Categorical value of {low} - {high})\n')
    return ''.join(lines)

def stat_prompts(x_train:pd.DataFrame, x_ood:pd.DataFrame = None, col:str = None) ->str:
    """Render summary statistics of ``x_train`` as prompt text.

    Args:
        x_train: frame whose ``describe()`` table is summarised.
        x_ood: accepted but not read by this function.
        col: column to summarise; when ``None`` all columns of ``x_train``
            are selected at once.

    Returns:
        The line ``id information of ...`` repeated twice.

    The statistics are picked from the ``describe()`` rows by position
    (rows 1-6 and the last row).
    """
    selected = x_train.columns if col is None else col
    rows = x_train.describe()[selected].values
    mean, std, low = rows[1], rows[2], rows[3]
    q25, q50, q75 = rows[4], rows[5], rows[6]
    high = rows[-1]
    line = f'id information of {selected}: mean:{mean}, std:{std}, min:{low}, max:{high}, 25%:{q25}, 50%:{q50}, 75%:{q75}\n'
    # NOTE(review): the same line is emitted twice and x_ood is never used;
    # kept as-is here -- confirm whether the second copy was meant to differ.
    return line + line


def create_label_shift(x_train:pd.DataFrame, y_train:pd.DataFrame, shift_ratio:float, direction:int) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Subsample one class so that the share of class ``direction`` grows.

    All rows labelled ``direction`` are kept; rows labelled
    ``abs(1 - direction)`` are randomly subsampled so that the kept class
    reaches ``ratio + (1 - ratio) * shift_ratio`` of the result, where
    ``ratio`` is the kept class's current share. The label share is read from
    the mean of the first column of ``y_train``.

    Args:
        x_train: feature frame.
        y_train: label frame; only its first column is used.
        shift_ratio: fraction of the remaining share to move to the kept class.
        direction: ``1`` keeps the positive class in full; any other value
            follows the other branch of the ratio computation.

    Returns:
        ``(features, labels)`` after shuffling and index reset. The labels
        are returned as the selected label column of the combined frame.
    """
    label = y_train.columns[0]
    n_rows = x_train.shape[0]
    positive_ratio = y_train.mean().values[0]

    # share of the class that is kept in full
    kept_ratio = positive_ratio if direction == 1 else 1 - positive_ratio
    target_ratio = kept_ratio + (1 - kept_ratio) * shift_ratio
    # number of rows of the other class needed to hit target_ratio
    n_sampled = int(np.round(n_rows * kept_ratio * (1 / target_ratio - 1)))

    merged = pd.concat([x_train, y_train], axis=1)
    thinned = merged[merged[label] == abs(1 - direction)].sample(n=n_sampled)
    kept = merged[merged[label] == direction]
    shifted = pd.concat([thinned, kept]).sample(frac=1).reset_index(drop=True)
    return shifted.drop(columns=label), shifted[label]


def propose_cols(e,convert_pd_x,covert_pd_y,is_train:bool=False):
    """Ask the ``deepseek-chat`` model for Python code that recovers the top feature.

    Steps:
      1. load experiment ``e`` with ``load_dataset`` and fit a
         ``GradientBoostingClassifier`` on its training split;
      2. take the training column with the largest feature importance as the
         column to recover;
      3. send a prompt that lists all columns (``construct_init_prompts``) and
         asks for a rule rebuilding that column from the others;
      4. send the returned rule back, asking for it as a Python function
         named ``recover`` that takes and returns the test DataFrame.

    Args:
        e: experiment name, forwarded to ``load_dataset``.
        convert_pd_x: not used by this function.
        covert_pd_y: only its first column name is used, as the prediction
            objective named in the prompt.
        is_train: not used by this function.

    Returns:
        The text of the second reply with Markdown code-fence markers removed.

    The API key is read from the ``DEEPSEEK_API_KEY`` environment variable
    when it is set, so the credential does not have to live in the source;
    otherwise the in-file placeholder is used as before.
    """
    import os

    # imported up front so a missing package fails before any training is done
    from openai import OpenAI

    x_train, _, _, y_train, _, _ = load_dataset(e)
    gbdt = GradientBoostingClassifier(random_state=37,n_estimators=100)
    gbdt.fit(x_train,y_train)
    init_prompt = construct_init_prompts(x_train)
    # stat_prompt = stat_prompts(x_train)
    recover_col = x_train.columns[gbdt.feature_importances_.argmax()]
    # A free-text objective can be substituted here, e.g.
    # 'the goal is to predict whether a diabetic patient is readmitted to the hospital within 30 days of their initial release'
    Objective = covert_pd_y.columns[0]
    job_describe = f'''### Your task ###
    Your objective is to predict {Objective}. You have access to the following attributes:
    {init_prompt}

    To enhance prediction performance in real world applications, you need to propose rule to recover the value of column {recover_col} using other columns to overcome the situation that when this columns value is missing when testing. The rule you designed should be reasonable to its original meaning, for example, for the original value, if it is the bigger the better, then the recovered value should have the same properties. You can change the value to catgorial with the value has same trend to original value.
    The rule you designed will be directly used to recover the whole column. You need to covert all the value to a number which can be direct used by machine learning model.

    ### Answer ###'''

    client = OpenAI(
        api_key=os.environ.get('DEEPSEEK_API_KEY', 'your api key'),
        base_url="https://api.deepseek.com/v1",
    )

    def ask(user_prompt):
        # one non-streaming chat round trip; returns the reply text
        reply = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": user_prompt},
            ],
            stream=False
        )
        return reply.choices[0].message.content

    rule = ask(job_describe)
    code = ask(f'We have a rule:{rule}, covert it to python code, the input parameter is only the test pd.Dataframe, we want the rule can be directly used with the function. The function name should be recover, it need to be used as test_df = recover(test_df). You need to just give me the code which can be directly use exec to run, no other imformation.')
    # strip Markdown fences so the text can be handed to exec() by add_col
    code = code.replace('```python','').replace('```','')
    return code

def add_col(e,code,convert_pd_x,covert_pd_y, is_train):
    """Apply the generated ``recover`` function to a feature frame.

    Args:
        e: not used by this function.
        code: Python source that defines ``recover(df)``.
        convert_pd_x: feature frame to transform. It is never modified:
            ``recover`` runs on a copy, so generated code that assigns
            columns in place cannot alter the caller's data.
        covert_pd_y: labels belonging to ``convert_pd_x``.
        is_train: when true, the untouched features are stacked on top of
            the recovered ones and the labels are duplicated to match;
            otherwise only the recovered features are returned together
            with the labels as passed in.

    Returns:
        ``(features, labels)``.
    """
    def run_code(code, x_ood):
        # SECURITY: `code` is model-generated text executed with exec();
        # only run it in an environment where that is acceptable.
        envs = {"x_ood":x_ood}
        exec(code, envs)
        exec('x_ood_tune = recover(x_ood)',envs)
        return envs['x_ood_tune']

    tuned_x = run_code(code, convert_pd_x.copy())
    if is_train:
        return pd.concat([convert_pd_x, tuned_x], axis=0), pd.concat([covert_pd_y,covert_pd_y], axis=0)
    return tuned_x, covert_pd_y
    
import scipy.stats as stats
from sklearn.mixture import GaussianMixture
def check_con_columns(data:pd.DataFrame, columns)->bool:
    """Return whether the dtype of ``data[columns]`` equals ``'float64'``.

    A function with this name and the same body is already defined earlier
    in the file; this later definition is the one bound once the whole file
    has been executed.
    """
    selected = data[columns]
    return selected.dtype == 'float64'
def _kde_column_distance(train_values, test_values, num_samples):
    """Distance between two non-empty numeric samples via KDE densities."""
    # small Gaussian jitter is added to both samples before fitting the KDEs
    train_noisy = train_values + np.random.normal(0, 1e-4, train_values.shape)
    test_noisy = test_values + np.random.normal(0, 1e-4, test_values.shape)

    train_density = stats.gaussian_kde(train_noisy)
    test_density = stats.gaussian_kde(test_noisy)

    # common evaluation grid spanning both (un-jittered) samples
    lower = min(train_values.min(), test_values.min())
    upper = max(train_values.max(), test_values.max())
    grid = np.linspace(lower, upper, num_samples)

    # the grid points are the values, the densities act as their weights
    return stats.wasserstein_distance(grid, grid, train_density(grid), test_density(grid))


def _frequency_column_distance(train_column, test_column):
    """``sqrt(2 * entropy(p, q))`` over the relative value frequencies."""
    train_freq = train_column.value_counts(normalize=True).sort_index()
    test_freq = test_column.value_counts(normalize=True).sort_index()

    categories = train_freq.index.union(test_freq.index)
    # values missing on one side get frequency 0, then 1e-8 is added to all
    p = train_freq.reindex(categories, fill_value=0).values + 1e-8
    q = test_freq.reindex(categories, fill_value=0).values + 1e-8
    return np.sqrt(2 * stats.entropy(p, q))


def compute_wasserstein_distance_kde(train_df, test_df, num_samples=100):
    """Average per-column distance between two frames over their shared columns.

    Columns that are float64 in ``train_df`` (per ``check_con_columns``) are
    compared through Gaussian KDEs evaluated on ``num_samples`` grid points
    and ``scipy.stats.wasserstein_distance``; such a column is skipped when
    either side has no non-null value. Every other column is compared
    through ``sqrt(2 * scipy.stats.entropy(train_freq, test_freq))``.

    Returns:
        The mean of the collected distances, or ``inf`` when none was
        collected.
    """
    shared_columns = list(set(train_df.columns) & set(test_df.columns))

    distances = []
    for col in shared_columns:
        if not check_con_columns(train_df, col):
            distances.append(_frequency_column_distance(train_df[col], test_df[col]))
            continue

        train_values = train_df[col].dropna().values
        test_values = test_df[col].dropna().values
        if len(train_values) == 0 or len(test_values) == 0:
            continue
        distances.append(_kde_column_distance(train_values, test_values, num_samples))

    if not distances:
        return float("inf")
    return np.mean(distances)

def _new_learner(kind):
    """Build an unfitted classifier for the ``--model`` choice ``kind``."""
    if kind == 'cat':
        return CatBoostClassifier(random_state=37,iterations=500,logging_level='Silent')
    if kind == 'light':
        return LGBMClassifier(random_state=37,n_estimators=100,verbose=-1)
    return GradientBoostingClassifier(random_state=37,n_estimators=100)


def creating_feature_x_set_and_learners(e, x_train,y_train):
    """Create the label-shifted training sets and one learner per set.

    The recovery code is requested once through ``propose_cols``. For each
    direction (0, 1) and each shift ratio (0.05 ... 0.4) a label-shifted copy
    of the training data is created with ``create_label_shift`` and extended
    with ``add_col``; an unfitted learner of the family selected by the
    module-level ``model`` is paired with it. A further learner is fitted
    on the un-shifted training data extended the same way.

    Returns:
        ``(base_learner, ori_tree, base_x, base_y, code)`` where
        ``base_learner`` is a list of ``[estimator, fitted_flag]`` pairs
        (flag initially ``False``), ``ori_tree`` is the learner fitted on
        the un-shifted data, ``base_x`` / ``base_y`` are the shifted
        features / labels in the same order, and ``code`` is the recovery
        code text.
    """
    code = propose_cols(e,x_train,y_train)
    direction = (0, 1)
    ratio = (0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4)
    base_learner = []
    base_x = []
    base_y = []
    print('creating shifted datasets')
    for d in direction:
        for r in ratio:
            x_shifted, y_shifted = create_label_shift(x_train, y_train, r, d)
            x_shift_tuned, y_shift_tuned = add_col(e, code,x_shifted, y_shifted, True)
            base_x.append(x_shift_tuned)
            base_y.append(y_shift_tuned)
            print('index: ',len(base_y) - 1,y_shifted.mean())
            # a list (not a tuple): the flag is flipped once the learner is fitted
            base_learner.append([_new_learner(model), False])
    # extend the un-shifted data once and reuse both halves of the result
    full_x, full_y = add_col(e, code, x_train, y_train, True)
    ori_tree = _new_learner(model).fit(full_x, full_y)
    return base_learner, ori_tree, base_x, base_y, code


def test(e):
    """Run the batched evaluation on the OOD split of experiment ``e``.

    The recovery code is applied to the OOD features, which are then
    predicted in batches of 128 rows. At batch ``k`` a new learner is
    selected with probability ``0.8 ** k`` (always at batch 0): the shifted
    training set with the smallest ``compute_wasserstein_distance_kde`` to
    all OOD rows seen so far is chosen. A learner is fitted the first time
    it is selected. The accuracy over all batches and the batch -> learner
    choices are printed; nothing is returned.

    The number of batches is the ceiling of ``len(df) / 128`` so that no
    empty trailing batch is handed to ``predict`` when the row count is an
    exact multiple of 128.
    """
    from sklearn.metrics import accuracy_score

    x_train,x_idtest,x_ood,y_train,y_idtest,y_ood = load_dataset(e)
    if model == 'cat':
        gbdt = CatBoostClassifier(random_state=37,iterations=500,logging_level='Silent')
    elif model == 'light':
        gbdt = LGBMClassifier(random_state=37,n_estimators=100,verbose=-1)
    else:
        gbdt = GradientBoostingClassifier(random_state=37,n_estimators=100)
    gbdt.fit(x_train,y_train)
    importances = gbdt.feature_importances_
    recover_col = x_train.columns[importances.argmax()]
    x_ood[recover_col] = 0
    # NOTE(review): the reload below replaces the x_ood that was just zeroed,
    # so the masking above does not reach the evaluated data -- confirm which
    # of the two is intended.
    x_train,x_idtest,x_ood,y_train,y_idtest,y_ood = load_dataset(e)
    x_ood_c = x_ood.copy()
    base_learner_list, ori_tree, base_x, base_y,code = creating_feature_x_set_and_learners(e, x_train, y_train)
    print("rule is \n",code)
    df,y_ood = add_col(e,code,x_ood_c,y_ood,False)
    results = []
    model_index = -1
    batch_index = {}

    def get_operation_frequency(batch_idx):
        # probability of re-selecting the learner decays with the batch number
        return 0.8 ** batch_idx

    batch_size = 128
    n_batches = -(-len(df) // batch_size)  # ceiling division
    for indexs in range(n_batches):
        stop = min(batch_size + indexs*batch_size, len(df))
        remain_test = df.iloc[indexs*batch_size : stop]
        freq = get_operation_frequency(indexs)
        print(f"batch {indexs}, freq:{freq}")
        if np.random.rand() < freq:
            print(f'-------selecting models at batch {indexs}----------')
            best = np.inf
            model_index = -1
            # every OOD row seen up to and including the current batch
            buffer = df.iloc[: stop]
            for i, xs in enumerate(base_x):
                dis = compute_wasserstein_distance_kde(xs, buffer)
                if dis < best:
                    best = dis
                    model_index = i
            print('seleccted model: ',model_index, 'target: ', base_y[model_index].mean(),'ground truth: ', y_ood.mean().values[0])
            batch_index[indexs] = model_index
            print('-------selecting models----------')
        slot = base_learner_list[model_index]
        if not slot[1]:
            # fit lazily, the first time this learner is actually needed
            slot[0].fit(base_x[model_index], base_y[model_index])
            slot[1] = True
        results.append(slot[0].predict(remain_test))
    # leading empty float array keeps the dtype promotion of the original loop
    res = np.concatenate([np.array([])] + results)
    print("DyLearner res acc", accuracy_score(res, y_ood))
    print(f"exp:{e}, batch_index : {batch_index}")


import argparse
# Command line: --model selects the classifier family read by the functions
# above through the module-level `model` ('cat', 'light', anything else falls
# back to sklearn's GradientBoostingClassifier).
parser = argparse.ArgumentParser()
parser.add_argument('--model', default=None)
agrs = parser.parse_args()
model = agrs.model

# Experiments to run, in order.
exp_list = ["anes", "brfss_blood_pressure", "acsincome", "acspubcov"]

for e in exp_list:
    print('\n',e)
    # Up to five attempts per experiment: a failing attempt is printed and
    # retried, the first successful one ends the retries.
    for retry in range(5):
        try:
            test(e)
        except Exception as err:
            print(e,err)
        else:
            break

    
