import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import re
from collections import defaultdict

def read_feature_importance(csv_path):
    """
    Load feature names and importance scores from a single CSV file.

    Scores are taken from the 'Importance' column when it is present,
    otherwise from the last column of the file. Names are taken from the
    'Feature' column when it is present, otherwise the row index is
    converted to strings and used instead.

    Returns: (feature names, importance values) as arrays. If anything goes
    wrong while reading, the error is printed and two empty lists are
    returned.
    """
    try:
        table = pd.read_csv(csv_path)
        columns = table.columns
        # Prefer the explicitly named score column; fall back to the last one.
        score_series = table['Importance'] if 'Importance' in columns else table.iloc[:, -1]
        if 'Feature' in columns:
            names = table['Feature'].values
        else:
            names = table.index.astype(str).values
        return names, score_series.values
    except Exception as e:
        print(f"Error reading {csv_path}: {e}")
        return [], []

def collect_feature_importance_matrix(base_dir, subdir, file_names):
    """
    Collects feature importance values for all files.

    Each file is read from os.path.join(base_dir, subdir, fname). The
    feature list is the ordered union of the feature names seen across the
    files (first occurrence wins, each name listed once). A feature that is
    missing from a file gets np.nan in that file's column.

    Returns: feature list, matrix (features x files). The matrix is always
    2-D with shape (len(features), len(file_names)), even when no features
    could be read, so callers can safely reduce over axis=1.
    """
    all_features = []
    seen_features = set()  # O(1) membership test instead of scanning the list
    feature_values_dict = {}
    for fname in file_names:
        path = os.path.join(base_dir, subdir, fname)
        features, values = read_feature_importance(path)
        feature_values_dict[fname] = dict(zip(features, values))
        # Append unseen features, preserving first-seen order across files.
        for feat in features:
            if feat not in seen_features:
                seen_features.add(feat)
                all_features.append(feat)
    rows = [
        [feature_values_dict[fname].get(feat, np.nan) for fname in file_names]
        for feat in all_features
    ]
    # np.array([]) would be 1-D; force the documented 2-D shape so the empty
    # case (no files / no readable features) behaves like the regular one.
    matrix = np.array(rows).reshape(len(all_features), len(file_names))
    return all_features, matrix

def heatmap(data, row_labels, col_labels, ax=None, cmap="bwr", cbarlabel="", vmin=-1, vmax=1):
    """
    Draws `data` as an image on `ax` with a colorbar and labelled ticks.

    data: 2-D array-like passed to imshow (rows x columns).
    row_labels / col_labels: tick labels for the y / x axis; one tick is
        placed per label.
    ax: Axes to draw on; the current Axes (plt.gca()) is used when None.
    cmap: colormap passed to imshow.
    cbarlabel: label written along the colorbar.
    vmin / vmax: data values mapped to the two ends of the colormap
        (defaults -1 and 1).

    Returns: (image returned by imshow, colorbar)
    """
    if ax is None:
        ax = plt.gca()
    im = ax.imshow(data, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
    ax.set_xticks(np.arange(len(col_labels)))
    # Column labels are drawn vertically.
    ax.set_xticklabels(col_labels, rotation=90)
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels)
    return im, cbar

def group_files_by_split_and_dataset(file_names):
    """
    Buckets file names first by split, then by dataset.

    Only names of the form '<split>_<dataset>_seed_<digits>.csv' are kept,
    where <split> is 'train_val' or 'test_val' and <dataset> consists of
    letters and digits only; every other name is ignored.

    Returns: nested defaultdict {split: {dataset: [file_names]}}, with the
    names of each bucket in input order.
    """
    name_re = re.compile(r'^(train_val|test_val)_([a-zA-Z0-9]+)_seed_\d+\.csv$')
    grouped = defaultdict(lambda: defaultdict(list))
    for name in file_names:
        parsed = name_re.match(name)
        if parsed is None:
            # Name does not follow the expected pattern; leave it out.
            continue
        split_key, dataset_key = parsed.groups()
        grouped[split_key][dataset_key].append(name)
    return grouped

