from dataclasses import dataclass
from array import array
import time
from bitarray import bitarray
import random
import datetime
import os

# Results go to a per-run file named after the start time; ':' and ' ' are
# swapped for '-' and '_' so the timestamp is a portable file name.
now = str(datetime.datetime.now()).translate(str.maketrans({':': '-', ' ': '_'}))
file = open("Results_" + now + ".results", 'a')

@dataclass
class Configuration:
    aggregate_demand: tuple

    def add_action_to_configuration(self, game, i: int, action: list):
        new_configuration = array('L', self.aggregate_demand)
        for resource in action:
            for j in range(game.k):
                new_configuration[resource * game.k + j] += game.demand[i][j]
        return Configuration(tuple(new_configuration))

    def remove_action_from_configuration(self, game, i: int, action: list):
        new_configuration = array('L', self.aggregate_demand)
        for resource in action:
            for j in range(game.k):
                if new_configuration[resource * game.k + j] < game.demand[i][j]:
                    return (None, False)
                new_configuration[resource * game.k + j] -= game.demand[i][j]
        return (Configuration(tuple(new_configuration)), True)
    
    # Utility when agent plays action under the configuration self.
    def utility(self, game, agent: int, action: list) -> float:
        cost = 0.0
        for resource in action:
            match game.resource_cost[resource][0]:
                case "kD non-monotonic":
                    this_resource_cost = 0.0
                    for j in range(game.k):
                        this_resource_cost += game.z[resource][j][self.aggregate_demand[resource * game.k + j]] # The aggregate demand in the jth dimension of the resource.

                    this_resource_cost *= game.alpha[resource]
                    this_resource_cost += game.beta[resource]
                    cost += this_resource_cost
                case "debug_fixed":
                    return 1
                case _:
                    raise Exception(f"Resource {resource} has a cost function with the unknown type: {game.resource_cost[resource]}")
        return -cost
    
    # Utility when agent stops playing action and starts playing alt_action under the configuration self.
    def alt_utility(self, game, agent: int, action: list, alt_action: list) -> (float, bool):
        (new_configuration, is_legal) = self.remove_action_from_configuration(game, agent, action)
        if not is_legal:
            return (None, False)
        new_configuration = new_configuration.add_action_to_configuration(game, agent, alt_action)
        return (new_configuration.utility(game, agent, alt_action), True)
    
    def empty_configuration(game):
        # Configurations are a 1D array that we pretend are a 2D array because of how Python works.
        con = array('L', [0 for i in range(game.k * game.num_resources)])
        # We convert the array to a tuple, because tuples are immutable: hence they can be hashed.
        con = tuple(con)
        return Configuration(con)

    def __key(self):    
        return (self.aggregate_demand)

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        if isinstance(other, Configuration):
            return self.__key() == other.__key()
        return NotImplemented

