# analysis_common.py
import os
import math
import numpy as np
import pandas as pd
from scipy.stats import ttest_ind
import matplotlib.pyplot as plt

def _ensure_dir(path):
    """
    Create directory `path` (and any missing parents) if it does not exist.

    path: directory path; an empty string is treated as "nothing to create".
    Returns: None
    """
    # os.path.dirname("plot.png") is "", and os.makedirs("") raises
    # FileNotFoundError, so skip creation when there is no directory part.
    if not path:
        return
    os.makedirs(path, exist_ok=True)

def log2p1(x):
    """Return log2(x + 1), applied element-wise by numpy."""
    shifted = x + 1.0
    return np.log2(shifted)

def simple_cpm(df_counts):
    """
    Counts-per-million normalization.

    df_counts: DataFrame whose columns are each scaled independently
    Returns: DataFrame of the same shape, each column divided by its own
             column sum and multiplied by 1e6
    """
    library_sizes = df_counts.sum(axis=0)
    fractions = df_counts.div(library_sizes, axis='columns')
    return fractions * 1e6

def _bh_adjust(pvals):
    """
    Benjamini-Hochberg adjusted p-values for a 1-D array of p-values.

    NaN entries are not counted as tests and stay NaN in the output.
    """
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    tested = ~np.isnan(p)
    n_tests = int(tested.sum())
    if n_tests == 0:
        return adjusted

    p_valid = p[tested]
    # Walk from the largest p-value down to the smallest so the running
    # minimum enforces monotonicity: a smaller p never gets a larger padj.
    desc = np.argsort(p_valid, kind='stable')[::-1]
    ranks = np.arange(n_tests, 0, -1)
    stepped = np.minimum.accumulate(p_valid[desc] * n_tests / ranks)

    # Scatter back to the original (positional) order and cap at 1.
    restored = np.empty(n_tests)
    restored[desc] = np.minimum(stepped, 1.0)
    adjusted[tested] = restored
    return adjusted


def differential_expression(counts, group_labels):
    """
    Two-group differential expression using Welch's t-test on log2(CPM + 1).

    counts: DataFrame genes x samples (raw counts or abundance proxies)
    group_labels: list or array of same length as columns of counts
    Returns: DataFrame indexed by gene with log2FC, pval, padj (BH),
             sorted by padj ascending. log2FC is (second group - first group)
             where groups are ordered by sorted label. Genes whose p-value is
             NaN are excluded from the BH test count and get padj = 1.0.
    Raises: ValueError if group_labels does not contain exactly 2 distinct values.
    """
    groups = pd.Series(group_labels, index=counts.columns)
    gvals = sorted(groups.unique())
    if len(gvals) != 2:
        raise ValueError("This function currently supports exactly 2 groups.")
    g1, g2 = gvals
    idx1 = groups[groups == g1].index
    idx2 = groups[groups == g2].index

    # log2-CPM for variance stabilization
    logcpm = np.log2(simple_cpm(counts) + 1.0)
    expr1 = logcpm[idx1]
    expr2 = logcpm[idx2]

    # difference of group means on the log scale (g2 vs g1)
    log2fc = expr2.mean(axis=1) - expr1.mean(axis=1)

    # Welch t-test on logCPM, one test per gene (row)
    _, pvals = ttest_ind(expr2.to_numpy(), expr1.to_numpy(), axis=1, equal_var=False)

    res = pd.DataFrame({
        'gene': counts.index,
        'log2FC': log2fc.values,
        'pval': np.asarray(pvals, dtype=float),
    }).set_index('gene')

    # Benjamini-Hochberg, computed positionally so duplicate gene ids are safe
    res['padj'] = _bh_adjust(res['pval'].to_numpy())
    # untestable genes (NaN p-value) are reported as non-significant
    res['padj'] = res['padj'].fillna(1.0)
    return res.sort_values('padj')

def volcano_plot(res_df, out_png, title="Volcano Plot", alpha=0.05, lfc_thresh=1.0):
    """
    Save a volcano plot (log2FC vs -log10 raw p-value) to `out_png`.

    res_df: DataFrame with 'log2FC' and 'pval' columns
    out_png: output image path; its parent directory is created if needed
    title: figure title
    alpha: p-value level at which the horizontal dashed line is drawn
    lfc_thresh: absolute log2FC at which the two vertical dashed lines are drawn
    Returns: None
    """
    # A bare filename has no directory part; os.makedirs("") would raise.
    out_dir = os.path.dirname(out_png)
    if out_dir:
        _ensure_dir(out_dir)

    x = res_df['log2FC'].values
    # clip so p == 0 does not produce an infinite y value
    y = -np.log10(res_df['pval'].clip(lower=1e-300)).values

    fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
    try:
        ax.scatter(x, y, s=6)
        # thresholds
        ax.axhline(-math.log10(alpha), linestyle='--')
        ax.axvline(-lfc_thresh, linestyle='--')
        ax.axvline(lfc_thresh, linestyle='--')
        ax.set_xlabel("log2FC")
        ax.set_ylabel("-log10(p)")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(out_png)
    finally:
        # release the figure even if drawing or saving fails
        plt.close(fig)

def heatmap_top_genes(counts, res_df, group_labels, out_png, top_n=20, title="Top DE genes heatmap"):
    """
    Save a heatmap of per-gene z-scored log2-CPM for the top genes by padj.

    counts: DataFrame genes x samples; all genes are used for library sizes
    res_df: DataFrame indexed by gene with a 'padj' column
    group_labels: currently unused; kept for interface compatibility
    out_png: output image path; its parent directory is created if needed
    top_n: number of genes (smallest padj first) to show
    title: figure title
    Returns: None
    """
    # A bare filename has no directory part; os.makedirs("") would raise.
    out_dir = os.path.dirname(out_png)
    if out_dir:
        _ensure_dir(out_dir)

    # choose top_n by padj
    top_genes = res_df.sort_values('padj').head(top_n).index

    # Normalize with the full library sizes first, THEN subset. Computing CPM
    # on the subset alone would divide by the sum of only the selected genes.
    logcpm = np.log2(simple_cpm(counts).loc[top_genes] + 1.0)

    # z-score by gene; the small constant avoids division by zero for flat genes
    gene_mean = logcpm.mean(axis=1).values[:, None]
    gene_sd = logcpm.std(axis=1).values[:, None]
    z = (logcpm - gene_mean) / (gene_sd + 1e-9)

    fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
    try:
        image = ax.imshow(z.to_numpy(), aspect='auto')
        fig.colorbar(image, ax=ax, label='Z (log2-CPM)')
        ax.set_yticks(range(len(top_genes)))
        ax.set_yticklabels(list(top_genes), fontsize=6)
        ax.set_xticks(range(len(counts.columns)))
        ax.set_xticklabels(list(counts.columns), rotation=90, fontsize=6)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(out_png)
    finally:
        # release the figure even if drawing or saving fails
        plt.close(fig)
