import numpy as np
from collections.abc import Sequence

import torch

from .provenance import Provenance, Tag, TagBatch

class DAMP(Provenance):
    """Provenance whose tags are combined with plain arithmetic.

    Tags behave like probabilities: ``neg`` is the complement ``1 - a``,
    ``add`` is a sum and ``mul`` is a product, and the batched addition /
    symbol reduction clamp their results to ``[0, 1]``.  The acronym is not
    expanded in this file (presumably "differentiable add-mult probability";
    TODO confirm).
    """

    def zero(self, batch_shape, device="cpu"):
        """Return the additive identity ``0.0``.

        ``batch_shape`` and ``device`` are accepted for interface
        compatibility but are not used: a Python float is returned.
        """
        return 0.0

    def one(self, batch_shape, device="cpu"):
        """Return the multiplicative identity ``1.0``.

        ``batch_shape`` and ``device`` are accepted but unused.
        """
        return 1.0

    def zeros(self, shape, device="cpu") -> TagBatch:
        """Return a tensor of zero tags with the given ``shape`` on ``device``."""
        return torch.zeros(shape, device=device)

    def add(self, a: Tag, b: Tag):
        """Combine two tags disjunctively: their plain sum.

        NOTE(review): unlike ``add_batch`` this result is not clamped to
        ``[0, 1]`` and can exceed 1 -- confirm the asymmetry is intended.
        """
        return a + b

    def mul(self, a: Tag, b: Tag):
        """Combine two tags conjunctively: their product."""
        return a * b

    def neg(self, a: Tag):
        """Negate a tag: its complement ``1 - a``."""
        return 1 - a

    def add_batch(self, a: TagBatch, b: TagBatch = None) -> TagBatch:
        """Sum tags and clamp the result to ``[0, 1]``.

        Call forms:

        * ``add_batch(a, b)`` -- element-wise ``a + b``; shapes must match.
        * ``add_batch(seq)`` -- element-wise sum over a sequence of tensors
          (Python ``sum``, so an empty sequence is not supported).
        * ``add_batch(t)`` -- sum of tensor ``t`` over its last dimension.
        """
        if b is not None:
            assert a.shape == b.shape, "a and b must be of the same shape"
            total = a + b
        elif isinstance(a, Sequence):
            total = sum(a)
        else:
            assert isinstance(a, torch.Tensor), "Expected a tensor"
            total = torch.sum(a, dim=-1)
        return total.clamp(min=0.0, max=1.0)

    def mul_batch(self, a: TagBatch, b: TagBatch = None) -> TagBatch:
        """Multiply tags (no clamping is applied).

        Call forms mirror ``add_batch``:

        * ``mul_batch(a, b)`` -- element-wise ``a * b``; shapes must match.
        * ``mul_batch(t)`` -- product of tensor ``t`` over its last dimension.
        * ``mul_batch(seq)`` -- element-wise product over a sequence of
          equally shaped tensors (stacked along a new leading dimension).
        """
        if b is not None:
            assert a.shape == b.shape, "a and b must be of the same shape"
            return a * b
        if isinstance(a, torch.Tensor):
            return torch.prod(a, dim=-1)
        assert isinstance(a, Sequence), "Expected a sequence"
        return torch.prod(torch.stack(a), dim=0)

    def neg_batch(self, a: TagBatch) -> TagBatch:
        """Negate a batch of tags element-wise: ``1 - a``."""
        return 1 - a

    def reduce_symbols(self, prod: TagBatch, results: np.ndarray):
        """Merge the tags of duplicate result symbols.

        The last dimension of ``prod`` must have length ``len(results)``:
        ``prod[..., i]`` is the tag of ``results[i]``.  Tags of entries that
        share a symbol are summed, then clamped to ``[0, 1]``.

        Returns:
            ``(final_probs, symbols)`` where ``final_probs[..., j]`` is the
            merged tag of ``symbols[j]``.  For non-object ``results``,
            ``symbols`` is the sorted ndarray from ``np.unique``; for object
            arrays it is a list in first-occurrence order.
        """
        if results.dtype != np.object_:
            symbols, inverse = np.unique(results, return_inverse=True)
        else:
            # np.unique sorts its input, which fails for object arrays whose
            # values are not mutually orderable; dedupe by hashing instead,
            # keeping first-seen order.
            position = {}
            symbols = []
            inverse = []
            for r in results:
                if r not in position:
                    position[r] = len(symbols)
                    symbols.append(r)
                inverse.append(position[r])

        # Explicit long index on prod's device; also handles an empty list.
        index = torch.as_tensor(inverse, dtype=torch.long, device=prod.device)
        # One-hot (len(results), len(symbols)) indicator.  It must share
        # prod's dtype: torch.eye defaults to float32, and matmul against a
        # float64 (or other non-float32) prod would raise a dtype mismatch.
        indicator = torch.eye(len(symbols), dtype=prod.dtype, device=prod.device)[index]

        final_probs = prod @ indicator
        return final_probs.clamp(min=0.0, max=1.0), symbols

    def cartesian_prod(self, a: TagBatch, b: TagBatch) -> TagBatch:
        """Return all pairwise products of two 2-D batches of tags.

        ``a`` has shape ``(Ba, Na)`` and ``b`` has shape ``(Bb, Nb)``.  A batch
        size of 1 is expanded to match the other; any other mismatch makes
        ``expand`` raise.  The result has shape ``(B, Na * Nb)`` with
        ``out[k, i * Nb + j] == a[k, i] * b[k, j]``.
        """
        batch_a, num_a = a.shape
        batch_b, num_b = b.shape
        batch = max(batch_a, batch_b)
        if batch_a < batch:
            a = a.expand(batch, num_a)
        elif batch_b < batch:
            b = b.expand(batch, num_b)
        # Per-row outer product: (B, Na, 1) @ (B, 1, Nb) -> (B, Na, Nb).
        outer = torch.bmm(a.view(-1, num_a, 1), b.view(-1, 1, num_b))
        return outer.view(batch, -1)