import pandas as pd
import matplotlib.pyplot as plt

# Path of the CSV file to analyse -- point this at your own data.
csv_path = (
    "/Users/home/Documents/naz/research_codes/uncert_prop/realworld_exp/hai_down1/concatenated_train.csv"
)
df = pd.read_csv(filepath_or_buffer=csv_path)

def is_boolean_col(col):
    """Return True if ``col`` looks like a two-state (boolean) column.

    Missing values are ignored. The column qualifies when it has at least one
    and at most two distinct values, and every value converts to a float with
    no fractional part (e.g. 0/1, True/False, "0"/"1").

    The whole-number requirement matters because this script casts the
    selected columns with ``astype(int)``, which would silently truncate
    values such as 0.5. An all-missing column is rejected: it carries no
    information, and it would make every row disappear when rows with
    missing values are dropped across the selected columns.

    Parameters
    ----------
    col : pandas.Series
        The column to inspect.

    Returns
    -------
    bool
        Whether the column is treated as boolean.
    """
    distinct = set(col.dropna().unique())
    if not 1 <= len(distinct) <= 2:
        return False
    try:
        as_floats = [float(value) for value in distinct]
    except (TypeError, ValueError, OverflowError):
        # Non-numeric values (free text, timestamps, ...) are not boolean.
        return False
    return all(value.is_integer() for value in as_floats)

# Select, in df column order, every column that is_boolean_col accepts.
bool_columns = list(filter(lambda name: is_boolean_col(df[name]), df.columns))
print(f"Number of boolean columns: {len(bool_columns)}")
print("Boolean columns:", bool_columns)

def _longest_stretches(combos, positions):
    """Return the longest run of adjacent rows for each combination.

    Parameters
    ----------
    combos : iterable of tuple
        Combination tuple of each kept row, in row order.
    positions : iterable of int
        0-based position of each kept row in the original DataFrame,
        parallel to ``combos`` and strictly increasing.

    Returns
    -------
    dict
        Maps each combination to ``(length, start, end)`` describing its
        longest run. ``start`` and ``end`` are inclusive row positions.
        Keys appear in order of first occurrence, and on ties the earliest
        run is kept.

    A run continues only while the combination is unchanged *and* the rows
    are adjacent in the original data. A row dropped for missing values
    therefore ends the run instead of being silently bridged.
    """
    runs = []  # [combo, start, end] for each maximal run, in row order
    for pos, combo in zip(positions, combos):
        if runs and runs[-1][0] == combo and runs[-1][2] == pos - 1:
            runs[-1][2] = pos  # same combination on the very next row
        else:
            runs.append([combo, pos, pos])

    stretches = {}
    for combo, start, end in runs:
        length = end - start + 1
        # Strict ">" keeps the earliest run when two runs are equally long.
        if combo not in stretches or length > stretches[combo][0]:
            stretches[combo] = (length, start, end)
    return stretches


if bool_columns:
    # Keep only the boolean columns and drop rows with NaN in any of them:
    # such rows have no well-defined combination.
    bool_df = df[bool_columns].dropna()
    bool_df = bool_df.astype(int)
    # 0-based positions in ``df`` of the rows dropna() kept. These make the
    # stretch indices reported below refer to the original data and let
    # gaps left by dropped rows break a stretch.
    complete_rows = df[bool_columns].notna().all(axis=1).to_numpy()
    kept_positions = complete_rows.nonzero()[0].tolist()

    # Count occurrences of each unique combination
    combo_counts = bool_df.value_counts().sort_index()
    n_combos = combo_counts.shape[0]
    print(f"Number of unique boolean combinations present in data: {n_combos}")
    print("Combinations and their counts:")
    for combo, count in combo_counts.items():
        print(f"{tuple(combo)}: {count}")

    # Bar chart of row counts per combination. Each label concatenates the
    # column values in bool_columns order.
    plt.figure(figsize=(10, 5))
    combo_labels = [''.join(map(str, combo)) for combo in combo_counts.index]
    plt.bar(combo_labels, combo_counts.values)
    plt.xlabel("Boolean Combination")
    plt.ylabel("Number of Data Points")
    plt.title("Histogram of Boolean Combinations")
    plt.tight_layout()
    plt.show()

    # Find the largest continuous stretch for each combination.
    # itertuples(name=None) yields plain tuples and avoids a slow row-wise apply().
    row_combos = bool_df.itertuples(index=False, name=None)
    stretches = _longest_stretches(row_combos, kept_positions)

    print("\nLargest continuous stretch for each combination:")
    for combo, (length, start, end) in stretches.items():
        print(f"Combination {combo}: Length={length}, Start Index={start}, End Index={end}")