@dataclass
class Game:
    n: int
    k: int
    seed: int
    num_resources: int
    resource_cost: list
    demand: list
    actions: list
    alpha: list
    beta: list
    z: list
    rand_util_lookup: dict
    L: int

    # Generate a list of sets of configurations. One list for each table, T_0 through T_n.
    def generate_all_configurations(self):
        table_i = {Configuration.empty_configuration(self)}
        total_number_of_configurations_full_and_partial = 1
        total_edges = 0

        for i in range(self.n):
            table_i_plus_plus = set()
            for partial_configuration in table_i:
                total_edges += 2
                for action in self.actions[i]:
                    next_configuration = partial_configuration.add_action_to_configuration(self, i, action)
                    table_i_plus_plus.add(next_configuration)
            table_i = table_i_plus_plus
            total_number_of_configurations_full_and_partial += len(table_i)

        return (table_i, total_number_of_configurations_full_and_partial, total_edges)
    
    def dynamic_program_table_based_asymptotic_measurement_procedure_1(self, configuration: Configuration):

        # Run procedure 1
        an_agent_lacks_a_configuration_BR = False
        BR = [[] for _ in range(self.n)]
        for i in range(self.n):
            for action in self.actions[i]:
                utility_for_action = configuration.utility(self, i, action)
                action_is_a_BR = True
                for alt_action in self.actions[i]:
                    if action == alt_action:
                        continue
                    (utility_for_alt_action, alt_is_legal) = configuration.alt_utility(self, i, action, alt_action)
                    if alt_is_legal and utility_for_alt_action > utility_for_action:
                        action_is_a_BR = False
                        # This is for measuring asymptotic runtime cost, so we do not break
                        #   early if their is an alternative best response
                        pass
                        # break
                # Because we are computing asymptotic runtime we consider every action a best response.
                BR[i].append(action)
            if len(BR[i]) == 0:
                an_agent_lacks_a_configuration_BR = False
                # This is for measuring asymptotic runtime cost, so we do not break if an agent lacks any best response.
                pass
                # break
        if an_agent_lacks_a_configuration_BR:
            pass # See above comments.
            # continue

        return BR
    
    def dynamic_program_table_based_asymptotic_measurement_procedure_2(self, configuration: Configuration, w_max, BR, bitarray_length):

        # Run procedure 2
        table_i = bitarray(bitarray_length)
        table_i.setall(1) # Because we are computing asymptotic runtime we set every bit to 1.
        table_i[0] = 1
        for i in range(self.n): # For each table i
            table_i_plus_plus = bitarray(bitarray_length)
            table_i_plus_plus.setall(1)

            # Final all bits in table_i bitarray that are set to 1.
            for bitarray_index in table_i.itersearch(bitarray('1')): # Because we are computing asymptotic runtime we set every bit to 1.

                # Convert bitarray index to partial configuation.
                partial_configuration = self.bitarray_index_to_configuration(bitarray_index, w_max)
                for action in BR[i]:
                    next_configuration = partial_configuration.add_action_to_configuration(self, i, action)
                    next_index = self.configuration_to_bitarray_index(next_configuration, w_max)
                    if next_index < bitarray_length: # Because every bit to 1 in the asymptotic runtime version we must add this guard to prevent index out of bounds.
                        table_i_plus_plus[next_index] = 1
            table_i = table_i_plus_plus

        # Wrap up
        configuration_index = self.configuration_to_bitarray_index(configuration, w_max)
        if table_i[configuration_index] == 1:
            return True
            
        return False

    def dynamic_program_table_based(self, quit_after_first_ne: bool, return_all_pure_ne: bool):
        num_configurations_with_pure_ne = 0
        w_max = self.w_max()
        (all_configurations, _, _) = self.generate_all_configurations()
        ne = []
        for configuration in all_configurations:
            # Run procedure 1
            an_agent_lacks_a_configuration_BR = False
            BR = [[] for _ in range(self.n)]
            for i in range(self.n):
                for action in self.actions[i]:
                    utility_for_action = configuration.utility(self, i, action)
                    action_is_a_BR = True
                    for alt_action in self.actions[i]:
                        if action == alt_action:
                            continue
                        (utility_for_alt_action, alt_is_legal) = configuration.alt_utility(self, i, action, alt_action)
                        if alt_is_legal and utility_for_alt_action > utility_for_action:
                            action_is_a_BR = False
                            break
                    if action_is_a_BR:
                        BR[i].append(action)
                if len(BR[i]) == 0:
                    an_agent_lacks_a_configuration_BR = False
                    break
            if an_agent_lacks_a_configuration_BR:
                continue

            # Run procedure 2
            table_i = bitarray(pow(w_max+1,self.k*self.num_resources))
            table_i.setall(0)
            table_i[0] = 1
            for i in range(self.n): # For each table i
                table_i_plus_plus = bitarray(pow(w_max+1,self.k*self.num_resources))
                table_i_plus_plus.setall(0)

                # Final all bits in table_i bitarray that are set to 1.
                for bitarray_index in table_i.itersearch(bitarray('1')):

                    # Convert bitarray index to partial configuation.
                    partial_configuration = self.bitarray_index_to_configuration(bitarray_index, w_max)
                    for action in BR[i]:
                        next_configuration = partial_configuration.add_action_to_configuration(self, i, action)
                        next_index = self.configuration_to_bitarray_index(next_configuration, w_max)
                        table_i_plus_plus[next_index] = 1
                table_i = table_i_plus_plus

            # Wrap up
            configuration_index = self.configuration_to_bitarray_index(configuration, w_max)
            if table_i[configuration_index] == 1:
                if return_all_pure_ne:
                    this_configurations_ne = self.extract_ne_from_configuration_sets(configuration, BR)
                    ne = ne + this_configurations_ne
                
                if quit_after_first_ne:
                    ne.sort()
                    ne = [array('L', ne[h]) for h in range(len(ne))]
                    return (1, ne)
                num_configurations_with_pure_ne += 1

        if return_all_pure_ne:
            ne.sort()
            ne = [array('L', ne[h]) for h in range(len(ne))]
        return (num_configurations_with_pure_ne, ne)

    def dynamic_program_set_based(self, quit_after_first_ne: bool, return_all_pure_ne: bool):
        num_configurations_with_pure_ne = 0
        (all_configurations, _, _) = self.generate_all_configurations()
        ne = []
        for configuration in all_configurations:
            # Run procedure 1
            an_agent_lacks_a_configuration_BR = False
            BR = [[] for _ in range(self.n)]
            for i in range(self.n):
                for action in self.actions[i]:
                    utility_for_action = configuration.utility(self, i, action)
                    action_is_a_BR = True
                    for alt_action in self.actions[i]:
                        if action == alt_action:
                            continue
                        (utility_for_alt_action, alt_is_legal) = configuration.alt_utility(self, i, action, alt_action)
                        if alt_is_legal and utility_for_alt_action > utility_for_action:
                            action_is_a_BR = False
                            break
                    if action_is_a_BR:
                        BR[i].append(action)
                if len(BR[i]) == 0:
                    an_agent_lacks_a_configuration_BR = False
                    break
            if an_agent_lacks_a_configuration_BR:
                continue

            # Run procedure 2
            table_i = {Configuration.empty_configuration(self)}
                
            for i in range(self.n): # For each table i
                table_i_plus_plus = set()
                for partial_configuration in table_i:
                    for action in BR[i]:
                        next_configuration = partial_configuration.add_action_to_configuration(self, i, action)
                        table_i_plus_plus.add(next_configuration)
                table_i = table_i_plus_plus

            # Wrap up
            if configuration in table_i:
                if return_all_pure_ne:
                    this_configurations_ne = self.extract_ne_from_configuration_sets(configuration, BR)
                    ne = ne + this_configurations_ne
                
                if quit_after_first_ne:
                    ne.sort()
                    ne = [array('L', ne[h]) for h in range(len(ne))]
                    return (1, ne)
                num_configurations_with_pure_ne += 1

        if return_all_pure_ne:
            ne.sort()
            ne = [array('L', ne[h]) for h in range(len(ne))]
        return (num_configurations_with_pure_ne, ne)

    def extract_ne_from_configuration_sets(self, configuration: Configuration, BR: list[list]):
        ne_extract = {configuration: [[]]}
        for i in range(self.n, 0, -1):
            next_ne_extract = dict()
            for partial_configuration, partial_nes in ne_extract.items():
                for action in BR[i-1]:
                    action_index = self.actions[i-1].index(action)
                    (next_configuration, legal_remove) = partial_configuration.remove_action_from_configuration(self, i-1, action)
                    if legal_remove:
                        next_partial_nes = []
                        if next_configuration in next_ne_extract:
                            next_partial_nes = next_ne_extract.pop(next_configuration)
                        for partial_ne in partial_nes:
                            partial_ne = partial_ne[:] # Create a copy
                            partial_ne.append(action_index)
                            next_partial_nes.append(partial_ne)
                        next_ne_extract.update({next_configuration:next_partial_nes})
            ne_extract = next_ne_extract
        assert Configuration.empty_configuration(self) in ne_extract
        # NE are currently in reverse order from agent n to 1.
        reversed_ne = ne_extract.pop(Configuration.empty_configuration(self))
        this_configurations_ne = list(reversed(reversed_ne))
        return this_configurations_ne

    def generate_kdim_parameters(kdim_cost_parameters, k, num_resources):
        alpha = []
        beta = []
        z = []
        L = 0
        match kdim_cost_parameters:
            case ("seed_non-monotonic_int", L, w_ceiling):
                for _ in range(num_resources):
                    alpha.append(random.randint(0, L))
                for _ in range(num_resources):
                    beta.append(random.randint(0, L))
                z = [[[random.randint(0, L) for _ in range(w_ceiling)] for _ in range(k)] for _ in range(num_resources)]

            case None:
                pass

            case x:
                raise ValueError(f"Invalid parameter in generate_kdim_parameters: {x}")
            
        return (alpha, beta, z, L)
    
    def generate_demand_vectors(demand_parameters, n, k):
        demand_vectors = []
        match demand_parameters:
            case None:
                demand_vectors = [[1 for _ in range(k)] for _ in range(n)]
                
            case ("random_int", low, high):
                demand_vectors = []
                i = 0
                while i < n:
                    all_zeros = True
                    demand_vector = []
                    for _ in range(k):
                        demand = random.randint(low, high)
                        demand_vector.append(demand)
                        if demand != 0:
                            all_zeros = False
                    if not all_zeros: # Regenerate agent i's demand vector if all zeros.
                        i += 1
                        demand_vectors.append(demand_vector)

            case x:
                raise ValueError(f"Invalid parameter in generate_demand_vectors: {x}")
        return demand_vectors



    def generate_parallel_link_model(n_value, num_links, resource_cost_function, k=1, kdim_cost_parameters=None, demand_parameters=None, seed=None):
        if seed != None:
            random.seed(seed)
        else:
            seed = random.randint(-2147483649, 2147483648)
        n = n_value
        num_resources = num_links
        resource_cost = [(resource_cost_function, 1) for _ in range(num_resources)]
        actions = [[[j] for j in range(num_resources)] for _ in range(n)]

        (alpha, beta, z, L) = Game.generate_kdim_parameters(kdim_cost_parameters, k, num_resources)

        demand = Game.generate_demand_vectors(demand_parameters, n, k)
        
        return Game(n = n, k = k, seed = seed, num_resources = num_resources, resource_cost = resource_cost, demand = demand, actions = actions, alpha = alpha, beta = beta, z = z, rand_util_lookup=dict(), L=L)

    def w_max(self):
        w_max = 0
        for j in range(self.k):
            w_j = 0
            for i in range(self.n):
                w_j += self.demand[i][j]
            w_max = max(w_max, w_j)
        return w_max
    
    def bitarray_index_to_configuration(self, index, w_max):
        w_max_adj = w_max + 1 # The w_max range is inclusive where both 0 and w_max are valid. So we increment by 1.
        new_configuration = array('L', [0 for i in range(self.k * self.num_resources)])
        for m in range(self.num_resources-1, -1, -1):
            for j in range(self.k-1, -1, -1):
                offset_multiplier = pow(w_max_adj, m * self.k + j)
                value = index // offset_multiplier
                new_configuration[m * self.k + j] = value
                index -= value * offset_multiplier
        assert index == 0
        return Configuration(tuple(new_configuration))

    def configuration_to_bitarray_index(self, configuration, w_max):
        index = 0
        w_max_adj = w_max + 1 # The w_max range is inclusive where both 0 and w_max are valid. So we increment by 1.
        for m in range(self.num_resources):
            for j in range(self.k):
                demand_on_resource_m_in_jth_dimension = configuration.aggregate_demand[m * self.k + j]
                multiplier = pow(w_max_adj, m * self.k + j)
                index += demand_on_resource_m_in_jth_dimension * multiplier
        return index

    def strategy_profile_to_configuration(self, strategy_profile) -> Configuration:
        assert len(strategy_profile) == self.n
        configuration = Configuration.empty_configuration(self)
        for i in range(self.n):
            configuration = configuration.add_action_to_configuration(self, i, self.actions[i][strategy_profile[i]])
        return configuration

    def strategy_profile_utility(self, i, strategy_profile) -> float:
        configuration = self.strategy_profile_to_configuration(strategy_profile)
        return configuration.utility(self, i, self.actions[i][strategy_profile[i]])

    def is_strategy_profile_NE(self, strategy_profile) -> bool:
        for i in range(self.n):
            utility = self.strategy_profile_utility(i, strategy_profile)
            for action_index in range(len(self.actions[i])):
                if action_index == strategy_profile[i]:
                    continue
                alt_strategy_profile = strategy_profile[:] # Create a copy
                alt_strategy_profile[i] = action_index
                if utility < self.strategy_profile_utility(i, alt_strategy_profile):
                    return False
        return True

    def brute_force(self, quit_after_first_ne: bool, return_all_pure_ne: bool):
        strategy_profile = array('L', [0 for _ in range(self.n)])
        num_NE = 0
        ne = []
        while True:
            if self.is_strategy_profile_NE(strategy_profile):
                num_NE += 1
                if return_all_pure_ne:
                    copy = strategy_profile[:] # Create a copy
                    ne.append(copy)
                if quit_after_first_ne:
                    return (num_NE, ne)

            for i in range(self.n + 1):
                if i == self.n:
                    return (num_NE, ne)
                
                if strategy_profile[i] >= len(self.actions[i]) - 1:
                    strategy_profile[i] = 0
                else:
                    strategy_profile[i] += 1
                    break

