#! /usr/bin/env python3

import argparse
import json
import random
from operator import add, sub, mul
from random import Random
from typing import Any, Optional

import sympy
from sympy import Symbol, symbols


def generate_instance_for_reasoning_gym(rng: Random, num_terms: int, min_value: int, max_value: int, max_target: int, num_iterations: int) -> Optional[dict[str, Any]]:
    """
    Code taken from: https://github.com/open-thought/reasoning-gym/blob/main/reasoning_gym/games/countdown.py
    Generate a countdown instance using the reasoning-gym procedure.

    Candidate expressions are drawn repeatedly until one evaluates to a
    target inside [1, max_target] or the attempt budget is exhausted.

    Args:
        rng: Random number generator
        num_terms: Number of terms to include
        min_value: Minimum value for the numbers
        max_value: Maximum value for the numbers
        max_target: Maximum target value (inclusive)
        num_iterations: Number of candidate expressions to try

    Returns:
        A dict ``{"state": numbers, "goal": target}`` for the first accepted
        candidate, or None if no candidate was accepted within
        ``num_iterations`` attempts.
    """
    for _ in range(num_iterations):
        try:
            expr, numbers, syms = _generate_candidate_expression_reasoning_gym(rng, num_terms, min_value, max_value)
            # Substitute the actual numbers for the placeholder symbols to get the target
            target = int(expr.subs(dict(zip(syms, numbers))))
        except (ValueError, ZeroDivisionError):
            # This candidate could not be evaluated; draw a new one
            continue

        # Ensure target is within bounds
        if 1 <= target <= max_target:
            return {"state": numbers, "goal": target}
    return None

def _generate_candidate_expression_reasoning_gym(rng: Random, num_terms: int, min_value: int, max_value: int) -> tuple[sympy.Expr, list[int], list[Symbol]]:
    """
    Code taken from: https://github.com/open-thought/reasoning-gym/blob/main/reasoning_gym/games/countdown.py
    Generate a candidate expression with random numbers and operators

    The expression is built left to right over placeholder symbols (one per
    number), combining the running expression with the next symbol using a
    randomly chosen operator. A division is only kept when a non-zero value
    among the not-yet-used numbers divides the running value evenly; that
    value then overwrites the number at the current position, so the
    returned numbers may differ from the ones originally drawn.

    Args:
        rng: Random number generator
        num_terms: Number of terms to include
        min_value: Lower bound (inclusive) for each drawn number
        max_value: Upper bound (inclusive) for each drawn number

    Returns:
        Tuple of (sympy expression, list of numbers, symbols); the i-th
        symbol stands for the i-th number
    """
    operators = ("+", "-", "*", "/")
    # Generate random numbers
    numbers = [rng.randint(min_value, max_value) for _ in range(num_terms)]

    # Create symbols for building expression
    syms = symbols(f"x:{num_terms}")

    # Build random expression, starting from the first symbol
    expr = syms[0]

    for i in range(1, num_terms):
        op = rng.choice(operators)
        if op == "+":
            expr = expr + syms[i]
        elif op == "-":
            expr = expr - syms[i]
        elif op == "*":
            expr = expr * syms[i]
        else:  # division
            # Handle division carefully to ensure integer results
            if numbers[i] != 0:  # Avoid division by zero
                # Get current value after substituting previous numbers
                current = int(expr.subs({sym: num for sym, num in zip(syms[:i], numbers[:i])}))
                # Try each remaining number to find one that divides evenly
                remaining = [n for n in numbers[i:] if n != 0]
                rng.shuffle(remaining)  # Randomize order for variety
                found_divisor = False
                for div in remaining:
                    if current % div == 0:  # Check if divides evenly
                        # The divisor is copied into position i; its original
                        # position in `numbers` is left unchanged
                        numbers[i] = div
                        expr = expr / syms[i]
                        found_divisor = True
                        break
                if not found_divisor:
                    # If no number divides evenly, fallback to subtraction
                    expr = expr - syms[i]
            else:
                # Fallback to addition for zero
                expr = expr + syms[i]
    return expr, numbers, syms

