import torch
import numpy as np
import random

def generate_gene_lengths(n_genes):
    """Draw a random integer length for each of ``n_genes`` genes.

    Every length is a negative-binomial draw (total_count=300, probs=0.02)
    raised to the power 3.9, shifted up by 150 and truncated to an integer.

    Args:
        n_genes: Number of lengths to sample.

    Returns:
        A 1-D ``torch.int32`` tensor holding ``n_genes`` lengths.
    """
    sampler = torch.distributions.NegativeBinomial(total_count=300, probs=0.02)
    raw_draws = sampler.sample((n_genes,))
    # Draws are non-negative counts, so after the power transform the +150
    # offset guarantees every length is at least 150.
    stretched = raw_draws ** 3.9
    return (stretched + 150).int()

def generate_gene_clusters(n_genes, rdna_tandem_cluster_size=400, mhc_cluster_size=200):
    """Partition ``n_genes`` consecutive genes into contiguous clusters.

    The first two clusters are the special rDNA-tandem and MHC clusters with
    the requested sizes. The genes left over are split into further clusters
    whose sizes are drawn uniformly from 1..50; the final cluster is truncated
    so that all sizes sum to exactly ``n_genes``.

    Args:
        n_genes: Total number of genes to assign to clusters.
        rdna_tandem_cluster_size: Size of the first (rDNA tandem) cluster.
        mhc_cluster_size: Size of the second (MHC) cluster.

    Returns:
        A tuple ``(cluster_sizes, n_clusters, clusters_to_genes,
        gene_cluster_label)`` where

        * ``cluster_sizes`` is a list of cluster sizes, in gene order,
        * ``n_clusters`` is ``len(cluster_sizes)``,
        * ``clusters_to_genes`` is a float array of shape
          ``(n_clusters, n_genes)`` with a 1 where the gene belongs to the
          cluster and 0 elsewhere,
        * ``gene_cluster_label`` is an integer array of length ``n_genes``
          giving the cluster index of each gene.

    Raises:
        ValueError: If a special cluster size is negative, or if the special
            clusters together need more genes than ``n_genes`` provides.
    """
    if rdna_tandem_cluster_size < 0 or mhc_cluster_size < 0:
        raise ValueError("special cluster sizes must be non-negative")

    # The special clusters always come first.
    cluster_sizes = [rdna_tandem_cluster_size, mhc_cluster_size]

    remaining_genes = n_genes - sum(cluster_sizes)
    if remaining_genes < 0:
        # Without this check the labels and the membership matrix would
        # silently disagree about how many genes exist.
        raise ValueError(
            f"n_genes ({n_genes}) is smaller than the combined size of the "
            f"special clusters ({sum(cluster_sizes)})"
        )

    # Fill the rest of the genome with randomly sized clusters.
    while remaining_genes > 0:
        size = min(random.randint(1, 50), remaining_genes)
        cluster_sizes.append(size)
        remaining_genes -= size

    n_clusters = len(cluster_sizes)

    # Cluster index of every gene: index i repeated cluster_sizes[i] times.
    gene_cluster_label = np.repeat(np.arange(n_clusters), cluster_sizes)

    # One-hot membership matrix derived from the labels, so both outputs are
    # consistent by construction.
    clusters_to_genes = np.zeros((n_clusters, n_genes))
    clusters_to_genes[gene_cluster_label, np.arange(n_genes)] = 1

    return cluster_sizes, n_clusters, clusters_to_genes, gene_cluster_label

def generate_gene_programs(n_clusters, n_gene_programs=100, frac_housekeeping=0.1):
    """Build a random program-by-cluster membership matrix.

    The first ``int(n_clusters * frac_housekeeping)`` cluster columns are
    treated as housekeeping and are set to 1 for every program. Each entry in
    the other columns is an independent Bernoulli draw with probability 0.1.

    Args:
        n_clusters: Number of gene clusters (matrix columns).
        n_gene_programs: Number of gene programs (matrix rows).
        frac_housekeeping: Fraction of clusters that are housekeeping.

    Returns:
        A tuple ``(programs_by_clusters, n_housekeeping)``: a float array of
        shape ``(n_gene_programs, n_clusters)`` with 0/1 entries, and the
        number of housekeeping clusters.
    """
    n_housekeeping = int(n_clusters * frac_housekeeping)
    n_regular = n_clusters - n_housekeeping

    # Housekeeping clusters take part in every program.
    housekeeping_block = np.ones((n_gene_programs, n_housekeeping))
    # Every other cluster joins a given program with probability 0.1.
    regular_block = np.random.binomial(1, 0.1, (n_gene_programs, n_regular))

    # The float ones-block promotes the stacked result to float.
    programs_by_clusters = np.hstack([housekeeping_block, regular_block])

    return programs_by_clusters, n_housekeeping

def generate_gaps_and_enhancers(cluster_sizes, lengths):
    """Sample gap lengths before genes and one enhancer length per cluster.

    For every cluster, one long gap (mean 1,000,000, std 100,000) is drawn
    for the space before the cluster, followed by ``cluster_size - 1`` short
    gaps (mean 5,000, std 500) for the spaces between its genes. Enhancer
    lengths are drawn with mean 500 and std 100. All draws come from normal
    distributions and are truncated to integers.

    Args:
        cluster_sizes: Sequence with the number of genes in each cluster.
        lengths: Gene lengths. Kept for interface compatibility; the sampled
            values do not depend on it.

    Returns:
        A tuple ``(gap_lengths, enhancer_lengths)``: a list of 0-dim
        ``torch.int32`` tensors in genome order (cluster gap first, then the
        within-cluster gaps), and a 1-D ``torch.int32`` tensor with one
        enhancer length per cluster.
    """
    # An average gap length used to be derived from ``lengths`` here, but it
    # was never used and divided by zero for a single-gene list.
    avg_cluster_gap_length = 1000000
    avg_gene_gap_length = 5000
    avg_enhancer_length = 500

    # Create distributions for sampling
    cluster_gap_length_distrib = torch.distributions.Normal(
        loc=avg_cluster_gap_length,
        scale=avg_cluster_gap_length / 10,
    )
    gene_gap_length_distrib = torch.distributions.Normal(
        loc=avg_gene_gap_length,
        scale=avg_gene_gap_length / 10,
    )
    enhancer_length_distrib = torch.distributions.Normal(
        loc=avg_enhancer_length,
        scale=avg_enhancer_length / 5,
    )

    # Sample gaps one at a time, in genome order.
    gap_lengths = []
    for cluster_size in cluster_sizes:
        # Gap before the cluster
        gap_lengths.append(cluster_gap_length_distrib.sample().int())
        # Gaps between consecutive genes inside the cluster
        gap_lengths.extend(
            gene_gap_length_distrib.sample().int()
            for _ in range(cluster_size - 1)
        )

    # One enhancer per cluster
    enhancer_lengths = enhancer_length_distrib.sample((len(cluster_sizes),)).int()

    return gap_lengths, enhancer_lengths