import itertools
import time
import csv
import numpy as np
import h5py
import os
import pickle

# from src.generate import generate_utility_matrices

def load_utility_matrices(filename):
    """Load the ``"utilities"`` dataset from an HDF5 file.

    Args:
        filename: path of the HDF5 file; it is opened read-only.

    Returns:
        A NumPy array holding the full contents of the ``"utilities"``
        dataset. The file is closed before the function returns.
    """
    with h5py.File(filename, 'r') as h5_file:
        return np.array(h5_file["utilities"])

def generate_utility_matrices(n_players, m_actions, util_range=None, seed=None):
    """Draw a random utility tensor for an n-player, m-action normal-form game.

    The result has shape ``(n_players, m_actions, ..., m_actions)`` with
    ``n_players`` action axes; entry ``[i, a_1, ..., a_n]`` is player ``i``'s
    utility for the joint action ``(a_1, ..., a_n)``.

    Args:
        n_players: number of players.
        m_actions: number of actions available to every player.
        util_range: ``None`` for float utilities drawn uniformly from
            ``[-1, 1)``; otherwise a pair ``(low, high)`` for integer
            utilities drawn uniformly from the inclusive range ``[low, high]``.
        seed: if not ``None``, NumPy's global (legacy) RNG is reseeded with it
            before drawing, which makes the output reproducible.

    Returns:
        The randomly generated ``numpy.ndarray``.
    """
    if seed is not None:
        np.random.seed(seed)

    dims = (n_players, *([m_actions] * n_players))

    if util_range is not None:
        low, high = util_range[0], util_range[1]
        # randint's upper bound is exclusive; +1 makes `high` attainable.
        return np.random.randint(low, high + 1, size=dims)
    return np.random.uniform(-1, 1, size=dims)

def find_best_pure_NE(payoffs, nb_players=None, nb_actions=None):
    """Find the pure strategy profile with the smallest total exploitability.

    For every joint pure action profile ``a`` and each player ``i``, the
    regret is ``max_{a_i'} u_i(a_i', a_-i) - u_i(a)``, i.e. how much player
    ``i`` could gain by unilaterally switching actions. Regrets are summed
    over players, and the profile minimizing the sum is returned. A sum of 0
    means the profile is a pure Nash equilibrium.

    Args:
        payoffs: array-like of shape ``(n_players, m_1, ..., m_n)``.
            ``payoffs[i]`` is player ``i``'s utility tensor, and axis ``i`` of
            it is player ``i``'s own action. Nested lists are accepted.
        nb_players: number of players to evaluate. Defaults to
            ``len(payoffs)``.
        nb_actions: unused; kept for backward compatibility with existing
            callers.

    Returns:
        A tuple ``(exploitability, coordinates)``:

        - ``exploitability`` is the minimal summed regret, as a Python float.
        - ``coordinates`` is a list of Python ints giving the minimizing
          action profile. On ties, the first profile in C (row-major)
          order is returned.
    """
    payoffs = np.asarray(payoffs)
    if nb_players is None:
        nb_players = payoffs.shape[0]

    regrets = []
    for i in range(nb_players):
        player_payoffs = payoffs[i]
        # Best-response value against every opponent profile. keepdims keeps
        # player i's axis (size 1) so it broadcasts back over that axis.
        best_response = np.max(player_payoffs, axis=i, keepdims=True)
        regrets.append(best_response - player_payoffs)
    total_regret = np.sum(regrets, axis=0)

    flat_index = np.argmin(total_regret)
    coordinates = [int(x) for x in np.unravel_index(flat_index, total_regret.shape)]
    # Read the minimum at the argmin instead of scanning the array again.
    return float(total_regret.flat[flat_index]), coordinates

import argparse

parser = argparse.ArgumentParser(description="solve a multiplayer game")

# Command-line options as (flags, type, default, help), registered in this order.
for _flags, _kind, _default, _help in (
    (('-n', '--n_players'), int, 2, "Number of players."),
    (('-m', '--m_actions'), int, 2, "Number of actions per player."),
    (('-s', '--seed'), int, None, "Random seed for reproducibility."),
    (('-k', '--max_seed'), int, None, "Run many seeds."),
    (('-if', '--input_file'), str, None, "Read input file."),
):
    parser.add_argument(*_flags, type=_kind, default=_default, help=_help)
del _flags, _kind, _default, _help

args = parser.parse_args()

# Game dimensions taken from the CLI; N is overwritten when an input file is loaded.
N, M = args.n_players, args.m_actions

if args.input_file is not None:
    # Choose the loader from the (case-insensitive) file extension.
    ext = os.path.splitext(args.input_file)[1].lower()

    if ext == ".h5":
        u = load_utility_matrices(args.input_file)
    elif ext == ".pkl":
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only pass .pkl files that come from a trusted source.
        with open(args.input_file, "rb") as f:
            data = pickle.load(f)
        # Normalize to an ndarray in case the payload stores the tensor as
        # nested lists (unverified), so the array indexing downstream works.
        # np.asarray is a no-op for inputs that are already ndarrays.
        u = np.asarray(data["multiplayer_tensor"])
    else:
        raise ValueError("Unsupported file type: must be .h5 or .pkl")

    # The player count comes from the loaded tensor, not from -n.
    N = len(u)
    exp, sol = find_best_pure_NE(u, N, None)
    print(exp)
    print(sol)

if args.seed is not None:
    u = np.array(generate_utility_matrices(N, M, seed=args.seed))
    exp, sol = find_best_pure_NE(u, N, M)
    sol_index = tuple(sol)
    # Each player's utility at the chosen profile, every entry followed by "; ".
    values = "".join(str(u[(i,) + sol_index]) + "; " for i in range(N))
    print(exp, sol, values)

    if exp < 1e-8:
        # Independent brute-force check: the profile is a pure Nash
        # equilibrium iff no player gains by a unilateral switch of action.
        is_nash = True
        for i in range(N):
            best_deviation = max(
                u[(i,) + sol_index[:i] + (a,) + sol_index[i + 1:]]
                for a in range(M)
            )
            if u[(i,) + sol_index] < best_deviation:
                is_nash = False
        print('is nash:', is_nash)


if args.max_seed is not None:
    # Benchmark seeds 0 .. max_seed-1, write every result to CSV, then report
    # the ten games with the highest exploitability.
    results = []

    # perf_counter is monotonic and high-resolution, so it suits measuring
    # elapsed time better than the wall clock returned by time.time().
    overall_time = time.perf_counter()

    for S in range(args.max_seed):
        u = np.array(generate_utility_matrices(N, M, seed=S))

        start_time = time.perf_counter()
        exp, sol = find_best_pure_NE(u, N, M)
        elapsed_time = time.perf_counter() - start_time

        sol_index = tuple(sol)
        values = "; ".join(str(u[(i,) + sol_index]) for i in range(N))
        results.append({"N": N, "M": M, "S": S, "exploitability": exp,
                        "time": elapsed_time, "values": values})

    with open("nash_results.csv", mode="w", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=["N", "M", "S", "exploitability", "time", "values"])
        writer.writeheader()
        writer.writerows(results)

    # Largest minimal summed regret first. sorted() is stable, so ties keep
    # seed order.
    sorted_results = sorted(results, key=lambda x: x["exploitability"], reverse=True)

    for res in sorted_results[:10]:
        print(f"S: {res['S']}, "
              f"Exploitability: {res['exploitability']:.6f}, "
              f"Time: {res['time']:.2f}, Values: {res['values']}")

    print("took:", time.perf_counter() - overall_time)
