from typing import Dict, Optional

from .global_tracker import GlobalRecencyTracker
from .local_tracker import LocalRecencyTracker
from .base import Tracker


class ReviewScorer(Tracker):
    """Score and rank destinations by global recency, excluding ones already seen per source.

    Combines two trackers built from the same inputs:

    * ``local_tracker`` (``LocalRecencyTracker``) is used only to ask whether
      ``dst`` has already been observed for a specific ``src``.
    * ``global_tracker`` (``GlobalRecencyTracker``) supplies scores/ranks for
      ``dst`` and is mostly queried with ``None`` in place of ``src``.

    A destination already seen for ``src`` is treated as irrelevant (each item
    is assumed to be reviewed only once) and receives the worst score/rank.
    """

    def __init__(self, all_timestamps: Dict[int, int], num_nodes: int):
        """Build the local and global recency trackers.

        Args:
            all_timestamps: Forwarded unchanged to both underlying trackers.
            num_nodes: Forwarded to both trackers; also used as the worst-case
                rank returned for destinations already seen for a source.
        """
        super().__init__()
        self.local_tracker = LocalRecencyTracker(all_timestamps, num_nodes)
        self.global_tracker = GlobalRecencyTracker(all_timestamps, num_nodes)
        self.num_nodes = num_nodes

    def update(self, src, dst, t):
        """Forward the observation ``(src, dst, t)`` to both the local and global tracker."""
        self.local_tracker.update(src, dst, t)
        self.global_tracker.update(src, dst, t)

    def contains_dst(self, src, dst):
        """Return whether ``dst`` is known to the global tracker.

        ``src`` is accepted to keep the tracker method signature but is ignored:
        membership is checked globally, not per source.
        """
        return self.global_tracker.contains_dst(None, dst)

    def get_total(self, src):
        """Return the global tracker's total; ``src`` is ignored."""
        return self.global_tracker.get_total(None)

    def get_score(self, src, dst):
        """Return the global score of ``dst``, or ``-inf`` if ``dst`` was already seen for ``src``."""
        if self.local_tracker.contains_dst(src, dst):
            # Already seen for this source -> irrelevant, so give it the lowest possible score.
            # NOTE(review): -inf is a float while tracker_score_type() reports int; confirm
            # downstream consumers accept this sentinel.
            return -float("inf")
        return self.global_tracker.get_score(None, dst)

    def get_rank(self, src, dst, current_t, to_filter: Optional[set] = None):
        """Return the ``(optimistic, pessimistic)`` rank of ``dst`` for ``src``.

        If ``dst`` was already seen for ``src`` it gets the worst rank,
        ``(num_nodes, num_nodes)``. Otherwise the global tracker's rank is
        returned. The ranks appear to be 0-based, because ``get_full_mrr`` adds 1
        to them; TODO confirm against ``GlobalRecencyTracker.get_rank``.

        Args:
            src: Source node whose already-seen destinations are excluded.
            dst: Destination node to rank.
            current_t: Query time, forwarded to the global tracker.
            to_filter: Optional set forwarded to the global tracker's ``get_rank``.
        """
        if self.local_tracker.contains_dst(src, dst):
            # Important assumption: a destination already seen for this source is
            # irrelevant, since each item should only be reviewed *once*.
            return self.num_nodes, self.num_nodes
        # NOTE(review): unlike the other delegations in this class, no src/None is
        # passed here. Confirm this matches GlobalRecencyTracker.get_rank's signature.
        return self.global_tracker.get_rank(dst, current_t, to_filter=to_filter)

    def get_full_mrr(self, src, dst, t, to_filter: Optional[set] = None):
        """Return ``(reciprocal_rank, optimistic_rank, pessimistic_rank)`` for ``dst``.

        The reciprocal rank uses the midpoint of the optimistic and pessimistic
        ranks. All three values are shifted by +1 relative to ``get_rank``.

        Raises:
            Exception: If ``dst`` is unknown to the global tracker.
        """
        if not self.contains_dst(None, dst):
            raise Exception("Should not happen")
        optimistic, pessimistic = self.get_rank(src, dst, t, to_filter=to_filter)
        assert optimistic <= pessimistic
        # Average the two bounds (to account for ties), then shift to a 1-based rank.
        rank = 0.5 * (optimistic + pessimistic) + 1
        return 1.0 / rank, optimistic + 1, pessimistic + 1

    def get_node_rank(self, src, dst, t):
        """Return the global tracker's node rank, or the worst rank if ``dst`` was already seen for ``src``."""
        if self.local_tracker.contains_dst(src, dst):
            # Already seen for this source: the worst rank fills all three slots.
            return self.num_nodes, self.num_nodes, self.num_nodes
        return self.global_tracker.get_node_rank(src, dst, t)

    def tracker_score_type(self):
        """Return ``int``, the score type this tracker reports."""
        return int