from typing import List
from .base import Tracker
from .local_tracker import LocalRecencyTracker
from .global_tracker import GlobalRecencyTracker
from .popularity_tracker import LocalPopularityTracker, GlobalPopularityTracker
from typing import Iterable

class CombinationTracker(Tracker):
    """Rank destinations with an ordered cascade of trackers.

    The first tracker in ``trackers_list`` produces the primary ranking; each
    following tracker is only consulted to refine the ranking among the
    entities the previous trackers could not separate.  Scores are exposed as
    tuples holding one entry per tracker.
    """

    def __init__(self, all_timestamps, num_nodes, num_edges, trackers_list : List[Tracker]):
        """Store the sub-trackers and the number of nodes.

        Args:
            all_timestamps: Not used by this class.
            num_nodes: Returned as the rank of a destination that no
                sub-tracker contains (see ``get_rank``).
            num_edges: Not used by this class.
            trackers_list: Non-empty list of trackers, in priority order.
        """
        super().__init__()
        assert len(trackers_list) > 0
        self.trackers = trackers_list
        self.num_nodes = num_nodes

    def update(self, src, dst, t):
        """Forward the event ``(src, dst, t)`` to every sub-tracker, in order."""
        for tracker in self.trackers:
            tracker.update(src, dst, t)

    def contains_dst(self, src, dst):
        """Return True if at least one sub-tracker contains ``dst`` for ``src``."""
        # Generator (not a list) so that we stop at the first tracker answering
        # True instead of always querying all of them.
        return any(tracker.contains_dst(src, dst) for tracker in self.trackers)

    def get_total(self, src):
        """Return the total reported by the first (highest priority) tracker."""
        return self.trackers[0].get_total(src)

    def get_score(self, src, dst):
        """Return a tuple with the score of ``(src, dst)`` under each sub-tracker."""
        return tuple(tracker.get_score(src, dst) for tracker in self.trackers)

    def _join_ranks(self, optimistic_ranks, pessimistic_ranks):
        """Combine per-level ranks into one ``(optimistic, pessimistic)`` pair.

        The optimistic rank is the sum of all optimistic ranks.  The
        pessimistic rank is the sum of the optimistic ranks of every level but
        the last, plus the pessimistic rank of the last level.
        """
        opt = sum(optimistic_ranks)
        pess = sum(optimistic_ranks[:-1]) + pessimistic_ranks[-1]
        return opt, pess

    @staticmethod
    def _merge_filters(entities_ranked_before, to_filter):
        """Return a fresh set: the union of ``to_filter`` and every collection in ``entities_ranked_before``."""
        # ``set().union(...)`` accepts arbitrary iterables, so this also works
        # if a tracker hands back something other than a plain ``set``.
        return set().union(*entities_ranked_before, to_filter)

    def get_entities_to_filter_out(self, src, dst, t):
        """Not supported by the combination tracker.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This function is not implemented for the CombinationTracker")

    def get_tied_entities(self, src, dst, t):
        """Not supported by the combination tracker.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This function is not implemented for the CombinationTracker")

    def get_rank(self, src, dst, current_t, to_filter : set = None):
        """Return the ``(optimistic, pessimistic)`` rank of ``dst`` for ``src``.

        The first tracker ranks ``dst``.  If it leaves a gap between the
        optimistic and the pessimistic rank, the next trackers are queried one
        after the other until one of them closes the gap or the trackers are
        exhausted; the per-tracker ranks are then merged by ``_join_ranks``.

        Args:
            src: Source entity.
            dst: Destination entity to rank.
            current_t: Time of the query, forwarded to the sub-trackers.
            to_filter: Entities to exclude from the ranking; ``None`` or an
                empty collection means nothing is excluded.

        Returns:
            ``(optimistic_rank, pessimistic_rank)``.  When no sub-tracker
            contains ``dst`` both values are ``self.num_nodes``.
            NOTE(review): ranks look 0-based, since ``get_full_mrr`` adds 1 to
            them - confirm against the sub-trackers.
        """
        to_filter = to_filter or set()
        if not self.contains_dst(src, dst):
            return self.num_nodes, self.num_nodes

        trackers = self.trackers
        num_trackers = len(trackers)

        opt_rank, pess_rank = trackers[0].get_rank(src, dst, current_t, to_filter=to_filter)
        if opt_rank == pess_rank:
            # The first tracker alone already gives an unambiguous rank.
            return opt_rank, pess_rank

        optimistic_ranks = [opt_rank]
        pessimistic_ranks = [pess_rank]

        # tied_entities: per tracker, the entities it cannot separate from dst;
        #   these are the ones the following trackers have to rank.
        # entities_ranked_before: per tracker that does not contain dst, the
        #   entities it reports through get_entities_to_filter_out; they are
        #   excluded when querying the following trackers.
        tied_entities = []
        entities_ranked_before = []

        tracker_idx = 1

        if not trackers[0].contains_dst(src, dst):
            entities_ranked_before.append(trackers[0].get_entities_to_filter_out(src, dst, current_t))

            # Walk over the trackers that do not contain dst either.
            while tracker_idx < num_trackers and not trackers[tracker_idx].contains_dst(src, dst):
                all_to_filter = self._merge_filters(entities_ranked_before, to_filter)
                opt_rank, pess_rank = trackers[tracker_idx].get_rank(src, dst, current_t, to_filter=all_to_filter)
                optimistic_ranks.append(opt_rank)
                pessimistic_ranks.append(pess_rank)
                if opt_rank == pess_rank:
                    return self._join_ranks(optimistic_ranks, pessimistic_ranks)

                entities_ranked_before.append(trackers[tracker_idx].get_entities_to_filter_out(src, dst, current_t))
                tracker_idx += 1

            if tracker_idx < num_trackers:
                # First tracker of the cascade that does contain dst.
                assert trackers[tracker_idx].contains_dst(src, dst), "Unexpected event? wtf? "
                all_to_filter = self._merge_filters(entities_ranked_before, to_filter)
                opt_rank, pess_rank = trackers[tracker_idx].get_rank(src, dst, current_t, to_filter=all_to_filter)
                optimistic_ranks.append(opt_rank)
                pessimistic_ranks.append(pess_rank)
                if opt_rank == pess_rank:
                    return self._join_ranks(optimistic_ranks, pessimistic_ranks)

                tracker_idx += 1
                # *Important* optimization: only fetch the ties when another
                # tracker is left to rank them.  Fetching them for the last
                # tracker (for example a global tracker with many ties) is a
                # large amount of computation whose result is never used.
                if tracker_idx < num_trackers:
                    tied_entities.append(trackers[tracker_idx - 1].get_tied_entities(src, dst, current_t))
        elif num_trackers > 1:
            # Same optimization as above: with a single tracker nobody would
            # consume the ties, so they are not fetched.
            tied_entities.append(trackers[0].get_tied_entities(src, dst, current_t))

        all_to_filter = self._merge_filters(entities_ranked_before, to_filter)
        while tracker_idx < num_trackers:
            # Entities tied under every tracker consulted so far, minus the
            # excluded ones.  Sorting by size starts the intersection from the
            # smallest set.
            candidates = set.intersection(*sorted(tied_entities, key=len)) - all_to_filter
            opt_rank, pess_rank = trackers[tracker_idx].get_rank_within_set(src, dst, set_of_dst=candidates, current_t=current_t)
            optimistic_ranks.append(opt_rank)
            pessimistic_ranks.append(pess_rank)
            if opt_rank == pess_rank:
                return self._join_ranks(optimistic_ranks, pessimistic_ranks)

            # Ties of the last tracker are never read: skip fetching them.
            if tracker_idx + 1 < num_trackers:
                tied_entities.append(trackers[tracker_idx].get_tied_entities(src, dst, current_t, candidate_set=candidates))

            tracker_idx += 1

        optimistic_rank, pessimistic_rank = self._join_ranks(optimistic_ranks, pessimistic_ranks)
        # Sanity checks on the combined result.
        assert optimistic_rank <= pessimistic_rank
        assert pessimistic_rank <= self.num_nodes
        assert pessimistic_rank <= pessimistic_ranks[0]

        return optimistic_rank, pessimistic_rank

    def get_full_mrr(self, src, dst, t, to_filter = None):
        """Return the reciprocal rank of ``dst`` together with its rank bounds.

        Returns:
            ``(1 / rank, optimistic + 1, pessimistic + 1)`` where ``rank`` is
            the mean of the optimistic and pessimistic ranks plus one.
            NOTE(review): when no sub-tracker contains ``dst`` a bare ``0`` is
            returned instead of a 3-tuple - confirm that callers handle both
            shapes.
        """
        if not self.contains_dst(src, dst):
            return 0
        optimistic, pessimistic = self.get_rank(src, dst, t, to_filter=to_filter)
        rank = (0.5*(optimistic + pessimistic)) + 1
        assert optimistic <= pessimistic
        return 1.0/rank, optimistic + 1, pessimistic + 1

    def get_node_rank(self, src, dst, t):
        """Collect ``get_node_rank`` from every sub-tracker.

        Returns:
            Three tuples with one entry per tracker, built from elements 0, 1
            and 2 of each tracker's result respectively.
        """
        all_data = [tracker.get_node_rank(src, dst, t) for tracker in self.trackers]
        tranks = tuple(data[0] for data in all_data)
        tdeltas = tuple(data[1] for data in all_data)
        totnums = tuple(data[2] for data in all_data)
        return tranks, tdeltas, totnums

    def tracker_score_type(self):
        """Return the type of the values produced by ``get_score`` (``tuple``)."""
        return tuple


