import torch
import numpy as np
import pandas as pd

def compute_translation_ease(aa_indices, aa_freqs, mask):
    """
    Score every gene by the mean frequency of the amino acids it contains.

    Args:
        aa_indices: integer tensor of shape [num_genes, max_seq_len] whose
            entries index into ``aa_freqs``
        aa_freqs: 1-D tensor holding one frequency per amino acid
        mask: tensor of the same shape as ``aa_indices``; non-zero marks a
            real residue, zero marks padding

    Returns:
        ease: tensor with one score per gene (masked frequency sum divided
            by the number of valid positions)
    """
    # Number of valid residues per gene; computed up front because it is the
    # denominator of the mean.
    valid_counts = mask.sum(dim=1)

    # Look up the frequency of every residue, zero out the padded slots and
    # add up what remains along the sequence axis.
    freq_totals = (aa_freqs[aa_indices] * mask).sum(dim=1)

    # A fully masked gene has a count of 0; clamping the denominator to 1
    # makes its score 0 instead of NaN.
    return freq_totals / torch.clamp(valid_counts, min=1.0)

def generate_aa_compositions(n_genes, lengths, aa_freq_file='./01_data/aa_freq_human.csv'):
    """Draw a random amino acid sequence for each gene and score it.

    Args:
        n_genes: number of genes (rows) to generate
        lengths: tensor with one length per gene; it is divided by 3 to get
            the number of residues (presumably nucleotide lengths -- confirm
            against the caller)
        aa_freq_file: CSV file with a column named
            'Observed Frequency in Vertebrates (%)', one row per amino acid

    Returns:
        aa_compositions: long tensor [n_genes, max_seq_len] of sampled amino
            acid indices; positions past a gene's length stay 0
        ease_of_translation: per-gene score from compute_translation_ease
        seq_lengths: int tensor with the residue count of every gene
        aa_freq_tensor: amino acid frequencies as fractions
    """
    # The table stores percentages; turn them into fractions.
    freq_table = pd.read_csv(aa_freq_file)
    fractions = freq_table['Observed Frequency in Vertebrates (%)'] / 100
    aa_freq_tensor = torch.tensor(fractions.values)

    # Width of the padded matrix and per-gene residue counts.
    max_seq_len = int(lengths.max() / 3) + 1
    seq_lengths = (lengths / 3).int()

    # A position is valid when its column index is below the gene's length.
    columns = torch.arange(max_seq_len).unsqueeze(0).expand(n_genes, -1)
    mask = (columns < seq_lengths.unsqueeze(1)).float()

    # Fill each row with residues drawn in proportion to the frequencies.
    # Genes are sampled one at a time, in order, so the sequence of random
    # draws stays the same for a given seed.
    aa_compositions = torch.zeros((n_genes, max_seq_len), dtype=torch.long)
    for gene in range(n_genes):
        n_residues = seq_lengths[gene].item()
        if n_residues <= 0:
            continue
        aa_compositions[gene, :n_residues] = torch.multinomial(
            input=aa_freq_tensor,
            num_samples=n_residues,
            replacement=True,
        )

    ease_of_translation = compute_translation_ease(aa_compositions, aa_freq_tensor, mask)

    return aa_compositions, ease_of_translation, seq_lengths, aa_freq_tensor

def prob_translation_success(length):
    """Return the translation success probability for a sequence length.

    The result is the product of an initiation and a termination
    probability, both 0.9 by default. Lengths below 1000 lower the
    initiation probability to 0.7; lengths above 10000 lower the
    termination probability to 0.7.
    """
    baseline = 0.9
    penalised = 0.7

    # Short sequences: initiation is penalised.
    if length < 1000:
        return penalised * baseline

    # Long sequences: termination is penalised.
    if length > 10000:
        return baseline * penalised

    return baseline * baseline

def prob_degradation(length):
    """Return the degradation probability for a sequence length.

    Lengths below 1000 give 0.5, lengths above 10000 give 0.1, and
    everything in between (both bounds included) gives 0.3.
    """
    if length < 1000:
        return 0.5
    if length > 10000:
        return 0.1
    return 0.3

def calculate_translation_params(seq_lengths):
    """Build per-gene translation success and degradation probabilities.

    Args:
        seq_lengths: iterable with one length per gene

    Returns:
        translation_probs: tensor of prob_translation_success values
        degradation_probs: tensor of prob_degradation values
    """
    success_values = list(map(prob_translation_success, seq_lengths))
    translation_probs = torch.tensor(success_values)

    degradation_values = list(map(prob_degradation, seq_lengths))
    degradation_probs = torch.tensor(degradation_values)

    return translation_probs, degradation_probs

