from .base import Tracker
from .fenwick_tree import FenwickTree
from collections import defaultdict
from typing import Dict, Set


class GlobalRecencyTracker(Tracker):
    """
    Track the recency of dst nodes globally (i.e. independent of src).

    For every destination seen so far the tracker remembers the latest
    timestamp passed to `update`, and ranks destinations by that timestamp
    (more recent = better rank). The `src` argument is accepted by every
    method but is never used to partition the state.

    NOTE(review): the ranking code uses `FenwickTree.get_sum(i)` as an
    inclusive prefix sum over time indices 0..i and `update(i, delta)` as a
    point update -- confirm against the FenwickTree implementation.
    """

    def __init__(self, all_timestamps : Dict[int,int], num_nodes : int, num_edges : int):
        """
        Args:
            all_timestamps: maps each timestamp to the index used for it in
                the Fenwick trees (stored as `time2idx`). Presumably the
                indices follow chronological order -- confirm with caller.
            num_nodes: total number of nodes; used as the size of the
                candidate universe when ranking a never-seen destination.
            num_edges: not used by this tracker.
        """
        super().__init__()
        # Per time index: how many destinations currently have that latest time.
        self.rank_tracker = FenwickTree(len(all_timestamps))
        # timestamp -> set of destinations whose latest time is that timestamp.
        self.dst_by_time = defaultdict(set)
        # destination -> latest timestamp at which it was seen.
        self.time_by_dst = {}
        self.time2idx = all_timestamps
        self.num_nodes = num_nodes
        # timestamp -> insertion order of that timestamp (only live timestamps).
        self.timetracker = {}
        self.timetrackermax = -1
        # Per time index: 1 once the timestamp lost its last destination.
        self.fw_remove_tracker = FenwickTree(len(all_timestamps))

    def get_entities_to_filter_out(self, src, dst, t):
        """
        Return the set of destinations that are strictly more recent than `dst`.

        If `dst` has never been seen, every seen destination is returned.
        """
        if not self.contains_dst(src, dst):
            return set(self.time_by_dst.keys())

        dst_time = self.time_by_dst[dst]
        return {
            cand_dst
            for cand_dst, cand_time in self.time_by_dst.items()
            if cand_time > dst_time and cand_dst != dst
        }

    def get_tied_entities(self, src, dst, t, candidate_set = None):
        """
        Return the destinations tied with `dst`, i.e. sharing its latest time.

        When `dst` is unseen, or when `candidate_set` is smaller than the
        number of seen destinations, the work is delegated to
        `_get_tied_entities_within_set`.

        NOTE(review): `_get_tied_entities_within_set` is not defined in this
        class; presumably it is provided by the `Tracker` base -- confirm.
        """
        if not self.contains_dst(src, dst):
            # Without a recorded time for dst, ties can only be resolved
            # against an explicit candidate set.
            assert candidate_set is not None
            return self._get_tied_entities_within_set(src, dst, t, candidate_set)

        if candidate_set is not None and len(candidate_set) < len(self.time_by_dst):
            return self._get_tied_entities_within_set(src, dst, t, candidate_set)

        dst_time = self.time_by_dst[dst]
        # TODO: This can be made *way* more efficient.
        return {
            cand_dst
            for cand_dst, cand_time in self.time_by_dst.items()
            if cand_time == dst_time and cand_dst != dst
        }

    def update(self, src, dst, t):
        """
        Record that `dst` was observed at timestamp `t`.

        Moves `dst` from its previous timestamp bucket (if any) to `t`,
        keeping both Fenwick trees and `timetracker` in sync. A repeated
        observation at the same timestamp is a no-op.
        """
        if dst not in self.time_by_dst:
            self.time_by_dst[dst] = t
            self.rank_tracker.update(self.time2idx[t], 1)
            self.dst_by_time[t].add(dst)
        else:
            old_t = self.time_by_dst[dst]
            if old_t == t:
                return
            # Move the destination's unit of mass from old_t to t.
            self.rank_tracker.update(self.time2idx[old_t], -1)
            self.rank_tracker.update(self.time2idx[t], 1)
            self.time_by_dst[dst] = t
            self.dst_by_time[old_t].remove(dst)
            if not self.dst_by_time[old_t]:
                # old_t no longer holds any destination: mark it as removed
                # so get_node_rank can discount it, and drop its bookkeeping.
                self.fw_remove_tracker.update(self.time2idx[old_t], 1)
                del self.dst_by_time[old_t]
                del self.timetracker[old_t]
            self.dst_by_time[t].add(dst)

        if t not in self.timetracker:
            # First destination at this timestamp: assign the next insertion index.
            self.timetrackermax += 1
            self.timetracker[t] = self.timetrackermax

    def filter_dst(self, src, dst, to_filter):
        """
        Count how many nodes of `to_filter` must be discounted from dst's rank.

        Only nodes that have been seen and differ from `dst` are counted;
        unseen nodes in `to_filter` are ignored.

        Returns:
            (num_to_filter_opt, num_to_filter_pess): the number of counted
            nodes strictly more recent than `dst`, and the number at least
            as recent as `dst`. If `dst` itself has never been seen, every
            counted node is treated as more recent, so both numbers are
            equal. `to_filter=None` yields (0, 0).
        """
        if to_filter is None:
            return 0, 0

        dst_seen = dst in self.time_by_dst
        dst_time = self.time_by_dst[dst] if dst_seen else None

        num_to_filter_opt = 0
        num_to_filter_pess = 0
        for node in to_filter:
            if node in self.time_by_dst and node != dst:
                if not dst_seen:
                    # An unseen dst is older than anything that has a time.
                    num_to_filter_opt += 1
                    num_to_filter_pess += 1
                    continue
                node_time = self.time_by_dst[node]
                num_to_filter_opt += int(node_time > dst_time)
                num_to_filter_pess += int(node_time >= dst_time)
        return num_to_filter_opt, num_to_filter_pess

    def get_rank(self, src, dst, current_t, to_filter : set = None):
        """
        Return the 0-based (optimistic, pessimistic) rank of `dst`.

        The optimistic rank counts the other destinations strictly more
        recent than `dst` (ties broken in dst's favour); the pessimistic
        rank counts those at least as recent (ties broken against dst).
        Nodes in `to_filter` are discounted via `filter_dst`.

        For a never-seen `dst` the optimistic rank is the number of seen
        destinations and the pessimistic rank is `num_nodes - 1`, both
        reduced by the filtered seen nodes.
        """
        if not self.contains_dst(None, dst):
            num_before = self.get_total(None)
            num_to_filter_opt, num_to_filter_pess = self.filter_dst(None, dst, to_filter)

            optimistic = num_before - num_to_filter_opt
            pessimistic = (self.num_nodes - 1) - num_to_filter_pess

            assert optimistic <= pessimistic
            return optimistic, pessimistic

        current_idx = self.time2idx[current_t]
        dst_idx = self.time2idx[self.time_by_dst[dst]]

        # Destinations seen up to current_t, excluding dst itself.
        num_total = self.rank_tracker.get_sum(current_idx) - 1

        # Destinations strictly older than dst.
        if dst_idx > 0:
            num_older = self.rank_tracker.get_sum(dst_idx - 1)
        else:
            num_older = 0
        # Destinations older than or tied with dst, excluding dst itself.
        num_not_newer = self.rank_tracker.get_sum(dst_idx) - 1

        # The query time must lie strictly after dst's last observation.
        assert current_idx > dst_idx

        num_to_filter_opt = 0
        num_to_filter_pess = 0
        if to_filter is not None:
            # TODO: This needs revisiting! Are we filtering correctly?
            num_to_filter_opt, num_to_filter_pess = self.filter_dst(None, dst, to_filter)

        optimistic = num_total - num_not_newer - num_to_filter_opt
        pessimistic = num_total - num_older - num_to_filter_pess
        assert optimistic >= 0 and pessimistic >= 0 and optimistic <= pessimistic

        return optimistic, pessimistic

    def contains_dst(self, src, dst):
        """Return True if `dst` has been recorded by `update`."""
        return dst in self.time_by_dst

    def get_score(self, src, dst):
        """Return the latest timestamp recorded for `dst`, or -1 if unseen."""
        return self.time_by_dst.get(dst, -1)

    def get_full_mrr(self, src, dst, t, to_filter : set = None):
        """
        Return (reciprocal rank, optimistic rank, pessimistic rank) for `dst`.

        The returned ranks are 1-based; the reciprocal rank uses the mean of
        the optimistic and pessimistic ranks.

        Raises:
            Exception: if `dst` has never been seen.
        """
        if not self.contains_dst(None, dst):
            raise Exception("Should not happen")
        optimistic, pessimistic = self.get_rank(src, dst, t, to_filter=to_filter)
        rank = (0.5*(optimistic + pessimistic)) + 1
        assert optimistic <= pessimistic
        return 1.0/rank, optimistic + 1, pessimistic + 1

    def get_total(self, src):
        """
        Return the number of destinations currently tracked.

        NOTE(review): this queries index `len(self.time2idx)`, one past the
        largest time index -- confirm FenwickTree.get_sum accepts it.
        """
        return self.rank_tracker.get_sum(len(self.time2idx))

    def get_node_rank(self, src, dst, t):
        """
        Return (rank, timedelta, num_live_timestamps) for `dst`.

        `rank` is 1-based over the distinct timestamps that still hold at
        least one destination, counted from the most recently inserted one.
        `timedelta` is `t` minus dst's latest timestamp. For an unseen `dst`
        the result is `(num_nodes, -1, 1)`.
        """
        if dst not in self.time_by_dst:
            return self.num_nodes, -1, 1

        dst_time = self.time_by_dst[dst]
        # Insertion index of dst's timestamp among all timestamps ever inserted.
        timerank = self.timetracker[dst_time]
        # Discount timestamps that were emptied at or before dst's time index,
        # then flip so that the most recently inserted live timestamp is rank 1.
        num_removed = self.fw_remove_tracker.get_sum(self.time2idx[dst_time])
        rank = len(self.timetracker) - (timerank - num_removed)
        assert rank <= self.num_nodes and 1 <= rank, (rank, self.num_nodes)
        timedelta = t - dst_time

        return rank, timedelta, len(self.timetracker)

    def tracker_score_type(self):
        """Return the Python type of the scores produced by `get_score`."""
        return int