import itertools
import argparse
import pickle
import os
from multiprocessing import Pool
from functools import partial
from tqdm import tqdm

from machiavelli.game.machiavelli_env import MachiavelliEnv 
from abstract_cf.machiavelli import compute_abstraction_given_state_actions_dist, machiavelli_abstraction_given_state_action

# -----------------------------------------------------------------------------
# Wrapper to reorder arguments so that we can fix category and field via partial.
# The original function likely has the signature:
#     machiavelli_abstraction_given_state_action(game, state, action, category, field, verbose)
#
# We define a wrapper with a default verbose argument. When called with
# (game, state, action, verbose), the category and field are provided via partial.
# -----------------------------------------------------------------------------
def abstraction_wrapper(game, state, action, verbose=False, category=None, field=None):
    """
    Forward a call to ``machiavelli_abstraction_given_state_action``.

    ``category`` and ``field`` are trailing keyword parameters here so that
    they can be pre-bound with ``functools.partial``; they are handed to the
    wrapped function positionally, ahead of ``verbose``.

    Returns:
      Whatever the wrapped function returns, unchanged.
    """
    # Positional order expected by the wrapped function.
    forwarded = (game, state, action, category, field, verbose)
    return machiavelli_abstraction_given_state_action(*forwarded)

# -----------------------------------------------------------------------------
def simulate_random_agent(game, p_abstraction_given_state_action):
    """
    Play random actions in a fresh environment until a non-degenerate state is hit.

    After every step the abstraction distribution over all currently available
    actions is computed; the state counts as non-degenerate when the mean of
    that distribution is neither exactly 0 nor exactly 1. At least one action
    is always taken before the first check.

    Parameters:
      game: The game identifier passed to ``MachiavelliEnv``.
      p_abstraction_given_state_action: Callable forwarded to
          ``compute_abstraction_given_state_actions_dist``.

    Returns:
      The tuple of actions taken to reach the state, or None if any exception
      was raised along the way (the error is printed).
    """
    try:
        env = MachiavelliEnv(game, agent_name='sampler')

        # Start a new episode and read the initial state.
        env.reset()
        state = env._get_state()

        taken_actions = []
        while True:
            chosen = env.action_space.sample()
            taken_actions.append(chosen)
            env.step(chosen)
            state = env._get_state()

            # Abstraction values for every action available in the new state.
            available_actions = list(range(env.action_space.n))
            dist = compute_abstraction_given_state_actions_dist(
                game,
                state,
                available_actions,
                p_abstraction_given_state_action
            )
            mean_abstraction = dist.sum() / len(dist)

            # Stop as soon as the mean is strictly between the two extremes.
            if not (mean_abstraction == 0 or mean_abstraction == 1):
                break

        return tuple(taken_actions)
    except Exception as e:
        print(f"Error in simulate_random_agent for game {game}: {e}")
        return None

# -----------------------------------------------------------------------------
def worker(dummy, game, p_abstraction_given_state_action):
    """
    Pool task: run one random simulation and return its result.

    ``dummy`` is the item drawn from the task iterator; it is not used and
    exists only so the function can be mapped over an arbitrary iterable.
    """
    outcome = simulate_random_agent(
        game=game,
        p_abstraction_given_state_action=p_abstraction_given_state_action,
    )
    return outcome

# -----------------------------------------------------------------------------
def find_non_degenerate_binary_abstraction_states(game, p_abstraction_given_state_action, num_states=10,
                                                  max_attempts=None):
    """
    Runs multiple random agents in parallel (using multiprocessing) to find states where
    the abstraction values are non-degenerate.

    Parameters:
      game: The game identifier (a string).
      p_abstraction_given_state_action: A callable that computes the abstraction probability
          given a state and an action. It must be picklable, since it is sent to pool workers.
      num_states: The (target) number of interesting action sequences to return.
          A value <= 0 returns an empty list without starting any worker.
      max_attempts: Optional upper bound on the number of simulations to launch.
          None (the default) keeps launching simulations until ``num_states`` unique
          sequences are found, which never terminates if every simulation fails.

    Returns:
      A list of unique interesting action sequences (each represented as a tuple of
      actions), in arbitrary order. It holds fewer than ``num_states`` entries only
      when ``max_attempts`` was exhausted first.
    """
    # Nothing to look for: avoid spawning a pool (previously one sequence was
    # still collected because the size check only ran after the first add).
    if num_states <= 0:
        return []
    if max_attempts is not None and max_attempts <= 0:
        return []

    interesting_action_sequences = set()  # Use a set to enforce uniqueness.

    # Create a worker function with game and the abstraction callable pre-bound.
    worker_func = partial(worker, game=game, p_abstraction_given_state_action=p_abstraction_given_state_action)

    # Task stream: unbounded by default, capped when the caller asks for it.
    attempts = itertools.count() if max_attempts is None else range(max_attempts)

    # Leaving the ``with Pool()`` block terminates the pool, which also discards
    # any simulations still in flight after the target count is reached.
    with Pool() as pool:
        with tqdm(total=num_states, desc="Interesting sequences found") as pbar:
            for seq in pool.imap_unordered(worker_func, attempts):
                # None marks a failed simulation; duplicates add no new information.
                if seq is None or seq in interesting_action_sequences:
                    continue
                interesting_action_sequences.add(seq)
                pbar.update(1)
                if len(interesting_action_sequences) >= num_states:
                    break

    return list(interesting_action_sequences)

# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Command-line entry point: search one game for non-degenerate abstraction
    # states of a given category/field and pickle the resulting action sequences.
    parser = argparse.ArgumentParser(description="Find non-degenerate binary abstraction states.")
    parser.add_argument('--game', type=str, default='aegis-project-180-files',
                        help='Game identifier')
    parser.add_argument('--category', type=str, default='morality',
                        help='Category')
    parser.add_argument('--field', type=str, default='physical_harm',
                        help='Field')
    parser.add_argument('--num_states', type=int, default=10,
                        help='Number of interesting states to find')
    parser.add_argument('--output', type=str, default='machiavelli_case_studies',
                        help='Directory in which to save the output pickle file')
    args = parser.parse_args()

    # Construct the default filename.
    filename = f"{args.game}_{args.category}_{args.field}.pkl"
    output_dir = args.output

    # Ensure the output directory exists. exist_ok=True replaces the former
    # isdir-then-makedirs pair, which could fail if the directory appeared
    # between the check and the creation.
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, filename)

    # Resolve the destination before the (long) search starts. Keep asking while
    # the chosen name is taken, so a replacement name cannot silently clobber a
    # different existing file; an empty answer means "overwrite this one".
    while os.path.exists(output_path):
        new_filename = input(f"File {filename} already exists in {output_dir}. Please enter a new filename (or press Enter to overwrite): ").strip()
        if not new_filename:
            break
        # If the new filename does not end with .pkl, append it.
        if not new_filename.endswith('.pkl'):
            new_filename += '.pkl'
        filename = new_filename
        output_path = os.path.join(output_dir, filename)

    # Create a partial abstraction function that fixes category and field.
    # The resulting function will have the signature (game, state, action, verbose)
    p_abstraction_func = partial(abstraction_wrapper, category=args.category, field=args.field)

    sequences = find_non_degenerate_binary_abstraction_states(args.game, p_abstraction_func, num_states=args.num_states)

    # Save the sequences as a pickle file.
    with open(output_path, 'wb') as f:
        pickle.dump(sequences, f)

    print("Interesting action sequences found:")
    for seq in sequences:
        print(seq)
    print(f"Results saved to {output_path}")