import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from plot_p_values import heatmap
import matplotlib.colors as mcolors
import random


def combine_dependent_p_values(p_values, number_to_combine=None, rng=None):
    """
    Combine p-values for dependent tests using the formula:
    P_combined = 1 - exp(sum(log(1 - p_i)) for i in range(n))
    which is the same as 1 - prod(1 - p_i).

    NaN entries are discarded first. If number_to_combine is set and at least
    that many valid p-values remain, a random subset of number_to_combine
    values is combined; if fewer remain, all of them are used.

    Args:
        p_values: numpy array or list of p-values (anything convertible to a
            float array; None/NaN entries are ignored)
        number_to_combine: int or None, number of p-values to combine. Must be
            >= 1 when given.
        rng: optional random.Random instance used for the subset selection.
            When None, the module-level ``random`` functions are used, so
            seeding ``random`` globally still controls the selection.

    Returns:
        float: combined p-value, or NaN when no valid p-values remain.

    Raises:
        ValueError: if number_to_combine is smaller than 1, or if p_values
            cannot be converted to floats.
    """
    # Combining zero p-values would yield 0.0 (a "perfectly significant"
    # result), so reject it explicitly instead of returning a misleading value.
    if number_to_combine is not None and number_to_combine < 1:
        raise ValueError(
            f"number_to_combine must be >= 1 or None, got {number_to_combine}"
        )

    # Force a float array so np.isnan works for lists holding None or other
    # Python objects (an object-dtype array makes np.isnan raise TypeError).
    p_values = np.asarray(p_values, dtype=float)

    # Remove any NaN values
    p_values = p_values[~np.isnan(p_values)]

    if p_values.size == 0:
        return np.nan

    if number_to_combine is not None and p_values.size >= number_to_combine:
        # Pick random values instead of the first ones; fall back to the
        # global `random` module so existing global seeding keeps working.
        sampler = random if rng is None else rng
        p_values = np.array(sampler.sample(list(p_values), number_to_combine))
    # Otherwise (no limit, or not enough values) use all available values.

    # Clamp into (0, 1): out-of-range inputs are clipped, and the upper bound
    # keeps log(1 - p) finite when p == 1.
    p_values = np.clip(p_values, 1e-15, 1 - 1e-15)

    # P_combined = 1 - exp(sum(log(1 - p_i))). log1p(-p) and -expm1(x) compute
    # log(1 - p) and 1 - exp(x) without cancellation, which would otherwise
    # wipe out most significant digits for very small p-values.
    log_sum = np.sum(np.log1p(-p_values))
    p_combined = -np.expm1(log_sum)

    return p_combined


def read_p_values_and_combine(csv_path, number_to_combine=None):
    """
    Read all p-values from a per-seed CSV file and combine them with
    combine_dependent_p_values (dependent p-values combination formula).

    The 'p_1000' column is used when present, otherwise the last column.
    If number_to_combine is set and the file holds at least that many
    non-NaN p-values, a *random* subset of number_to_combine values is
    combined; with fewer values, all of them are used.

    Args:
        csv_path: path to CSV file
        number_to_combine: int or None, number of p-values to combine

    Returns:
        float: combined p-value, or NaN if the file cannot be read or
        processed (the error is printed).
    """
    try:
        df = pd.read_csv(csv_path)

        # Get p-values for all seeds from 'p_1000', or the last column if
        # 'p_1000' does not exist.
        if 'p_1000' in df.columns:
            p_values = df['p_1000'].values
        else:
            p_values = df.iloc[:, -1].values

        # Combine p-values using the dependent test formula
        return combine_dependent_p_values(p_values, number_to_combine=number_to_combine)

    except Exception as e:
        # Deliberately best-effort: one unreadable file should not abort the
        # whole heatmap, so the problem is reported and the cell becomes NaN.
        print(f"Error reading {csv_path}: {e}")
        return np.nan


def read_p_value(csv_path):
    """
    Return the single p-value stored for seed 'seed_0' in a per-seed CSV.

    Kept for backward compatibility; read_p_values_and_combine aggregates
    over all seeds instead.

    Args:
        csv_path: path to a CSV file that has a 'seed' column.

    Returns:
        float: the 'p_1000' entry of the 'seed_0' row, or that row's last
        column when 'p_1000' is absent. NaN when no 'seed_0' row exists or
        the file cannot be read/parsed (the error is printed).
    """
    try:
        frame = pd.read_csv(csv_path)
        seed0_rows = frame.loc[frame['seed'] == 'seed_0']
        if seed0_rows.empty:
            return np.nan

        first_row = seed0_rows.iloc[0]
        # Prefer the named column; otherwise take the right-most value.
        value = first_row['p_1000'] if 'p_1000' in seed0_rows.columns else first_row.iloc[-1]
        return float(value)
    except Exception as err:
        print(f"Error reading {csv_path}: {err}")
        return np.nan