def generate_random_successor_state(state: list[int]) -> list[int]:
    """Combine two randomly chosen numbers of a state into a single number

    Args:
        state: The current state

    Returns:
        A new state that is one element shorter than ``state``: the combined
        value comes first, followed by the untouched numbers in their
        original order. An empty list is returned when ``state`` holds
        exactly one number.
    """
    def candidate_results(a: int, b: int) -> list[int]:
        """List every value obtainable by combining the two numbers

        Args:
            a: The first number
            b: The second number

        Returns:
            Sum, product and non-negative difference, followed by each exact
            quotient (a / b, then b / a) whose divisor is non-zero
        """
        outcomes = [add(a, b), mul(a, b), abs(sub(a, b))]
        for numerator, denominator in ((a, b), (b, a)):
            if denominator != 0 and numerator % denominator == 0:
                outcomes.append(numerator // denominator)
        return outcomes

    # A single number has no pair left to combine
    if len(state) == 1:
        return []

    size = len(state)
    # Pick two distinct positions with first < second
    first = random.randrange(size - 1)
    second = random.randrange(first + 1, size)
    combined = random.choice(candidate_results(state[first], state[second]))
    untouched = [value for pos, value in enumerate(state) if pos not in (first, second)]
    return [combined] + untouched


def get_state(state_size: int, num_upper_bound: int) -> list[int]:
    """Draw a random initial state

    Args:
        state_size: How many numbers the state contains
        num_upper_bound: Largest value (inclusive) a number may take; the
            smallest is 1

    Returns:
        A list of ``state_size`` random integers
    """
    return [random.randint(1, num_upper_bound) for _ in range(state_size)]


def get_random_path_from_state(s: list[int]) -> int:
    """Follow random successor steps from a state until one number is left

    Args:
        s: The starting state

    Returns:
        The single number remaining at the end of the random path
    """
    current = list(s)
    # One random action per pair of numbers that has to be merged
    for _ in range(1, len(s)):
        current = generate_random_successor_state(current)

    assert len(current) == 1
    (final_value,) = current
    return final_value


def get_smallest_key_with_min_value(res: dict[int, int]) -> int:
    """Find the key whose value is lowest, preferring the smallest key on ties

    Args:
        res: A dictionary mapping keys to values

    Returns:
        The smallest key among those that share the minimum value
    """
    # Order by value first and by key second, so ties resolve to the smallest key
    return min(res, key=lambda key: (res[key], key))

def _get_least_and_most_frequent_val(res: dict[int, int]) -> tuple[int, int]:
    """Pick the rarest and the most common value from a frequency table

    Args:
        res: A dictionary mapping each value to how often it occurred

    Returns:
        A tuple of (least frequent value, most frequent value)
    """
    rarest = get_smallest_key_with_min_value(res)
    most_common = max(res, key=lambda value: res[value])
    return rarest, most_common


def create_example(state_size: int, num_iterations: int, num_upper_bound: int, target_upper_bound: int) -> dict[str, Any]:
    """Create an example countdown problem

    A random initial state is drawn, random paths are sampled from it, and
    the least frequently reached value is chosen as the goal.

    Args:
        state_size: The size of the state
        num_iterations: The number of random paths to sample
        num_upper_bound: The upper bound for the numbers
        target_upper_bound: The upper bound for the target

    Returns:
        An example with the initial state, the goal, and the frequencies of
        the least and most frequently reached values

    Raises:
        ValueError: If no sampled path ends inside [0, target_upper_bound]
    """
    init = get_state(state_size, num_upper_bound)

    # Count the frequency of each in-range value reached by the random paths
    res: dict[int, int] = {}
    for _ in range(num_iterations):
        r = get_random_path_from_state(init)
        if 0 <= r <= target_upper_bound:
            res[r] = res.get(r, 0) + 1

    if not res:
        # Fail with a descriptive message rather than an opaque
        # "min() arg is an empty sequence" from the frequency helpers
        raise ValueError(
            f"no random path from {init} reached a value in [0, {target_upper_bound}] "
            f"within {num_iterations} iterations"
        )

    g, x = _get_least_and_most_frequent_val(res)

    return {"state": init, "goal": g, "extra": {"min": g, "min_freq": res[g], "max": x, "max_freq": res[x]}}


if __name__ == "__main__":
    # Command-line entry point: prints one JSON-encoded instance per line
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-instances", type=int, default=100)
    parser.add_argument("--num-iterations", type=int, default=10000)
    parser.add_argument("--state-size", type=int, required=True)
    parser.add_argument("--upper-bound", type=int, default=99)
    parser.add_argument("--target-upper-bound", type=int, default=99)
    parser.add_argument("--seed", type=int, default=2025)
    parser.add_argument("--reasoning-gym", action='store_true')

    args = parser.parse_args()

    # Seeds the module-level generator used by the random-path method
    random.seed(args.seed)

    if args.reasoning_gym:
        # If generation method is reasoning gym, we use the code from https://github.com/open-thought/reasoning-gym/blob/main/reasoning_gym/games/countdown.py
        # The dedicated generator must be seeded too, otherwise --seed has no
        # effect in this mode and runs are not reproducible
        rng = Random(args.seed)
        for _ in range(args.num_instances):
            instance = generate_instance_for_reasoning_gym(rng, args.state_size, 1, args.upper_bound, args.target_upper_bound, args.num_iterations)
            # An instance of None (no candidate accepted) is emitted as "null"
            print(json.dumps(instance))
    else:
        # We create an instance that includes a random target value
        for _ in range(args.num_instances):
            print(json.dumps(create_example(args.state_size, args.num_iterations, args.upper_bound, args.target_upper_bound)))
