# Load TRRUST and generate GRN with added noisy edges
import pandas as pd
import numpy as np
from SERGIO.sergio import sergio
from collections import defaultdict

# Read the TRRUST mouse table; the file is tab-separated and has no header row
df = pd.read_csv(
    "trrust_rawdata.mouse.tsv",
    sep="\t",
    header=None,
    names=["TF", "Target", "Mode", "PMID"],
)

# Number of TFs to keep as master regulators
n_tfs = 5

# Rank TFs by the number of rows (interactions) they appear in and take the
# first n_tfs. No independence filtering is applied: the candidates are used
# as-is as the master regulators.
tf_row_counts = df["TF"].value_counts()
candidate_tfs = tf_row_counts.index[:n_tfs].tolist()
top_tfs = candidate_tfs
print("Picked master regulators:", top_tfs)

# Keep only rows whose TF is one of the selected regulators and whose mode is
# exactly "Activation" or "Repression" (rows with any other mode are dropped)
is_top_tf = df["TF"].isin(top_tfs)
has_known_mode = df["Mode"].isin(["Activation", "Repression"])
df = df[is_top_tf & has_known_mode]

# Thin out the genes regulated by exactly one TF so the mixture is more
# challenging: a random 70% of them (rounded down) is removed, ~30% is kept.
gene_to_regulators = defaultdict(set)
for target_gene, regulator in zip(df["Target"], df["TF"]):
    gene_to_regulators[target_gene].add(regulator)

# Targets whose regulator set contains a single TF
filtered_genes_single_TF = [
    gene
    for gene, tfs in gene_to_regulators.items()
    if len(tfs) == 1
]

# Sample the genes to delete without replacement, then drop all their rows
n_single_to_drop = int(0.7 * len(filtered_genes_single_TF))
del_single_genes = np.random.choice(
    filtered_genes_single_TF,
    size=n_single_to_drop,
    replace=False,
)
is_deleted_target = df["Target"].isin(del_single_genes)
df = df[~is_deleted_target]

# Summary: rebuild the target -> regulators map on the filtered table and
# report how many genes are regulated by 1, 2, ..., n_tfs TFs
gene_to_regulators = defaultdict(set)
for target_gene, regulator in zip(df["Target"], df["TF"]):
    gene_to_regulators[target_gene].add(regulator)

total_genes = len(gene_to_regulators)
print("Total genes after filtering:", total_genes)

for i in range(n_tfs):
    # Genes whose regulator set has exactly i + 1 members
    filtered_genes = [
        gene
        for gene, tfs in gene_to_regulators.items()
        if len(tfs) == i + 1
    ]
    print(f"Number of genes regulated by {i+1} TFs: {len(filtered_genes)}")

# Build gene index: every name found in the TF or Target column gets an
# integer id, numbered in the order pd.unique returns the names
all_genes = pd.unique(df[["TF", "Target"]].values.ravel())
gene2idx = dict(zip(all_genes, range(len(all_genes))))
idx2gene = dict(enumerate(all_genes))
tf2idx = {name: gene2idx[name] for name in top_tfs}
n_genes = len(all_genes)

# Write regulators.csv: one headerless row per master regulator holding its
# integer id and a value drawn uniformly from [5.0, 7.0).
# NOTE(review): the file is later passed to SERGIO as input_file_regs; the
# second column is presumably the production rate -- confirm against SERGIO.
regulator_rows = []
for tf_name in top_tfs:
    regulator_rows.append([gene2idx[tf_name], 5.0 + 2.0 * np.random.rand()])
regulators_table = pd.DataFrame(regulator_rows)
regulators_table.to_csv("regulators.csv", index=False, header=False)

# Prepare GRN
# Every gene that is not a master regulator is wired to ALL selected TFs:
#   * TF -> gene edges present in the filtered TRRUST table get a "strong"
#     interaction value whose sign follows the annotated mode
#     (negative for "Repression", positive otherwise);
#   * every other TF -> gene edge is an added noise edge with a smaller
#     magnitude and a random sign.
noise_k_range = (0.1, 1.0)   # magnitude range of the added noise edges
strong_k_range = (1.5, 1.8)  # magnitude range of TRRUST-supported edges
coop_range = (1, 4)          # np.random.randint excludes the upper bound -> 1, 2 or 3

# Look up the annotated mode of each (target, TF) edge in a dict built in one
# pass, instead of re-filtering the DataFrame for every target and every edge.
# setdefault keeps the first row seen for a pair, i.e. the same row the
# previous `.iloc[0]` lookup selected.
edge_mode = {}
for tf_name, target_name, mode_name in zip(df["TF"], df["Target"], df["Mode"]):
    edge_mode.setdefault((target_name, tf_name), mode_name)

# The regulator id list is identical for every target, so compute it once
tf_indices = [gene2idx[tf] for tf in top_tfs]

target_rows = []
for target in all_genes:
    # Master regulators are not written as targets
    if target in top_tfs:
        continue
    target_idx = gene2idx[target]

    k_vals = []
    coop_vals = []
    for tf in top_tfs:
        if (target, tf) in edge_mode:
            # Real TRRUST edge: strong magnitude, sign taken from the mode
            sign = -1 if edge_mode[(target, tf)] == "Repression" else 1
            k = sign * np.random.uniform(*strong_k_range)
        else:
            # Added noise edge: weak magnitude, random sign
            k = np.random.uniform(*noise_k_range) * np.random.choice([-1, 1])
        c = np.random.randint(*coop_range)
        k_vals.append(k)
        coop_vals.append(c)

    # Row layout: target id, number of regulators, regulator ids,
    # interaction values, cooperativity values
    row = [target_idx, len(tf_indices)] + tf_indices + k_vals + coop_vals
    target_rows.append(row)

# Pad and save
# All rows currently have the same length (2 + 3 * len(top_tfs)), so the
# padding is a no-op; it is kept so the CSV stays rectangular if the number
# of regulators per target ever varies.
target_len = max(len(row) for row in target_rows)
target_rows = [row + [""] * (target_len - len(row)) for row in target_rows]
pd.DataFrame(target_rows).to_csv("targets.csv", index=False, header=False)

# Run SERGIO simulation
# Integer ids of the master regulators, in the same order as top_tfs
master_regulators = [tf2idx[tf] for tf in top_tfs]

# Simulator settings, gathered in one place before construction
sergio_settings = dict(
    number_genes=n_genes,
    number_bins=1,
    number_sc=20000,
    noise_params=0.2,
    decays=0.8,
    sampling_state=15,
    noise_type='dpd',
)
sim = sergio(**sergio_settings)

# NOTE(review): the keyword is spelt "input_file_taregts" here; presumably it
# matches SERGIO's own signature -- do not "correct" it without checking.
sim.build_graph(
    input_file_taregts="targets.csv",
    input_file_regs="regulators.csv",
    shared_coop_state=0,
)

sim.simulate()

# Retrieve simulated expression and TF activity: take the first entry of the
# simulator output and transpose it.
# NOTE(review): presumably this gives one row per cell and one column per
# gene id -- confirm against SERGIO's getExpressions().
expression = sim.getExpressions()
X = expression[0].T
# Columns of X belonging to the master regulators
true_tfs = X[:, master_regulators]

# Save to CSV
np.savetxt("X.csv", X, delimiter=",")
np.savetxt("true_tfs.csv", true_tfs, delimiter=",")
