import torch
import numpy as np
from scipy.sparse import csr_matrix

def generate_stress_effect(n_housekeeping, n_clusters, clusters_to_genes):
    """Generate stress-induced closure of specific gene clusters.

    The first ``n_housekeeping`` clusters are never closed by stress. Each of
    the remaining ``n_clusters - n_housekeeping`` clusters is closed
    independently with probability 0.1.

    Args:
        n_housekeeping: Number of leading clusters protected from stress.
        n_clusters: Total number of clusters.
        clusters_to_genes: Cluster-to-gene matrix whose first dimension has
            length ``n_clusters`` (presumably ``(n_clusters, n_genes)`` with a
            nonzero entry where a gene belongs to a cluster — TODO confirm).
            A torch tensor or any array-like accepted by ``torch.as_tensor``.

    Returns:
        ``(stress_closure, stress_closure_by_gene)``:
        ``stress_closure`` is a float64 0/1 tensor of length ``n_clusters``
        marking closed clusters. ``stress_closure_by_gene`` is its product with
        ``clusters_to_genes``, clamped to ``[0, 1]``, in the dtype and on the
        device of ``clusters_to_genes``.

    Raises:
        ValueError: If ``n_housekeeping`` exceeds ``n_clusters``.

    Note:
        The random draws use NumPy's global RNG (``np.random``), not torch's,
        so ``np.random.seed`` controls reproducibility here.
    """
    n_stressable = n_clusters - n_housekeeping
    if n_stressable < 0:
        raise ValueError(
            f"n_housekeeping ({n_housekeeping}) exceeds n_clusters ({n_clusters})"
        )

    stress_closure = torch.tensor(np.concatenate((
        np.zeros(n_housekeeping),
        np.random.binomial(1, 0.1, n_stressable),
    )))

    clusters_to_genes = torch.as_tensor(clusters_to_genes)
    # stress_closure is float64 (NumPy's default). Match the matrix's dtype and
    # device so the matmul does not fail against a float32/int/GPU matrix.
    closure_for_product = stress_closure.to(
        device=clusters_to_genes.device, dtype=clusters_to_genes.dtype
    )
    # A gene in several closed clusters would sum above 1; clamp to a 0/1 flag.
    stress_closure_by_gene = torch.clamp(
        closure_for_product @ clusters_to_genes, min=0, max=1
    )
    return stress_closure, stress_closure_by_gene

def generate_cell_cycle_effect(n_genes, cell_cycle_phases=('G1', 'S', 'G2', 'M'), *,
                               poisson_rate=1.0, open_prob=0.99):
    """Generate cell cycle effects on gene expression.

    Args:
        n_genes: Number of genes.
        cell_cycle_phases: Phase labels. Only their count is used, giving one
            row per phase. The default is a tuple so no mutable default is
            shared between calls.
        poisson_rate: Rate of the Poisson draw for every (phase, gene) entry
            of ``cell_cycle_genes`` (default 1.0, as before).
        open_prob: Probability that a (phase, gene) entry of
            ``cell_cycle_open`` is 1 (default 0.99, as before).

    Returns:
        ``(cell_cycle_genes, cell_cycle_open)``, each a float tensor of shape
        ``(len(cell_cycle_phases), n_genes)``. The first holds non-negative
        integer-valued Poisson counts; the second is a 0/1 Bernoulli mask.
        Both are drawn from torch's global RNG.
    """
    shape = (len(cell_cycle_phases), n_genes)
    cell_cycle_genes = torch.distributions.Poisson(rate=poisson_rate).sample(shape)
    cell_cycle_open = torch.distributions.Bernoulli(probs=open_prob).sample(shape)
    return cell_cycle_genes, cell_cycle_open

def calculate_open_chromatin(n_cells, ct, transcription_activity, ct_to_genes,
                             cell_cycle, cell_cycle_genes, cell_cycle_open,
                             stress_level, stress_closure_by_gene):
    """Calculate open chromatin regions based on cell type and other factors.

    For each cell ``c`` and gene ``g`` the result is::

        (transcription_activity[c] * ct_to_genes[ct[c], g]
         + cell_cycle_genes[cell_cycle[c], g])
        * cell_cycle_open[cell_cycle[c], g]
        * (1 - stress_level[c] * stress_closure_by_gene[g])

    Args:
        n_cells: Number of cells. Only the first ``n_cells`` entries of
            ``transcription_activity`` are used.
        ct: Cell-type index per cell (integer tensor/array of length
            ``n_cells``, or a single index broadcast to all cells).
        transcription_activity: 1-D tensor of per-cell activity scalars.
        ct_to_genes: 2-D tensor indexed ``[cell_type, gene]``.
        cell_cycle: Cell-cycle phase index per cell (same conventions as ``ct``).
        cell_cycle_genes: 2-D tensor indexed ``[phase, gene]`` added to the
            cell-type signal.
        cell_cycle_open: 2-D tensor indexed ``[phase, gene]`` multiplied in
            (presumably a 0/1 mask — TODO confirm).
        stress_level: 1-D per-cell tensor.
        stress_closure_by_gene: 1-D per-gene tensor.

    Returns:
        Tensor of shape ``(n_cells, n_genes)``.
    """
    # Gather each cell's own cell-type / phase row directly. The previous form
    # built the full (n_cells, n_types, n_genes) and (n_cells, n_phases, n_genes)
    # products and then indexed one slice per cell. The values are identical,
    # but this avoids that memory blow-up.
    activity = transcription_activity[:n_cells].unsqueeze(1)
    open_chromatin_per_sample = activity * ct_to_genes[ct]

    # Cell cycle: phase-specific additive offset, then phase-specific mask.
    open_chromatin_per_sample = open_chromatin_per_sample + cell_cycle_genes[cell_cycle]
    open_chromatin_per_sample = open_chromatin_per_sample * cell_cycle_open[cell_cycle]

    # Stress closes each gene in proportion to the cell's stress level.
    stress_penalty = stress_level.unsqueeze(1) * stress_closure_by_gene.unsqueeze(0)
    return open_chromatin_per_sample * (1 - stress_penalty)

