from __future__ import annotations
from collections import defaultdict
import logging
from time import time
from typing import Any, List, Tuple, Dict, Callable, Union
import itertools
import uuid

import torch
import numpy as np

# Module-wide logger; timing statistics are attached to it as ad-hoc attributes.
logger = logging.getLogger("dolphin.symbolic")

# Accumulated timings keyed by stat name; unseen keys start at 0.0.
logger.stats = defaultdict(float)


def _reset_stats():
    """Drop every accumulated timing from ``logger.stats``."""
    return logger.stats.clear()


logger.reset_stats = _reset_stats

class Distribution:
    """
    A set of symbols paired with a tensor of provenance tags.

    The provenance object and the top-k limit are stored on the class
    (``_p`` / ``_k``) and exposed through the ``provenance`` and ``k``
    properties, so assigning either one through any instance changes it for
    every instance of that class.
    """
    # Class-level storage behind the ``provenance`` property (None until set).
    _p = None
    # Class-level storage behind the ``k`` property (None until set).
    _k = None

    @property
    def provenance(self):
        return type(self)._p
    
    @provenance.setter
    def provenance(self, value: Any) -> None:
        type(self)._p = value

    @property
    def k(self) -> Any:
        return type(self)._k
    
    @k.setter
    def k(self, value: Any) -> None:
        type(self)._k = value

    @staticmethod
    def copy(d: Distribution) -> Distribution:
        assert isinstance(d, Distribution), "Input must be of type Distribution"
        return Distribution(d.tags, d.symbols, dist_as_probs=False, src=d.src)

    @staticmethod
    def stack(distributions: List) -> Distribution:
        """Combine ``distributions`` by delegating to ``symbolic_collate_fn`` from the sibling ``utils`` module."""
        # The import stays inside the function, as in the original
        # (presumably to avoid a circular import -- TODO confirm).
        from .utils import symbolic_collate_fn as collate
        stacked = collate(distributions)
        return stacked
    
    @staticmethod
    def l_and(a: Distribution, b: Distribution) -> Distribution:
        assert isinstance(a, Distribution) and isinstance(b, Distribution), "All inputs must be of type Distribution"
        # assert a.type == np.bool_ and b.type == np.bool_, "All inputs must have the type `np.bool_`"

        return a.__compute_possibilities(b, lambda s1, s2 : s1 and s2)

    @staticmethod
    def l_or(a: Distribution, b: Distribution) -> Distribution:
        assert isinstance(a, Distribution) and isinstance(b, Distribution), "All inputs must be of type Distribution"
        # assert a.type == np.bool_ and b.type == np.bool_, "All inputs must have the type `np.bool_`"

        return a.__compute_possibilities(b, lambda s1, s2 : s1 or s2)

    @staticmethod
    def l_not(a: Distribution) -> Distribution:
        assert isinstance(a, Distribution), "All inputs must be of type Distribution"
        # assert a.type == np.bool_, "All inputs must have the type `np.bool_`"

        return a.__compute_possibilities(a, lambda s1, s2 : not s1)
        
    def __get_symbols_from_array(self, symbol_list):
        if isinstance(symbol_list, np.ndarray) and len(symbol_list.shape) == 1:
            # print("Straight copy")
            return symbol_list
        t = time()
        try:
            symbols = np.array(symbol_list)
        except:
            symbols = np.array(symbol_list, dtype=np.object_)
        # symbols = np.array(symbol_list, dtype=np.object_)
        # symbols = np.array(symbol_list, dtype=object)
        # symbols = np.empty(len(symbol_list), dtype=object)
        # symbols[:] = symbol_list
        if symbols.shape == ():
            symbols = np.array([symbols])
        elif len(symbols.shape) > 1:
            symbols = np.empty(len(symbol_list), dtype=object)
            symbols[:] = symbol_list
        logger.stats["T_SymArr"] += (time() - t)
        return symbols

    def __init__(self, distribution: torch.Tensor, symbols, dist_as_probs = True, disjunctions = None, src = None) -> None:
        """
        distribution:
            torch.Tensor of shape (N, M)  where N, M is the shape of symbols
        dist_as_probs:
            True if we want to treat distribution tensor as probabilities
            False if we want to treat distribution tensor as tags
        """
        assert self.provenance is not None, "Provenance not set"
        
        self.symbols = self.__get_symbols_from_array(symbols)

        if dist_as_probs:
            assert distribution.dim() <= 2, "Distribution must be 1D or 2D"
            assert (distribution.dim() == 0 and len(self.symbols.shape) == 0) or distribution.shape[-1] == len(self.symbols), f"Length of symbols must match number of columns of the distribution: {distribution.shape[-1]}, {len(self.symbols)}"
           
            if distribution.dim() > 2:
                probs = distribution.view(distribution.shape[0], -1)
            elif distribution.dim() == 1:
                probs = distribution.view(1, -1)
            else:
                probs = distribution

            if disjunctions is None:
                disjunctions = [list(range(len(self.symbols)))]

            self.tags = self.provenance.tags_from_probs(probs, disjunctions=disjunctions)
        else:
            # TODO: dimensionality checks for tag distributions
            if distribution.dim() == 1:
                self.tags = distribution.view(1, -1)
            else:
                self.tags = distribution

        if src:
            self.src = src
        else:
            self.src = [self]

        self.inverted = False

        self.id = uuid.uuid4()

    # def num_possibilities(self) -> Distribution:
    #     pwrset_idx = list(powerset(range(len(self.symbols))))
    #     pwrset_tensors = []

    #     for i in range(len(pwrset_idx)):
    #         idx_in = pwrset_idx[i]
    #         idx_out = pwrset_idx[-(i+1)]
            
    #         tensor_in = self.provenance.mul(self.tags[:, idx_in]) if len(idx_in) > 0 else torch.full(self.tags.shape[0], self.provenance.one(self.tags.shape), device=self.tags.device)
    #         tensor_out = self.provenance.mul(self.provenance.negate(self.tags[:, idx_out])) if len(idx_out) > 0 else torch.full(self.tags.shape[0], self.provenance.one(self.tags.shape), device=self.tags.device)

    #         pwrset_tensors.append(self.provenance.mul(tensor_in, tensor_out))

    #     subsets = Distribution(torch.stack(pwrset_tensors, dim=1), pwrset_idx, dist_as_probs=False)
        
    #     return subsets.map(lambda x : len(x))
            

    def sample_top_k(self, k, categorical=True) -> Distribution:
        t = time()
        if k is None or k > len(self.symbols):
            return self
        
        p = self.get_probabilities()
        # Avoid all-zero distributions
        p += p.sum(dim=1).view(p.shape[0], -1) == 0
        if categorical:
            categ = torch.distributions.Categorical(p)
            indices = categ.sample((k, )).T
        else:
            topk = torch.topk(p, k)
            indices = topk.indices
        flattened_indices = indices.unique()
        
        new_symbols = self.symbols[flattened_indices.cpu()]
        if not isinstance(new_symbols, np.ndarray):
            new_symbols = np.array([new_symbols, ])
        if len(flattened_indices) != len(new_symbols):
            assert len(flattened_indices) == 1, f"Indices: {flattened_indices}, Symbols: {new_symbols}"
            new_symbols = [new_symbols]
        
        if self.tags.dim() == 1:
            new_distribution = self.tags[flattened_indices]
        else:
            # mask = torch.full(self.tags.shape, self.provenance.zero(self.tags.shape), dtype=torch.float, device=self.tags.device).scatter_(1, indices, 1.)
            # new_distribution = (self.tags * mask)[:, flattened_indices]
            mask = self.provenance.zeros(self.tags.shape, device=self.tags.device)
            mask[torch.arange(self.tags.shape[0]).unsqueeze(1), indices] = self.provenance.one(self.tags.shape, device=self.tags.device)
            new_distribution = self.provenance.mul_batch(self.tags, mask)[:, flattened_indices]
        # print("SAMPLED: ", new_symbols, new_distribution)
        d = Distribution(new_distribution, new_symbols, dist_as_probs=False, src=self.src)
        # logger.stats["T_TopK"] += (time() - t)

        return d

    def __calculate_possibilities(self, dist_b: Distribution, function: Callable, conditional = False) -> Dict[object, List[torch.Tensor]]:
        if len(self.symbols) == 0:
            return Distribution.copy(self)
        if len(dist_b.symbols) == 0:
            return Distribution.copy(dist_b)
        
        a = self.sample_top_k(self.k)
        b = dist_b.sample_top_k(self.k)

        a_tags, b_tags, ab_src = self.provenance.combine_tag_sources(a, b)

        num_a = len(a.symbols)
        num_b = len(b.symbols)

        res_list = list((function(a.symbols[ia], b.symbols[ib]) for ia, ib in itertools.product(range(num_a), range(num_b))))
        # print(num_a, num_b, len(res_list))
        results = self.__get_symbols_from_array(res_list)
        
        prod_distribution = Distribution(self.provenance.cartesian_prod(a_tags, b_tags), results, dist_as_probs=False, src=ab_src)
        if conditional:
            prod_distribution = prod_distribution.drop_symbol(None)
        final_tags, symbols = self.provenance.reduce_symbols(prod_distribution.tags, prod_distribution.symbols)
        
        return Distribution(final_tags, symbols, dist_as_probs=False, src=ab_src)
    
    def __compute_possibilities(self, dist_b: Union[Distribution|np.ndarray|Any], function: Callable, conditional = False) -> Distribution:
        assert self.provenance is not None, "Provenance not set"
        if not isinstance(dist_b, Distribution):
            if isinstance(dist_b, np.ndarray):
                assert len(dist_b) == len(self.symbols), "Length of symbols must match"
                dist_b = Distribution(torch.ones(self.tags.shape[:2], device=self.tags.device), dist_b)
            else:
                if self.tags.dim() > 1:
                    dist_b = Distribution(torch.ones((self.tags.shape[0], 1), device=self.tags.device), [dist_b, ])
                else:
                    dist_b = Distribution(torch.ones(1, device=self.tags.device), [dist_b, ])

        t = time()
        res = self.__calculate_possibilities(dist_b, function, conditional)
        logger.stats["T_Compute"] += (time() - t)
        return res
    
    def apply(self, dist_b: Distribution, function: Callable) -> Distribution:
        """
        self:
            symbols: T1
            distribution: torch.Tensor

        dist_b:
            symbols: T2
            distribution: torch.Tensor

        function:
            T1 x T2 -> T3 (caveat: T3 should be hashable)
        """
        return self.__compute_possibilities(dist_b, function)
    
    def apply_if(self, dist_b: Distribution, function: Callable, condition: Callable) -> Distribution:
        """
        self:
            symbols: T1
            distribution: torch.Tensor

        dist_b:
            symbols: T2
            distribution: torch.Tensor

        function:
            T1 x T2 -> T3 (caveat: T3 should be hashable)

        condition:
            T1 x T2 -> Bool (only consider results for which condition returns True)
        """
        return self.__compute_possibilities(dist_b, lambda a, b: function(a, b) if condition(a, b) else None, conditional=True)
    
    def softmax(self) -> Distribution:
        t = time()
        probs = self.get_probabilities()
        d = Distribution(torch.nn.functional.softmax(probs, dim=-1), self.symbols, src=self.src)
        logger.stats["T_Softmax"] += (time() - t)
        return d
    
    def map(self, function: Callable) -> Distribution:
        """
        function: something that applies to each symbol
        """
        t = time()

        # results = np.array(list(map(function, self.symbols)))
        results = self.__get_symbols_from_array(list(map(function, self.symbols)))
        final_tags, symbols = self.provenance.reduce_symbols(self.tags, results)
        
        logger.stats["T_Map"] += (time() - t)

        return Distribution(final_tags, symbols, dist_as_probs=False, src=self.src)

    def __add__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 + s2)
    
    def __mul__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 * s2)
    
    def __sub__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 - s2)
    
    def __truediv__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 / s2)
    
    def __floordiv__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 // s2)
    
    def __mod__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 % s2)
    
    def __pow__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 ** s2)
    
    def __eq__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 == s2)
    
    def __ne__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 != s2)
    
    def __gt__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 > s2)
    
    def __ge__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 >= s2)
    
    def __lt__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 < s2)
    
    def __le__(self, dist_b: Union[Distribution|np.ndarray|Any]) -> Distribution:
        return self.__compute_possibilities(dist_b, lambda s1, s2 : s1 <= s2)

    def __invert__(self) -> Distribution:
        """
        Negate the tags of a two-symbol distribution.

        NOTE(review): this definition is dead code. A second ``__invert__`` is
        defined later in this class body and replaces it, so the two-symbol
        assertion below never runs. Either remove this definition or merge its
        check into the later one.
        """
        assert len(self.symbols) == 2, "Only binary distributions can be inverted"
        return Distribution(self.provenance.neg_batch(self.tags), self.symbols, dist_as_probs=False, src=self.src)
    
    def __repr__(self) -> str:
        return f"{{Symbols: {self.symbols}, Distribution: {self.tags}}}"
    
    def __getitem__(self, indices: int | slice | torch.Tensor | List | Tuple | None) -> Distribution:
        t = time()
        d = self.tags[indices]
        if torch.numel(d):
            d = d.view([-1] + list(self.tags.shape[1:]))
        else:
            d = d.view([1] + list(self.tags.shape[1:]))
        # logger.stats["T_GetItemDIST"] += (time() - t)
        t = time()
        if isinstance(indices, tuple):
            s = self.symbols[indices[1]]
        else:
            if self.tags.dim() <= 1:
                s = self.symbols[indices]
            else:
                s = self.symbols
        # logger.stats["T_GetItemSYM"] += (time() - t)
        t = time()
        f = Distribution(d, s, dist_as_probs=False, src=self.src)
        # logger.stats["T_GetItemINIT"] += (time() - t)
        # logger.stats["T_GetItem"] += (time() - t)

        return f
    
    def __iter__(self):
        return (self[i] for i in range(len(self)))
    
    def __len__(self) -> int:
        return len(self.tags)
    
    def __hash__(self):
        return self.id.int
    
    def filter(self, filter_function) -> Distribution:
        t = time()
        filtered_indices = [ filter_function(s) for s in self.symbols ]
        true_symbols = self.symbols[filtered_indices]
        logger.stats["T_TrueSym"] += (time() - t)
        if self.tags.dim() == 1:
            d = Distribution(self.tags[filtered_indices], true_symbols, dist_as_probs=False, src=self.src)
        else:
            d = Distribution(self.tags[:, filtered_indices], true_symbols, dist_as_probs=False, src=self.src)

        logger.stats["T_Filter"] += (time() - t)
        return d
    
    def map_symbols(self, new_symbols: np.ndarray) -> Distribution:
        t = time()
        # new_symbols = np.array(new_symbols)
        t_norm = time()
        new_symbols = self.__get_symbols_from_array(new_symbols)

        if self.tags.dim() == 1:
            org_dist = self.tags.view(1, -1)
        else:
            org_dist = self.tags
        logger.stats["T_Norm"] += (time() - t_norm)

        t_alloc = time()
        new_shape = list(org_dist.shape)
        new_shape[1] = len(new_symbols)

        new_dist = self.provenance.zeros(new_shape, device=self.tags.device)
        if self.inverted:
            new_dist = self.provenance.neg_batch(new_dist)

        logger.stats["T_Alloc"] += (time() - t_alloc)

        t_copy = time()
        if len(self.symbols) != 0:
            _, idx1, idx2 = np.intersect1d(self.symbols, new_symbols, return_indices=True, assume_unique=True)
            new_dist[:, idx2] = org_dist[:, idx1]

        logger.stats["T_Copy"] += (time() - t_copy)

        t_instance = time()
        d = Distribution(new_dist, new_symbols, dist_as_probs=False, src=self.src)
        logger.stats["T_Instance"] += (time() - t_instance)

        logger.stats["T_MapSym"] += (time() - t)
        # print(d.tags.shape)
        return d
    
    def diff(self, dist_b: Distribution) -> Distribution:
        t = time()
        a = self.sample_top_k(self.k)
        b = dist_b.sample_top_k(self.k)

        new_symbols = np.setdiff1d(a.symbols, b.symbols, assume_unique=True)
        logger.stats["T_DiffSym"] += (time() - t)
        x = a.map_symbols(new_symbols)
        logger.stats["T_Diff"] += (time() - t)
        return x
    
    # probabilistic logic
    def __and__(self, dist_b: Distribution) -> Distribution:
        assert self.provenance is not None, "Provenance not set"
        # assert self.distribution.shape == dist_b.distribution.shape, "Distributions must have the same shape"
        # symbols = np.unique(np.concatenate((self.symbols, dist_b.symbols)), axis=0)
        t = time()
        symbols = np.union1d(self.symbols, dist_b.symbols)

        if not np.array_equal(self.symbols, symbols):
            a = self.map_symbols(symbols)
        else:
            a = self

        if not np.array_equal(dist_b.symbols, symbols):
            b = dist_b.map_symbols(symbols)
        else:
            b = dist_b

        a_tags, b_tags, ab_src = self.provenance.combine_tag_sources(a, b)
        new_dist = self.provenance.mul_batch(a_tags, b_tags)
        d = Distribution(new_dist, symbols, dist_as_probs=False, src=ab_src)
        logger.stats["T_PAnd"] += (time() - t)
        return d
    
    def __or__(self, dist_b: Distribution) -> Distribution:
        assert self.provenance is not None, "Provenance not set"
        # assert self.distribution.shape == dist_b.distribution.shape, "Distributions must have the same shape"
        # symbols = np.unique(np.concatenate((self.symbols, dist_b.symbols)), axis=0)
        t = time()
        symbols = np.union1d(self.symbols, dist_b.symbols)
        logger.stats["T_UnionOR"] += (time() - t)
        
        t_map = time()
        if not np.array_equal(self.symbols, symbols):
            a = self.map_symbols(symbols)
        else:
            a = self

        if not np.array_equal(dist_b.symbols, symbols):
            b = dist_b.map_symbols(symbols)
        else:
            b = dist_b

        logger.stats["T_MapOR"] += (time() - t_map)

        a_tags, b_tags, ab_src = self.provenance.combine_tag_sources(a, b)
        new_dist = self.provenance.add_batch(a_tags, b_tags)
        d = Distribution(new_dist, symbols, dist_as_probs=False, src=ab_src)
        logger.stats["T_POr"] += (time() - t)
        return d
    
    def __invert__(self) -> Distribution:
        assert self.provenance is not None, "Provenance not set"
        t = time()
        new_dist = self.provenance.neg_batch(self.tags)
        d = Distribution(new_dist, self.symbols, dist_as_probs=False, src=self.src)
        d.inverted = not self.inverted
        logger.stats["T_PNot"] += (time() - t)
        return d

    def drop_symbol(self, symbol) -> Distribution:
        return self.filter(lambda s : s != symbol)
    
    def get_probabilities(self) -> torch.Tensor:
        return self.provenance.probs_from_tags(self.tags)
