import re

from abcvoting import abcrules
from abcvoting.preferences import Profile
import networkx as nx
import matplotlib.pyplot as plt

# Node-attribute key under which each reconfiguration-graph node stores its committee.
COMMITTEE = "committee"
# Keys of the per-pair entries of ReconfigurationSolverExhaustive.path_dictionary:
# PATH -> list of committees on a shortest path (or None), LENGTH -> number of committees on it (or -1).
PATH = "path"
LENGTH = "length"
# NOTE(review): not referenced in the code visible here; presumably used by other modules — confirm before removing.
SDIFF = "score difference"


def show_graph(G):
    """
    Draws the graph G with matplotlib and displays the figure.

    Every node is labelled with the committee stored in its ``COMMITTEE`` node attribute.

    Parameters
    ----------
    G : nx.Graph
        The graph to draw; each node must carry a ``COMMITTEE`` attribute.
    """
    node_labels = {}
    for node in G.nodes:
        node_labels[node] = G.nodes[node][COMMITTEE]
    nx.draw(G, with_labels=True, labels=node_labels)
    plt.show()


def read_committees_from_file(filename):
    """
    Reads committees from the first line of a text file.

    The line is expected to list parenthesised committees of whitespace-separated integer alternatives, e.g.
    ``(0 1) (2 3) (1 3)``. Consecutive committees may be separated by any amount of whitespace (including none).
    Only the first line of the file is read; the rest of the file is ignored.

    Parameters
    ----------
    filename : str
        Path to the file.

    Returns
    -------
    list[set[int]]
        The committees in the order in which they appear on the line (empty list for an empty first line).

    Raises
    ------
    ValueError
        If an alternative cannot be parsed as an integer.
    """
    with open(filename, "r", encoding="utf-8") as f:
        line = f.readline()
    committees = []
    # Split between committees: a ")" followed by optional whitespace and then "(".
    # Splitting on the literal ") (" would merge committees separated by two spaces or none.
    for chunk in re.split(r"\)\s*\(", line):
        if not chunk.strip():
            continue
        # Drop the outer parentheses of the first/last chunk; split() also handles tabs and the trailing newline.
        items = chunk.replace(")", "").replace("(", "").split()
        committees.append({int(item) for item in items})
    return committees


def get_all_winning_committees(m, voters, rule, k, optimality_requirement=1):
    r"""
    Finds the size-k committees whose score is at least optimality_requirement times the score of a highest scoring
    committee.

    Parameters
    ----------
    m : int
        The number of alternatives.
    voters : list[set[int]]
        The list of voters (approval ballots).
    rule : str
        The voting rule, passed to ``abcrules.compute`` as the rule id.
    k : int
        The committee size.
    optimality_requirement : float
        The minimum proportion of the score of an optimal committee every committee in the reconfiguration graph must
        obtain. Formally, we set $\delta_s$ = score(W)(1 - optimality_requirement), where W is an optimal committee.

    Returns
    -------
    list[CommitteeSet]
        List of the committees whose score is at least optimality_requirement times that of a highest scoring
        committee, as returned by ``abcrules.compute``.
    """
    profile = Profile(num_cand=m)
    profile.add_voters(voters)
    return abcrules.compute(rule, profile, committeesize=k, optimality_requirement=optimality_requirement)


def _committee_to_id(committeeset):
    """
    Turns a CommitteeSet object into a sorted tuple so it can be used as a dictionary key.

    Parameters
    ----------
    committeeset : The committee.

    Returns
    -------
    tuple[int]
        The tuple corresponding to the committee.
    """
    return tuple(sorted(list(committeeset)))


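# This class builds, by brute force, the reconfiguration graph over all (near-)optimal committees: one node per
# committee, with an edge between two committees whose symmetric difference is at most `symdiff`.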
class ReconfigurationSolverExhaustive():
    """
    Brute-force reconfiguration graph over a set of committees.

    Every committee becomes a node of an undirected networkx graph, and two committees are joined by an edge when
    their symmetric difference is at most ``symdiff``. Shortest reconfiguration paths are read from an all-pairs
    dictionary built up front, or computed on demand with ``nx.shortest_path`` when that dictionary was not built.

    Path *lengths* throughout this class count the committees on the path (both endpoints included), i.e. one more
    than the number of edges.
    """

    def __init__(self, m=None, voters=None, rule=None, k=None, symdiff=2, construct_path_dictionary=True,
                 optimality_requirement=1, committee_file=None):
        r"""
        Constructs the reconfiguration graph for the given election, as described in Preliminaries, paragraph
        Reconfiguration Graphs.

        Parameters
        ----------
        m : int
            The number of alternatives.
        voters : list[set[int]]
            The voters in the instance.
        rule : str
            The voting rule.
        k : int
            The committee size.
        symdiff : int
            $\delta_c$, the maximum symmetric difference between adjacent committees.
        construct_path_dictionary : bool
            Construct a dictionary containing all-pair shortest paths. When False, paths are computed on demand.
        optimality_requirement : float
            The minimum proportion of the score of an optimal committee every committee in the reconfiguration graph
            must obtain. Formally, we set $\delta_s$ = score(W)(1 - optimality_requirement), where W is an optimal
            committee.
        committee_file : str, optional
            If given, the committees are read from the first line of this file (see ``read_committees_from_file``)
            instead of being computed, and m, voters, rule and k are not required.

        Raises
        ------
        ValueError
            If committee_file is not given and any of m, voters, rule, k is None.
        """
        self.m = m
        self.voters = voters
        self.rule = rule
        self.k = k
        if committee_file:
            self.committees = read_committees_from_file(committee_file)
        else:
            # Identity checks: `None in [...]` would use ==, which is ambiguous for array-like voters.
            if any(arg is None for arg in (m, voters, rule, k)):
                raise ValueError("Filename not provided, m, voters, rule, k, required")
            self.committees = get_all_winning_committees(m, voters, rule, k,
                                                         optimality_requirement=optimality_requirement)
        self.path_dictionary = None
        G = nx.Graph()
        # Maps the sorted-tuple form of a committee to its node index in G.
        self.committee_dictionary = {}
        for i, committee in enumerate(self.committees):
            G.add_node(i)
            G.nodes[i][COMMITTEE] = committee
            self.committee_dictionary[_committee_to_id(committee)] = i
        nc = len(self.committees)
        for i in range(nc):
            for j in range(i + 1, nc):
                if len(self.committees[i].symmetric_difference(self.committees[j])) <= symdiff:
                    G.add_edge(i, j)
        self.G = G
        if construct_path_dictionary:
            self.construct_all_shortest_paths()

    def construct_all_shortest_paths(self):
        """
        Constructs a dictionary containing all-pair shortest paths.

        ``self.path_dictionary[id1][id2]`` (ids as produced by ``_committee_to_id``) is a dict with ``PATH`` — the
        list of committees on a shortest path, or None if unreachable — and ``LENGTH`` — the number of committees on
        that path, or -1 if unreachable.
        """
        ids = [_committee_to_id(committee) for committee in self.committees]
        self.path_dictionary = {}
        for i, paths_from_i in nx.all_pairs_shortest_path(self.G):
            row = {}
            for j, target_id in enumerate(ids):
                if j in paths_from_i:
                    node_path = paths_from_i[j]
                    row[target_id] = {PATH: [self.committees[x] for x in node_path], LENGTH: len(node_path)}
                else:
                    row[target_id] = {PATH: None, LENGTH: -1}
            self.path_dictionary[ids[i]] = row

    def get_winning_committees(self):
        """
        Gives the committees in the reconfiguration graph.

        Returns
        -------
        list[set[int]]
            The committees in the reconfiguration graph.
        """
        return self.committees

    def get_path_length(self, c1, c2):
        """
        Returns the length of the shortest path between two committees c1, c2.

        The length is the number of committees on the path, both endpoints included (so a committee has length 1 to
        itself). Works whether or not the all-pairs path dictionary was constructed.

        Parameters
        ----------
        c1 : set[int]
            Initial committee W_0
        c2 : set[int]
            Goal committee W_t

        Returns
        -------
        int
            The length of the path, -1 if no path exists.

        Raises
        ------
        KeyError
            If c1 or c2 is not a committee of the reconfiguration graph.
        """
        if getattr(self, "path_dictionary", None) is not None:
            return self.path_dictionary[_committee_to_id(c1)][_committee_to_id(c2)][LENGTH]
        # No precomputed dictionary: fall back to an on-demand shortest path.
        path = self.get_path(c1, c2)
        return -1 if path is None else len(path)

    def get_path(self, c1, c2):
        """
        Returns a shortest path between two committees c1, c2.

        Parameters
        ----------
        c1 : set[int]
            Initial committee W_0
        c2 : set[int]
            Goal committee W_t

        Returns
        -------
        list[set[int]]
            List of the committees on the path, None if there is no path.

        Raises
        ------
        KeyError
            If c1 or c2 is not a committee of the reconfiguration graph.
        """
        if getattr(self, "path_dictionary", None) is not None:
            return self.path_dictionary[_committee_to_id(c1)][_committee_to_id(c2)][PATH]
        try:
            node_path = nx.shortest_path(self.G, self.committee_dictionary[_committee_to_id(c1)],
                                         self.committee_dictionary[_committee_to_id(c2)])
        except nx.exception.NetworkXNoPath:
            return None
        return [self.committees[x] for x in node_path]

    def _connected_pairs(self):
        """Yields every pair (c_i, c_j), i < j, of committees that are joined by a reconfiguration path."""
        nc = len(self.committees)
        for i in range(nc):
            for j in range(i + 1, nc):
                if self.get_path(self.committees[i], self.committees[j]) is not None:
                    yield self.committees[i], self.committees[j]

    def get_connected_components_nro(self):
        """
        Gives the number of connected components in the reconfiguration graph.

        Returns
        -------
        int
            The number of connected components in the reconfiguration graph.
        """
        return nx.number_connected_components(self.G)

    def compute_path_existence_probability(self):
        """
        Computes the proportion of pairs of committees that have a reconfiguration path between them.

        Returns
        -------
        float
            The proportion; 1 when there are fewer than two committees (no pairs to disconnect).
        """
        nc = len(self.committees)
        if nc <= 1:
            return 1
        path_exists = sum(1 for _ in self._connected_pairs())
        return path_exists / (nc * (nc - 1) // 2)

    def compute_average_path_length(self):
        """
        Computes the average path length between two committees in the reconfiguration graph, over all pairs of
        committees that are in the same connected component.

        Lengths count the committees on the path (see ``get_path_length``).

        Returns
        -------
        float
            The average. 1 if there is a single committee; 0 if no pair of distinct committees is connected
            (including when there are no committees).
        """
        nc = len(self.committees)
        if nc == 1:
            return 1
        nro_pairs = 0
        total_length = 0
        for c1, c2 in self._connected_pairs():
            nro_pairs += 1
            total_length += self.get_path_length(c1, c2)
        if nro_pairs == 0:
            return 0
        return total_length / nro_pairs

    def show_graph(self):
        """
        Shows the reconfiguration graph using Matplotlib.
        """
        show_graph(self.G)

    def compute_extra_swaps_needed(self, c1, c2):
        """
        Given two committees, computes how much longer their shortest reconfiguration path is than the shortest
        possible path based on their symmetric difference. See Example 1 in the paper for explanation.

        The lower bound used is ``|c1 Δ c2| // 2`` steps, i.e. one alternative swapped per step.

        Parameters
        ----------
        c1 : set[int]
            The first committee.
        c2 : set[int]
            The second committee.

        Returns
        -------
        int
            The extra path length needed. -1 if no path exists.
        """
        path_length = self.get_path_length(c1, c2)
        if path_length == -1:
            return -1
        symdiff = len(set(c1).symmetric_difference(c2))
        # path_length counts committees, so path_length - 1 is the number of steps actually taken.
        return path_length - 1 - symdiff // 2

    def compute_average_extra_swaps_needed(self):
        """
        Computes how much longer the shortest reconfiguration path is compared to shortest needed based on the
        symmetric difference between the two committees (see Example 1 in the paper), averaged over all pairs of
        committees that are in the same connected component.

        Returns
        -------
        float:
            The average; 0 if no pair of distinct committees is connected.
        """
        nc = len(self.committees)
        if nc == 1:
            return 0
        nro_pairs = 0
        total_extra_pairs = 0
        for c1, c2 in self._connected_pairs():
            nro_pairs += 1
            total_extra_pairs += self.compute_extra_swaps_needed(c1, c2)
        if nro_pairs == 0:
            return 0
        return total_extra_pairs / nro_pairs

    def compute_all_extra_swaps_needed(self):
        """
        Computes how much longer the shortest reconfiguration path is compared to shortest needed based on the
        symmetric difference between the two committees (see Example 1 in the paper), summed over all pairs of
        committees that are in the same connected component.

        Returns
        -------
        int:
            The total extra path length over all connected pairs (0 if there are none).
        """
        nc = len(self.committees)
        if nc == 1:
            return 0
        return sum(self.compute_extra_swaps_needed(c1, c2) for c1, c2 in self._connected_pairs())


def committee_score(voters, committee, rule):
    """
    Computes the score of a committee according to some (Thiele) voting rule.

    Each voter contributes w_1 + ... + w_a, where a is the number of committee members the voter approves and
    (w_1, w_2, ...) is the rule's weight vector:

    * ``"cc"`` (Chamberlin-Courant): (1, 0, 0, ...)
    * ``"pav"`` (Proportional Approval Voting): (1, 1/2, 1/3, ...)
    * ``"av"`` (Approval Voting): (1, 1, 1, ...)

    Parameters
    ----------
    voters : list[set[int]]
        The voters.
    committee : set[int]
        The committee whose score we compute.
    rule : str
        The rule according to which we compute the score: "cc", "pav" or "av".

    Returns
    -------
    float
        The committee score (an int for "cc" and "av").

    Raises
    ------
    ValueError
        If rule is not one of the supported rules.
    """
    size = len(committee)
    if rule == "cc":
        weights = [1] + [0] * (size - 1)
    elif rule == "pav":
        weights = [(1 / (x + 1)) for x in range(size)]
    elif rule == "av":
        weights = [1] * size
    else:
        raise ValueError(f"Unsupported rule {rule!r}; expected 'cc', 'pav' or 'av'")
    score = 0
    for vote in voters:
        # The a-th approved committee member (in iteration order) earns weights[a].
        approved = 0
        for c in committee:
            if c in vote:
                score += weights[approved]
                approved += 1
    return score


def reconfiguration_heuristic(voters, rule, k, init_C, goal_C, symdiff=2, optimality_requirement=1):
    r"""
    Heuristically finds a reconfiguration path from init_C to goal_C.

    The election is restricted to the alternatives in the union of init_C and goal_C, the restricted instance is
    solved exhaustively with ``ReconfigurationSolverExhaustive``, and a shortest path found there is mapped back to
    the original alternative names.

    Parameters
    ----------
    voters : list[set[int]]
        The voters in the instance.
    rule : str
        The voting rule.
    k : int
        The committee size
    init_C : set[int]
        The first committee on the path.
    goal_C : set[int]
        The last committee on the path.
    symdiff : int
        $\delta_c$, the maximum symmetric difference between adjacent committees.
    optimality_requirement: float
        The minimum proportion of the score of an optimal committee every committee in the reconfiguration graph
        must obtain. Formally, we set $\delta_s$ = score(W)(1 - optimality_requirement), where W is an optimal
        committee.

    Returns
    -------
    list[set[int]] or None
        The committees on a shortest path of the restricted instance, in the original alternative names, or None if
        init_C and goal_C are not connected there.

    Raises
    ------
    KeyError
        If init_C or goal_C is not among the committees of the restricted instance (raised by the solver's path
        lookup).
    """
    # Rename the selected alternatives to 0, ..., m' - 1 so the restricted instance uses contiguous candidate ids.
    selected_to_all_alts = sorted(set(init_C).union(goal_C))
    all_alts_to_selected = {alt: i for i, alt in enumerate(selected_to_all_alts)}
    _m = len(selected_to_all_alts)
    new_voters = []
    for voter in voters:
        new_voter = [all_alts_to_selected[x] for x in voter if x in all_alts_to_selected]
        # Voters approving none of the selected alternatives are dropped from the restricted instance.
        if new_voter:
            new_voters.append(new_voter)
    solver = ReconfigurationSolverExhaustive(_m, new_voters, rule, k, symdiff=symdiff,
                                             construct_path_dictionary=False,
                                             optimality_requirement=optimality_requirement)
    new_init_C = {all_alts_to_selected[x] for x in init_C}
    new_goal_C = {all_alts_to_selected[x] for x in goal_C}
    path = solver.get_path(new_init_C, new_goal_C)
    if path is None:
        return None
    # Return the original names for the alternatives.
    return [{selected_to_all_alts[y] for y in committee} for committee in path]


if __name__ == "__main__":
    # Small demo: 3 alternatives, committees of size 2 under Chamberlin-Courant,
    # keeping every committee that reaches at least 80% of the optimal score.
    example_voters = [{1, 2}, {2, 0}, {1}]
    demo_solver = ReconfigurationSolverExhaustive(
        m=3,
        voters=example_voters,
        rule="cc",
        k=2,
        optimality_requirement=0.8,
    )
    print(demo_solver.committees)
