import itertools
import re

import abcvoting.abcrules_gurobi
from abcvoting import abcrules
from abcvoting.preferences import Profile
import networkx as nx
import matplotlib.pyplot as plt

# Node-attribute key under which each reconfiguration-graph node stores its committee.
COMMITTEE = "committee"
# Keys of the per-pair entries stored in ReconfigurationSolverExhaustive.path_dictionary.
PATH = "path"
LENGTH = "length"
# Not referenced in this part of the file; presumably a key for score differences — confirm before relying on it.
SDIFF = "score difference"


def show_graph(G):
    """
    Draw a reconfiguration graph, labelling every node with its committee, and display it.

    Parameters
    ----------
    G : nx.Graph
        The graph to be shown. Every node must carry its committee under the COMMITTEE attribute.
    """
    node_labels = {}
    for node in G.nodes:
        node_labels[node] = G.nodes[node][COMMITTEE]
    nx.draw(G, with_labels=True, labels=node_labels)
    plt.show()


def read_committees_from_file(filename):
    """
    Read a list of committees from the first line of a text file.

    The line is expected to hold parenthesised committees of integer alternatives, e.g. ``(0 1 2) (1 2 3)``.
    Consecutive committees may be separated by any amount of whitespace (including none). The alternatives
    inside a committee may be separated by whitespace and/or commas. Only the first line of the file is read.

    Parameters
    ----------
    filename : str
        The location of the committee file.

    Returns
    -------
    list[set[int]]
        The committees, in the order they appear in the file.

    Raises
    ------
    ValueError
        If a committee contains a token that is not an integer.
    """
    with open(filename, "r", encoding="utf-8") as f:
        line = f.readline()
    committees = []
    # Split between consecutive committees: a ")" followed by optional whitespace and a "(".
    for chunk in re.split(r"\)\s*\(", line):
        if not chunk.strip():
            continue
        # Treat leftover brackets and commas as separators; split() also handles tabs and repeated spaces.
        tokens = chunk.replace("(", " ").replace(")", " ").replace(",", " ").split()
        committees.append({int(token) for token in tokens})
    return committees


def get_all_winning_committees(m, voters, rule, k, optimality_requirement=1, min_score=None, fixed_alts=None):
    r"""
    Finds the k-size committees whose score is at least optimality_requirement of a highest scoring committee.

    Parameters
    ----------
    m : int
        The number of alternatives
    voters : list[set[int]]
        The list of voters
    rule : str
        The voting rule
    k : int
        The committee size.
    optimality_requirement : float
        The minimum proportion of the score of an optimal committee every committee in the reconfiguration graph must
        obtain. Formally, we set $\delta_s$ = score(W)(1 - optimality_requirement), where W is an optimal committee.
    min_score : float
        The minimum score of a considered committee. If None, this is computed based on the optimality requirement.
    fixed_alts : list[int] or None
        Alternatives that are required to be in the committees. None (the default) means no alternative is fixed.

    Returns
    -------
    list[CommitteeSet]
        List of the committees whose score is at least optimality_requirement of a highest scoring committee.
    """
    # Build a fresh list per call instead of sharing one mutable default list across calls.
    if fixed_alts is None:
        fixed_alts = []
    if optimality_requirement < 1 and min_score is not None:
        print("Warning, both min_score and optimality requirement set")
    profile = Profile(num_cand=m)
    profile.add_voters(voters)
    return abcrules.compute(rule, profile, committeesize=k, optimality_requirement=optimality_requirement,
                            min_score=min_score, fixed_alts=fixed_alts)


def _committee_to_id(committeeset):
    """
    Turns a CommitteeSet object into a sorted tuple so it can be used as a dictionary key.

    Parameters
    ----------
    committeeset : The committee.

    Returns
    -------
    tuple[int]
        The tuple corresponding to the committee.
    """
    return tuple(sorted(list(committeeset)))


