import math
import multiprocessing
import random
from os import listdir, makedirs
from os.path import isfile, join, basename
import reconfiguration
import vote_reader
import csv
import argparse
import pandas

import time
import timeout_decorator
from datetime import datetime

# Values written to the "status" column of the statistics CSV files.
OK = "OK"  # set by _find_path / find_path_rerun once the path computation returned within its time limit
TO = "TO"  # the path computation timed out
NOT_READY = "NOT READY"  # initial status of a row; replaced by TO if still present after the computation
NOT_ENOUGH_COMMITTEES = "NOT ENOUGH COMMITTEES"  # no committee pairs could be sampled for the instance

# Path-finding strategies understood by _find_path and find_path_rerun.
_BRUTE_FORCE = "brute-force"  # dispatches to reconfiguration.brute_force_solve
_INITIAL_ALTS = "initial-alts"  # dispatches to reconfiguration.graph_union_heuristic

# Strategies run on every committee pair by find_path_and_update_rows.
STRATEGIES = [_INITIAL_ALTS]


def get_all_files(folder_dir, extension=""):
    """
    List the regular files directly inside a folder, optionally filtered by extension.

    Sub-directories are skipped. The returned paths are prefixed with ``folder_dir`` and appear in the order
    produced by ``os.listdir``.

    Parameters
    ----------
    folder_dir : os.PathLike
        The folder to list.
    extension : str
        Keep only file names ending with this suffix. The empty string keeps every file.

    Returns
    -------
    list[str]
        Paths of the matching files.
    """
    matches = []
    for entry in listdir(folder_dir):
        full_path = join(folder_dir, entry)
        if not isfile(full_path):
            continue
        if extension != "" and not entry.endswith(extension):
            continue
        matches.append(full_path)
    return matches


@timeout_decorator.timeout(seconds=60)
def _find_path(m, voters, rule, init_C, goal_C, strategy, row):
    """
    Compute a reconfiguration path between two committees with the given strategy, within a 60-second time limit.

    On success, the elapsed time (seconds, formatted with 5 decimals) is stored in ``row["time"]`` and
    ``row["status"]`` is set to OK. On timeout, ``row`` is left untouched, so its status keeps whatever value the
    caller put there.

    Parameters
    ----------
    m : int
        The number of alternatives.
    voters : list[set[int]]
        The approval preferences.
    rule : str
        The voting rule.
    init_C : set[int]
        The initial committee.
    goal_C : set[int]
        The target committee.
    strategy : str
        The strategy to use to find the path: _BRUTE_FORCE or _INITIAL_ALTS.
    row : dict
        Dictionary that contains all the computed values for the instance; updated in place.

    Returns
    -------
    list[set[int]] or None
        The path between the two committees, as returned by the chosen reconfiguration routine.

    Raises
    ------
    timeout_decorator.TimeoutError
        When the execution times out.
    ValueError
        When ``strategy`` is not a known strategy.
    """
    # perf_counter is monotonic, so the measurement is not affected by wall-clock adjustments.
    start_time = time.perf_counter()
    if strategy == _BRUTE_FORCE:
        path = reconfiguration.brute_force_solve(voters, rule, m, init_C, goal_C)
    elif strategy == _INITIAL_ALTS:
        path = reconfiguration.graph_union_heuristic(voters, rule, m, init_C, goal_C)
    else:
        # Without this branch an unknown strategy surfaced as an UnboundLocalError on `path`.
        raise ValueError("Unknown strategy %r; expected %r or %r" % (strategy, _BRUTE_FORCE, _INITIAL_ALTS))
    row["time"] = "%.5f" % (time.perf_counter() - start_time)
    row["status"] = OK
    return path


