import numpy as np
import pandas as pd
from scipy import stats
from typing import Tuple, Optional, Dict, Any

class TFSimulatedData:
    """
    Simulated dataset with transcription factors and regulated/non-regulated genes.

    With the default arguments the simulation creates:
    - 100 hidden TFs with Poisson-distributed activity
    - 1000 genes directly regulated by TFs through a sparse binary matrix
    - 10000 non-regulated genes drawn from negative binomial distributions

    Gaussian noise and random dropout are then applied to the combined
    gene expression matrix.
    """

    def __init__(
        self,
        n_samples: int = 10000,
        n_tfs: int = 100,
        n_regulated_genes: int = 1000,
        n_nonregulated_genes: int = 10000,
        tf_lambda_range: Tuple[float, float] = (0.5, 5.0),
        tf_gene_sparsity: float = 0.95,  # 95% zeros in the regulatory matrix
        noise_level: float = 0.1,
        dropout_rate: float = 0.2,
        random_state: Optional[int] = None
    ):
        """
        Initialize the TF-gene simulation.

        Args:
            n_samples: Number of samples (cells)
            n_tfs: Number of transcription factors
            n_regulated_genes: Number of genes directly regulated by TFs
            n_nonregulated_genes: Number of non-regulated genes (high variance)
            tf_lambda_range: Range of lambda parameters for Poisson distributions
            tf_gene_sparsity: Sparsity level for TF-gene regulatory matrix (0-1)
            noise_level: Level of noise to add to gene expression (>= 0)
            dropout_rate: Probability of dropout in gene expression (0-1)
            random_state: Random seed for reproducibility. When given, the
                instance draws from its own generator, so other instances and
                the global NumPy random state are not affected. When None,
                the global ``np.random`` state is used.

        Raises:
            ValueError: If a probability is outside [0, 1], if noise_level is
                negative, or if regulated genes are requested with no TFs.
        """
        if not 0.0 <= tf_gene_sparsity <= 1.0:
            raise ValueError(
                f"tf_gene_sparsity must be in [0, 1], got {tf_gene_sparsity!r}"
            )
        if not 0.0 <= dropout_rate <= 1.0:
            raise ValueError(
                f"dropout_rate must be in [0, 1], got {dropout_rate!r}"
            )
        if not noise_level >= 0:
            raise ValueError(
                f"noise_level must be non-negative, got {noise_level!r}"
            )
        if n_regulated_genes > 0 and n_tfs < 1:
            raise ValueError(
                "n_tfs must be at least 1 when n_regulated_genes > 0"
            )

        self.n_samples = n_samples
        self.n_tfs = n_tfs
        self.n_regulated_genes = n_regulated_genes
        self.n_nonregulated_genes = n_nonregulated_genes
        self.n_total_genes = n_regulated_genes + n_nonregulated_genes
        self.tf_lambda_range = tf_lambda_range
        self.tf_gene_sparsity = tf_gene_sparsity
        self.noise_level = noise_level
        self.dropout_rate = dropout_rate
        self.random_state = random_state

        # Use a private generator when a seed is given instead of re-seeding
        # the global one: seeding globally lets a second instance (or any
        # other np.random user) clobber this instance's seed before
        # generate_data() runs. RandomState(seed) yields the same stream that
        # np.random.seed(seed) would, so seeded output is unchanged.
        if random_state is None:
            self._rng = np.random
        else:
            self._rng = np.random.RandomState(random_state)

        # Initialize data structures (populated by generate_data)
        self.tf_expression = None
        self.tf_gene_matrix = None
        self.gene_expression = None
        self.gene_names = None
        self.tf_names = None

    def generate_data(self) -> Dict[str, Any]:
        """
        Generate the complete simulated dataset.

        The random draws happen in a fixed order (TF expression, regulatory
        matrix, regulated genes, non-regulated genes, noise, dropout), so a
        given seed always produces the same dataset.

        Returns:
            Dict containing generated data components
        """
        # Generate TF expression (Poisson distributed)
        self._generate_tf_expression()

        # Generate TF-gene regulatory matrix (binary with sparsity)
        self._generate_tf_gene_matrix()

        # Generate regulated gene expression
        regulated_expression = self._generate_regulated_genes()

        # Generate non-regulated genes with high variance
        nonregulated_expression = self._generate_nonregulated_genes()

        # Combine gene expression: regulated columns first, then non-regulated
        self.gene_expression = np.hstack([regulated_expression, nonregulated_expression])

        # Apply noise
        self._add_noise()

        # Apply dropout
        self._apply_dropout()

        # Create gene and TF names
        self._create_feature_names()

        # Return data as dictionary
        return self.get_data_dict()

    def _generate_tf_expression(self):
        """Generate Poisson-distributed TF expression, one lambda per TF."""
        # Generate random lambda values for each TF
        min_lambda, max_lambda = self.tf_lambda_range
        tf_lambdas = self._rng.uniform(min_lambda, max_lambda, self.n_tfs)

        # Generate Poisson-distributed expression for each TF.
        # Drawn column by column so the random stream order stays stable.
        self.tf_expression = np.zeros((self.n_samples, self.n_tfs))
        for tf_idx, tf_lambda in enumerate(tf_lambdas):
            self.tf_expression[:, tf_idx] = self._rng.poisson(tf_lambda, self.n_samples)

    def _generate_tf_gene_matrix(self):
        """Generate binary TF-gene regulatory matrix with controllable sparsity."""
        # Each entry is 1 with probability (1 - sparsity), otherwise 0
        self.tf_gene_matrix = self._rng.binomial(
            1, 1 - self.tf_gene_sparsity, (self.n_tfs, self.n_regulated_genes)
        )

        # Ensure each gene is regulated by at least one TF: every all-zero
        # column gets a single randomly chosen TF. Column sums are computed
        # once; fixing one column cannot change another column's sum.
        unregulated = np.flatnonzero(self.tf_gene_matrix.sum(axis=0) == 0)
        for gene_idx in unregulated:
            tf_idx = self._rng.randint(0, self.n_tfs)
            self.tf_gene_matrix[tf_idx, gene_idx] = 1

    def _generate_regulated_genes(self) -> np.ndarray:
        """
        Generate expression for genes regulated by TFs.

        Each regulated gene is the sum of the expression of the TFs that
        regulate it, plus a per-gene Gamma(2, 2) baseline.

        Returns:
            Array of regulated gene expression, shape (n_samples, n_regulated_genes)
        """
        # Calculate expression of regulated genes based on TF activity
        regulated_expression = np.dot(self.tf_expression, self.tf_gene_matrix)

        # Add baseline expression (one value per gene, shared by all samples)
        baseline = self._rng.gamma(2, 2, size=self.n_regulated_genes)
        regulated_expression += baseline[np.newaxis, :]

        return regulated_expression

    def _generate_nonregulated_genes(self) -> np.ndarray:
        """
        Generate expression for non-regulated genes with high variance.

        Returns:
            Array of non-regulated gene expression,
            shape (n_samples, n_nonregulated_genes)
        """
        # Generate base expression for non-regulated genes
        nonregulated_expression = np.zeros((self.n_samples, self.n_nonregulated_genes))

        # Use negative binomial to create high-variance genes. Parameters are
        # drawn per gene, interleaved with the samples, so the loop is kept
        # to preserve the random stream order.
        for gene_idx in range(self.n_nonregulated_genes):
            # Parameters for negative binomial: higher r = higher variance
            r = self._rng.uniform(2, 10)
            p = self._rng.uniform(0.1, 0.5)
            nonregulated_expression[:, gene_idx] = self._rng.negative_binomial(
                r, p, self.n_samples
            )

        return nonregulated_expression

    def _add_noise(self):
        """Add Gaussian noise to gene expression data, clipping at zero."""
        # Noise scale is relative to the mean over the whole expression matrix
        noise_scale = self.noise_level * np.mean(self.gene_expression)
        noise = self._rng.normal(0, noise_scale, self.gene_expression.shape)
        self.gene_expression = np.maximum(0, self.gene_expression + noise)

    def _apply_dropout(self):
        """Apply dropout to gene expression data."""
        # Mask entry is 0 (value dropped) with probability dropout_rate
        dropout_mask = self._rng.binomial(1, 1 - self.dropout_rate, self.gene_expression.shape)
        self.gene_expression = self.gene_expression * dropout_mask

    def _create_feature_names(self):
        """Create names for TFs and genes."""
        self.tf_names = [f"TF_{i+1}" for i in range(self.n_tfs)]
        regulated_names = [f"Gene_reg_{i+1}" for i in range(self.n_regulated_genes)]
        nonregulated_names = [
            f"Gene_nonreg_{i+1}" for i in range(self.n_nonregulated_genes)
        ]
        # Order matches the column order of gene_expression
        self.gene_names = regulated_names + nonregulated_names

    def get_data_dict(self) -> Dict[str, Any]:
        """
        Get the generated data as a dictionary.

        Returns:
            Dictionary containing all generated data components

        Raises:
            RuntimeError: If generate_data() has not been called yet.
        """
        if (
            self.tf_expression is None
            or self.tf_gene_matrix is None
            or self.gene_expression is None
            or self.gene_names is None
            or self.tf_names is None
        ):
            raise RuntimeError(
                "No data available: call generate_data() before requesting the data."
            )

        return {
            "tf_expression": pd.DataFrame(self.tf_expression, columns=self.tf_names),
            "gene_expression": pd.DataFrame(self.gene_expression, columns=self.gene_names),
            "tf_gene_matrix": pd.DataFrame(self.tf_gene_matrix, index=self.tf_names,
                                          columns=self.gene_names[:self.n_regulated_genes]),
            "tf_names": self.tf_names,
            "gene_names": self.gene_names,
            "n_regulated_genes": self.n_regulated_genes,
            "n_nonregulated_genes": self.n_nonregulated_genes
        }

    def save_data(self, output_dir: str):
        """
        Save generated data to CSV files.

        Writes tf_expression.csv, gene_expression.csv and tf_gene_matrix.csv
        into output_dir, creating the directory if it does not exist.

        Args:
            output_dir: Directory to save the data

        Raises:
            RuntimeError: If generate_data() has not been called yet.
        """
        from pathlib import Path

        # Fetch the data first so nothing is created on disk if it is missing
        data_dict = self.get_data_dict()

        out_path = Path(output_dir)
        out_path.mkdir(parents=True, exist_ok=True)

        for key in ("tf_expression", "gene_expression", "tf_gene_matrix"):
            data_dict[key].to_csv(out_path / f"{key}.csv")