# Exhaustive solver: builds the reconfiguration graph over every committee that meets the score threshold.
class ReconfigurationSolverExhaustive:
    """
    Reconfiguration graph of an election, built exhaustively.

    Node i stores committee self.committees[i] under the COMMITTEE attribute. Two committees are adjacent when
    their symmetric difference has size at most symdiff. Shortest paths are either precomputed for all pairs
    (self.path_dictionary) or computed on demand with a shortest-path search on self.G.
    """

    def __init__(self, m=None, voters=None, rule=None, k=None, symdiff=2, construct_path_dictionary=True,
                 optimality_requirement=1, min_score=None, committee_file=None):
        r"""
        Constructs the reconfiguration graph for the given election, as described in Preliminaries, paragraph
        Reconfiguration Graphs.

        Parameters
        ----------
        m : int
            The number of alternatives.
        voters : list[set[int]]
            The voters in the instance.
        rule : str
            The voting rule.
        k : int
            the committee size
        symdiff : int
            $\delta_c$, the symmetric difference between adjacent committees.
        construct_path_dictionary : bool
            Construct a dictionary containing all-pair shortest paths.
        optimality_requirement : float
            The minimum proportion of the score of an optimal committee every committee in the reconfiguration graph
            must obtain. Formally, we set $\delta_s$ = score(W)(1 - optimality_requirement), where W is an optimal
            committee.
        min_score : float
            The minimum score of a committee on a reconfiguration graph
        committee_file : str or None
            Path of a file with precomputed committees (see read_committees_from_file). When given, the committees
            are read from it instead of being computed, and m, voters, rule, k are not required.

        Raises
        ------
        ValueError
            If no committee_file is given and any of m, voters, rule, k is None.
        """
        self.m = m
        self.voters = voters
        self.rule = rule
        self.k = k
        self.symdiff = symdiff
        if not committee_file:
            if None in [m, voters, rule, k]:
                raise ValueError("Filename not provided, m, voters, rule, k, required")
            self.committees = get_all_winning_committees(m, voters, rule, k,
                                                         optimality_requirement=optimality_requirement,
                                                         min_score=min_score)
        else:
            self.committees = read_committees_from_file(committee_file)
        self.path_dictionary = None
        self.committee_dictionary = {}
        G = nx.Graph()
        for i, committee in enumerate(self.committees):
            G.add_node(i)
            G.nodes[i][COMMITTEE] = committee
            self.committee_dictionary[_committee_to_id(committee)] = i
        for i, j in itertools.combinations(range(len(self.committees)), 2):
            if len(self.committees[i].symmetric_difference(self.committees[j])) <= symdiff:
                G.add_edge(i, j)
        self.G = G
        if construct_path_dictionary:
            self.construct_all_shortest_paths()

    def construct_all_shortest_paths(self):
        """
        Constructs self.path_dictionary, which holds all-pair shortest paths.

        path_dictionary[id1][id2] (ids as produced by _committee_to_id) holds PATH, the list of committees on a
        shortest path, and LENGTH, the number of committees on that path (so a committee to itself has length 1).
        When no path exists, PATH is None and LENGTH is -1.
        """
        ids = [_committee_to_id(committee) for committee in self.committees]
        self.path_dictionary = {}
        for i, paths_from_i in nx.all_pairs_shortest_path(self.G):
            row = {}
            for j, target_id in enumerate(ids):
                if j in paths_from_i:
                    path = paths_from_i[j]
                    row[target_id] = {PATH: [self.committees[x] for x in path], LENGTH: len(path)}
                else:
                    row[target_id] = {PATH: None, LENGTH: -1}
            self.path_dictionary[ids[i]] = row

    def get_winning_committees(self):
        """
        Gives the committees in the reconfiguration graph.

        Returns
        -------
        list[set[int]]
            The committees in the reconfiguration graph.
        """
        return self.committees

    def _has_path_dictionary(self):
        """Whether an all-pair shortest path dictionary is currently available."""
        return getattr(self, "path_dictionary", None) is not None

    def get_path_length(self, c1, c2):
        """
        Returns the length of the shortest path between two committees c1, c2.

        The length counts the committees on the path, i.e. one more than the number of reconfiguration steps.
        Works both with and without a precomputed path dictionary.

        Parameters
        ----------
        c1 : set[int]
            Initial committee W_0
        c2 : set[int]
            Goal committee W_t

        Returns
        -------
        int
            The length of the path, -1 if no path exists.
        """
        if self._has_path_dictionary():
            return self.path_dictionary[_committee_to_id(c1)][_committee_to_id(c2)][LENGTH]
        # No precomputed dictionary: derive the length from an on-demand shortest path.
        path = self.get_path(c1, c2)
        return -1 if path is None else len(path)

    def get_path(self, c1, c2):
        """
        Returns a shortest path between two committees c1, c2.

        Parameters
        ----------
        c1 : set[int]
            Initial committee W_0
        c2 : set[int]
            Goal committee W_t

        Returns
        -------
        list[set[int]]
            List of the committees on the path, None if there is no path.

        Raises
        ------
        KeyError
            If (without a path dictionary) c1 or c2 is not a committee of the graph.
        """
        if self._has_path_dictionary():
            return self.path_dictionary[_committee_to_id(c1)][_committee_to_id(c2)][PATH]
        source = self.committee_dictionary[_committee_to_id(c1)]
        target = self.committee_dictionary[_committee_to_id(c2)]
        try:
            node_path = nx.shortest_path(self.G, source, target)
        except nx.exception.NetworkXNoPath:
            return None
        return [self.committees[x] for x in node_path]

    def get_connected_components_nro(self):
        """
        Gives the number of connected components in the reconfiguration graph.

        Returns
        -------
        int
            The number of connected components in the reconfiguration graph.
        """
        return nx.number_connected_components(self.G)

    def _connected_pairs(self):
        """Yield (c1, c2, path) for every unordered pair of committees joined by a reconfiguration path."""
        for c1, c2 in itertools.combinations(self.committees, 2):
            path = self.get_path(c1, c2)
            if path is not None:
                yield c1, c2, path

    @staticmethod
    def _extra_swaps_for_path(c1, c2, path):
        """Steps on `path` (len(path) - 1) beyond |c1 Δ c2| // 2, the value compute_extra_swaps_needed reports."""
        return len(path) - 1 - len(c1.symmetric_difference(c2)) // 2

    def compute_path_existence_probability(self):
        """
        Computes the proportion of pairs of committees that have a reconfiguration path between them.

        Returns
        -------
        float
            The proportion; 1 when there are fewer than two committees (no pairs to check).
        """
        nc = len(self.committees)
        if nc <= 1:
            return 1
        path_exists = sum(1 for _ in self._connected_pairs())
        return path_exists / (nc * (nc - 1) // 2)

    def compute_average_path_length(self):
        """
        Computes the average path length between two committees in the reconfiguration graph, over all pairs of
        committees that are in the same connected component.

        Path length counts the committees on the path (as stored under LENGTH).

        Returns
        -------
        float
            The average. When there are no pairs to average over (fewer than two committees, or no connected pair),
            1 is returned, matching the single-committee convention.
        """
        if len(self.committees) <= 1:
            return 1
        lengths = [len(path) for _, _, path in self._connected_pairs()]
        if not lengths:
            return 1
        return sum(lengths) / len(lengths)

    def show_graph(self):
        """
        Shows the reconfiguration graph using Matplotlib.
        """
        show_graph(self.G)

    def compute_extra_swaps_needed(self, c1, c2):
        """
        Given two committees, computes how much longer their shortest reconfiguration path is than the shortest
        possible path based on their symmetric difference. See Example 1 in the paper for explanation.

        Parameters
        ----------
        c1 : set[int]
            The first committee.
        c2 : set[int]
            The second committee.

        Returns
        -------
        int
            The number of steps on a shortest path minus |c1 Δ c2| // 2. -1 if no path exists.
        """
        length = self.get_path_length(c1, c2)
        if length == -1:
            return -1
        return length - 1 - len(c1.symmetric_difference(c2)) // 2

    def compute_minimal_path_probability(self):
        """
        Computes how likely it is that two connected committees admit a reconfiguration path that is as short as
        possible based on their symmetric difference.

        Returns
        -------
        float:
            The probability. When there are no connected pairs (including fewer than two committees) 1 is
            returned, matching the single-committee convention.
        """
        if len(self.committees) <= 1:
            return 1
        path_exists = 0
        path_minimal = 0
        for c1, c2, path in self._connected_pairs():
            path_exists += 1
            if self._extra_swaps_for_path(c1, c2, path) == 0:
                path_minimal += 1
        if path_exists == 0:
            return 1
        return path_minimal / path_exists

    def compute_average_extra_swaps_needed(self):
        """
        Computes how much longer the shortest reconfiguration path is compared to shortest needed based on the
        symmetric difference between the two committees (see Example 1 in the paper), averaged over all pairs of
        committees that are in the same connected component.

        Returns
        -------
        float:
            The average; 0 when there are no connected pairs.
        """
        if len(self.committees) <= 1:
            return 0
        extras = [self._extra_swaps_for_path(c1, c2, path) for c1, c2, path in self._connected_pairs()]
        if not extras:
            return 0
        return sum(extras) / len(extras)

    def compute_all_extra_swaps_needed(self):
        """
        Computes how much longer the shortest reconfiguration path is compared to shortest needed based on the
        symmetric difference between the two committees (see Example 1 in the paper), summed over all pairs of
        committees that are in the same connected component.

        Returns
        -------
        int:
            The total number of extra swaps; 0 when there are fewer than two committees.
        """
        if len(self.committees) <= 1:
            return 0
        return sum(self._extra_swaps_for_path(c1, c2, path) for c1, c2, path in self._connected_pairs())

    def add_committee_to_graph(self, committee):
        """
        Adds a committee to the reconfiguration graph, connecting it to every committee within symdiff.

        A committee that is already in the graph is ignored. Any precomputed path dictionary is discarded,
        because it would otherwise give stale paths. Paths are then computed on demand.

        Parameters
        ------
        committee : set
            The committee to be added
        """
        committee_id = _committee_to_id(committee)
        if committee_id in self.committee_dictionary:
            return
        if self._has_path_dictionary():
            print("WARNING: path dictionary discarded, paths will be recomputed on demand")
            self.path_dictionary = None
        i = len(self.committees)
        self.G.add_node(i)
        self.G.nodes[i][COMMITTEE] = committee
        self.committee_dictionary[committee_id] = i
        self.committees.append(committee)
        for j in range(i):
            if len(committee.symmetric_difference(self.committees[j])) <= self.symdiff:
                self.G.add_edge(j, i)


def committee_score(voters, committee, rule):
    """
    Computes the score of a committee according to some voting rule.

    Each voter contributes vector[0] + ... + vector[a - 1], where a is the number of committee members the voter
    approves and vector is the rule's weight vector.

    Parameters
    ----------
    voters : list[set[int]]
        The voters.
    committee : set[int]
        The committee whose score we compute.
    rule : str
        The rule according to which we compute the score: "cc" (Chamberlin-Courant), "pav" (Proportional
        Approval Voting) or "av" (Approval Voting).

    Returns
    -------
    float
        The committee score.

    Raises
    ------
    ValueError
        If rule is not one of the supported rules.
    """
    size = len(committee)
    if rule == "cc":
        vector = [1] + [0] * (size - 1)
    elif rule == "pav":
        vector = [1 / (x + 1) for x in range(size)]
    elif rule == "av":
        vector = [1] * size
    else:
        raise ValueError(f"Unsupported rule {rule!r}; expected 'cc', 'pav' or 'av'.")
    score = 0
    for vote in voters:
        # The i-th approved committee member (in iteration order) contributes vector[i].
        index = 0
        for c in committee:
            if c in vote:
                score += vector[index]
                index += 1
    return score


def brute_force_solve(voters, rule, m, init_C, goal_C, symdiff=2, optimality_requirement=None, delta_s=0):
    r"""
    Find a shortest reconfiguration path between two committees by building the full reconfiguration graph.

    Parameters
    ----------
    voters : list[set[int]]
        The voters in the instance.
    rule : str
        The voting rule.
    m : int
        The number of alternatives
    init_C : set[int]
        The first committee on the path.
    goal_C : set[int]
        The last committee on the path.
    symdiff : int
        $\delta_c$, the symmetric difference between adjacent committees. Values larger than 2 not supported yet.
    optimality_requirement: float
        The minimum proportion of the score of an optimal committee every committee in the reconfiguration graph
        must obtain. Formally, we set $\delta_s$ = score(W)(1 - optimality_requirement), where W is an optimal
        committee. None if we want delta_s instead.
    delta_s : float
        The score disimprovement from the starting committee that is allowed

    Returns
    -------
    list[set[int]] or None
        The committees on a shortest path from init_C to goal_C (both included), or None if no path exists.

    Raises
    ------
    ValueError
        If symdiff > 2.
    KeyError
        If init_C or goal_C is not among the committees meeting the score threshold.
    """
    if symdiff > 2:
        raise ValueError("Symdiff >2 not implemented yet.")
    if optimality_requirement is not None:
        min_score = None
    else:
        # Threshold relative to the starting committee: its score minus the allowed disimprovement.
        min_score = committee_score(voters, init_C, rule) - delta_s
        optimality_requirement = 1
    solver = ReconfigurationSolverExhaustive(m=m, voters=voters, rule=rule, k=len(init_C), symdiff=symdiff,
                                             min_score=min_score, optimality_requirement=optimality_requirement,
                                             construct_path_dictionary=False)
    return solver.get_path(init_C, goal_C)


def graph_union_heuristic(voters, rule, m, init_C, goal_C, symdiff=2, optimality_requirement=None, delta_s=0):
    r"""
    Heuristically search for a reconfiguration path by gradually growing the set of considered alternatives.

    The search starts with only the alternatives of init_C and goal_C, builds the reconfiguration graph of the
    election restricted to them, and looks for a path. While none is found, it adds the not-yet-considered
    alternative approved by the most voters (ties go to the larger index). It then inserts the committees
    containing that alternative into the graph and searches again.

    Parameters
    ----------
    voters : list[set[int]]
        The voters in the instance.
    rule : str
        The voting rule.
    m : int
        The number of alternatives
    init_C : set[int]
        The first committee on the path.
    goal_C : set[int]
        The last committee on the path.
    symdiff : int
        $\delta_c$, the symmetric difference between adjacent committees.
    optimality_requirement: float
        The minimum proportion of the score of an optimal committee every committee in the reconfiguration graph
        must obtain. Formally, we set $\delta_s$ = score(W)(1 - optimality_requirement), where W is an optimal
        committee. None if we want delta_s instead.
    delta_s : float
        The score disimprovement from the starting committee that is allowed

    Returns
    -------
    list[set[int]] or None
        The committees on the path found (in the original alternative names), or None if no path is found even
        after all m alternatives have been considered.
    """
    k = len(init_C)
    # Rename the considered alternatives to 0, ..., _m - 1: selected_to_all_alts[i] is the original name of
    # renamed alternative i, and all_alts_to_selected is the inverse mapping.
    selected_to_all_alts = sorted(set(init_C).union(goal_C))
    all_alts_to_selected = {alt: i for i, alt in enumerate(selected_to_all_alts)}
    new_init_C = {all_alts_to_selected[x] for x in init_C}
    new_goal_C = {all_alts_to_selected[x] for x in goal_C}
    if optimality_requirement is not None:
        min_score = None
    else:
        min_score = committee_score(voters, init_C, rule) - delta_s
        optimality_requirement = 1
    solver = None
    _m = 0
    while _m < m:
        _m = len(selected_to_all_alts)
        # Restrict every ballot to the considered (renamed) alternatives; voters left with nothing are dropped.
        new_voters = []
        for voter in voters:
            new_voter = [all_alts_to_selected[x] for x in voter if x in all_alts_to_selected]
            if new_voter:
                new_voters.append(new_voter)
        if solver is None:
            solver = ReconfigurationSolverExhaustive(_m, new_voters, rule, k, symdiff, False, optimality_requirement,
                                                     min_score)
        else:
            solver.voters = new_voters
            # Only committees containing the newly added alternative are new. It was appended last, so its
            # renamed index is _m - 1, and the restricted profile has exactly _m candidates.
            # NOTE(review): optimality_requirement is not forwarded here. When the caller used it (min_score is
            # None), later rounds only add optimal committees containing the new alternative — confirm intended.
            try:
                new_committees = get_all_winning_committees(_m, new_voters, rule, k, min_score=min_score,
                                                            fixed_alts=[_m - 1])
            except abcvoting.abcrules_gurobi.NoSolutionException:
                new_committees = []
            for committee in new_committees:
                solver.add_committee_to_graph(committee)
        path = solver.get_path(new_init_C, new_goal_C)
        if path:
            # Translate the path back to the original alternative names.
            return [{selected_to_all_alts[y] for y in x} for x in path]
        if _m >= m:
            break
        # Add the not-yet-considered alternative with the most approvals (a single-member "cc" score counts them).
        selected_alt = max((committee_score(voters, {c}, "cc"), c) for c in range(m)
                           if c not in all_alts_to_selected)[1]
        all_alts_to_selected[selected_alt] = len(selected_to_all_alts)
        selected_to_all_alts.append(selected_alt)
    return None


def find_path_ilp(voters, rule, m, init_C, goal_C, delta_s=0, max_path_len=None):
    """
    Find a reconfiguration path with an ILP. NOT IMPLEMENTED.

    Parameters
    ----------
    voters : list[set[int]]
        The voters in the instance.
    rule : str
        The voting rule.
    m : int
        The number of alternatives
    init_C : set[int]
        The first committee on the path.
    goal_C : set[int]
        The last committee on the path.
    delta_s : float
        The score disimprovement from the starting committee that is allowed
    max_path_len : int
        The maximum length of a desired path

    Returns
    -------
    list[set[int]] : the committees on the path

    Raises
    ------
    NotImplementedError
        Always. Raising avoids returning None, which callers would otherwise read as "no path exists".
    """
    raise NotImplementedError("find_path_ilp is not implemented yet")


if __name__ == "__main__":
    # Small demo instance: four voters on a cycle of alternatives 0-3 and six voters approving only 4.
    voters = [{0, 1}, {1, 2}, {2, 3}, {0, 3}, {4}, {4}, {4}, {4}, {4}, {4}]
    min_score = 3
    rule = "pav"
    m = 5
    k = 2
    # Size-k committees that contain alternative 0 and score at least min_score (k is a required argument).
    winning_committees = get_all_winning_committees(m, voters, rule, k, min_score=min_score, fixed_alts=[0])
    print(winning_committees)
    # For comparison, the directly computed score of every size-k committee.
    for committee in itertools.combinations(range(m), k):
        print(committee, committee_score(voters, committee, rule))
