from collections import defaultdict
from typing import Dict
from .base import Tracker
from .fenwick_tree import FenwickTree

class LocalRecencyTracker(Tracker):
    """Rank the destinations of each source node by how recently they were seen.

    For every ``src`` the tracker remembers, per destination, the timestamp of
    the most recent ``update(src, dst, t)`` call. Destinations seen more
    recently rank higher. Destinations sharing the same latest timestamp are
    tied, so ranks are reported as an (optimistic, pessimistic) interval.

    Per-source state (every structure is keyed by ``src`` first):

    * ``time_by_dst[src][dst]`` -- latest timestamp observed for ``(src, dst)``.
    * ``dst_by_time[src][t]`` -- set of destinations whose latest timestamp is
      ``t``.
    * ``rank_tracker[src]`` -- Fenwick tree holding ``+1`` at ``time2idx[t]``
      for every destination whose latest timestamp is ``t``.
    * ``timetrackers[src][t]`` -- counter value assigned when timestamp ``t``
      first gained a destination; removed when its destination set empties.
    * ``timetrackermaxes[src]`` -- last counter value handed out (starts at -1).
    * ``fw_remove_trackers[src]`` -- Fenwick tree with ``+1`` at
      ``time2idx[t]`` each time the destination set of timestamp ``t`` became
      empty.
    """

    def __init__(self, all_timestamps : Dict[int,int], num_nodes : int, num_edges : int):
        """Create an empty tracker.

        Args:
            all_timestamps: Mapping from every timestamp that will be passed in
                to its index in the Fenwick trees. Its length sizes each tree.
            num_nodes: Total number of nodes. Used as the worst-case rank bound
                for destinations that were never seen.
            num_edges: Not used by this class. Kept for interface compatibility.
        """
        super().__init__()
        self.rank_tracker = defaultdict(lambda: FenwickTree(len(all_timestamps)))
        self.dst_by_time = defaultdict(lambda: defaultdict(set))
        self.time_by_dst = defaultdict(dict)
        self.timetrackers = defaultdict(dict)
        self.time2idx = all_timestamps
        self.num_nodes = num_nodes
        self.timetrackermaxes = defaultdict(lambda : -1)
        self.fw_remove_trackers = defaultdict(lambda: FenwickTree(len(all_timestamps)))

    def update(self, src, dst, t):
        """Record that ``src`` interacted with ``dst`` at timestamp ``t``.

        Moves ``dst`` from its previous timestamp bucket (if any) to ``t`` and
        keeps the Fenwick trees and timestamp counters in sync. Repeating an
        update with the same timestamp is a no-op.
        """
        times = self.time_by_dst[src]
        if dst not in times:
            times[dst] = t
            self.rank_tracker[src].update(self.time2idx[t], 1)
            self.dst_by_time[src][t].add(dst)
        else:
            old_t = times[dst]
            if old_t == t:
                return
            # Move dst's single count from its old timestamp slot to the new one.
            self.rank_tracker[src].update(self.time2idx[old_t], -1)
            self.rank_tracker[src].update(self.time2idx[t], 1)
            times[dst] = t
            self.dst_by_time[src][old_t].remove(dst)
            if len(self.dst_by_time[src][old_t]) == 0:
                # old_t no longer has any destination: record the removal so
                # get_node_rank can discount it when converting counters to ranks.
                self.fw_remove_trackers[src].update(self.time2idx[old_t], 1)
                del self.dst_by_time[src][old_t]
                del self.timetrackers[src][old_t]
            self.dst_by_time[src][t].add(dst)

        if t not in self.timetrackers[src]:
            # First destination at timestamp t: hand out the next counter value.
            self.timetrackermaxes[src] += 1
            self.timetrackers[src][t] = self.timetrackermaxes[src]

    def filter_dst(self, src, dst, to_filter):
        """Count filtered candidates that would otherwise be ranked ahead of ``dst``.

        Args:
            src: Source node.
            dst: Destination being ranked.
            to_filter: Iterable of destinations to exclude from the ranking, or
                ``None`` for no filtering.

        Returns:
            ``(num_to_filter_opt, num_to_filter_pess)``. If ``dst`` is tracked,
            these count the tracked nodes in ``to_filter`` (other than ``dst``)
            whose latest timestamp is strictly later than, respectively no
            earlier than, ``dst``'s. If ``dst`` is untracked, every tracked node
            in ``to_filter`` counts towards both values.
        """
        if to_filter is None:
            # get_rank() calls this for untracked dsts with its default
            # to_filter=None; iterating None used to raise TypeError.
            return 0, 0

        times = self.time_by_dst[src]
        if dst not in times:
            # get_rank places an untracked dst after every tracked node, so each
            # tracked filtered node counts towards both bounds.
            count = sum(1 for node in to_filter if node in times and node != dst)
            return count, count

        dst_time = times[dst]
        num_to_filter_opt = 0
        num_to_filter_pess = 0
        for node in to_filter:
            if node in times and node != dst:
                node_time = times[node]
                num_to_filter_opt += int(node_time > dst_time)
                num_to_filter_pess += int(node_time >= dst_time)
        return num_to_filter_opt, num_to_filter_pess

    def get_rank(self, src, dst, current_t, to_filter : set = None):
        """Return the 0-based ``(optimistic, pessimistic)`` recency rank of ``dst``.

        Both values count destinations ranked ahead of ``dst`` for ``src``,
        minus those excluded via ``to_filter`` (see :meth:`filter_dst`).

        * Tracked ``dst``: ``optimistic`` counts destinations whose latest
          timestamp is strictly newer than ``dst``'s (and not after
          ``current_t``). ``pessimistic`` additionally counts destinations tied
          with ``dst``.
        * Untracked ``dst``: ``optimistic`` is the number of tracked
          destinations. ``pessimistic`` is ``num_nodes - 1``.

        Raises:
            AssertionError: if ``current_t``'s index is not strictly after
                ``dst``'s latest timestamp index, or the bounds are inconsistent.
        """
        if not self.contains_dst(src, dst):
            num_before = self.get_total(src)

            num_to_filter_opt, num_to_filter_pess = self.filter_dst(src, dst, to_filter)

            optimistic = num_before - num_to_filter_opt
            # If it is not contained, the worst case scenario is that it comes
            # after ALL other nodes except for itself!
            pessimistic = (self.num_nodes - 1) - num_to_filter_pess

            return optimistic, pessimistic

        ranks = self.rank_tracker[src]
        current_idx = self.time2idx[current_t]
        dst_idx = self.time2idx[self.time_by_dst[src][dst]]

        # NOTE(review): the "- 1" corrections below assume FenwickTree.get_sum(i)
        # is an inclusive prefix sum over indices 0..i -- confirm.
        # Destinations seen at or before current_t, excluding dst itself.
        num_total = ranks.get_sum(current_idx) - 1

        # Destinations strictly older than dst.
        if dst_idx > 0:
            num_older = ranks.get_sum(dst_idx - 1)
        else:
            num_older = 0

        # Destinations older than or tied with dst, excluding dst itself.
        num_older_or_tied = ranks.get_sum(dst_idx) - 1

        assert current_idx > dst_idx

        # TODO: This needs revisiting! Are we filtering correctly?
        num_to_filter_opt, num_to_filter_pess = self.filter_dst(src, dst, to_filter)

        optimistic = (num_total - num_older_or_tied) - num_to_filter_opt
        pessimistic = (num_total - num_older) - num_to_filter_pess

        assert optimistic >= 0 and pessimistic >= 0 and optimistic <= pessimistic

        return optimistic, pessimistic

    def get_entities_to_filter_out(self, src, dst, t):
        """Return the tracked destinations of ``src`` that rank ahead of ``dst``.

        If ``dst`` is untracked this is every tracked destination. Otherwise it
        is every destination whose latest timestamp is strictly later than
        ``dst``'s. ``t`` is not used.
        """
        times = self.time_by_dst[src]
        if not self.contains_dst(src, dst):
            return set(times.keys())

        dst_time = times[dst]
        return {cand_dst for cand_dst, cand_time in times.items()
                if cand_time > dst_time and cand_dst != dst}

    def get_tied_entities(self, src, dst, t, candidate_set : set = None):
        """Return destinations of ``src`` whose latest timestamp equals ``dst``'s.

        When ``dst`` is untracked, or ``candidate_set`` is smaller than the
        tracked set, the lookup is delegated to ``_get_tied_entities_within_set``.
        An untracked ``dst`` requires ``candidate_set``.
        """
        # NOTE(review): _get_tied_entities_within_set is not defined in this
        # class; presumably provided by the Tracker base class -- confirm.
        if not self.contains_dst(src, dst):
            assert candidate_set is not None
            return self._get_tied_entities_within_set(src, dst, t, candidate_set)

        if candidate_set is not None and len(candidate_set) < len(self.time_by_dst[src]):
            return self._get_tied_entities_within_set(src, dst, t, candidate_set)

        dst_time = self.time_by_dst[src][dst]
        return {cand_dst for cand_dst, cand_time in self.time_by_dst[src].items()
                if cand_time == dst_time and cand_dst != dst}

    def contains_dst(self, src, dst):
        """Return True if ``update(src, dst, ...)`` has been recorded."""
        return dst in self.time_by_dst[src]

    def get_score(self, src, dst):
        """Return the latest timestamp seen for ``(src, dst)``, or -1 if none."""
        return self.time_by_dst[src].get(dst, -1)

    def get_full_mrr(self, src, dst, t, to_filter : set = None):
        """Return ``(reciprocal_rank, optimistic_rank, pessimistic_rank)`` for ``dst``.

        The ranks are the 1-based bounds from :meth:`get_rank`. The reciprocal
        rank uses their midpoint.

        NOTE(review): returns the plain int ``0`` (not a tuple) when ``dst`` is
        untracked. Callers that unpack the result must special-case this.
        """
        if not self.contains_dst(src, dst):
            return 0
        optimistic, pessimistic = self.get_rank(src, dst, t, to_filter=to_filter)
        rank = (0.5*(optimistic + pessimistic)) + 1
        assert optimistic <= pessimistic
        return 1.0/rank, optimistic + 1, pessimistic + 1

    def get_total(self, src):
        """Return the number of destinations tracked for ``src`` (whole-tree sum)."""
        return self.rank_tracker[src].get_sum(len(self.time2idx))

    def get_node_rank(self, src, dst, t):
        """Return ``(rank, timedelta, num_timestamps)`` for ``dst``.

        ``rank`` is the 1-based dense rank of ``dst``'s latest timestamp among
        the distinct timestamps that currently have destinations (1 = most
        recent). ``timedelta`` is ``t`` minus that timestamp. ``num_timestamps``
        is the number of such distinct timestamps.

        For an untracked ``dst`` this returns ``(num_nodes, -1, 1)``.
        """
        if dst not in self.time_by_dst[src]:
            return self.num_nodes, -1, 1
        dst_time = self.time_by_dst[src][dst]
        timerank = self.timetrackers[src][dst_time]
        # Subtracting removed timestamps at or before dst_time turns the counter
        # into dst_time's 0-based position among live timestamps.
        # NOTE(review): this assumes updates arrive in non-decreasing timestamp
        # order, so that counter order matches time order -- confirm.
        removed_before = self.fw_remove_trackers[src].get_sum(self.time2idx[dst_time])
        rank = len(self.timetrackers[src]) - (timerank - removed_before)
        assert rank <= self.num_nodes and 1 <= rank, (rank, self.num_nodes)
        timedelta = t - dst_time

        return rank, timedelta, len(self.timetrackers[src])

    def tracker_score_type(self):
        """Return the type of the scores from :meth:`get_score` (timestamps)."""
        return int