import math

import torch

try:
    from .semiring import TensorSemiring
except ImportError:
    from semiring import TensorSemiring



class LogCountingSemiring(TensorSemiring):
    """Counting semiring over truncated power series, stored in log space.

    Each semiring element is a tensor whose last dimension has length
    ``size``; entry ``k`` holds the log of the weight associated with count
    ``k``. Addition is element-wise log-sum-exp; multiplication is the
    truncated (log-space) convolution of the two series, so counts add and
    any count of ``size`` or more is discarded.
    """

    def __init__(self, size: int):
        super().__init__()
        # Number of tracked counts (0 .. size - 1) = length of the last dim.
        self.size = size

    def add(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return ``a + b``: element-wise log-sum-exp of two series."""
        return torch.logaddexp(a, b)

    def add_in_place(self, a: torch.Tensor, b: torch.Tensor) -> None:
        """Overwrite ``a`` with ``a + b``."""
        torch.logaddexp(a, b, out=a)

    def add_one_in_place(self, a: torch.Tensor) -> None:
        """Add the multiplicative identity to ``a`` in place.

        The identity has weight 1 (log-weight 0) at count 0 and log 0
        elsewhere, so only the count-0 coefficient changes.
        """
        out = a[..., 0]
        torch.logaddexp(out, a.new_zeros(()), out=out)

    def sum(self, a: torch.Tensor, dims: tuple[int, ...]) -> torch.Tensor:
        """Semiring-sum ``a`` over ``dims`` (log-sum-exp reduction).

        An empty ``dims`` means "no reduction": ``a`` is returned as-is
        rather than relying on how ``torch.logsumexp`` treats an empty
        ``dim`` argument.
        """
        if not dims:
            return a
        return torch.logsumexp(a, dim=dims)

    @classmethod
    def multiply(cls, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return ``a * b``: the truncated log-space convolution of two series.

        ``result[..., j] = logsumexp_{0 <= k <= j} (a[..., k] + b[..., j - k])``.
        Leading (batch) dimensions of ``a`` and ``b`` broadcast.
        """
        r = torch.arange(b.size(-1), device=b.device)
        # shift[j, k] = j - k is the index into ``b`` paired with a[..., k]
        # for output count j. Negative shifts are impossible pairings; they
        # are clamped to a valid index and then masked to log 0 = -inf.
        shift = r[:, None] - r[None, :]
        b_shifted = b[..., shift.clamp(min=0)].masked_fill(shift < 0, -math.inf)
        return torch.logsumexp(a.unsqueeze(-2) + b_shifted, dim=-1)

    def star(self, a: torch.Tensor) -> torch.Tensor:
        """Return the Kleene star ``a* = 1 + a + a*a + ...``.

        Solves ``s = 1 + a * s`` coefficient by coefficient:
        ``s[i] = a0* * (delta_{i,0} + sum_{k=1..i} a[k] * s[i-k])`` where
        ``a0* = 1 / (1 - a[0])`` (computed by ``log_star``). The count-0
        weight of ``a`` must be below 1 (log-weight < 0); otherwise
        ``log_star`` yields inf/NaN.
        """
        a_star = torch.empty_like(a)
        # The `torch.flip()` function always creates a copy instead of a view,
        # so we maintain a flipped copy of a_star to avoid flipping it for
        # every i: flipped_a_star[..., size - 1 - j] == a_star[..., j].
        flipped_a_star = torch.empty_like(a)
        a0_star = log_star(a[..., 0])
        size = a.size(-1)
        for i in range(size):
            # Sum of a[k] * a_star[i - k] for k = 1..i; the slices are empty
            # (giving -inf) when i == 0. Only already-written entries of
            # flipped_a_star are read.
            c = torch.logsumexp(a[..., 1:i + 1] + flipped_a_star[..., size - i:], dim=-1)
            if i == 0:
                # The identity term (weight 1 at count 0).
                torch.logaddexp(c, a.new_zeros(()), out=c)
            a_star[..., i] = flipped_a_star[..., size - 1 - i] = a0_star + c
        return a_star

    def zeros(self, size: tuple[int, ...], dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """Return semiring zeros (log 0 = -inf) of shape ``size + (self.size,)``."""
        return torch.full(size + (self.size,), -math.inf, dtype=dtype, device=device)

    def ones(self, size: tuple[int, ...], dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """Return semiring ones: log-weight 0 at count 0, -inf at all other counts."""
        result = self.zeros(size, dtype, device)
        result[..., 0] = 0
        return result

def log_star(a: torch.Tensor) -> torch.Tensor:
    """Return ``-log(1 - exp(a))``, the log-space Kleene star of a weight.

    ``a`` holds log-weights; for a weight ``p = exp(a) < 1`` the result is
    ``log(1 + p + p**2 + ...) = log(1 / (1 - p))``.

    Computed as ``-log1mexp(a)`` with the standard two-branch formulation:
    ``log1p(-exp(a))`` for ``a < -log 2`` and ``log(-expm1(a))`` otherwise.
    This stays accurate both when ``p`` is tiny and when ``p`` is close to 1;
    the naive ``-log(1 - exp(a))`` collapses to 0 or ``inf`` in those regimes.
    Inputs with ``a >= 0`` (``p >= 1``) have no finite star and give
    ``inf`` (``a == 0``) or ``NaN`` (``a > 0``), as before.
    """
    threshold = -math.log(2.0)
    use_log1p = a < threshold
    # Feed each branch only the inputs it handles well; the others are
    # replaced by the threshold value. Otherwise the branch not selected by
    # the outer torch.where could produce inf/NaN that leaks into gradients.
    fill = torch.full_like(a, threshold)
    log1mexp = torch.where(
        use_log1p,
        torch.log1p(-torch.exp(torch.where(use_log1p, a, fill))),
        torch.log(-torch.expm1(torch.where(use_log1p, fill, a))),
    )
    return -log1mexp



def _series_closure(semiring, W: torch.Tensor, num_terms: int) -> torch.Tensor:
    """Approximate ``W* = I + W + W^2 + ...`` by its first ``num_terms`` powers.

    ``W`` is a square matrix of semiring elements with shape
    ``(n_states, n_states, size)``; the result has the same shape and equals
    ``I + W + ... + W^num_terms``.
    """
    n_states = W.size(0)
    W_star = torch.full_like(W, -math.inf)
    diag = torch.arange(n_states)
    W_star[diag, diag, 0] = 0.0  # Identity: weight 1 at count 0 on the diagonal
    term = W.clone()
    for _ in range(num_terms):
        W_star = semiring.add(W_star, term)
        # term <- term * W as a semiring matrix product: broadcast to
        # (i, k, j) triples, multiply the series entry-wise, sum over k.
        term = semiring.sum(semiring.multiply(term.unsqueeze(2), W.unsqueeze(0)), (1,))
    return W_star


def _path_sum(semiring, lam: torch.Tensor, W_star: torch.Tensor, rho: torch.Tensor) -> torch.Tensor:
    """Return ``λᵀ W* ρ``: the semiring sum over i, j of ``lam[i] * W_star[i, j] * rho[j]``."""
    weighted = semiring.multiply(lam.unsqueeze(1), W_star)  # indexed [i, j]
    weighted = semiring.multiply(weighted, rho.unsqueeze(0))
    return semiring.sum(weighted, (0, 1))


def test_complex_machine_with_multiple_accept_states():
    """Compare two encodings of final weights in the log counting semiring.

    Builds a 4-state weighted automaton whose semiring tracks how many times
    the symbol 'a' is emitted, then computes the distribution over emission
    counts in two ways:

    1. final weight 0.05 on every state, path sum ``λᵀ W* ρ``;
    2. an extra accepting state (index 4) reached with weight 0.05 from every
       state and carrying final weight 1.

    ``W*`` is approximated by the truncated series ``I + W + ... + W^200``.
    Prints both distributions and their differences and returns
    ``(wrong_probs, correct_probs)`` as probability (not log) tensors of
    length ``semiring_size``.
    """
    # Parameters for our test
    semiring_size = 15  # Track up to 14 occurrences
    num_terms = 200  # Powers of W summed when approximating W*
    semiring = LogCountingSemiring(semiring_size)
    log_accept = math.log(0.05)

    # Create a 4-state machine where all states have a small acceptance probability
    # All states can accept with 0.05 probability

    # Approach 1: final weights placed directly on every state
    W = torch.full((4, 4, semiring_size), -math.inf)

    # State 0 transitions
    W[0, 0, 1] = math.log(0.35)  # Self-loop emitting 'a'
    W[0, 1, 0] = math.log(0.30)  # To state 1 (no emit)
    W[0, 2, 0] = math.log(0.30)  # To state 2 (no emit)

    # State 1 transitions
    W[1, 0, 1] = math.log(0.30)  # To state 0 emitting 'a'
    W[1, 1, 0] = math.log(0.30)  # Self-loop (no emit)
    W[1, 3, 1] = math.log(0.30)  # To state 3 emitting 'a'

    # State 2 transitions
    W[2, 1, 1] = math.log(0.30)  # To state 1 emitting 'a'
    W[2, 2, 0] = math.log(0.35)  # Self-loop (no emit)
    W[2, 3, 0] = math.log(0.30)  # To state 3 (no emit)

    # State 3 transitions
    W[3, 0, 0] = math.log(0.30)  # To state 0 (no emit)
    W[3, 2, 1] = math.log(0.30)  # To state 2 emitting 'a'
    W[3, 3, 1] = math.log(0.30)  # Self-loop emitting 'a'

    # Initial state distribution (start at state 0 with probability 1)
    lambda_vec = torch.full((4, semiring_size), -math.inf)
    lambda_vec[0, 0] = 0.0  # log(1) = 0

    # Final acceptance weights (all states accept with 0.05 probability)
    rho_vec = torch.full((4, semiring_size), -math.inf)
    rho_vec[:, 0] = log_accept

    W_star = _series_closure(semiring, W, num_terms)
    wrong_path_sum = _path_sum(semiring, lambda_vec, W_star, rho_vec)

    # Approach 2: With explicit accepting state
    # Create a 5-state automaton: [states 0-3, accepting state 4]
    # NOTE(review): state 4 has no outgoing transitions, so this path sum
    # equals Approach 1 mathematically; any printed difference should come
    # only from series truncation and floating-point rounding.
    W_correct = torch.full((5, 5, semiring_size), -math.inf)
    W_correct[:4, :4] = W  # Same transitions as before
    W_correct[:4, 4, 0] = log_accept  # Every state reaches the accepting state

    # Initial state
    lambda_vec_correct = torch.full((5, semiring_size), -math.inf)
    lambda_vec_correct[0, 0] = 0.0  # Start at state 0

    # Only the accepting state has final weight 1.0
    rho_vec_correct = torch.full((5, semiring_size), -math.inf)
    rho_vec_correct[4, 0] = 0.0  # log(1) = 0

    W_star_correct = _series_closure(semiring, W_correct, num_terms)
    correct_path_sum = _path_sum(semiring, lambda_vec_correct, W_star_correct, rho_vec_correct)

    # Convert from log domain to probabilities
    wrong_probs = torch.exp(wrong_path_sum)
    correct_probs = torch.exp(correct_path_sum)

    # Print results
    print("Approach 1 (Wrong) - Accepting weight after Kleene star:")
    for i in range(min(10, semiring_size)):
        print(f"Probability of emitting 'a' {i} times: {wrong_probs[i]:.8f}")

    print("\nApproach 2 (Correct) - With explicit accepting state:")
    for i in range(min(10, semiring_size)):
        print(f"Probability of emitting 'a' {i} times: {correct_probs[i]:.8f}")

    # Calculate the sum of probabilities to check if they're normalized
    # (mass on counts >= semiring_size is truncated away, so sums are < 1)
    print(f"\nSum of wrong probabilities: {wrong_probs[:semiring_size].sum():.8f}")
    print(f"Sum of correct probabilities: {correct_probs[:semiring_size].sum():.8f}")

    # Calculate the absolute difference between approaches
    diff = torch.abs(wrong_probs - correct_probs)
    print(f"\nMaximum absolute difference: {diff.max():.8f}")
    print(f"Average absolute difference: {diff.mean():.8f}")

    return wrong_probs, correct_probs

if __name__ == "__main__":
    # Script entry point: the demo's printed report is the output we want,
    # so the returned probability tensors are deliberately discarded.
    _ = test_complex_machine_with_multiple_accept_states()