import torch
import numpy as np

def prob_transcription_success(length):
    """Return the transcription success probability for a gene of ``length``.

    The result is the product of an "init" factor and an "end" factor, each
    0.95 by default. Genes shorter than 1000 get a lower init factor (0.9);
    genes longer than 10000 get a lower end factor (0.8).
    """
    # The two penalties apply to disjoint length ranges, so at most one fires.
    init_factor = 0.9 if length < 1000 else 0.95
    end_factor = 0.8 if length > 10000 else 0.95
    return init_factor * end_factor

def prob_mrna_degradation(length):
    """Return the mRNA degradation probability for a gene of ``length``.

    Length bands and their probabilities:
        length < 500            -> 0.8
        500 <= length < 1000    -> 0.5
        1000 <= length <= 10000 -> 0.3
        length > 10000          -> 0.1
    """
    if length < 500:
        return 0.8
    if length < 1000:
        return 0.5
    # Only very long genes get the lowest degradation probability.
    return 0.1 if length > 10000 else 0.3

def calculate_transcription_params(lengths):
    """Build noisy per-gene transcription-success and degradation probabilities.

    For every entry of ``lengths`` the length-based base probability is looked
    up, Gaussian jitter (mean 0, standard deviation 0.05) is added, and the
    result is clamped to [0.01, 0.99].

    Args:
        lengths: Sized iterable of gene lengths.

    Returns:
        Tuple ``(transcription_probs, degradation_probs)`` of 1-D tensors with
        one entry per gene.
    """
    gene_count = len(lengths)
    jitter = torch.distributions.normal.Normal(loc=0.0, scale=0.05)

    # Transcription success: base value per gene, then jitter, then clamp.
    success_base = torch.tensor([prob_transcription_success(gene_len) for gene_len in lengths])
    transcription_probs = torch.clamp(success_base + jitter.sample((gene_count,)), 0.01, 0.99)

    # Degradation: same recipe, with an independent jitter draw.
    decay_base = torch.tensor([prob_mrna_degradation(gene_len) for gene_len in lengths])
    degradation_probs = torch.clamp(decay_base + jitter.sample((gene_count,)), 0.01, 0.99)

    return transcription_probs, degradation_probs

def generate_gene_expression_base(n_genes):
    """Sample a base expression level for each of ``n_genes`` genes.

    Levels are drawn from ``NegativeBinomial(total_count=1000, probs=0.01)``
    and shifted up by one, so every returned value is at least 1.

    Returns:
        1-D tensor with ``n_genes`` entries.
    """
    sampler = torch.distributions.NegativeBinomial(total_count=1000, probs=0.01)
    raw_levels = sampler.sample((n_genes,))
    # Offset by one so no gene has a zero base level.
    return raw_levels + 1

def calculate_transcription(open_chromatin_per_sample, mean_expression, 
                           damage_prob, transcription_probs, degradation_probs):
    """Compute potential and realised transcription counts.

    Potential transcription is chromatin openness times the base expression,
    truncated to integers. Realised transcription scales that by the undamaged
    fraction ``1 - damage_prob``, by the transcription success probability and
    by ``exp(-degradation_probs)``, then truncates to integers again.

    Shapes are implied by the broadcasting below (assumed -- confirm against
    the caller): ``open_chromatin_per_sample`` is (n_samples, n_genes),
    ``damage_prob`` is (n_samples,), and the remaining arguments are (n_genes,).

    Returns:
        Tuple ``(potential_transcription, real_transcription)`` of int32 tensors.
    """
    # NOTE: multiplicative Gaussian sampling noise (mean 1.0, sd 0.1) on the
    # potential transcription is currently disabled.
    expression_row = torch.unsqueeze(mean_expression, 0)
    potential_transcription = (open_chromatin_per_sample * expression_row).int()

    # Per-sample factor as a column; per-gene factors as rows.
    undamaged_fraction = 1 - torch.unsqueeze(damage_prob, 1)
    success_row = torch.unsqueeze(transcription_probs, 0)
    survival_row = torch.exp(-torch.unsqueeze(degradation_probs, 0))

    # Keep this left-to-right multiplication order: the result is truncated.
    scaled = potential_transcription * undamaged_fraction * success_row * survival_row
    real_transcription = scaled.int()

    return potential_transcription, real_transcription