class LocalGlobalRecencyComboTracker(CombinationTracker):
    """Combination of, in priority order, a local and a global recency tracker."""

    def __init__(self, all_timestamps, num_nodes, num_edges):
        # Order matters: the first entry is the highest-priority tracker.
        tracker_classes = (LocalRecencyTracker, GlobalRecencyTracker)
        super().__init__(
            all_timestamps,
            num_nodes,
            num_edges,
            trackers_list=[
                tracker_cls(all_timestamps, num_nodes, num_edges)
                for tracker_cls in tracker_classes
            ],
        )


class LocalRecencyLocalPopularityGlobalRecencyComboTracker(CombinationTracker):
    """Combination of, in priority order: local recency, local popularity, global recency."""

    def __init__(self, all_timestamps, num_nodes, num_edges):
        # Order matters: the first entry is the highest-priority tracker.
        tracker_classes = (
            LocalRecencyTracker,
            LocalPopularityTracker,
            GlobalRecencyTracker,
        )
        super().__init__(
            all_timestamps,
            num_nodes,
            num_edges,
            trackers_list=[
                tracker_cls(all_timestamps, num_nodes, num_edges)
                for tracker_cls in tracker_classes
            ],
        )



class RecencyPopularityTracker(CombinationTracker):
    """Combination of, in priority order: local recency, global recency, local popularity."""

    def __init__(self, all_timestamps, num_nodes, num_edges):
        # Order matters: the first entry is the highest-priority tracker.
        tracker_classes = (
            LocalRecencyTracker,
            GlobalRecencyTracker,
            LocalPopularityTracker,
        )
        super().__init__(
            all_timestamps,
            num_nodes,
            num_edges,
            trackers_list=[
                tracker_cls(all_timestamps, num_nodes, num_edges)
                for tracker_cls in tracker_classes
            ],
        )

