import argparse
import itertools
import random
import os
import math
import argparse
import re
import sys
import numpy as np

import matplotlib.pyplot as plt
from collections import Counter

from reward import Stat, Act
from config import SUPPORT_PROTOCOLS

# Logging configuration, driven by the LOG environment variable:
#   LOG=1 -> logging enabled, verbosity level 1
#   LOG=2 -> logging enabled, verbosity level 2
#   anything else (or unset) -> logging disabled
if os.getenv("LOG") == "1":
    ENABLE = True
    LEVEL = 1
elif os.getenv("LOG") == "2":
    ENABLE = True
    LEVEL = 2
else:
    ENABLE = False
    # Define LEVEL on this branch too, so code that reads it (e.g. log_level)
    # does not hit a NameError when logging is disabled.
    LEVEL = 0


def str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--gpu False`` used to *enable* the option. This
    converter understands the usual spellings instead.

    :param value: the raw command-line string (a ``bool`` is passed through).
    :return: the parsed boolean; the empty string maps to False, as before.
    :raises argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arg(argv=None):
    """Parse the command-line options for this program.

    :param argv: argument list to parse; ``None`` (the default) means
        ``sys.argv[1:]``, exactly as before.
    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # --players and --rounds are required, so argparse never used a default for them.
    parser.add_argument("--players", type=int, required=True, help="number of players")
    parser.add_argument("--rounds", type=int, required=True, help="number of rounds")
    parser.add_argument("--save_dir", type=str, help="path to save the model")
    parser.add_argument("--load_dir", type=str, help="path to load the model")
    parser.add_argument("--protocol", type=str, default="primary_backup", help="protocol type", choices=SUPPORT_PROTOCOLS)
    parser.add_argument(
        "--math_fn",
        type=str,
        help="math function to learn (only valid when --protocol == math)",
        choices=["sum", "min_idx", "max_idx"],
    )
    parser.add_argument(
        "--model_type", type=str, required=True, help="choose model (MLP or Transformer)", choices=["mlp", "transformer"]
    )

    def add_bool_flag(name, help_text):
        # `--flag VALUE` is parsed by str2bool, so `--flag False` really is False.
        # A bare `--flag` means True; omitting the flag keeps the False default.
        parser.add_argument(name, type=str2bool, nargs="?", const=True, default=False, help=help_text)

    add_bool_flag("--gpu", "use gpu for training")
    add_bool_flag("--debug", "debug mode")
    add_bool_flag("--verbose", "verbose mode")
    add_bool_flag("--freeze", "freeze some action")
    add_bool_flag("--gt", "use ground truth for evaluation")
    add_bool_flag("--fix", "whether enable fixing logical in mcts")
    add_bool_flag("--encode_id", "whether encode node id into input state")
    parser.add_argument("--test_id", type=int, default=None, help="specify the test id that you want to test")
    add_bool_flag("--history", "encode previous history into input state")
    parser.add_argument("--pre_train", type=int, default=0, help="iterations for pre train without fixing logic")
    parser.add_argument("--interval", type=int, default=20, help="interval for fixing logic")

    return parser.parse_args(argv)


def construct_bits_string(states, n):
    """Render ``states`` in binary, left-padded with zeros to at least ``n`` digits.

    Example: ``construct_bits_string(5, 4) == "0101"``.
    """
    # Format spec "0<n>b": zero-pad to width n, base-2 representation.
    spec = "0" + str(n) + "b"
    return format(states, spec)


"""
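    Place `state`, followed by a Stat.Dummy marker, into slot `index` of a
    `players`-slot vector and fill every other slot with 0.
    state: one player's state values, e.g., [1, -1, 1]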
"""


def pad_input_state(state, players, index):
    """Build a flat vector of ``players`` slots with ``state`` placed in slot ``index``.

    Slot ``index`` receives the entries of ``state`` followed by
    ``float(Stat.Dummy.value)``; every other slot is ``players + 1`` zeros.
    NOTE(review): slots are only equally wide if ``len(state) == players`` —
    presumably callers guarantee this; confirm.
    """
    padded = []
    for slot in range(players):
        if slot != index:
            padded.extend([0.0] * (players + 1))
            continue
        padded.extend(state)
        padded.append(float(Stat.Dummy.value))
    return padded


def map_idx_to_action(act_idx):
    """Convert an action index into the float value of the corresponding ``Act``.

    0 -> ``Act.Abort``, 1 -> ``Act.Commit``, 2 -> ``Act.DoNothing``.

    :param act_idx: the action index.
    :return: ``float`` of the selected ``Act`` member's ``value``.
    :raises ValueError: if ``act_idx`` is not 0, 1 or 2.
    """
    if act_idx == 0:
        return float(Act.Abort.value)
    if act_idx == 1:
        return float(Act.Commit.value)
    if act_idx == 2:
        return float(Act.DoNothing.value)
    # Raising a plain string is itself a TypeError in Python 3 ("exceptions must
    # derive from BaseException"), which hid the real problem.
    raise ValueError(f"action index out of range: {act_idx!r}")


def print_digit_information():
    """Print a legend of the numeric codes used for actions and for state/message entries."""
    msg = """
    Digits for Action:
    4: Abort
    5: Commit
    6: Do nothing
    3: Lost

    Digits in state/message:
    0: no information (fill unneeded positions)
    1: abort
    2: commit
    3: Lost
    """
    print(msg)


def log(func, level, message):
    """Print ``[func] message`` if logging is enabled and ``level`` <= LEVEL.

    ENABLE is checked first, so LEVEL is only read when logging is on.
    """
    if not ENABLE:
        return
    if not level <= LEVEL:
        return
    print(f"[{func}] {message}")


def log_level(level, func, message):
    """Print ``[func] message`` if logging is enabled and ``level`` <= LEVEL.

    Same filter as ``log`` but with ``level`` as the first argument.
    """
    # Check ENABLE first: when LOG is unset the module sets ENABLE = False and may
    # not define LEVEL at all, so comparing against LEVEL unguarded raised NameError.
    if ENABLE and level <= LEVEL:
        print(f"[{func}] {message}")


def compute_coordinator_cases(players):
    """Return (sum over i = 0..players of C(players, i)) * 2**players.

    The binomial sum equals 2**players, so for players >= 0 this is 4**players.
    """
    binomial_total = sum(math.comb(players, k) for k in range(players + 1))
    return binomial_total * 2**players


def nCr(n, r):
    """Return the binomial coefficient C(n, r) ("n choose r").

    Returns 0 when ``r > n``; the previous factorial formula raised ValueError there.

    :raises ValueError: if ``n`` or ``r`` is negative.
    :raises TypeError: if ``n`` or ``r`` is not an integer.
    """
    # math.comb computes the value directly instead of building three
    # (potentially huge) factorials and dividing them.
    return math.comb(n, r)


"""
Generate all possible inputs for the coordinator: every list of the given
length whose entries are drawn from {1.0, 2.0, 3.0}.
"""


def generate_arrays(length, values=(1.0, 2.0, 3.0)):
    """Return every list of ``length`` entries drawn from ``values``.

    Lists are ordered with the last position varying fastest. For the default
    values and ``length=2`` the result starts
    ``[[1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [2.0, 1.0], ...]``.
    ``length == 0`` yields ``[[]]``.

    :param length: number of entries per list; must be >= 0.
    :param values: per-entry alternatives. The default is the previously
        hard-coded set (1.0, 2.0, 3.0).
    :return: a list of ``len(values) ** length`` fresh lists.
    :raises ValueError: if ``length`` is negative.
    """
    if length < 0:
        # The former recursive version never reached its base case here and
        # died with RecursionError.
        raise ValueError(f"length must be non-negative, got {length}")
    # itertools.product varies its rightmost element fastest, matching the order
    # of the former recursive construction, without recursion-depth limits.
    return [list(combo) for combo in itertools.product(values, repeat=length)]


"""
Hash observations by their multiset of values (order-insensitive), so that
different agents make the same action whenever they see the same observation
during training.
"""


def hash_observation_for_embedding(state):
    """Return an order-insensitive string key for an observation.

    The key is the textual form of the sorted ``(value, count)`` pairs, so any
    permutation of the same values maps to the same key.
    """
    value_counts = Counter(state)
    ordered_pairs = sorted(value_counts.items())
    return str(ordered_pairs)


"""
Hash by the exact input (order-sensitive): the key is str(state).
"""


def hash_observation(state):
    """Return an order-sensitive key for an observation: its ``str()`` form."""
    # `!s` applies str() before formatting, so the result is exactly str(state).
    return f"{state!s}"


# map action index to action value
def action_map(act_index):
    return act_index + 1


def extract_failed_cases(log_file, output_path="min_idx_15p_transformer.png"):
    """Plot failed-case counts over training steps from a training log.

    Only lines of the form ``<int>:<int>`` (ignoring surrounding whitespace)
    are used. The first number of each such line is plotted against its
    position among the matching lines, and the figure is saved at 300 dpi.

    :param log_file: path of the log file to read.
    :param output_path: where to save the figure. The default is the filename
        previously hard-coded here.
    :return: the extracted first numbers, in file order.
    """
    pattern = re.compile(r"^(\d+):\d+$")
    numbers = []
    with open(log_file, "r", encoding="utf-8", errors="replace") as file:
        for line in file:
            # Strip before matching so CRLF line endings and trailing spaces
            # do not prevent `$` from matching.
            match = pattern.match(line.strip())
            if match:
                numbers.append(int(match.group(1)))

    # Draw on a dedicated figure and close it afterwards. This keeps repeated
    # calls from plotting onto (and leaking) pyplot's implicit global figure.
    fig, ax = plt.subplots()
    try:
        ax.plot(range(len(numbers)), numbers, marker="o", linewidth=0.5, markersize=1)
        ax.set_xlabel("Training step")
        ax.set_ylabel("Failed Cases")
        fig.savefig(output_path, dpi=300)
    finally:
        plt.close(fig)
    return numbers


def get_number_of_parameters(network):
    """Count the scalar parameters in ``network.trainable_variables``.

    Each variable contributes the product of ``v.get_shape().as_list()``. A
    scalar variable (shape ``[]``) counts as 1.

    :return: the total as a Python ``int`` (0 when there are no variables).
    """
    # math.prod/sum keep the arithmetic in Python ints. np.prod([]) and
    # np.sum([]) return floats (1.0 / 0.0), so the previous version returned a
    # float for scalar variables or for a network with no trainable variables.
    return sum(math.prod(v.get_shape().as_list()) for v in network.trainable_variables)


def divide_list(lst, n):
    """Split ``lst`` into ``n`` consecutive slices whose lengths differ by at most one.

    The first ``len(lst) % n`` slices get one extra element. When ``n`` exceeds
    ``len(lst)``, the trailing slices are empty.
    """
    base, extra = divmod(len(lst), n)
    # Slice k starts after k base-sized chunks plus one extra element for each
    # earlier slice that received one (there are min(k, extra) of those).
    return [
        lst[k * base + min(k, extra):(k + 1) * base + min(k + 1, extra)]
        for k in range(n)
    ]


def evenly_distributed_crash_with_survival(alive_nodes):
    """
    Randomly selects a subset of nodes to crash, ensuring uniform probability
    across all possible crash scenarios while keeping at least one node alive.

    The scenarios are every subset of ``alive_nodes`` except the full set,
    including the empty set (no crash). Each of the ``2**n - 1`` scenarios is
    therefore returned with probability ``1 / (2**n - 1)``.

    :param alive_nodes: List of indices representing currently alive nodes.
    :return: A set of crashed nodes (possibly empty).
    """
    n = len(alive_nodes)
    if n <= 1:
        # A single node must stay alive, and with no nodes there is nothing to
        # crash. (Previously n == 0 reached random.choice([]) -> IndexError.)
        return set()

    # Scenarios are enumerated in the same order as before: subsets of size
    # 0, 1, ..., n-1, each size in itertools.combinations order. Picking the
    # index with randrange draws from the RNG the same way random.choice did on
    # the materialised list (both reduce to Random._randbelow(count) in CPython).
    # Seeded runs therefore pick the same crashes, without building all
    # 2**n - 1 sets in memory.
    scenario_count = 2**n - 1
    chosen = random.randrange(scenario_count)
    scenarios = itertools.chain.from_iterable(
        itertools.combinations(alive_nodes, r) for r in range(n)
    )
    return set(next(itertools.islice(scenarios, chosen, None)))


if __name__ == "__main__":
    # Usage: python <this file> <log_file>  -- plots failed cases from the log.
    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1,
        # instead of failing with a bare IndexError on sys.argv[1].
        sys.exit(f"usage: {sys.argv[0]} <log_file>")
    extract_failed_cases(sys.argv[1])