def _region_sampling_efficiency(n_cells, n_windows, floor):
    """Return an (n_cells, n_windows) matrix: a per-cell Beta(10, 1) draw times a per-window decay max(floor, 1 - k/100)."""
    per_cell = torch.distributions.Beta(
        concentration1=10.0, concentration0=1.0
    ).sample((n_cells, 1)).expand(-1, n_windows)
    decay = torch.Tensor(
        [max(floor, (1 - k / 100)) for k in range(n_windows)]
    ).unsqueeze(0).expand(n_cells, -1)
    return per_cell * decay


def _sample_region(signal, n_cells, n_windows, floor):
    """Return (clean, noisy) window values for one enhancer or gene body.

    ``signal`` is an (n_cells, 1) column broadcast over the windows. ``clean``
    is the signal scaled by the sampling efficiency. ``noisy`` keeps each
    window of ``clean`` with probability 0.4 and zeroes the rest.
    """
    clean = signal * _region_sampling_efficiency(n_cells, n_windows, floor)
    keep = torch.distributions.Bernoulli(probs=0.4).sample((n_cells, n_windows)).int()
    return clean, clean * keep


def _sample_gap(n_cells, gap_length, window_size, max_windows):
    """Return (clean, noisy) window values for an inter-region gap.

    Clean values are all zero. Noisy values are 1 with probability 0.01
    (background noise).
    """
    # NOTE: evaluated as (gap_length / window_size) * 10, then capped.
    n_windows = min(int(gap_length / window_size * 10), max_windows)
    noisy = torch.distributions.Bernoulli(probs=0.01).sample((n_cells, n_windows)).int()
    return torch.zeros(n_cells, n_windows), noisy


def generate_peak_profiles(n_cells, cluster_sizes, lengths, enhancer_lengths, gap_lengths,
                           open_chromatin_per_sample, window_size=500):
    """Generate chromatin accessibility peak profiles.

    Columns are laid out cluster by cluster. Each cluster contributes:
    its enhancer windows, then a gap, then for each of its genes the
    gene-body windows followed by a gap. Genes are assumed to be ordered by
    cluster in ``open_chromatin_per_sample`` (columns
    ``offset : offset + cluster_size`` belong to one cluster).

    Args:
        n_cells: Number of cells (rows).
        cluster_sizes: Number of genes in each cluster.
        lengths: Per-gene body length; a gene gets
            ``min(1, int(length / window_size))`` windows.
        enhancer_lengths: Per-cluster enhancer length; same window rule.
        gap_lengths: Per-gene gap length. Gaps are capped at 3 windows after
            an enhancer and 5 after a gene body.
        open_chromatin_per_sample: (n_cells, n_genes) tensor of openness.
        window_size: Bin width used to convert lengths to window counts
            (default 500, as before).

    Returns:
        ``(peaks, peaks_nonoise)``: a ``scipy.sparse.csr_matrix`` of noisy
        values (with dropout and gap noise) and a torch tensor of the
        corresponding noise-free values, both of shape (n_cells, n_windows).
        Random draws use torch's global RNG, in the same order as before.
    """
    peaks = []
    peaks_nonoise = []

    gene_offset = 0
    for cluster_idx, cluster_size in enumerate(cluster_sizes):
        start_gene = gene_offset
        gene_offset += cluster_size

        # NOTE(review): min(1, ...) caps every region at one window, and gives
        # 0 windows below window_size. The per-window decay suggests max(1, ...)
        # may have been intended. Confirm before changing, since it alters the
        # output width.
        enhancer_windows = min(1, int(enhancer_lengths[cluster_idx] / window_size))
        # Enhancer signal: cluster-mean openness truncated to int, minus 2,
        # floored at 0.
        cluster_mean = open_chromatin_per_sample[:, start_gene:gene_offset].sum(1) / cluster_size
        enhancer_signal = torch.clamp(cluster_mean.int() - 2, min=0).unsqueeze(1)
        clean, noisy = _sample_region(enhancer_signal, n_cells, enhancer_windows, floor=0.1)
        peaks_nonoise.append(clean)
        peaks.append(noisy)

        # NOTE(review): this gap reuses gap_lengths[start_gene], the same entry
        # used for the first gene's trailing gap below.
        clean, noisy = _sample_gap(n_cells, gap_lengths[start_gene], window_size, max_windows=3)
        peaks_nonoise.append(clean)
        peaks.append(noisy)

        for gene in range(start_gene, gene_offset):
            gene_windows = min(1, int(lengths[gene] / window_size))
            gene_signal = open_chromatin_per_sample[:, gene].int().unsqueeze(1)
            clean, noisy = _sample_region(gene_signal, n_cells, gene_windows, floor=0.01)
            peaks_nonoise.append(clean)
            peaks.append(noisy)

            clean, noisy = _sample_gap(n_cells, gap_lengths[gene], window_size, max_windows=5)
            peaks_nonoise.append(clean)
            peaks.append(noisy)

    if not peaks:
        # No clusters: concatenating an empty list would raise, so return
        # empty outputs with the expected row count instead.
        return csr_matrix(np.zeros((n_cells, 0))), torch.zeros(n_cells, 0)

    peaks = csr_matrix(np.concatenate(peaks, axis=1))
    peaks_nonoise = torch.cat(peaks_nonoise, dim=1)
    return peaks, peaks_nonoise