#!/usr/bin/env python
"""
Main script for running multimodal genomic data simulation.
This maintains the original structure while using modular components.
"""
import torch
import numpy as np
import math
from scipy import sparse
import json
import pandas as pd
import tqdm

# Add the directory two levels above this file to sys.path so the `src.*` imports below resolve
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from src.data.utils import set_seed
from src.data.genome import (generate_gene_lengths, generate_gene_clusters, generate_gene_programs,
                           generate_gaps_and_enhancers)
from src.data.cell_types import (generate_cell_type_hierarchy,
                               create_cell_type_program_mapping, connect_cell_types_to_genes)
from src.data.chromatin import (generate_stress_effect, generate_cell_cycle_effect,
                              calculate_open_chromatin, generate_peak_profiles)
from src.data.transcription import (calculate_transcription_params, generate_gene_expression_base,
                                 calculate_transcription, simulate_technical_effects)
from src.data.translation import (generate_aa_compositions, calculate_translation_params,
                                simulate_protein_variables, simulate_abundance)

if __name__ == "__main__":
    # Script flow:
    #   1. Build the gene-level "genome" structures and parameters once, under
    #      a fixed seed, and write them to ``data_dir``.
    #   2. For each batch of ``n_cells`` cells, reseed, sample per-cell
    #      variables, simulate the ATAC / RNA / protein layers and save the
    #      per-batch outputs.

    # Configure simulation parameters
    n_samples_total = 500_000  # integer count; only used to derive the number of batches
    n_cells = 10000  # batch size for sample generation
    n_genes = 20000  # protein coding genes
    rdna_tandem_cluster_size = 400
    mhc_cluster_size = 200
    n_stem_cell_types = 3
    # Intermediate (non-observed) matrices are only written for batches with
    # an index below this threshold.
    n_ground_truth_batches = 10
    data_dir = '01_data/mm_sim/'
    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(data_dir, exist_ok=True)

    # IMPORTANT: Set seed 0 for genome structure creation
    genome_seed = 0
    set_seed(genome_seed)

    # Generate genome structure (with seed 0)
    lengths = generate_gene_lengths(n_genes)
    cluster_sizes, n_clusters, clusters_to_genes, gene_cluster_labels = generate_gene_clusters(
        n_genes, rdna_tandem_cluster_size, mhc_cluster_size
    )
    # save clusters to genes as sparse npz matrix
    sparse.save_npz(data_dir+"clusters_to_genes.npz", sparse.csr_matrix(clusters_to_genes))
    gap_lengths, enhancer_lengths = generate_gaps_and_enhancers(cluster_sizes, lengths)
    # save enhancer lengths
    np.save(data_dir+"enhancer_lengths.npy", enhancer_lengths.numpy())

    # Generate cell type hierarchy (with seed 0)
    cell_types, n_cell_types = generate_cell_type_hierarchy(n_stem_cell_types)
    # save cell types as json
    with open(data_dir+"cell_types.json", "w") as f:
        json.dump(cell_types, f)
    programs_by_clusters, n_housekeeping = generate_gene_programs(n_clusters)
    # save programs by clusters as sparse npz matrix
    sparse.save_npz(data_dir+"programs_by_clusters.npz", sparse.csr_matrix(programs_by_clusters))
    cell_type_to_programs, cell_type_to_programs_mtrx, cell_type_levels = create_cell_type_program_mapping(
        n_stem_cell_types, cell_types, len(programs_by_clusters)
    )
    # save the cell_type_to_programs_mtrx as sparse npz matrix
    sparse.save_npz(data_dir+"cell_type_to_programs_mtrx.npz", sparse.csr_matrix(cell_type_to_programs_mtrx))
    ct_to_clusters, ct_to_genes = connect_cell_types_to_genes(
        cell_type_to_programs_mtrx, programs_by_clusters, clusters_to_genes
    )
    # save ct_to_genes as sparse npz matrix
    sparse.save_npz(data_dir+"ct_to_genes.npz", sparse.csr_matrix(ct_to_genes))

    # Generate chromatin accessibility effects (with seed 0)
    stress_closure, stress_closure_by_gene = generate_stress_effect(
        n_housekeeping, n_clusters, clusters_to_genes
    )
    cell_cycle_genes, cell_cycle_open = generate_cell_cycle_effect(n_genes)

    # Generate transcription parameters (with seed 0)
    transcription_probs, degradation_probs_1 = calculate_transcription_params(lengths)
    mean_expression = generate_gene_expression_base(n_genes)

    # Generate translation parameters (with seed 0)
    aa_compositions, ease_of_translation, seq_lengths, aa_freq_tensor = generate_aa_compositions(
        n_genes, lengths
    )
    translation_probs, degradation_probs_2 = calculate_translation_params(seq_lengths)

    # Collect all gene-specific parameters, one column per parameter, and
    # save them as a single CSV. Rows 0..3 of the cell-cycle tensors are
    # labelled G1, S, G2 and M respectively.
    gene_specific_params = {
        "gene_ids": np.arange(n_genes),
        "cluster_labels": gene_cluster_labels,
        "lengths": lengths.numpy(),
        "stress_closure": stress_closure_by_gene.numpy(),
        "cell_cycle_amplification (G1)": cell_cycle_genes[0, :].numpy(),
        "cell_cycle_amplification (S)": cell_cycle_genes[1, :].numpy(),
        "cell_cycle_amplification (G2)": cell_cycle_genes[2, :].numpy(),
        "cell_cycle_amplification (M)": cell_cycle_genes[3, :].numpy(),
        "cell_cycle_open (G1)": cell_cycle_open[0, :].numpy(),
        "cell_cycle_open (S)": cell_cycle_open[1, :].numpy(),
        "cell_cycle_open (G2)": cell_cycle_open[2, :].numpy(),
        "cell_cycle_open (M)": cell_cycle_open[3, :].numpy(),
        "transcription_probs": transcription_probs.numpy(),
        "mrna_degradation_probs": degradation_probs_1.numpy(),
        "mean_expression": mean_expression.numpy(),
        "translation_probs": translation_probs.numpy(),
        "translation_aa_comp": ease_of_translation.numpy(),
        "protein_degradation_probs": degradation_probs_2.numpy(),
    }
    gene_specific_params_df = pd.DataFrame(gene_specific_params)
    gene_specific_params_df.to_csv(data_dir+"gene_specific_params.csv", index=False)

    # Loop-invariant sampling distributions. Constructing a distribution draws
    # no random numbers, so building them once before the per-batch reseeding
    # leaves the sampled values unchanged.
    p_ct = torch.distributions.categorical.Categorical(1/n_cell_types*torch.ones(n_cell_types))  # uniform over cell types
    stress_level_distrib = torch.distributions.Bernoulli(probs=0.05)
    cell_cycle_distrib = torch.distributions.categorical.Categorical(torch.tensor([0.1, 0.2, 0.3, 0.4]))
    transcription_activity_distribution = torch.distributions.poisson.Poisson(rate=4.0)
    damage_distrib = torch.distributions.Beta(concentration1=1.0, concentration0=2.0)

    # Column order must match the concatenation order of `causal_variables` below.
    column_names = ["cell_type", "stress_level", "cell_cycle", "transcription_activity", "damage_prob", "mrna_batch_effect", "ribosome_rate", "free_ribosomes", "tRNA_availability", "proteasome_activity", "prot_batch_effect"]

    # Process batches of cells with different seeds
    n_batches = math.ceil(n_samples_total / n_cells)
    for cell_batch in tqdm.tqdm(range(n_batches)):
        ###
        # Set batch-specific seed for cell sampling
        ###
        cell_sampling_seed = cell_batch + 1  # Different seed for each batch (0 is reserved for the genome)
        set_seed(cell_sampling_seed)

        # Intermediate matrices are only stored for the first few batches.
        save_ground_truth = cell_batch < n_ground_truth_batches

        ###
        # Sample cell-specific variables
        ###
        ct = p_ct.sample((n_cells,))
        stress_level = stress_level_distrib.sample((n_cells,))
        cell_cycle = cell_cycle_distrib.sample((n_cells,))
        # Poisson draw clamped to [0, 9] and shifted by one, i.e. values in 1..10
        transcription_activity = torch.clamp(transcription_activity_distribution.sample((n_cells,)), min=0, max=9) + 1

        ###
        # chromatin accessibility
        ###
        open_chromatin_per_sample = calculate_open_chromatin(
            n_cells, ct, transcription_activity, ct_to_genes,
            cell_cycle, cell_cycle_genes, cell_cycle_open,
            stress_level, stress_closure_by_gene
        )

        # Generate peak profiles
        peaks, peaks_nonoise = generate_peak_profiles(
            n_cells, cluster_sizes, lengths, enhancer_lengths, gap_lengths, open_chromatin_per_sample
        )

        ###
        # transcriptomics
        ###
        # Calculate DNA damage probability: Beta(1, 2) sample scaled by 1/10,
        # then multiplied by the per-cell entry of `cell_type_levels`.
        damage_prob = (damage_distrib.sample((n_cells,)) / 10)
        ct_diff_stage = [cell_type_levels[cell_type_idx] for cell_type_idx in ct.tolist()]
        damage_prob = damage_prob * torch.tensor(ct_diff_stage)

        # Calculate transcription
        potential_transcription, real_transcription = calculate_transcription(
            open_chromatin_per_sample, mean_expression, damage_prob,
            transcription_probs, degradation_probs_1
        )

        # Apply technical effects to transcription
        rna_dropout_rates = [0.4, 0.3, 0.2]
        rna_efficiency_rates = [0.3, 0.25, 0.2]
        observed_transcription, dropout_mask, capture_efficiency, mrna_batches = simulate_technical_effects(
            real_transcription, n_cells, n_genes, rna_dropout_rates, rna_efficiency_rates
        )

        ###
        # proteomics
        ###
        # Generate protein variables
        protein_sample_variables, protein_dropout_mask, prot_batches = simulate_protein_variables(
            n_cells, real_transcription, cluster_sizes, n_genes
        )

        # Simulate protein abundance
        prots_translated, prots_real, prot_counts = simulate_abundance(
            real_transcription, real_transcription.sum(1), ease_of_translation,
            translation_probs, degradation_probs_2, protein_sample_variables, protein_dropout_mask
        )

        ###
        # save results per batch
        ###
        # One row per cell; the last column of protein_sample_variables is
        # dropped and capture_efficiency is intentionally not included.
        causal_variables = torch.cat((
            ct.unsqueeze(1),
            stress_level.unsqueeze(1),
            cell_cycle.unsqueeze(1),
            transcription_activity.unsqueeze(1),
            damage_prob.unsqueeze(1),
            mrna_batches.unsqueeze(1),
            protein_sample_variables[:, :-1],
            prot_batches.unsqueeze(1)
        ), dim=1)
        causal_variables_df = pd.DataFrame(causal_variables.numpy(), columns=column_names)
        causal_variables_df.to_csv(data_dir+f"causal_variables_batch_{cell_batch}.csv", index=False)

        # ATAC
        # `peaks` is handed to save_npz as-is; the other matrices are
        # converted from dense tensors to CSR first.
        sparse.save_npz(data_dir+f"peaks_batch_{cell_batch}.npz", peaks)  # noisy data
        if save_ground_truth:
            sparse.save_npz(
                data_dir+f"open_chromatin_batch_{cell_batch}.npz",
                sparse.csr_matrix(open_chromatin_per_sample.numpy())
            )  # raw data
            sparse.save_npz(
                data_dir+f"peaks_nonoise_batch_{cell_batch}.npz",
                sparse.csr_matrix(peaks_nonoise.numpy())
            )  # processed, non-noisy data

        # RNA
        sparse.save_npz(
            data_dir+f"observed_transcription_batch_{cell_batch}.npz",
            sparse.csr_matrix(observed_transcription.numpy())
        )  # noisy data
        if save_ground_truth:
            sparse.save_npz(
                data_dir+f"potential_transcription_batch_{cell_batch}.npz",
                sparse.csr_matrix(potential_transcription.numpy())
            )  # raw data
            sparse.save_npz(
                data_dir+f"real_transcription_batch_{cell_batch}.npz",
                sparse.csr_matrix(real_transcription.numpy())
            )  # processed, non-noisy data

        # Prot
        # Save the `prot_counts` output of simulate_abundance as the noisy
        # data; `prots_real` is stored separately below.
        # assumes prot_counts is a dense torch tensor like prots_real — TODO confirm
        sparse.save_npz(
            data_dir+f"prot_counts_batch_{cell_batch}.npz",
            sparse.csr_matrix(prot_counts.numpy())
        )  # noisy data
        if save_ground_truth:
            sparse.save_npz(
                data_dir+f"prots_translated_batch_{cell_batch}.npz",
                sparse.csr_matrix(prots_translated.numpy())
            )  # raw data
            sparse.save_npz(
                data_dir+f"prots_real_batch_{cell_batch}.npz",
                sparse.csr_matrix(prots_real.numpy())
            )  # processed, non-noisy data

        # NOTE(review): generation stops after batch index 10 (11 batches in
        # total) even though the loop range covers all `n_batches` — confirm
        # whether this cap is intentional before relying on n_samples_total.
        if cell_batch >= 10:
            break