def print_and_write(string):
    """Append ``string`` plus a newline to the module-level results file, then echo it to stdout."""
    line = f"{string}\n"
    file.write(line)
    print(string)

def simulations():
    """Benchmark brute force (BF), set-based DP (SDP) and table-based DP (TDP).

    Each algorithm starts at ``starting_n`` agents, and n grows by ``step_n``.
    Every (algorithm, n) pair produces one output row with one runtime per
    trial seed. The pending experiment with the smallest previous average
    runtime always runs next. An algorithm stops growing n when:
    - its average runtime exceeds ``quit_if_average_run_exceeds`` seconds,
    - the next n would pass ``max_n``, or
    - TDP's bit table would need at least 8e9 bits.
    The module-level results file is closed at the end.
    """
    # Ensure that set iteration is deterministic for experiments; warn if the hash seed is not pinned.
    # See: https://stackoverflow.com/questions/3848091/set-iteration-order-varies-from-run-to-run
    if not os.getenv('PYTHONHASHSEED'):
        print("""
***********************************************
***********************************************
Set environmental variable \"PYTHONHASHSEED\" = 0
***********************************************
***********************************************\n""")

    # Run parameters
    num_trials = 15
    master_seed = 2024
    starting_n = 2
    max_n = 30
    step_n = 1
    w_ceiling = 500 # w_ceiling should always be greater than w_max. This ensures that each time a trial is run with n+1 agents it is identical to the trial with n agents except for the additional agent.
    quit_if_average_run_exceeds = 300 # Seconds
    tdp_bit_limit = 8 * 1000000000 # 8e9 bits, i.e. 1 GB for TDP's bit table.

    # Comma-delimited header for the output file.
    columns = ["algorithm", "n", "num_links", "k", "rand_range"]
    columns.extend(f"trial_{trial}_runtime" for trial in range(num_trials))
    print_and_write(",".join(columns))

    random.seed(master_seed)
    seed_list = [(seed_index, random.random()) for seed_index in range(num_trials)]

    # Each pending entry: (previous average runtime, algorithm, k, num_links, q, n).
    pending = [
        (0.0, algorithm, k, num_links, q, starting_n)
        for algorithm in ("BF", "SDP", "TDP")
        for k in (3,)
        for num_links in (2,)
        for q in (1,)
    ]

    while pending:
        # Cheapest pending experiment first; the stable sort keeps insertion order on ties.
        pending.sort(key=lambda entry: entry[0])
        _, algorithm, k, num_links, q, n = pending.pop(0)

        print(f"{algorithm},{n},{num_links},{k},{q}")

        total_runtime = 0.0
        runtimes = []
        within_limits = True
        for _, seed in seed_list:
            game = Game.generate_parallel_link_model(n, num_links, "kD non-monotonic", k=k, kdim_cost_parameters=("seed_non-monotonic_int", 100, w_ceiling), demand_parameters=("random_int", 0, q), seed=seed)
            w_max = game.w_max()

            start = time.perf_counter()
            if algorithm == "BF":
                game.brute_force(False, False)
            elif algorithm == "SDP":
                game.dynamic_program_set_based(False, False)
            elif algorithm == "TDP":
                # Skip TDP (and stop growing n for it) once its table would need too much memory.
                if pow(w_max + 1, game.k * game.num_resources) >= tdp_bit_limit:
                    within_limits = False
                    break
                game.dynamic_program_table_based(False, False)
            else:
                raise ValueError(f"Invalid parameter in main: {algorithm}")

            elapsed = time.perf_counter() - start
            total_runtime += elapsed
            runtimes.append(f"{elapsed:0.8f}")

        print_and_write(f"{algorithm},{n},{num_links},{k},{q},{','.join(runtimes)}")

        # Requeue with the next n only while this algorithm stays fast and within bounds.
        average_runtime = total_runtime / num_trials
        next_n = n + step_n
        if within_limits and next_n <= max_n and average_runtime <= quit_if_average_run_exceeds:
            pending.append((average_runtime, algorithm, k, num_links, q, next_n))

    file.close()

