import sys
from resampling_data_to_get_balanced_data import *
from reattachcolumns import recover_full_rows_from_originals, get_excluded_and_subsets, recover_full_rows_by_query_keys
import pandas as pd
import os
from wrangle_data import add_contradiction_stats_to_df
from clean_data import process_dataframe, transform_edge_data

# Remove every pandas display limit so the diagnostic prints in this script
# show whole rows and cell values instead of truncated ones.
pd.options.display.max_rows = None
pd.options.display.max_columns = None
pd.options.display.width = None
pd.options.display.max_colwidth = None
if __name__ == "__main__":
    # Task 1: create the training set.
    #
    # Steps: load the pickled frame, keep the rows at one `max_non_path_atoms`
    # level, sample a capped number of rows per `max_rule_chain_len` value,
    # re-attach the remaining columns, add contradiction stats, clean, transform
    # the edge data, and pickle the result to the working directory.

    # Sampling configuration: per-chain-length cap and a fixed seed for reproducibility.
    max_rows_per_chain_len = 500
    sampling_seed = 42

    # An optional first command-line argument overrides the input path; with no
    # argument the script falls back to `GiveFileName`, exactly as before.
    # NOTE(review): `GiveFileName` is not defined in this file. Unless the star
    # import at the top provides it, running without an argument raises NameError.
    # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    input_pickle_file_path = sys.argv[1] if len(sys.argv) > 1 else GiveFileName
    df = pd.read_pickle(input_pickle_file_path)

    desired_OPEC_level = 5
    df = df[df['max_non_path_atoms'] == desired_OPEC_level]
    print(df.shape)
    print(df['max_rule_chain_len'].unique())

    # Fail early with a clear message: without this, an empty selection only
    # surfaces later as pandas' "No objects to concatenate" ValueError.
    if df.empty:
        raise ValueError(
            f"No rows with max_non_path_atoms == {desired_OPEC_level} "
            f"in {input_pickle_file_path!r}; nothing to sample."
        )

    # Sample each chain length separately so no single length dominates.
    sampled_dfs = []
    for chain_len in sorted(df['max_rule_chain_len'].unique()):
        subset = df[df['max_rule_chain_len'] == chain_len]
        sample_size = min(max_rows_per_chain_len, len(subset))  # take the cap, or everything if fewer
        sampled_dfs.append(subset.sample(n=sample_size, random_state=sampling_seed))

    # Combine the samples. Keep the original index: per the original author's
    # warning, resetting it here breaks the later steps.
    df = pd.concat(sampled_dfs, ignore_index=False)

    print(f'''   Before recovering columsn we have  {df.shape}  ''')
    # Presumably re-attaches the columns missing from the sampled frame, using
    # the original datasets stored under this directory -- confirm in reattachcolumns.
    df = recover_full_rows_by_query_keys(df, '../../DatsetsGenerated/NonPathMeasurementMOreNoAmbg')
    print(f''' After recovering columns we have  {df.shape} ''')
    print(df.iloc[0][['unique_facts', 'query_edge', 'query_relation', 'derivation_chain', 'max_non_path_atoms']])
    print(f'''  just befor saving --------------------------------------------------------------------------------------------------------------   ''')

    df = add_contradiction_stats_to_df(df)
    df = process_dataframe(input=df)
    print(df.iloc[0][['unique_facts', 'query_edge', 'query_relation', 'derivation_chain']])
    # Original author's note: keeps the story encoding aligned with the world rules.
    df = transform_edge_data(df)

    output_pickle_path = f'dataset_{desired_OPEC_level}_off_path_nodes.pkl'
    df.to_pickle(output_pickle_path)