@timeout_decorator.timeout(seconds=60 * 60)
def find_path_rerun(filename, rule, init_C, goal_C, row, strategy=_INITIAL_ALTS):
    """
    Read an instance and compute a path between two committees, with a one-hour time limit.

    Intended for rerunning instances that timed out under the shorter limit used by _find_path. Fills ``row["m"]``,
    ``row["n"]`` and ``row["strategy"]``. On success, it also fills ``row["time"]`` (seconds, formatted with 5
    decimals) and sets ``row["status"]`` to OK.

    Parameters
    ----------
    filename : str
        Location of the instance.
    rule : str
        The voting rule.
    init_C : set[int]
        The initial committee.
    goal_C : set[int]
        The target committee.
    row : dict
        Dictionary that contains all the computed values for the instance; updated in place.
    strategy : str
        The strategy to use to find the path: _BRUTE_FORCE or _INITIAL_ALTS.

    Returns
    -------
    list[set[int]] or None
        The path between the two committees, as returned by the chosen reconfiguration routine.

    Raises
    ------
    timeout_decorator.TimeoutError
        When the execution times out.
    ValueError
        When ``strategy`` is not a known strategy.
    """
    voters, m, _ = vote_reader.read_file(filename)
    row["m"] = m
    row["n"] = len(voters)
    row["strategy"] = strategy
    start_time = time.perf_counter()
    if strategy == _BRUTE_FORCE:
        path = reconfiguration.brute_force_solve(voters, rule, m, init_C, goal_C)
    elif strategy == _INITIAL_ALTS:
        path = reconfiguration.graph_union_heuristic(voters, rule, m, init_C, goal_C)
    else:
        raise ValueError("Unknown strategy %r; expected %r or %r" % (strategy, _BRUTE_FORCE, _INITIAL_ALTS))
    # Measure once, so the printed and the recorded times agree.
    elapsed = time.perf_counter() - start_time
    print(elapsed)
    row["time"] = "%.5f" % elapsed
    row["status"] = OK
    return path


def find_path(m, voters, rule, row, init_C, goal_C, strategy):
    """
    Run one strategy on one committee pair and record the resulting path in ``row``.

    On success, ``row["pathlen"]`` holds the number of committees on the path and ``row["path"]`` holds its string
    form without the outer brackets. If the strategy returned None, ``row["pathlen"]`` is NaN. A timeout is swallowed:
    ``row["status"]`` then keeps the value the caller set, which is how the caller detects it.

    Parameters
    ----------
    m : int
        The number of alternatives.
    voters : list[set[int]]
        The voters.
    rule : str
        The voting rule.
    row : dict
        Dictionary to store statistics about the run; updated in place.
    init_C : set[int]
        The alternatives in the first committee.
    goal_C : set[int]
        The alternatives in the goal committee.
    strategy : str
        The strategy to use.
    """
    print(rule, strategy)
    try:
        found = _find_path(m, voters, rule, init_C, goal_C, strategy, row)
    except timeout_decorator.TimeoutError:
        # Deliberately ignored: the untouched status tells the caller that the run timed out.
        return
    if found is None:
        row['pathlen'] = float("nan")
        return
    row['pathlen'] = len(found)
    row['path'] = str(found)[1:-1]


def get_random_committee_pairs(m, k, voters, rule, nrcommittees, nrpairs):
    """
    Sample nrcommittees distinct random committees, rank them by score and pair up the 2 * nrpairs best.

    Neighbours in the ranking form a pair: the best with the second best, the third with the fourth, and so on. Within
    a pair, the committee ranked lower comes first and is used as the initial committee. The one ranked higher comes
    second and is used as the goal committee.

    Parameters
    ----------
    m : int
        The number of alternatives.
    k : int
        The committee size.
    voters : list[set[int]]
        The voters.
    rule : str
        The voting rule.
    nrcommittees : int
        The number of distinct committees sampled.
    nrpairs : int
        The number of pairs returned.

    Returns
    -------
    list[list[tuple[int, ...]]] or None
        nrpairs pairs ``[initial committee, goal committee]``; each committee is a sorted tuple of alternatives.
        None when nrcommittees is at least the number of possible committees of size k.

    Raises
    ------
    ValueError
        If nrcommittees < 2 * nrpairs, i.e. too few committees would be sampled to form the requested pairs.
    """

    def _get_random_committee(m, k):
        alts = list(range(m))
        random.shuffle(alts)
        # Sorted tuples give each committee a canonical, hashable form, so the set below deduplicates correctly.
        return tuple(sorted(alts[:k]))

    if nrcommittees < 2 * nrpairs:
        raise ValueError("Need at least %d sampled committees to form %d pairs, got %d"
                         % (2 * nrpairs, nrpairs, nrcommittees))
    if nrcommittees >= math.comb(m, k):
        print("WARNING: Not enough committees")
        return None
    committees = set()
    while len(committees) < nrcommittees:
        committees.add(_get_random_committee(m, k))
    # Best first; equal scores are ordered by the committee tuple, so the ranking does not depend on set order.
    scored_committees = sorted(((reconfiguration.committee_score(voters, c, rule), c) for c in committees),
                               reverse=True)
    return [[scored_committees[2 * i + 1][1], scored_committees[2 * i][1]] for i in range(nrpairs)]