def _discover_datasets(pvalues_dir):
    """Return sorted dataset names <name> for every train_val_<name>.csv in pvalues_dir."""
    prefix, suffix = "train_val_", ".csv"
    return sorted(
        fname[len(prefix):-len(suffix)]
        for fname in os.listdir(pvalues_dir)
        if fname.startswith(prefix) and fname.endswith(suffix)
    )


def _format_p_value(val):
    """Return the heatmap cell label for a p-value: 'nan', '<1e-3', or 3 decimals."""
    if np.isnan(val):
        return "nan"
    if val < 0.001:
        return "<1e-3"
    return f"{val:.3f}"


def main():
    """
    Build a 2 x N heatmap of p-values and save it as a PDF in --base_dir.

    Rows are "Member" (from train_val_<dataset>.csv) and "Non-member" (from
    test_val_<dataset>.csv); columns are the datasets found in
    <base_dir>/<pvalues_subdir>. Each cell is either the seed_0 p-value or,
    with --combine_seeds, the combination of several seeds' p-values.
    Unreadable or missing files show up as 'nan' cells.
    """
    parser = argparse.ArgumentParser(description="Generate heatmap from linear_di p-values.")
    parser.add_argument('--base_dir', type=str, required=True, help='Base results_linear_di directory')
    parser.add_argument('--pvalues_subdir', type=str, required=True, help='Subdirectory with p-value csv files')
    parser.add_argument('--output_name', type=str, default='linear_di_heatmap.jpg',
                        help='Output filename; its extension is replaced by .pdf')
    parser.add_argument('--combine_seeds', action='store_true', help='Combine p-values from all seeds using dependent test formula')
    parser.add_argument('--number_to_combine', type=int, default=3, help='Number of p-values to combine when using --combine_seeds (default: 3)')
    parser.add_argument('--title', type=str, default='', help='Title for the plot')
    args = parser.parse_args()

    # Combining zero (or a negative number of) p-values is meaningless; fail
    # before reading any files.
    if args.combine_seeds and args.number_to_combine < 1:
        parser.error("--number_to_combine must be >= 1")

    pvalues_dir = os.path.join(args.base_dir, args.pvalues_subdir)
    if not os.path.isdir(pvalues_dir):
        parser.error(f"p-value directory not found: {pvalues_dir}")

    datasets_sorted = _discover_datasets(pvalues_dir)
    if not datasets_sorted:
        # A 2 x 0 matrix cannot be drawn; report the cause instead.
        parser.error(f"no train_val_*.csv files found in {pvalues_dir}")

    def read_func(path):
        """Read one cell: combined over seeds or the seed_0 value, per --combine_seeds."""
        if args.combine_seeds:
            return read_p_values_and_combine(path, number_to_combine=args.number_to_combine)
        return read_p_value(path)

    member_pvalues = []
    nonmember_pvalues = []
    for ds_name in datasets_sorted:
        train_path = os.path.join(pvalues_dir, f"train_val_{ds_name}.csv")
        test_path = os.path.join(pvalues_dir, f"test_val_{ds_name}.csv")
        member_pvalues.append(read_func(train_path))
        nonmember_pvalues.append(read_func(test_path))

    data = np.array([member_pvalues, nonmember_pvalues], dtype=float)
    row_labels = ["Member", "Non-member"]
    col_labels = datasets_sorted

    custom_cmap = mcolors.LinearSegmentedColormap.from_list(
        "red_white_blue_005",
        [
            (0.0, "#ff6666"),   # red for low p-values
            (0.05, "white"),    # white at p=0.05 (requires colour limits [0, 1])
            (1.0, "blue")       # blue for high p-values
        ]
    )

    fig, ax = plt.subplots(figsize=(max(10, len(col_labels) * 0.7), 3.5))
    cbar_kw = {"shrink": 0.5, "aspect": 20}
    im, cbar = heatmap(
        data,
        row_labels,
        col_labels,
        ax=ax,
        cmap=custom_cmap,
        cbarlabel="P value",
        cbar_kw=cbar_kw
    )
    # The colormap puts white at 5% of the colour scale. With matplotlib's
    # default autoscaling that is 5% of the way between the smallest and
    # largest plotted value, not p = 0.05. Pinning the limits to [0, 1] makes
    # the red/white/blue split match the significance threshold.
    # NOTE(review): assumes `heatmap` returns a matplotlib ScalarMappable
    # (e.g. an AxesImage) as its first value; confirm in plot_p_values.
    if hasattr(im, "set_clim"):
        im.set_clim(0.0, 1.0)

    for i, j in np.ndindex(data.shape):
        ax.text(j, i, _format_p_value(data[i, j]), ha="center", va="center", color="black")

    # Set title from argument
    ax.set_title(args.title, fontsize=16)

    fig.tight_layout()
    # The output is always written as PDF, whatever extension was given.
    out_path = os.path.join(args.base_dir, os.path.splitext(args.output_name)[0] + ".pdf")
    fig.savefig(out_path)
    plt.close(fig)
    print(f"Saved heatmap to {out_path}")


if __name__ == "__main__":
    main()
