from .base import Tracker
from .fenwick_tree import FenwickTree
from collections import defaultdict, Counter
from typing import Dict, Set


class PopularityTracker(Tracker):
    """
    Tracking and scoring using popularity of dst nodes globally or locally.

    The popularity of a destination is the number of times ``update`` has been
    called with it. When ``use_local_popularity`` is true the counts are kept
    separately per source node; otherwise every source is mapped to the single
    key ``0`` and the counts are shared by all sources.

    NOTE(review): ``FenwickTree`` is defined elsewhere. The code below relies on
    ``get_sum(i)`` being an inclusive prefix sum and ``update(i, delta)`` adding
    ``delta`` at index ``i`` (this is what the asserts in ``update`` check).
    """

    def __init__(self, all_timestamps : Dict[int,int], num_nodes : int, num_edges : int, use_local_popularity : bool):
        """
        Args:
            all_timestamps: Not used by this tracker (we do not care about the
                timestamps here).
            num_nodes: Total number of nodes; used as the worst-case rank bound
                for destinations that were never seen.
            num_edges: Size passed to every ``FenwickTree`` that is created.
            use_local_popularity: Track popularity per source node when true,
                globally when false.
        """
        super().__init__()
        # srcnode -> Fenwick tree indexed by popularity; holds how many dsts
        # currently have exactly that popularity.
        self.rank_tracker = defaultdict(lambda: FenwickTree(num_edges))
        # srcnode -> popularity -> set of dsts with exactly that popularity.
        self.dst_by_popularity = defaultdict(lambda: defaultdict(set))
        # srcnode -> Counter mapping dst -> popularity.
        self.popularity_by_dst = defaultdict(Counter)

        self.num_nodes = num_nodes

        # srcnode -> Fenwick tree indexed by popularity; holds 1 for every
        # popularity value that at least one dst currently has, 0 otherwise.
        self.popularity_fw_trees = defaultdict(lambda: FenwickTree(num_edges))
        # Stored as 0/1 so that get_srcnode can simply multiply by it.
        self.use_local_popularity = int(use_local_popularity)
        # srcnode -> highest popularity reached by any dst so far.
        self.max_popularity = defaultdict(int)

    def get_entities_to_filter_out(self, src, dst, t):
        """
        Return the set of seen dsts that are strictly more popular than ``dst``.

        If ``dst`` has never been seen for this source key, every seen dst is
        returned. ``t`` is ignored.
        """
        srcnode = self.get_srcnode(src)
        if not self.contains_dst(src, dst):
            return set(self.popularity_by_dst[srcnode].keys())

        popularity = self.popularity_by_dst[srcnode]
        dst_popularity = popularity[dst]
        return {
            cand_dst
            for cand_dst, cand_popularity in popularity.items()
            if cand_popularity > dst_popularity and cand_dst != dst
        }

    def get_tied_entities(self, src, dst, t, candidate_set : Set[int] = None):
        """
        Return the seen dsts (other than ``dst``) whose popularity equals ``dst``'s.

        When ``dst`` is unseen, or when ``candidate_set`` is smaller than the
        number of seen dsts, the work is delegated to
        ``_get_tied_entities_within_set`` (not defined in this class;
        presumably provided by ``Tracker`` -- TODO confirm).

        Raises:
            AssertionError: if ``dst`` is unseen and ``candidate_set`` is None.
        """
        srcnode = self.get_srcnode(src)
        if not self.contains_dst(src, dst):
            assert candidate_set is not None
            return self._get_tied_entities_within_set(src, dst, t, candidate_set)

        popularity = self.popularity_by_dst[srcnode]
        # Scanning the candidate set is cheaper when it is the smaller collection.
        if candidate_set is not None and len(candidate_set) < len(popularity):
            return self._get_tied_entities_within_set(src, dst, t, candidate_set)

        dst_popularity = popularity[dst]
        return {
            cand_dst
            for cand_dst, cand_popularity in popularity.items()
            if cand_popularity == dst_popularity and cand_dst != dst
        }

    def get_srcnode(self, src):
        """Return the key under which ``src``'s counts live: ``src`` if local, else ``0``."""
        return src*self.use_local_popularity

    def update(self, src, dst, t):
        """
        Record one more occurrence of ``dst`` for ``src`` and refresh all indexes.

        ``t`` is ignored.
        """
        srcnode = self.get_srcnode(src)
        counts = self.popularity_by_dst[srcnode]
        counts[dst] += 1
        current_popularity = counts[dst]

        buckets = self.dst_by_popularity[srcnode]
        level_tree = self.popularity_fw_trees[srcnode]
        rank_tree = self.rank_tracker[srcnode]

        buckets[current_popularity].add(dst)
        if len(buckets[current_popularity]) == 1:
            # dst is the first node at this popularity: mark the level as occupied.
            assert level_tree.get_sum(current_popularity) - level_tree.get_sum(current_popularity - 1) == 0
            level_tree.update(current_popularity, 1)
            assert level_tree.get_sum(current_popularity) - level_tree.get_sum(current_popularity - 1) == 1

        rank_tree.update(current_popularity, 1)

        if current_popularity > 1:
            # dst moved up from the previous level; take it out of there.
            prev_popularity = current_popularity - 1
            rank_tree.update(prev_popularity, -1)
            buckets[prev_popularity].remove(dst)

            if len(buckets[prev_popularity]) == 0:
                # Nobody is left at the previous level: mark it as unoccupied.
                level_tree.update(prev_popularity, -1)
                del buckets[prev_popularity]

        self.max_popularity[srcnode] = max(self.max_popularity[srcnode], current_popularity)

    def filter_dst(self, src, dst, to_filter):
        """
        Count how many nodes of ``to_filter`` must be discounted from ``dst``'s rank.

        Only nodes that have been seen for this source key and differ from
        ``dst`` are considered.

        Args:
            to_filter: Iterable of nodes to discount, or ``None`` for "nothing
                to filter".

        Returns:
            ``(num_to_filter_opt, num_to_filter_pess)``. If ``dst`` has been
            seen, the first counts nodes strictly more popular than ``dst`` and
            the second counts nodes at least as popular. If ``dst`` is unseen,
            both count every seen node in ``to_filter``.
        """
        if to_filter is None:
            return 0, 0

        srcnode = self.get_srcnode(src)
        popularity = self.popularity_by_dst[srcnode]
        # Loop-invariant facts about dst, looked up once.
        dst_seen = self.contains_dst(src, dst)
        dst_popularity = popularity[dst] if dst_seen else 0

        num_to_filter_opt = 0
        num_to_filter_pess = 0
        for node in to_filter:
            if node == dst or node not in popularity:
                continue
            if dst_seen:
                node_popularity = popularity[node]
                num_to_filter_opt += int(node_popularity > dst_popularity)
                num_to_filter_pess += int(node_popularity >= dst_popularity)
            else:
                # Every seen node outranks an unseen dst.
                num_to_filter_opt += 1
                num_to_filter_pess += 1

        return num_to_filter_opt, num_to_filter_pess

    def get_rank(self, src, dst, current_t, to_filter : set = None):
        """
        Return 0-based ``(optimistic, pessimistic)`` ranks of ``dst`` by popularity.

        The optimistic rank counts only dsts strictly more popular than ``dst``;
        the pessimistic rank additionally counts the dsts tied with it. Nodes in
        ``to_filter`` are discounted via ``filter_dst``. For an unseen ``dst``
        the optimistic rank places it right after every seen dst and the
        pessimistic rank places it after all other ``num_nodes - 1`` nodes.

        ``current_t`` is ignored.

        Raises:
            AssertionError: if the computed ranks are negative or the
                optimistic rank exceeds the pessimistic one.
        """
        srcnode = self.get_srcnode(src)
        # filter_dst treats ``None`` as "nothing to filter", so both branches
        # below work without an explicit filter set.
        num_to_filter_opt, num_to_filter_pess = self.filter_dst(src, dst, to_filter)

        if not self.contains_dst(src, dst):
            optimistic = self.get_total(src) - num_to_filter_opt
            # If it is not contained, the worst case scenario is that it comes
            # after ALL other nodes except for itself!
            pessimistic = (self.num_nodes - 1) - num_to_filter_pess
            return optimistic, pessimistic

        rank_tree = self.rank_tracker[srcnode]
        current_popularity = self.popularity_by_dst[srcnode][dst]
        # All seen dsts except dst itself.
        num_total = rank_tree.get_sum(self.max_popularity[srcnode]) - 1

        # Dsts strictly less popular than dst.
        if current_popularity > 1:
            num_after_dst_pessimistic = rank_tree.get_sum(current_popularity - 1)
        else:
            num_after_dst_pessimistic = 0
        # Dsts at most as popular as dst, excluding dst itself.
        num_after_dst_optimistic = rank_tree.get_sum(current_popularity) - 1

        optimistic = (num_total - num_after_dst_optimistic) - num_to_filter_opt
        pessimistic = (num_total - num_after_dst_pessimistic) - num_to_filter_pess
        assert 0 <= optimistic <= pessimistic, (
            f"inconsistent ranks for src={src!r}, dst={dst!r}: "
            f"optimistic={optimistic}, pessimistic={pessimistic}"
        )
        return optimistic, pessimistic

    def contains_dst(self, src, dst):
        """Return True if ``dst`` has been recorded for ``src``'s source key."""
        return dst in self.popularity_by_dst[self.get_srcnode(src)]

    def get_popularity(self, src, dst):
        """Return ``dst``'s popularity for ``src``'s source key (0 if unseen)."""
        return self.popularity_by_dst[self.get_srcnode(src)][dst]

    def get_score(self, src, dst):
        """Return the score of ``dst``, which is its popularity."""
        return self.get_popularity(src, dst)

    def get_full_mrr(self, src, dst, t, to_filter = None):
        """
        Return the reciprocal rank of ``dst`` together with its 1-based rank bounds.

        Returns:
            ``(1 / rank, optimistic + 1, pessimistic + 1)`` where ``rank`` is
            the mean of the 1-based optimistic and pessimistic ranks.
            NOTE: for an unseen ``dst`` a bare ``0`` is returned instead of a
            tuple; callers must handle both shapes.
        """
        if not self.contains_dst(src, dst):
            return 0

        optimistic, pessimistic = self.get_rank(src, dst, t, to_filter)
        assert optimistic <= pessimistic
        rank = (0.5*(optimistic + pessimistic)) + 1
        return 1.0/rank, optimistic + 1, pessimistic + 1

    def get_total(self, src):
        """Return the number of distinct dsts seen for ``src``'s source key."""
        srcnode = self.get_srcnode(src)
        return self.rank_tracker[srcnode].get_sum(self.max_popularity[srcnode]+1)

    def get_node_rank(self, src, dst, t):
        """
        Rank ``dst`` among the distinct popularity values currently in use.

        Returns:
            ``(rank, popularitydelta, num_seen)`` where ``rank`` is the 1-based
            position of ``dst``'s popularity among the distinct popularity
            values sorted in descending order, ``popularitydelta`` is the gap
            to the maximum popularity and ``num_seen`` is the number of seen
            dsts. For an unseen ``dst`` the result is ``(num_nodes, -1, 1)``.

        ``t`` is ignored.
        """
        srcnode = self.get_srcnode(src)
        if not self.contains_dst(src, dst):
            # Then we are in a different situation!
            return self.num_nodes, -1, 1

        level_tree = self.popularity_fw_trees[srcnode]
        popularity = self.popularity_by_dst[srcnode]
        dst_popularity = popularity[dst]
        top_popularity = self.max_popularity[srcnode]

        # Number of occupied popularity levels overall, and up to dst's level.
        totaltorank = level_tree.get_sum(top_popularity+1)
        noderank = level_tree.get_sum(dst_popularity)

        rank = (totaltorank - noderank) + 1
        popularitydelta = top_popularity - dst_popularity
        return rank, popularitydelta, len(popularity)

    def tracker_score_type(self):
        """Return the Python type of the scores produced by ``get_score``."""
        return int



class GlobalPopularityTracker(PopularityTracker):
    """Popularity tracker constructed with ``use_local_popularity=False``."""

    def __init__(self, all_timestamps : Dict[int,int], num_nodes : int, num_edges : int):
        super().__init__(
            all_timestamps=all_timestamps,
            num_nodes=num_nodes,
            num_edges=num_edges,
            use_local_popularity=False,
        )

class LocalPopularityTracker(PopularityTracker):
    """Popularity tracker constructed with ``use_local_popularity=True``."""

    def __init__(self, all_timestamps : Dict[int,int], num_nodes : int, num_edges : int):
        super().__init__(
            all_timestamps=all_timestamps,
            num_nodes=num_nodes,
            num_edges=num_edges,
            use_local_popularity=True,
        )