def get_random_committees(m, k, voters, rule, opt_req):
    """
    Pick two distinct committees of size k.

    If opt_req is 0, two distinct uniformly random committees are drawn. They are returned ordered by score, with the
    lower-scored (or equally scored) committee first.

    Otherwise, two committees are picked at random from ``reconfiguration.get_all_winning_committees(m, voters, rule,
    k, opt_req)``. Presumably these are the committees meeting the optimality requirement; confirm in
    ``reconfiguration``. They are returned in random order.

    Parameters
    ----------
    m : int
        The number of alternatives.
    k : int
        The committee size.
    voters : list[set[int]]
        The voters.
    rule : str
        The voting rule.
    opt_req : float
        The optimality requirement of the random committees; 0 means no requirement.

    Returns
    -------
    tuple or None
        The two committees (sorted tuples of alternatives when opt_req is 0). None if fewer than two candidate
        committees exist.
    """

    def _get_random_committee(m, k):
        alts = list(range(m))
        random.shuffle(alts)
        # Canonical (sorted) form: otherwise the same committee drawn in a different order compares unequal and the
        # distinctness loop below could return one committee twice.
        return tuple(sorted(alts[:k]))

    if opt_req == 0:
        # With fewer than two possible committees the rejection loop below would never terminate
        # (covers m == k, k > m and k == 0).
        if math.comb(m, k) < 2:
            if m == k:
                print("WARNING: Only one possible committee")
            else:
                print("WARNING: Fewer than two possible committees")
            return None
        c1 = _get_random_committee(m, k)
        c2 = c1
        while c1 == c2:
            c2 = _get_random_committee(m, k)
        s1 = reconfiguration.committee_score(voters, c1, rule)
        s2 = reconfiguration.committee_score(voters, c2, rule)
        if s1 <= s2:
            return c1, c2
        return c2, c1
    # Shuffle a copy: the returned collection may not be a list, and must not be reordered for its owner.
    committees = list(reconfiguration.get_all_winning_committees(m, voters, rule, k, opt_req))
    if len(committees) < 2:
        return None
    random.shuffle(committees)
    return committees[0], committees[1]


def find_path_and_update_rows(rule, file, voters, m, opt, k, writer, nrcommittees=100, nrpairs=10):
    """
    Sample committee pairs for one instance, run every strategy in STRATEGIES on each pair and write the results.

    Pairs come from get_random_committee_pairs. One CSV row is written per (pair, strategy). Each row holds the
    instance, the committees, the rule, the optimality requirement, k, the strategy, the running time, the path and
    its length, and a status: OK on success, TO on timeout. If no pairs can be sampled, a single row with status
    NOT ENOUGH COMMITTEES is written instead.

    Parameters
    ----------
    rule : str
        The voting rule.
    file : os.PathLike
        The instance location.
    voters : list[set[int]]
        The voters.
    m : int
        The number of alternatives.
    opt : float
        The optimality requirement; recorded in the "opt_req" column.
    k : int
        The committee size.
    writer : csv.DictWriter
        Writer for the statistics CSV file.
    nrcommittees : int
        Number of random committees sampled before pairing (default 100).
    nrpairs : int
        Number of committee pairs formed from the best sampled committees (default 10).
    """
    n = len(voters)
    pairs = get_random_committee_pairs(m, k, voters, rule, nrcommittees, nrpairs)
    if pairs is None:
        writer.writerow({"rule": rule, "file": file, "n": n, "m": m, "opt_req": opt, "k": k,
                         "status": NOT_ENOUGH_COMMITTEES})
        print("Not enough committees")
        return
    for init_C, goal_C in pairs:
        print("Committees", init_C, goal_C)
        for strategy in STRATEGIES:
            # opt_req is recorded on every row so results can be grouped by optimality requirement.
            row = {"rule": rule, "file": file, "n": n, "m": m, "opt_req": opt, "k": k, "status": NOT_READY,
                   "strategy": strategy, "init_C": init_C, "goal_C": goal_C}
            find_path(m=m, voters=voters, rule=rule, strategy=strategy, row=row, init_C=init_C,
                      goal_C=goal_C)
            if row["status"] == NOT_READY:
                # The status is only set on success, so NOT_READY means the computation timed out.
                print("timed out")
                row["status"] = TO
            writer.writerow(row)