def average_feature_importance(base_dir, subdir, file_names):
    """
    Computes average feature importance for a group of files.

    The per-file values are collected with collect_feature_importance_matrix
    and averaged per feature across files; NaN entries (feature missing from
    a file) are ignored by the mean.

    Returns: feature list, averaged values (one value per feature). When no
    values could be collected, the averaged values are all NaN (an empty
    array when there are no features) instead of raising.
    """
    features, matrix = collect_feature_importance_matrix(base_dir, subdir, file_names)
    matrix = np.asarray(matrix)
    # An empty or 1-D matrix (nothing could be read) has no axis 1 to reduce
    # over; np.nanmean(..., axis=1) would raise on it.
    if matrix.ndim != 2 or matrix.size == 0:
        return features, np.full(len(features), np.nan)
    # Average across columns (files)
    avg_values = np.nanmean(matrix, axis=1)
    return features, avg_values

def main():
    """
    Command-line entry point: draws one feature-importance heatmap per split.

    Files named 'train_val_*.csv' / 'test_val_*.csv' in
    <base_dir>/<feature_importance_subdir> are grouped by split and dataset,
    averaged per dataset, and plotted as a features x datasets heatmap with
    the numeric value written in each cell.

    Each split is saved to <base_dir>/<split>_<output_name>; with the default
    --output_name this gives '<split>_feature_importance_heatmap.pdf'.
    """
    parser = argparse.ArgumentParser(description="Generate heatmap from feature importance values.")
    parser.add_argument('--base_dir', type=str, required=True, help='Base directory')
    parser.add_argument('--feature_importance_subdir', type=str, required=True, help='Subdirectory with feature importance csv files')
    parser.add_argument('--output_name', type=str, default='feature_importance_heatmap.pdf', help='Output image filename')
    args = parser.parse_args()

    fi_dir = os.path.join(args.base_dir, args.feature_importance_subdir)
    file_names_sorted = sorted(
        fname for fname in os.listdir(fi_dir)
        if fname.startswith(("train_val_", "test_val_")) and fname.endswith(".csv")
    )

    # Group files by split and dataset
    groups = group_files_by_split_and_dataset(file_names_sorted)

    custom_cmap = mcolors.LinearSegmentedColormap.from_list(
        "red_white_blue",
        [
            (0.0, "red"),    # red for -1
            (0.5, "white"),  # white for 0
            (1.0, "blue")    # blue for 1
        ]
    )

    # Split --output_name so the split prefix lands on the file name even if
    # a sub-directory was given; fall back to PDF when no extension is given.
    out_subdir, out_base = os.path.split(args.output_name)
    if not os.path.splitext(out_base)[1]:
        out_base += ".pdf"

    for split in ['train_val', 'test_val']:
        if split not in groups:
            continue
        datasets = sorted(groups[split])

        # Average each dataset, and build the ordered union of feature names
        # so features present only in later datasets are not dropped.
        feature_list = []
        seen_features = set()
        per_dataset_values = []
        for dataset in datasets:
            feats, avg_vals = average_feature_importance(fi_dir, "", groups[split][dataset])
            per_dataset_values.append(dict(zip(feats, avg_vals)))
            for feat in feats:
                if feat not in seen_features:
                    seen_features.add(feat)
                    feature_list.append(feat)

        if not feature_list:
            print(f"No feature importance values found for split '{split}'; skipping heatmap.")
            continue

        # shape: features x datasets; NaN where a dataset lacks a feature
        avg_matrix = np.array(
            [[values.get(feat, np.nan) for values in per_dataset_values] for feat in feature_list],
            dtype=float,
        )

        fig, ax = plt.subplots(figsize=(max(10, len(datasets)*0.7), max(8, len(feature_list)*0.4)))
        heatmap(avg_matrix, feature_list, datasets, ax=ax, cmap=custom_cmap, cbarlabel="Feature importance")

        # Write the numeric value into every cell.
        for (i, j), val in np.ndenumerate(avg_matrix):
            txt = f"{val:.2f}" if not np.isnan(val) else "nan"
            ax.text(j, i, txt, ha="center", va="center", color="black", fontsize=8)

        fig.tight_layout()
        out_path = os.path.join(args.base_dir, out_subdir, f"{split}_{out_base}")
        fig.savefig(out_path)
        # Release the figure so repeated splits do not accumulate open figures.
        plt.close(fig)
        print(f"Saved heatmap to {out_path}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