class LocalRecGlobalPop(CombinationTracker):
    """Combination of, in priority order: local recency, global popularity, global recency."""

    def __init__(self, all_timestamps, num_nodes, num_edges):
        # Order matters: the first entry is the highest-priority tracker.
        tracker_classes = (
            LocalRecencyTracker,
            GlobalPopularityTracker,
            GlobalRecencyTracker,
        )
        super().__init__(
            all_timestamps,
            num_nodes,
            num_edges,
            trackers_list=[
                tracker_cls(all_timestamps, num_nodes, num_edges)
                for tracker_cls in tracker_classes
            ],
        )

class LocalRecGlobalRecGlobalPop(CombinationTracker):
    """Combination of, in priority order: local recency, global recency, global popularity."""

    def __init__(self, all_timestamps, num_nodes, num_edges):
        # Order matters: the first entry is the highest-priority tracker.
        tracker_classes = (
            LocalRecencyTracker,
            GlobalRecencyTracker,
            GlobalPopularityTracker,
        )
        super().__init__(
            all_timestamps,
            num_nodes,
            num_edges,
            trackers_list=[
                tracker_cls(all_timestamps, num_nodes, num_edges)
                for tracker_cls in tracker_classes
            ],
        )

class AllHeuristicsInOne(CombinationTracker):
    """Combination of, in priority order: local recency, global recency, local popularity, global popularity."""

    def __init__(self, all_timestamps, num_nodes, num_edges):
        # Order matters: the first entry is the highest-priority tracker.
        tracker_classes = (
            LocalRecencyTracker,
            GlobalRecencyTracker,
            LocalPopularityTracker,
            GlobalPopularityTracker,
        )
        super().__init__(
            all_timestamps,
            num_nodes,
            num_edges,
            trackers_list=[
                tracker_cls(all_timestamps, num_nodes, num_edges)
                for tracker_cls in tracker_classes
            ],
        )


class GlobalRecGlobalPop(CombinationTracker):
    """Combination of, in priority order: global recency, global popularity."""

    def __init__(self, all_timestamps, num_nodes, num_edges):
        # Order matters: the first entry is the highest-priority tracker.
        tracker_classes = (GlobalRecencyTracker, GlobalPopularityTracker)
        super().__init__(
            all_timestamps,
            num_nodes,
            num_edges,
            trackers_list=[
                tracker_cls(all_timestamps, num_nodes, num_edges)
                for tracker_cls in tracker_classes
            ],
        )