def default_k_creator(m):
    """
    Choose the committee size used in the first Netflix experiment for an instance with m alternatives.

    Below 400 alternatives: about 2 % of m, but at least 3. From 400 alternatives upwards: about 1 % of m, but at most
    8. In both cases the percentage is rounded to the nearest integer, with exact halves rounded down.

    Parameters
    ----------
    m : int
        The number of alternatives.

    Returns
    -------
    int
        The committee size.
    """
    if m >= 400:
        one_percent = (m + 49) // 100
        return min(one_percent, 8)
    two_percent = (2 * m + 49) // 100
    return max(two_percent, 3)


def create_stats_fixed_committees(base_folder, source_folders, committee_sizes, rules, optimality_requirements, start,
                                  min_voters=0, min_alternatives=0, max_alternatives=-1, percentage=0, files=None,
                                  default_ks=False, max_k=float("inf"), min_k=2, end=None):
    """
    Sample committee pairs and compute reconfiguration paths between them for all instances in the source folders,
    assuming the folders are within base_folder.

    For each source folder, a CSV file ``path_stats<folder>_<timestamp>.csv`` is written to base_folder. It has one
    row per (instance, rule, optimality requirement, k, committee pair, strategy), as produced by
    find_path_and_update_rows. The row records the status (OK / TO / NOT ENOUGH COMMITTEES), computation time and the
    path found.

    Parameters
    ----------
    base_folder : str
        The folder that contains the experiment folders and the results will be produced to.
    source_folders : list[str]
        The folders within base folders which contain the input instances.
    committee_sizes : list[int]
        Committee sizes to use when neither percentage nor default_ks is set.
    rules : list[str]
        The voting rules, can be cc and pav.
    optimality_requirements : list[float]
        Optimality requirements to iterate over (recorded in the output rows).
    start : int
        Index, in the sorted file list of each folder, of the first instance to process; useful when restarting.
    min_voters : int
        Ignores the instances with fewer than min_voters voters.
    min_alternatives : int
        Ignores the instances with fewer than min_alternatives alternatives.
    max_alternatives : int
        Ignores the instances with more than max_alternatives alternatives; negative means no limit.
    percentage : float
        Set k as a given percentage of the number of alternatives. If 0, uses other methods of setting k instead.
    files : list[str]
        Only run the experiment on these specific files (not modified).
    default_ks : bool
        Default false, if true, sets k as described in the first Netflix experiment of the paper.
    max_k : int
        Maximum value of k.
    min_k : int
        Minimum value of k.
    end : int or None
        Stop before this instance index (exclusive). None means the end of each folder's file list.

    Raises
    ------
    ValueError
        If min_k is larger than max_k.
    """
    if min_k > max_k:
        raise ValueError("Min k %r larger than max k %r" % (min_k, max_k))

    fieldnames = ["file", "init_C", "goal_C", "n", "m", "rule", "opt_req", "k", "status", "time", "strategy",
                  "pathlen", "path"]

    def k_to_range(k):
        # Clamp k into [min_k, max_k].
        return min(max_k, max(int(k), min_k))

    def committee_sizes_for(m):
        # The committee sizes to try for an instance with m alternatives.
        if percentage:
            return [k_to_range((m * percentage + 49) // 100)]
        if default_ks:
            return [k_to_range(default_k_creator(m))]
        return [k for k in map(k_to_range, committee_sizes) if k <= m]

    for source_folder in source_folders:
        if files:
            # sorted() copies, so the caller's list is left untouched.
            folder_files = sorted(files)
        else:
            folder_files = sorted(get_all_files(join(base_folder, source_folder)))
        # Resolve the end per folder: folders can hold different numbers of instances.
        folder_end = len(folder_files) if end is None else end
        timestamp = str(datetime.now()).replace(" ", "--").replace(":", "-").split(".")[0]
        # Format the file name before joining, so a '%' in base_folder cannot break the formatting.
        out_path = join(base_folder, "path_stats%s_%s.csv" % (source_folder, timestamp))
        # newline="" as required by the csv module, to avoid blank lines on platforms that translate newlines.
        with open(out_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for i, file in enumerate(folder_files[start:folder_end]):
                print(file, i + start, " out of ", len(folder_files))
                voters, m, _ = vote_reader.read_file(file)
                print("voters", len(voters), "alts", m)
                if len(voters) < min_voters or m < min_alternatives or (0 <= max_alternatives < m):
                    print("skipped")
                    continue
                for rule in rules:
                    for opt in optimality_requirements:
                        for k in committee_sizes_for(m):
                            print("k", k)
                            find_path_and_update_rows(rule, file, voters, m, opt, k, writer)


def read_committee(st):
    """
    Turn the string representation of a committee back into a set of alternatives.

    Accepts the forms committees take when written to the statistics CSV files:
    - tuples, e.g. ``"(1, 2, 3)"``, including one-element tuples such as ``"(4,)"``;
    - sets, e.g. ``"{1, 2}"``;
    - lists, e.g. ``"[1, 2]"``.

    Parameters
    ----------
    st : str
        The string representation.

    Returns
    -------
    set[int]
        The committee.

    Raises
    ------
    ValueError
        If an element is not an integer.
    """
    # Drop every bracket kind, then split on bare commas, so "(4,)" and "1,2" parse as well as "(1, 2)".
    cleaned = st.translate(str.maketrans("", "", "()[]{}"))
    return {int(item) for item in cleaned.split(",") if item.strip() != ""}


def rerun_timeouts(base_folder, source_folders):
    """
    Rerun the committee pairs that timed out, using the longer time limit of find_path_rerun.

    Every ``.csv`` file in the given folders is read. Each row with status TO is recomputed, and the results are
    written to ``path_statsrerun_<timestamp>.csv`` in base_folder. The path is stored in the same format as
    find_path uses. CSV files without a "status" column are skipped.

    Parameters
    ----------
    base_folder : str
        The folder that contains all source_folders.
    source_folders : list[str]
        Folders (relative to base_folder) holding the result .csv files with timeout information.
    """
    # Collect the inputs before creating the output file, so the new file can never be read as an input.
    result_files = []
    for folder in source_folders:
        result_files.extend(get_all_files(join(base_folder, folder), ".csv"))

    fieldnames = ["file", "init_C", "goal_C", "n", "m", "rule", "opt_req", "k", "status", "time", "strategy",
                  "pathlen", "path"]
    timestamp = str(datetime.now()).replace(" ", "--").replace(":", "-").split(".")[0]
    # Format the file name before joining, so a '%' in base_folder cannot break the formatting.
    out_path = join(base_folder, "path_stats%s_%s.csv" % ("rerun", timestamp))
    with open(out_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for result_file in result_files:
            print(result_file)
            df = pandas.read_csv(result_file, index_col=None, header=0)
            if "status" not in df.columns:
                print("skipped, no status column")
                continue
            timeouts = df[df["status"] == TO]
            for _, line in timeouts.iterrows():
                print(line)
                rule = line["rule"]
                init_C, goal_C = read_committee(line["init_C"]), read_committee(line["goal_C"])
                # Distinct name from result_file: this is the instance file, not the CSV being scanned.
                instance_file = line["file"]
                print(instance_file, init_C, goal_C)
                row = {"rule": rule, "file": instance_file, "k": len(init_C), "status": NOT_READY,
                       "init_C": init_C, "goal_C": goal_C}
                try:
                    path = find_path_rerun(instance_file, rule, init_C, goal_C, row)
                except timeout_decorator.TimeoutError:
                    row["status"] = TO
                    print("timeout")
                else:
                    # Same representation as find_path, so rerun rows can be merged with the original results.
                    if path is not None:
                        row["pathlen"] = len(path)
                        row["path"] = str(path)[1:-1]
                    else:
                        row["pathlen"] = float("nan")
                    print("success")
                writer.writerow(row)


def str_to_bool(value):
    """
    Parse a boolean command-line value such as "True", "false", "yes" or "0".

    argparse's ``type=bool`` converts every non-empty string, including "False", to True. The boolean flags below
    use this parser instead.

    Parameters
    ----------
    value : str or bool
        The raw argument value.

    Returns
    -------
    bool
        The parsed value.

    Raises
    ------
    argparse.ArgumentTypeError
        If value is not a recognised spelling of a boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


if __name__ == "__main__":
    # Example invocations:
    # python3 path_experiment_runner.py --base_folder path_data/resampling/ --source_folder  pfresampling_p0_05_f0_25  pfresampling_p0_05_f0_75  pfresampling_p0_25_f0_25  pfresampling_p0_25_f0_75 pfresampling_p0_50_f0_25  pfresampling_p0_50_f0_75  pfresampling_p0_75_f0_25  pfresampling_p0_75_f0_75 pfresampling_p0_05_f0_50  pfresampling_p0_05_f1_00  pfresampling_p0_25_f0_50  pfresampling_p0_25_f1_00 pfresampling_p0_50_f0_50  pfresampling_p0_50_f1_00  pfresampling_p0_75_f0_50  pfresampling_p0_75_f1_00 --rules pav cc --committee_sizes 5
    # python3 path_experiment_runner.py --base_folder path_data/manhattan/ --source_folder  d2_r0_10  d2_r0_20  d2_r0_30  d2_r0_40 d2_r0_50  d2_r0_60  d2_r0_70  --rules pav cc --committee_sizes 5
    # python3 path_experiment_runner.py --base_folder path_data/netflix/ --source_folder  data  --rules pav cc --committee_sizes 5
    parser = argparse.ArgumentParser(description='Run committees')
    parser.add_argument("--base_folder", type=str, help="The folder that contains the experiment folders and the "
                                                        "results will be produced to.")
    parser.add_argument("--source_folders", type=str, nargs="*",
                        help="The folders within base folders which contain the input instances.")
    parser.add_argument("--rules", type=str, nargs="*",
                        help="The voting rules, can be cc and pav.")
    parser.add_argument("--min_voters", type=int, default=10,
                        help="Ignores the instances with fewer than min_voters voters.")
    parser.add_argument("--max_alternatives", type=int, default=-1,
                        help="Ignores the instances with more than max_alternatives alternatives.")
    parser.add_argument("--min_alternatives", type=int, default=0,
                        help="Ignores the instances with fewer than min_alternatives alternatives.")
    parser.add_argument("--start", type=int, default=0,
                        help="Starts from the instance number start as opposed from the beginning, useful when "
                             "restarting.")
    parser.add_argument("--committee_sizes", type=int, nargs="*",
                        help="Committee sizes for which to compute the reconfiguration graph.")
    parser.add_argument("--optimality_requirements", type=float, nargs="*", default=[0],
                        help="Optimality requirements of the random committees. Should be either very high or 0,"
                             " because otherwise there are too many possible committees.")
    parser.add_argument("--percentage", type=float, default=0,
                        help="Set k as a given percentage of the number of alternatives.")
    parser.add_argument("--files", type=str, nargs="*", default=None,
                        help="Only run the experiment on these specific files.")
    # nargs="?" with const=True keeps both "--default_k" and "--default_k True" working, and makes
    # "--default_k False" actually mean False.
    parser.add_argument("--default_k", type=str_to_bool, nargs="?", const=True, default=False,
                        help="Use the k as described in the first Netflix experiment in the paper.")
    parser.add_argument("--min_k", type=int, default=0,
                        help="If percentage sets the k smaller than this, make this the k")
    parser.add_argument("--max_k", type=int, default=float("inf"),
                        help="If percentage sets the k higher than this, make this the k")
    parser.add_argument("--end", type=int, default=None,
                        help="Stop experiment on this instance.")
    parser.add_argument("--rerun", type=str_to_bool, nargs="?", const=True, default=False,
                        help="rerun cases that timed out")
    args = parser.parse_args()
    print(args)
    if args.rerun:
        rerun_timeouts(args.base_folder, args.source_folders)
    else:
        create_stats_fixed_committees(
            source_folders=args.source_folders,
            committee_sizes=args.committee_sizes,
            rules=args.rules,
            optimality_requirements=args.optimality_requirements,
            start=args.start,
            min_voters=args.min_voters,
            max_alternatives=args.max_alternatives,
            percentage=args.percentage,
            files=args.files,
            base_folder=args.base_folder,
            min_alternatives=args.min_alternatives,
            default_ks=args.default_k,
            end=args.end,
            min_k=args.min_k,
            max_k=args.max_k
        )
