from sympy import gamma
from logdiff.score.sampling_compositional import Expression
import torch

class ModelCompExpression(Expression):
    """Common base for the model-composition nodes in this module.

    ``log_probability`` ignores all of its arguments and reports log(1) = 0.0.
    Subclasses that need a real value (e.g. ``OrModels``) override it.
    """

    def log_probability(self, classifier, x, t):
        # Note: this is a plain Python float, not a tensor.
        log_one = 0.0
        return log_one
        
### AND ###
# AND: constant, equal-weight mix of the two models, 0.5 * s(a) + 0.5 * s(b), scaled by guidance_dict["constant"]["and"]
class AndModelsAB(ModelCompExpression):
    """AND as a fixed, equal-weight mix of the two child outputs.

    The composed output is ``c * (0.5 * left + 0.5 * right)``, where ``c`` is
    read from ``guidance_dict["constant"]["and"]`` each time the closure runs.
    """
    method = "ModelComp"

    def __init__(self, left: Expression, right: Expression):
        super().__init__()
        self.left: Expression = left
        self.right: Expression = right

    def __str__(self):
        return "({} ∧ {})".format(self.left, self.right)

    def to_fn(self, model, classifier, x_0_unconditional, xt, t, guidance_dict, ll_state, scheduler, score_cache, neg_guiding=False, method = "ModelComp"):
        """Return a zero-argument closure computing the scaled 50/50 mix.

        The ``method`` argument is not used; children are always built with
        this class's own ``method``.
        """
        shared_args = (model, classifier, x_0_unconditional, xt, t, guidance_dict,
                       ll_state, scheduler, score_cache, neg_guiding)
        # Build the left child first, then the right (same order as before).
        left_fn, right_fn = (
            child.to_fn(*shared_args, method=self.method)
            for child in (self.left, self.right)
        )

        def fn():
            and_scale = guidance_dict["constant"]["and"]
            half_mix = 0.5 * left_fn() + 0.5 * right_fn()
            return and_scale * half_mix

        return fn
    


    
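# AND: superposition-style composition (method "Skreta") whose per-sample mixing weights come from a 2x2 linear system over the two children's scores.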
class AndSkreta(ModelCompExpression):
    """AND composition that mixes two children with per-sample solved weights.

    The children's epsilon predictions are converted to scores. Per-sample
    weights ``w = G^{-1} b`` are solved from the Gram matrix ``G`` of the two
    scores, with ``b = [||s1||^2, ||s2||^2]``. The weighted score is then
    converted back to an epsilon prediction. This node's ``ll_state`` entry is
    the sum of its children's entries.
    """
    method = "Skreta"

    def __init__(self, left, right):
        # Initialise the base class like every sibling composition node does.
        super().__init__()
        self.left = left
        self.right = right

    def __str__(self):
        # str(node) is the ll_state key that parent nodes (e.g. OrSkreta) read,
        # so give it the same readable form as the other AND node.
        return f"({self.left} ∧ {self.right})"

    def solve_superposition_system(self, eps_l, eps_r, xt, t, scheduler):
        """Combine two epsilon predictions using weights solved per sample.

        Args:
            eps_l: Epsilon prediction of the left child; dim 0 is the batch.
            eps_r: Epsilon prediction of the right child, same shape as ``eps_l``.
            xt: Unused; kept for interface compatibility.
            t: Index into ``scheduler.alphas_cumprod``.
            scheduler: Object exposing ``alphas_cumprod``.

        Returns:
            The combined epsilon prediction, same shape as ``eps_l``.
        """
        # 1. Convert to scores: s = -eps / sigma_t with sigma_t = sqrt(1 - alpha_bar_t).
        sigma_t = torch.sqrt(1 - scheduler.alphas_cumprod[t])
        s1 = -eps_l / sigma_t
        s2 = -eps_r / sigma_t

        # 2. Compute per-sample inner products (Eq. 15).
        e1 = s1.reshape(s1.shape[0], -1)
        e2 = s2.reshape(s2.shape[0], -1)
        dot11 = torch.sum(e1 * e1, dim=1)
        dot22 = torch.sum(e2 * e2, dim=1)
        dot12 = torch.sum(e1 * e2, dim=1)

        # 3. Solve G w = b with G = [[dot11, dot12], [dot12, dot22]] and
        #    b = [dot11, dot22], using the closed-form 2x2 inverse.
        # NOTE(review): the true determinant -> 0 when the two scores are
        # (nearly) collinear, which makes the weights ill-conditioned; the
        # 1e-8 only prevents division by exactly zero.
        det = dot11 * dot22 - dot12**2 + 1e-8
        w1 = (dot11 * dot22 - dot12 * dot22) / det
        w2 = (dot11 * dot22 - dot11 * dot12) / det

        # 4. Broadcast the (B,) weights over every non-batch dim. A fixed
        #    view(-1, 1, 1, 1) only broadcasts correctly for 4-D inputs.
        batch_shape = (-1,) + (1,) * (s1.dim() - 1)
        s_and = w1.reshape(batch_shape) * s1 + w2.reshape(batch_shape) * s2

        # Convert back to epsilon.
        return -s_and * sigma_t

    def to_fn(self, model, classifier, x_0_unconditional, xt, t, guidance_dict, 
              ll_state, scheduler, score_cache, neg_guiding=False, method="Skreta"):
        """Return a zero-argument closure producing the AND epsilon.

        When called, the closure also writes this node's log-density
        (the sum of the children's entries) to ``ll_state[str(self)]``.
        The ``method`` argument is unused; children get this class's ``method``.
        """
        left_fn = self.left.to_fn(model, classifier, x_0_unconditional, xt, t, 
                                  guidance_dict, ll_state, scheduler, score_cache, neg_guiding, method=self.method)
        right_fn = self.right.to_fn(model, classifier, x_0_unconditional, xt, t, 
                                   guidance_dict, ll_state, scheduler, score_cache, neg_guiding, method=self.method)

        def fn():
            eps_l = left_fn()
            eps_r = right_fn()

            # Solve the Skreta linear system.
            eps_and = self.solve_superposition_system(eps_l, eps_r, xt, t, scheduler)

            # Density update for AND: this node's log-density is the sum of
            # the children's (Eq. 14/15). Children that are Skreta nodes have
            # just written their entries inside left_fn()/right_fn().
            l_l = ll_state[str(self.left)]
            l_r = ll_state[str(self.right)]
            ll_state[str(self)] = l_l + l_r

            return eps_and

        return fn
    

### OR ###

class OrModels(ModelCompExpression):
    """OR composition weighted by the children's classifier log-probabilities.

    Each child's output is weighted by ``P_child / P_or``, capped at
    ``self.max_guidance_scale``. The weighted outputs are renormalised by the
    sum of the weights, and the result is scaled by
    ``guidance_dict["ours"]["or_me"]``. ``max_guidance_scale``,
    ``match_dims`` and ``get_classifier_x`` are not defined here; presumably
    they come from the ``Expression`` base class.
    """
    method = "ModelComp"

    def __init__(self, left: Expression, right: Expression):
        super().__init__()
        self.left: Expression = left
        self.right: Expression = right

    def __str__(self):
        return f"({self.left} ∨ {self.right})"

    @staticmethod
    def _or_log_probability(log_p_left, log_p_right):
        """Return ``min(log(P_left + P_right), 0)``, computed in log space."""
        # Children inheriting ModelCompExpression.log_probability return the
        # Python float 0.0, which torch.logaddexp rejects; lift to tensors.
        log_p_or = torch.logaddexp(torch.as_tensor(log_p_left), torch.as_tensor(log_p_right))
        # P_left + P_right can exceed 1, so clamp at log(1) = 0.
        return torch.clamp_max(log_p_or, 0.0)

    def log_probability(self, classifier, x, t):
        """Log-probability of the disjunction: log(P_left + P_right), clamped to <= 0."""
        log_p_left = self.left.log_probability(classifier, x, t)
        log_p_right = self.right.log_probability(classifier, x, t)
        return self._or_log_probability(log_p_left, log_p_right)

    def to_fn(self, model, classifier, x_0_unconditional, xt, t, guidance_dict, ll_state, scheduler, score_cache, neg_guiding=False, method="ModelComp"):
        """Return a zero-argument closure computing the weighted OR output.

        The classifier terms are evaluated eagerly here; the children's
        outputs are evaluated when the closure is called. The ``method``
        argument is unused; children get this class's ``method``.
        """
        # Compute the classifier input once, and each child's log-prob once.
        # Calling self.log_probability here would re-run both classifiers.
        classifier_x = self.get_classifier_x(xt, x_0_unconditional)
        log_p_L = self.left.log_probability(classifier, classifier_x, t)
        log_p_R = self.right.log_probability(classifier, classifier_x, t)
        log_p_OR = self._or_log_probability(log_p_L, log_p_R)

        left_fn = self.left.to_fn(model, classifier, x_0_unconditional, xt, t, guidance_dict, ll_state, scheduler, score_cache, neg_guiding=neg_guiding, method=self.method)
        right_fn = self.right.to_fn(model, classifier, x_0_unconditional, xt, t, guidance_dict, ll_state, scheduler, score_cache, neg_guiding=neg_guiding, method=self.method)

        # P_L / P_OR and P_R / P_OR, computed in log space, then exponentiated.
        weight_L = torch.exp(log_p_L - log_p_OR)
        weight_R = torch.exp(log_p_R - log_p_OR)

        def fn():
            w_L = self.match_dims(weight_L, xt)
            w_R = self.match_dims(weight_R, xt)

            # Cap each weight at self.max_guidance_scale.
            max_scale = torch.tensor(self.max_guidance_scale).to(device=xt.device)
            w_L = torch.clamp_max(w_L, max_scale)
            w_R = torch.clamp_max(w_R, max_scale)

            # Weighted average of the children's outputs.
            or_me_score = 1/(w_L + w_R) * (w_L * left_fn() + w_R * right_fn())
            return guidance_dict["ours"]["or_me"] * or_me_score
        return fn
    
    
# OR via superposition of diffusion models (Skreta et al.): https://arxiv.org/abs/2412.17762
class OrSkreta(ModelCompExpression):
    """OR composition that mixes two children with softmax weights over their log-densities.

    The mixing weights are ``softmax([ll_left, ll_right])``, read from
    ``ll_state``. This node's own ``ll_state`` entry is the ``logsumexp`` of
    the children's entries, so enclosing nodes can use it.
    """
    method = "Skreta"

    def __init__(self, left: Expression, right: Expression):
        super().__init__()
        self.left: Expression = left
        self.right: Expression = right

    def __str__(self):
        return f"({self.left} ∨ {self.right})"


    def to_fn(self, model, classifier, x_0_unconditional, xt, t, guidance_dict, 
              ll_state, scheduler, score_cache, neg_guiding=False, method="Skreta"):
        """Return a zero-argument closure producing the OR epsilon.

        When called, the closure also writes this node's log-density to
        ``ll_state[str(self)]``. The ``method`` argument is unused; children
        get this class's ``method``.
        """
        # Pass the score_cache down to the children.
        left_fn = self.left.to_fn(model, classifier, x_0_unconditional, xt, t, 
                                  guidance_dict, ll_state, scheduler, score_cache, neg_guiding, method=self.method)
        right_fn = self.right.to_fn(model, classifier, x_0_unconditional, xt, t, 
                                   guidance_dict, ll_state, scheduler, score_cache, neg_guiding, method=self.method)

        def fn():
            eps_l = left_fn()
            eps_r = right_fn()

            # 1. Read the children's log-density estimates.
            # NOTE(review): Skreta-style children write these entries inside
            # left_fn()/right_fn() just above; other children may have stored
            # values from an earlier step — confirm against the sampler.
            l_l = ll_state[str(self.left)]
            l_r = ll_state[str(self.right)]
            log_densities = torch.stack([l_l, l_r])

            # 2. Compute mixing weights via softmax (Eq. 12), broadcast over
            #    every non-batch dim. A fixed view(-1, 1, 1, 1) only
            #    broadcasts correctly for 4-D inputs.
            weights = torch.softmax(log_densities, dim=0)
            batch_shape = (-1,) + (1,) * (eps_l.dim() - 1)
            w_l = weights[0].reshape(batch_shape)
            w_r = weights[1].reshape(batch_shape)

            # 3. Combine the noise predictions.
            eps_or = w_l * eps_l + w_r * eps_r

            # 4. This node's log-density is the logsumexp of the children's.
            #    Parent nodes read it when this OR is nested.
            ll_state[str(self)] = torch.logsumexp(log_densities, dim=0)

            return eps_or

        return fn
    
class NotModels(ModelCompExpression):  
    """NOT composition: the unconditional term minus a scaled copy of the child.

    Computes ``x_0_unconditional - c * alpha * child``, where ``c`` is
    ``guidance_dict["constant"]["not"]``. In score notation this is
    s(∅) − alpha * s(a).
    """
    method = "ModelComp"

    def __init__(self, expression: Expression, alpha: float = 0.07):
        """
        Args:
            expression: Child expression whose output is subtracted.
            alpha: Strength of the subtracted child term. Defaults to 0.07.
        """
        super().__init__()
        self.expression: Expression = expression
        self.alpha = alpha

    def __str__(self):
        return f"¬{self.expression}"

    def to_fn(self, model, classifier, x_0_unconditional, xt, t, guidance_dict, ll_state, scheduler, score_cache, neg_guiding=False, method="ModelComp"):
        """Return a zero-argument closure computing the NOT output.

        The ``method`` argument is unused; the child gets this class's ``method``.
        """
        atom_fn = self.expression.to_fn(model, classifier, x_0_unconditional, xt, t, guidance_dict, ll_state, scheduler, score_cache, neg_guiding, method=self.method)

        def fn():
            # x_0_unconditional plays the role of s(∅); the child term is
            # scaled by both the "not" constant and alpha, then subtracted.
            return x_0_unconditional + guidance_dict["constant"]["not"] * -self.alpha * atom_fn()
        return fn