def simulate_technical_effects(real_transcription, n_cells, n_genes, 
                              rna_dropout_rates, rna_efficiency_rates, n_cov_rna=3):
    """Simulate technical batch effects in RNA-seq.

    Each cell is assigned uniformly at random to one of ``n_cov_rna`` batches.
    The batch selects the cell's dropout rate and its mean capture efficiency.
    Counts are scaled by the (noisy, clamped) capture efficiency, truncated to
    integers, and then zeroed wherever the dropout mask is set.

    Args:
        real_transcription: Tensor of counts, expected shape (n_cells, n_genes).
        n_cells: Number of cells (rows of ``real_transcription``).
        n_genes: Number of genes (columns of ``real_transcription``).
        rna_dropout_rates: Per-batch dropout probabilities; 1-D with at least
            ``n_cov_rna`` entries (entry ``b`` is used for batch ``b``).
        rna_efficiency_rates: Per-batch mean capture efficiency; anything that
            broadcasts against an (n_cells, n_cov_rna) tensor.
        n_cov_rna: Number of batches.

    Returns:
        Tuple ``(observed_transcription, dropout_mask, capture_efficiency,
        batch_samples)``: int32 counts of shape (n_cells, n_genes), a 0/1 float
        mask of the same shape (1 = dropped), per-cell efficiencies of shape
        (n_cells,), and per-cell batch indices of shape (n_cells,).

    Raises:
        ValueError: If ``rna_dropout_rates`` is not 1-D or has fewer than
            ``n_cov_rna`` entries.
    """
    # as_tensor avoids the copy (and warning) torch.tensor makes for tensor inputs.
    dropout_rates = torch.as_tensor(rna_dropout_rates)
    if not dropout_rates.is_floating_point():
        dropout_rates = dropout_rates.float()
    # Fail up front: otherwise a short rate vector only breaks when some cell
    # happens to be assigned to a batch without a rate.
    if dropout_rates.dim() != 1 or dropout_rates.shape[0] < n_cov_rna:
        raise ValueError(
            f"rna_dropout_rates must be 1-D with at least n_cov_rna={n_cov_rna} "
            f"entries, got shape {tuple(dropout_rates.shape)}"
        )

    # Sample batch assignments (uniform over batches)
    p_batch = torch.distributions.categorical.Categorical(torch.ones(n_cov_rna) / n_cov_rna)
    batch_samples = p_batch.sample((n_cells,))
    cell_index = torch.arange(n_cells)

    # Create dropout mask. Look up each cell's batch-specific rate first so only
    # an (n_cells, n_genes) mask is sampled, not one mask per batch.
    per_cell_dropout = dropout_rates[batch_samples].unsqueeze(1).expand(n_cells, n_genes)
    dropout_mask = torch.distributions.Bernoulli(probs=per_cell_dropout).sample()

    # Sample capture efficiencies: pick each cell's batch mean...
    efficiency_table = torch.ones((n_cells, n_cov_rna)) * torch.as_tensor(rna_efficiency_rates)
    capture_efficiency = efficiency_table[cell_index, batch_samples]

    # ...add per-cell noise and keep the result in a plausible range.
    efficiency_noise = torch.distributions.normal.Normal(loc=0.0, scale=0.02).sample((n_cells,))
    capture_efficiency = torch.clamp(capture_efficiency + efficiency_noise, 0.01, 0.7)

    # Apply technical effects: scale per cell, truncate, then zero dropped entries.
    observed_transcription = (real_transcription * capture_efficiency.unsqueeze(1)).int()
    observed_transcription[dropout_mask.bool()] = 0

    return observed_transcription, dropout_mask, capture_efficiency, batch_samples