def simulate_protein_variables(n_cells, real_transcription, cluster_sizes, n_genes):
    """Simulate protein-related variables for each cell.

    Args:
        n_cells: number of cells to simulate
        real_transcription: 2-D tensor of transcription values; its first
            ``cluster_sizes[0]`` columns are summed per row (assumed to have
            one row per cell, i.e. ``n_cells`` rows)
        cluster_sizes: sequence whose first entry is the number of leading
            columns that make up the rDNA cluster
        n_genes: number of genes, i.e. the width of the dropout mask

    Returns:
        protein_sample_variables: tensor [n_cells, 5] with columns
            (ribosome_rate, free_ribosomes, tRNA_availability,
            proteasome_activity, prot_noise)
        dropout_mask: tensor [n_cells, n_genes] of 0/1 values drawn with the
            dropout rate of each cell's batch
        batch_samples: long tensor [n_cells] with each cell's batch index
    """
    # NOTE: the order of the sampling calls below determines how the global
    # RNG stream is consumed; keep it stable to preserve seeded results.

    # Ribosome movement rate: Beta(1, 1), i.e. uniform on [0, 1]
    ribosome_rate_distrib = torch.distributions.Beta(concentration1=1.0, concentration0=1.0)
    ribosome_rate = ribosome_rate_distrib.sample((n_cells,))

    # Ribosome availability (dependent on rDNA cluster): total transcription
    # of the first cluster, scaled by 100
    ribosome_availability = real_transcription[:, :cluster_sizes[0]].sum(axis=1)
    free_ribosomes = ribosome_availability * 100

    # tRNA availability: Beta(2, 1), skewed towards 1
    tRNA_availability_distrib = torch.distributions.Beta(concentration1=2.0, concentration0=1.0)
    tRNA_availability = tRNA_availability_distrib.sample((n_cells,))

    # Proteasome activity: Beta(1, 2), skewed towards 0
    proteasome_activity_distrib = torch.distributions.Beta(concentration1=1.0, concentration0=2.0)
    proteasome_activity = proteasome_activity_distrib.sample((n_cells,))

    # Technical batch effects: every cell is assigned uniformly at random to
    # one of n_prot_cov batches
    n_prot_cov = 2
    p_noise = torch.distributions.categorical.Categorical(1/n_prot_cov * torch.ones(n_prot_cov))
    batch_samples = p_noise.sample((n_cells,))

    # One noise factor is drawn per (cell, batch); the advanced index keeps
    # only the draw belonging to the cell's own batch. Clamped to [0, 1].
    noise_distribution = torch.distributions.normal.Normal(torch.tensor([0.8, 0.9]), 0.03)
    prot_noise = torch.clamp(noise_distribution.sample((n_cells,))[torch.arange(n_cells), batch_samples], 0, 1)

    # Dropout: the sample has shape [n_cells, n_genes, n_batches]; indexing
    # the last axis with the cell's batch leaves [n_cells, n_genes].
    prot_dropout_rates = [0.2, 0.4]
    dropout_distribution = torch.distributions.Bernoulli(probs=torch.tensor(prot_dropout_rates))
    dropout_mask = dropout_distribution.sample((n_cells, n_genes))[torch.arange(n_cells), :, batch_samples]

    # Combine variables into one [n_cells, 5] matrix
    protein_sample_variables = torch.cat((
        ribosome_rate.unsqueeze(1),
        free_ribosomes.unsqueeze(1),
        tRNA_availability.unsqueeze(1),
        proteasome_activity.unsqueeze(1),
        prot_noise.unsqueeze(1)
    ), dim=1)

    return protein_sample_variables, dropout_mask, batch_samples

def simulate_abundance(rna, total_molecules, avg_aa_freq, t_s_prob, d_s_prob, sample_vars, dropout):
    """
    Simulate protein abundance based on RNA levels and various factors.

    Args:
        rna: RNA expression tensor, cells along dim 0 and genes along dim 1
        total_molecules: per-cell sum of RNA molecules
        avg_aa_freq: average amino acid frequency for each gene
        t_s_prob: translation success probability for each gene
        d_s_prob: degradation probability for each gene
        sample_vars: per-cell variables; columns 1-4 are read here as
            free ribosomes, tRNA availability, proteasome activity and
            technical noise (the column order built by
            simulate_protein_variables)
        dropout: mask with the same shape as ``rna``; non-zero entries are
            zeroed in the observed output

    Returns:
        A 3-tuple of tensors, each truncated to int:
        prots_translated: proteins translated before degradation
        prots_real: protein levels after degradation
        prots_observed: observed levels after technical noise and dropout
    """
    free_ribosomes = sample_vars[:, 1]
    trna_availability = sample_vars[:, 2]
    proteasome_activity = sample_vars[:, 3]
    technical_noise = sample_vars[:, 4]

    # Ribosome efficiency is the outer product of a per-gene factor (row
    # vector) and a per-cell factor (column vector). The 0.05 divisor
    # rescales the average frequency (presumably 1/20, the uniform share of
    # one of 20 amino acids -- confirm).
    gene_factor = ((avg_aa_freq / 0.05) * t_s_prob).unsqueeze(0)
    cell_factor = ((free_ribosomes / total_molecules) * trna_availability).unsqueeze(1)
    prots_translated = rna * (gene_factor * cell_factor)

    # Exponential decay with rate = gene degradation prob * cell proteasome
    # activity.
    decay_rate = d_s_prob.unsqueeze(0) * proteasome_activity.unsqueeze(1)
    prots_real = prots_translated * torch.exp(-decay_rate)

    # Technical effects: scale by the per-cell noise factor, then zero out
    # the dropped entries.
    prots_observed = prots_real * technical_noise.unsqueeze(1)
    dropped = dropout.bool()
    prots_observed[dropped] = 0

    return prots_translated.int(), prots_real.int(), prots_observed.int()