import math
import torch
from random import random

try:
    from .semiring import TensorSemiring
    from .log_counting_semiring import LogCountingSemiring
except ImportError:
    from semiring import TensorSemiring
    from log_counting_semiring import LogCountingSemiring


class OverflowLogCountingSemiring(TensorSemiring):
    """
    Log-space counting semiring whose last bin absorbs overflow.

    An element is a tensor whose last dimension has ``self.size`` bins of
    log-weights. Bin ``k < size - 1`` holds the log-weight of a count of exactly
    ``k``; the last bin (``size - 1``) holds the log-weight of every count
    ``>= size - 1``. Addition is pointwise ``logaddexp``; multiplication is a
    log-space Cauchy product (counts add) with out-of-range mass folded into the
    last bin. The zero element is all ``-inf``; the one element is ``0`` in bin 0
    and ``-inf`` elsewhere.
    """

    def __init__(self, size: int):
        super().__init__()
        assert size >= 2, "size must be >= 2"
        # Number of bins in the last dimension, including the overflow bin.
        self.size = size

    def add(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Semiring addition: pointwise ``logaddexp(a, b)`` (broadcasting as torch does)."""
        return torch.logaddexp(a, b)

    def add_in_place(self, a: torch.Tensor, b: torch.Tensor) -> None:
        """In-place semiring addition: ``a <- logaddexp(a, b)``."""
        torch.logaddexp(a, b, out=a)

    def add_one_in_place(self, a: torch.Tensor) -> None:
        """Add the one element to ``a`` in place: ``a[..., 0] <- logaddexp(a[..., 0], 0)``."""
        torch.logaddexp(a[..., 0], a.new_zeros(()), out=a[..., 0])

    def sum(self, a: torch.Tensor, dims: tuple[int, ...]) -> torch.Tensor:
        """
        Semiring sum (``logsumexp``) of ``a`` over ``dims``.

        An empty ``dims`` returns ``a`` itself (not a copy). ``dims`` presumably
        should name batch dimensions only, not the bin dimension; this is not
        checked here.
        """
        if dims:
            return torch.logsumexp(a, dim=dims)
        return a

    @classmethod
    def multiply(cls, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """
        Overflow Cauchy product of two log-space elements.

        ``out[..., i]`` is the ``logsumexp`` of ``a[..., r] + b[..., s]`` over all
        pairs with ``min(r + s, size - 1) == i``: counts add, and every total
        ``>= size - 1`` lands in the last (overflow) bin.

        Args:
            a, b: log-weight tensors whose last dimensions have the same size;
                their leading (batch) dimensions are broadcast together,
                right-aligned, as in ordinary torch broadcasting.

        Returns:
            Tensor of shape ``broadcast(a.shape[:-1], b.shape[:-1]) + (size,)``.
        """
        size_a = a.size(-1)
        size_b = b.size(-1)
        assert size_a == size_b, f"Last dimension sizes must match: got {size_a} and {size_b}"
        size = size_a

        # pairs[..., r * size + s] = a[..., r] + b[..., s]; torch broadcasts the batch dims.
        pairs = (a.unsqueeze(-1) + b.unsqueeze(-2)).flatten(-2)

        # Destination bin of each (r, s) pair: r + s, saturated at the overflow bin.
        idx = torch.arange(size, device=pairs.device)
        dest = (idx.unsqueeze(-1) + idx.unsqueeze(0)).clamp(max=size - 1).flatten()

        # Every bin receives at least the pair (0, i), so no logsumexp is over an empty set.
        # Terms are gathered in r-major order, matching a plain double loop over (r, s).
        bins = [torch.logsumexp(pairs[..., dest == i], dim=-1) for i in range(size)]
        return torch.stack(bins, dim=-1)

    def star(self, x: torch.Tensor, max_iter=150, tol=1e-10) -> torch.Tensor:
        """
        Iterative Kleene star in log space: ``x^* = sum_{n=0..inf} x^n``.

        Starting from ``result = term = one``, repeat
        ``term = term * x`` (semiring multiply) and
        ``result = result + term`` (semiring add) until the largest absolute
        change of any log-weight drops below ``tol``, or ``max_iter`` steps have
        run. In the latter case the partial sum is returned.

        Args:
            x: log-space element; its last dimension must have ``self.size`` bins.
            max_iter: maximum number of multiply/add steps.
            tol: convergence threshold on the absolute change of log-weights.

        Returns:
            Tensor of shape ``x.shape[:-1] + (self.size,)``.
        """
        result = self.ones(x.shape[:-1], x.dtype, x.device)
        term = result.clone()

        for _ in range(max_iter):
            term = self.multiply(term, x)
            new_result = self.add(result, term)
            with torch.no_grad():
                # Bins still at -inf in both iterates would give (-inf) - (-inf) = nan,
                # which makes the max nan so the test below never passes; count
                # exactly-equal bins as unchanged instead.
                change = torch.where(
                    new_result == result,
                    torch.zeros_like(new_result),
                    (new_result - result).abs(),
                )
                converged = change.numel() == 0 or change.max().item() < tol
            result = new_result
            if converged:
                break
        return result

    def star_efficient(self, x):
        """Closed-form star; not implemented (use ``star``)."""
        raise NotImplementedError("Missing, use the iterative and the fact that should sum to 1 with acceptance!")

    def zeros(self, size: tuple[int, ...], dtype: torch.dtype, device: torch.device):
        """Semiring zero: all bins ``-inf``, shape ``size + (self.size,)``."""
        return torch.full(size + (self.size,), -math.inf, dtype=dtype, device=device)

    def ones(self, size: tuple[int, ...], dtype: torch.dtype, device: torch.device):
        """Semiring one: ``0`` (log 1) in bin 0, ``-inf`` elsewhere; shape ``size + (self.size,)``."""
        result = torch.full(size + (self.size,), -math.inf, dtype=dtype, device=device)
        result[..., 0] = 0
        return result



def assert_close(a, b, atol=1e-7, rtol=1e-7, msg=""):
    """
    Assert that ``a`` and ``b`` agree within ``atol + rtol * scale``.

    Tensors: the largest elementwise ``|a - b|`` must not exceed
    ``atol + rtol * max|b|``, where the max is taken over the finite entries of
    ``b``. Entries that are exactly equal (including matching infinities) count
    as a zero difference. Empty tensors trivially pass. Any other pair of values
    is compared with ``math.isclose``.

    Raises:
        AssertionError: if the values are not close; the message begins with ``msg``.
    """
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        delta = a - b
        # (-inf) - (-inf) is nan, which would make equal log-space zeros compare as "not close".
        gap = torch.where(a == b, torch.zeros_like(delta), delta.abs())
        if gap.numel() == 0:
            return
        diff = gap.max()
        # Only finite entries set the scale; an infinite entry would make the
        # tolerance infinite and accept any mismatch.
        finite_b = b[torch.isfinite(b)]
        scale = float(finite_b.abs().max()) if finite_b.numel() > 0 else 0.0
        # `not <=` also rejects a nan difference.
        if not float(diff) <= atol + rtol * scale:
            raise AssertionError(f"{msg}  diff={diff}")
    else:
        if not math.isclose(a, b, abs_tol=atol, rel_tol=rtol):
            raise AssertionError(f"{msg}  |{a} - {b}| too large")


def test_zeros_ones():
    """Check the shapes and contents of the semiring's zero and one elements."""
    cpu = torch.device('cpu')
    sr = OverflowLogCountingSemiring(size=4)

    zero_elem = sr.zeros((2,3), torch.float32, cpu)
    assert zero_elem.shape == (2,3,4), f"zeros shape mismatch, got {zero_elem.shape}"
    all_neg_inf = torch.isinf(zero_elem).all() and (zero_elem < 0).all()
    assert all_neg_inf, "zeros should be -inf"

    one_elem = sr.ones((2,3), torch.float32, cpu)
    assert one_elem.shape == (2,3,4), f"ones shape mismatch, got {one_elem.shape}"
    first_bin = one_elem[...,0]
    assert torch.allclose(first_bin, torch.zeros_like(first_bin))
    inf_count = torch.isinf(one_elem).sum()
    assert inf_count == (2*3*(4-1)), "expected all bins except bin0 to be -inf"


def test_add():
    """Pointwise log-space addition of two disjoint one-hot elements."""
    neg_inf = -math.inf
    sr = OverflowLogCountingSemiring(size=3)
    lhs = torch.tensor([[0.0, neg_inf, neg_inf]])
    rhs = torch.tensor([[neg_inf, 0.0, neg_inf]])
    row = sr.add(lhs, rhs)[0]

    assert_close(row[0].item(), 0.0, msg="bin0 must be 0.0")
    assert_close(row[1].item(), 0.0, msg="bin1 must be 0.0")
    last_bin = row[2].item()
    assert math.isinf(last_bin) and last_bin < 0, "bin2 must be -inf"


def test_multiply():
    """Overflow Cauchy product on a 3-bin example: bin 2 collects every count >= 2."""
    sr = OverflowLogCountingSemiring(size=3)
    lhs = torch.log(torch.tensor([[0.5, 0.5, 0.0]]))
    rhs = torch.log(torch.tensor([[0.2, 0.8, 0.0]]))
    probs = torch.exp(sr.multiply(lhs, rhs))[0]

    expectations = ((0.1, "bin0 mismatch"), (0.5, "bin1 mismatch"), (0.4, "bin2 mismatch"))
    for bin_idx, (want, label) in enumerate(expectations):
        assert_close(probs[bin_idx].item(), want, msg=label)


def test_star_small():
    """The Kleene star of a small 3-bin element must contain no NaNs."""
    sr = OverflowLogCountingSemiring(size=3)
    log_x = torch.log(torch.tensor([[0.3, 0.5, 0.1]]))
    starred = torch.exp(sr.star(log_x))
    has_nan = torch.isnan(starred).any()
    assert not has_nan, "NaN in star result"


def test_compare_regular_and_overflow():
    """
    Cross-check the first N bins against LogCountingSemiring(N).

    A random (B, N) log-probability input (each row scaled to total mass 0.9) is
    padded with an empty (-inf) overflow bin for OverflowLogCountingSemiring(N+1).
    The first N bins of x*x and of the star must match the non-overflow semiring
    to within 1e-6 in log space.
    """
    N = 20
    reference = LogCountingSemiring(N)
    overflow = OverflowLogCountingSemiring(N+1)

    B = 2
    weights = torch.rand(B, N) * 0.5
    weights = weights * (0.9 / weights.sum(dim=-1, keepdim=True))
    log_w = torch.log(weights)

    # Same log-weights with an extra, empty overflow bin appended.
    log_w_ovf = torch.full((B, N+1), float('-inf'))
    log_w_ovf[..., :N] = log_w

    def check_leading_bins(kind, ref_res, ref_label, ovf_res, ovf_label):
        # Compare bins 0..N-1 one at a time, reporting the first mismatch.
        for k in range(N):
            ref_bin = ref_res[..., k]
            ovf_bin = ovf_res[..., k]
            worst = torch.abs(ref_bin - ovf_bin).max().item()
            assert worst < 1e-6, (
                f"Mismatch in {kind} result bin {k}:\n"
                f"  {ref_label}{ref_bin}\n"
                f"  {ovf_label}{ovf_bin}\n"
                f"  max_diff={worst}"
            )

    check_leading_bins(
        "multiply",
        reference.multiply(log_w, log_w), "non-overflow=",
        overflow.multiply(log_w_ovf, log_w_ovf), "overflow=    ",
    )
    check_leading_bins(
        "star",
        reference.star(log_w), "star_non=",
        overflow.star(log_w_ovf), "star_ovf=",
    )
    print("test_compare_regular_and_overflow: PASS")


def test_star():
    """
    Check that the star of a sub-stochastic step, closed by acceptance, has total mass ~1.

    x puts probability 0.7 on "take one more counted step" (bin 1), and acc puts
    0.3 on "stop without counting" (bin 0). Summing exp(x^* * acc) over all bins
    should therefore give 0.3 * sum_{n>=0} 0.7**n = 1. Counts >= N-1 are
    collected in the overflow bin, so no mass is lost.
    """
    N = 4
    semiring = OverflowLogCountingSemiring(size=N)

    acc_values = [0.3, 0, 0, 0]
    x_values = [0, 0.7, 0.0, 0.0]
    x_log = torch.log(torch.tensor(x_values))
    acc_log = torch.log(torch.tensor(acc_values))

    # x^* = sum_n x^n, then close every path with a single acceptance step.
    x_star = semiring.star(x_log)
    xr = semiring.multiply(x_star, acc_log)
    # Total probability mass over all count bins (including overflow).
    final_sum = torch.exp(xr).sum(dim=-1).item()

    # we want final_sum ~ 1
    print(f"Sub-stochastic star sum => {final_sum}")
    assert math.isclose(final_sum, 1.0, rel_tol=1e-5, abs_tol=1e-5), (
        f"Expected ~1 but got {final_sum}"
    )


def test_associativity():
    """
    Check that OverflowLogCountingSemiring.multiply is associative.

    For each size in (2, 3, 4, 10), every ordered triple (a, b, c) drawn from a
    set of log-space test vectors must satisfy (a*b)*c == a*(b*c) to within
    1e-5 per log-weight. The vectors are one-hots, all mass in the overflow bin,
    mass split over bin 0 / bin 1 / overflow, uniform, and two random softmax
    draws.

    Returns:
        True when every combination agrees.

    Raises:
        AssertionError: on the first violation, after printing diagnostics.
    """
    print("Testing associativity of OverflowLogCountingSemiring multiplication...")

    for size in [2, 3, 4, 10]:
        semiring = OverflowLogCountingSemiring(size=size)

        # 0.4 in bin 0, 0.3 in bin 1, 0.3 in the overflow bin. Built by index so the
        # vector always has exactly `size` entries; for size == 2, bin 1 *is* the
        # overflow bin and receives 0.6.
        mixed = [0.0] * size
        mixed[0] += 0.4
        mixed[1] += 0.3
        mixed[-1] += 0.3

        # We test a bunch of vectors
        test_vectors = [
            # Simple one-hot vectors
            torch.log(torch.tensor([[1.0, 0.0] + [0.0] * (size-2)])),
            torch.log(torch.tensor([[0.0, 1.0] + [0.0] * (size-2)])),

            # Vectors with values in overflow bin
            torch.log(torch.tensor([[0.0] * (size-1) + [1.0]])),

            # Vectors with distribution across multiple bins including overflow
            torch.log(torch.tensor([mixed])),

            # Uniform vectors
            torch.log(torch.tensor([[1.0/size] * size])),

            # Random vectors (normalized to sum to 1)
            torch.log(torch.nn.functional.softmax(torch.randn(1, size), dim=-1)),
            torch.log(torch.nn.functional.softmax(torch.randn(1, size), dim=-1)),
        ]

        # Test all combinations of three vectors -- for every size, not just the last one.
        for i, a in enumerate(test_vectors):
            for j, b in enumerate(test_vectors):
                for k, c in enumerate(test_vectors):
                    abc_left = semiring.multiply(semiring.multiply(a, b), c)
                    abc_right = semiring.multiply(a, semiring.multiply(b, c))

                    # Bins that are -inf in both results are equal, but (-inf) - (-inf)
                    # is nan, and a single nan makes torch.max return nan, which would
                    # hide every other mismatch. Count exactly-equal bins as zero difference.
                    gap = torch.where(
                        abc_left == abc_right,
                        torch.zeros_like(abc_left),
                        torch.abs(abc_left - abc_right),
                    )
                    max_diff = torch.max(gap).item()

                    # `not <=` also treats a genuine nan difference as a failure.
                    if not max_diff <= 1e-5:
                        print(f"Associativity violation detected with size={size}, vectors {i},{j},{k}")
                        print(f"a: {torch.exp(a)}")
                        print(f"b: {torch.exp(b)}")
                        print(f"c: {torch.exp(c)}")
                        print(f"(ab)c: {torch.exp(abc_left)}")
                        print(f"a(bc): {torch.exp(abc_right)}")
                        print(f"Max difference: {max_diff}")
                        raise AssertionError(
                            f"Associativity violation detected with size={size}, vectors {i},{j},{k}"
                        )

    print("Associativity test passed for all test cases!")
    return True


def run_all_tests():
    """Run the module's self-tests in sequence, printing a PASS line after each one."""
    print("Running tests for difference-trickOverflowLogCountingSemiring...")
    test_zeros_ones()
    print(" test_zeros_ones: PASS")
    test_add()
    print(" test_add: PASS")
    test_multiply()
    print(" test_multiply: PASS")
    test_star_small()
    print(" test_star_small: PASS")
    test_star()
    print(" test_star: PASS")
    # test_associativity may report a violation by returning False instead of
    # raising; check its result so a failure is never followed by a PASS line.
    if not test_associativity():
        raise AssertionError("test_associativity reported an associativity violation")
    print(" test_associativity: PASS")

    print("All tests passed successfully!")


# Script entry point: running this file directly executes the self-tests via run_all_tests().
if __name__ == "__main__":
    run_all_tests()