import torch

class ClassifierFreeGuidance:
    """Classifier-free guidance (CFG) wrapper around a conditional diffusion model.

    Training (``sample_flag=False``): entries of the selected condition columns
    are randomly zeroed with probability ``p_uncond``, so the same network also
    learns an unconditional prediction.

    Sampling (``sample_flag=True``): the network is evaluated once with the
    conditions and once with all-zero conditions. The two predictions are
    combined as ``(1 + w) * pred_cond - w * pred_uncond``.

    Args:
        w: Guidance weight applied at sampling time.
        p_uncond: Probability of dropping each (sample, column) condition entry
            during training.
        cond_to_disgard: Column indices (dim 1 of ``conditions``) eligible for
            random dropping during training. ``None`` means every column. The
            misspelled name is kept for backward compatibility.
        cond_to_ignore: Column indices that are always zeroed at sampling time.
            ``None`` means no column is forced to zero.
    """

    def __init__(self,
                 w: float = 0.5,
                 p_uncond: float = 0.1,
                 cond_to_disgard: list = None,
                 cond_to_ignore: list = None,
        ):
        self.w = w
        self.p_uncond = p_uncond
        self.cond_to_disgard = cond_to_disgard
        self.cond_to_ignore = cond_to_ignore

    @staticmethod
    def _check_columns(indices, conditions):
        """Assert that ``indices`` are valid column indices of ``conditions``."""
        n_cols = conditions.shape[1]
        assert len(indices) <= n_cols
        assert all(0 <= x < n_cols for x in indices)

    def discard_conditions(self, conditions):
        """Randomly zero condition entries for CFG training.

        Each entry in the selected columns is kept with probability
        ``1 - p_uncond`` and zeroed otherwise. Unselected columns pass through
        unchanged.

        Args:
            conditions: Tensor whose dim 0 is the batch and dim 1 indexes
                condition columns. Assumed 2-D ``(batch, n_cond)``; the mask
                assignment below only broadcasts for that layout (TODO confirm).

        Returns:
            A new tensor ``conditions * mask``. The mask is float32, so integer
            inputs come back as floating point.
        """
        if self.cond_to_disgard is not None:
            self._check_columns(self.cond_to_disgard, conditions)
            columns = self.cond_to_disgard
        else:
            # Resolve "all columns" per call. Storing this on self would freeze
            # the column count of the first batch and break later batches.
            columns = list(range(conditions.shape[1]))

        mask = torch.ones_like(conditions).float()
        keep = torch.rand(size=(conditions.shape[0], len(columns)),
                          device=conditions.device) >= self.p_uncond
        mask[:, columns] = keep.float()
        return conditions * mask

    @staticmethod
    def get_assets_from_model(model):
        """Return ``model.diffusion_model``, else ``model.model.diffusion_model``."""
        if hasattr(model, "diffusion_model"):
            return model.diffusion_model
        return model.model.diffusion_model

    def __call__(self, model, sample_flag=False, x_orig=None, x_noised=None,
                 t=None, pe_input=None, conditions=None, input_gene_list=None,
                 target_gene_list=None):
        """Run the wrapped diffusion model with classifier-free guidance.

        All tensor/list arguments except ``conditions`` are forwarded unchanged
        to the diffusion model.

        Args:
            model: Object exposing ``diffusion_model`` directly or via
                ``model.model.diffusion_model``.
            sample_flag: ``True`` for guided sampling, ``False`` for training.
            conditions: Condition tensor. It is never modified in place.

        Returns:
            Sampling: the guided prediction
            ``(1 + w) * pred_cond - w * pred_uncond``.
            Training: the ``(pred, mask)`` tuple returned by the diffusion model
            when called with ``mask=True``. The meaning of ``mask`` is defined
            by that model.
        """
        diffusion_model = self.get_assets_from_model(model)

        if not sample_flag:
            conditions = self.discard_conditions(conditions)
            pred, mask = diffusion_model(x_orig, x_noised, t, pe_input=pe_input,
                                         conditions=conditions, mask=True,
                                         input_gene_list=input_gene_list,
                                         target_gene_list=target_gene_list)
            return pred, mask

        if self.cond_to_ignore is not None:
            self._check_columns(self.cond_to_ignore, conditions)
            # Clone first so the caller's tensor is not zeroed in place.
            conditions = conditions.clone()
            conditions[:, self.cond_to_ignore] = 0

        pred_cond, _ = diffusion_model(x_orig, x_noised, t, pe_input=pe_input,
                                       conditions=conditions, mask=False,
                                       input_gene_list=input_gene_list,
                                       target_gene_list=target_gene_list)
        # zeros_like gives a true null condition even if inputs contain inf/NaN
        # (where `conditions * 0` would produce NaN).
        pred_uncond, _ = diffusion_model(x_orig, x_noised, t, pe_input=pe_input,
                                         conditions=torch.zeros_like(conditions),
                                         mask=False,
                                         input_gene_list=input_gene_list,
                                         target_gene_list=target_gene_list)
        return (1 + self.w) * pred_cond - self.w * pred_uncond
        
        