def asymptotic_measurements():
    num_trials = 15
    master_seed = 2024
    start_n =2
    max_n = 100
    bitarray_length = 1000
    w_ceiling = 500 # w_ceiling should always be greater than w_max. This ensures that each time a trial is run with n+1 agents it is identical to the trial with n agents except for the additional agent.

    # Create a comma-delimited header for the output file.
    s = [f"trial_{i}_runtime" for i in range(num_trials)]
    s = ','.join(s)
    s = "algorithm,n,num_links,k,rand_range," + s
    global file
    print_and_write(s)

    random.seed(master_seed)
    seed_list = [random.random() for _ in range(num_trials)]

    experiment_parameters = []
    for algorithm in ["TDP_Asymptotic_Procedure_1", "TDP_Asymptotic_Procedure_2"]:
        for k in [3]:
            for num_links in [2]:
                for q in [1]:
                    experiment_parameters.append((algorithm, k, num_links, q, start_n))

    while experiment_parameters:
        (algorithm, k, num_links, q, n) = experiment_parameters.pop(0)

        runtimes = []
        continue_tests = True
        for seed in seed_list:
            game = Game.generate_parallel_link_model(n, num_links, "kD non-monotonic", k=k, kdim_cost_parameters=("seed_non-monotonic_int", 100, w_ceiling), demand_parameters=("random_int", 0, q), seed=seed)

            (start, stop) = (None, None)
            match algorithm:
                case "TDP_Asymptotic_Procedure_1":
                    strategy_profile = array('L', [0 for _ in range(game.n)])
                    configuration = game.strategy_profile_to_configuration(strategy_profile)
                    
                    start = time.perf_counter()
                    game.dynamic_program_table_based_asymptotic_measurement_procedure_1(configuration)
                    stop = time.perf_counter()
                case "TDP_Asymptotic_Procedure_2":
                    w_max = game.w_max()
                    configuration = Configuration.empty_configuration(game) # The exact configuration does not impact procedure 2, as the configuration is mainly used in procedure 1. An empty configuration is used to avoid an index out of range for a single call on the bitarray.
                    
                    # Make every action a best response to measure asymptotic time.
                    BR = [[action for action in game.actions[i]] for i in range(game.n)]
                    start = time.perf_counter()
                    game.dynamic_program_table_based_asymptotic_measurement_procedure_2(configuration, w_max, BR, bitarray_length)
                    stop = time.perf_counter()

            runtimes.append(f"{stop - start:0.8f}")
        print_and_write(f"{algorithm},{n},{num_links},{k},{q},{','.join(runtimes)}")
        if n + 1 < max_n and continue_tests:
           experiment_parameters.append((algorithm, k, num_links, q, n + 1))
    
    file.close()

def main():
    """Ask which experiment to run on stdin, then dispatch to it."""
    print("""
1: Run simulations with current coded settings.
2: Get asymptotic measurements for this computer.
Enter choice: """)
    choice = input()
    if choice == "1":
        simulations()
    elif choice == "2":
        asymptotic_measurements()
    else:
        print("Invalid entry")
    


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C aborts a run: explicitly close the module-level results file so
        # rows already written are flushed to disk before exiting.
        file.close()