import torch
import numpy as np

def generate_cell_type_hierarchy(n_stem_cell_types=3, n_cell_levels=4):
    """
    Build the fixed four-level cell type hierarchy and count its cell types.

    Each key of the returned mapping names one level. Its value is the lineage
    path from the stem level down to and including that level. Every level
    holds twice as many cell types as the level above it, starting from
    ``n_stem_cell_types`` at the stem level.

    NOTE(review): ``n_cell_levels`` is accepted but not used. The number of
    levels always comes from the hard-coded hierarchy below.

    Returns:
        (cell_types, n_cell_types): the level -> lineage-path mapping, and
        ``sum(n_stem_cell_types * 2**level)`` over all levels.
    """
    level_names = ['stem', 'progenitor', 'differentiated', 'specialized']
    # Each level's path is the prefix of level_names ending at that level.
    cell_types = {
        name: level_names[:depth + 1] for depth, name in enumerate(level_names)
    }

    n_levels = len(cell_types)
    n_cell_types = sum(n_stem_cell_types * 2 ** level for level in range(n_levels))

    return cell_types, n_cell_types

def create_cell_type_program_mapping(n_stem_cell_types, cell_types, n_gene_programs):
    """Create mapping between cell types and gene programs.

    For each stem cell line ``i``, a binary vector of active gene programs is
    sampled, with each program active with probability 0.9 (seed ``i``). The
    lineage is then expanded level by level. Every cell type at level ``j``
    produces two children. Each child is a copy of its parent in which every
    active program is independently switched off with probability 0.2. The
    number of levels is ``len(cell_types)``; nothing else of ``cell_types``
    is read.

    Args:
        n_stem_cell_types: number of independent stem cell lines.
        cell_types: level hierarchy; only its length is used.
        n_gene_programs: length of each program-activity vector.

    Returns:
        cell_type_to_programs: nested dict ``{line: {level: {...}}}``. Every
            level has a ``'parents'`` list of activity vectors, and every level
            except the last also has a ``'children'`` list. The next level's
            ``'parents'`` is that same list object.
        cell_type_to_programs_mtrx: float array with one row per generated
            cell type, in generation order (shape ``(n_rows, n_gene_programs)``).
        cell_type_levels: level index of each matrix row.

    Side effects:
        Seeds NumPy's global legacy RNG with 0 before returning. This is kept
        for callers that may rely on it. Sampling itself uses private
        ``RandomState`` instances with the same seeds the global RNG used to be
        given, so the generated data is unchanged.
    """
    n_levels = len(cell_types)
    cell_type_to_programs = {}
    # Rows are collected and stacked once at the end; calling np.vstack per
    # row re-copied the whole matrix each time (quadratic).
    rows = []
    cell_type_levels = []

    for i in range(n_stem_cell_types):
        # Seed per stem line for reproducible cell line divergence.
        active_programs = np.random.RandomState(i).binomial(1, 0.9, n_gene_programs)

        lineage = {0: {'parents': [active_programs]}}
        cell_type_to_programs[i] = lineage
        rows.append(active_programs)
        cell_type_levels.append(0)

        # Generate differentiated progeny, two children per parent per level.
        for j in range(n_levels - 1):
            children = []
            lineage[j]['children'] = children

            for parent in lineage[j]['parents']:
                for k in range(2):
                    # NOTE(review): the seed depends only on the child index k,
                    # so child k of every parent (in every line and level) draws
                    # from the same random stream. Preserved to keep the
                    # generated data identical; confirm this is intended.
                    child = parent.copy()

                    # Switch off each active program with probability 0.2.
                    active_ids = np.where(child == 1)[0]
                    remove = np.random.RandomState(k).binomial(1, 0.2, len(active_ids))
                    child[active_ids[remove == 1]] = 0

                    children.append(child)
                    rows.append(child)
                    cell_type_levels.append(j + 1)

            # The children of level j are the parents of level j + 1.
            lineage[j + 1] = {'parents': children}

    # The empty float seed keeps the (0, n_gene_programs) shape when no lines
    # are generated, and makes the result float64 as before.
    cell_type_to_programs_mtrx = np.vstack([np.zeros((0, n_gene_programs)), *rows])

    # Reset the global random seed (pre-existing behavior callers may rely on).
    np.random.seed(0)

    return cell_type_to_programs, cell_type_to_programs_mtrx, cell_type_levels

def connect_cell_types_to_genes(cell_type_to_programs_mtrx, programs_by_clusters, clusters_to_genes):
    """Connect cell types to genes through programs and clusters.

    Presumably 2-D inputs (TODO confirm against callers):
    ``cell_type_to_programs_mtrx`` is (n_cell_types, n_programs),
    ``programs_by_clusters`` is (n_programs, n_clusters) and
    ``clusters_to_genes`` is (n_clusters, n_genes).

    Returns:
        ct_to_clusters: float32 0/1 tensor. An entry is 1 when the cell type's
            program-weighted score for a cluster exceeds half of that cluster's
            maximum score over all cell types (dim 0).
        ct_to_genes: ``ct_to_clusters @ clusters_to_genes`` as a tensor with no
            autograd history.
    """
    scores = torch.as_tensor(cell_type_to_programs_mtrx @ programs_by_clusters)
    # Normalize each cluster column by its max over cell types. A column whose
    # max is 0 gives NaN here; NaN fails the > 0.5 test, so it becomes 0.
    scores = scores / scores.max(dim=0).values
    ct_to_clusters = (scores > 0.5).float()

    # torch.tensor(t) on an existing tensor emits a copy-construct UserWarning.
    # as_tensor(...).detach() yields the same values, likewise detached.
    ct_to_genes = torch.as_tensor(ct_to_clusters @ clusters_to_genes).detach()

    return ct_to_clusters, ct_to_genes