import numpy as np

# ------------------------------------------------------------------
# parameters
# ------------------------------------------------------------------
density = 1_000  # number of ticks on the unit time interval

# Tick grid 0, 1/density, ..., (density - 1)/density.  Excluding the
# endpoint 1.0 keeps the hazard factor 1 / (1 - step) finite on every tick.
steps = np.linspace(start=0, stop=1, num=density, endpoint=False)

# ------------------------------------------------------------------
# helper: probability matrices for the transient states
# ------------------------------------------------------------------
# Transition matrices for the transient states, built once at import time
# rather than on every call (get_probabilities is called once per tick of
# every simulated run).  Each row sums to 1.
_TRANSITION_TABLE = {
    (0, 0): np.array([[0.9, 0.1],
                      [0.1, 0.9]]),

    (1, 0): np.array([[1.0, 0.0],
                      [0.2, 0.8]]),

    (2, 0): np.array([[0.0, 1.0],
                      [0.3, 0.7]]),

    (0, 1): np.array([[0.8, 0.2],
                      [1.0, 0.0]]),

    (0, 2): np.array([[0.5, 0.5],
                      [0.0, 1.0]]),
}


def get_probabilities(state):
    """Return the 2 × 2 transition-probability matrix for a transient state.

    Parameters
    ----------
    state : sequence of int
        Current chain state; only ``state[0]`` and ``state[1]`` are used as
        the lookup key.

    Returns
    -------
    numpy.ndarray
        A fresh copy of the tabulated matrix, so a caller that modifies the
        result cannot corrupt the shared table.

    Raises
    ------
    ValueError
        If ``(state[0], state[1])`` is not one of the tabulated states.
    """
    key = (state[0], state[1])
    try:
        matrix = _TRANSITION_TABLE[key]
    except KeyError:
        # str() keeps NumPy scalars readable ("1" rather than "np.int64(1)");
        # "from None" hides the internal KeyError from the traceback.
        shown = ", ".join(str(value) for value in state)
        raise ValueError(
            f"State ({shown}) not covered by lookup table"
        ) from None
    return matrix.copy()

# ------------------------------------------------------------------
# one run of the two-coordinate chain
# ------------------------------------------------------------------
def chain_run():
    """Simulate a single trajectory of the two-coordinate chain.

    The walk starts at ``(0, 0)`` and advances along the module-level
    ``steps`` grid.  On each tick, every coordinate that is still 0 may jump
    to 1 or 2.  The jump weights are the matching row of
    ``get_probabilities(state)`` scaled by the hazard factor
    ``(1 / density) / (1 - step)``; the leftover mass is the chance of
    staying at 0.  The walk stops as soon as neither coordinate is 0.

    Returns
    -------
    numpy.ndarray
        Length-2 integer array holding the state when the walk stopped.
    """
    current = np.zeros(2, dtype=int)

    for t in steps:
        # The matrix is fetched from the state at the *start* of the tick, so
        # coordinate 1 still sees it even if coordinate 0 moved this tick.
        # NOTE(review): confirm simultaneous-update semantics are intended.
        table = get_probabilities(current)
        hazard = (1.0 / density) / (1.0 - t)

        for coord in range(2):
            if current[coord] != 0:
                continue

            jump = hazard * table[coord]      # weights for moving to 1 / 2
            stay = 1.0 - jump.sum()           # weight for remaining at 0
            if stay < 0:
                stay = 0.0                    # guard against round-off
            weights = np.array([stay, jump[0], jump[1]])
            weights /= weights.sum()

            current[coord] = np.random.choice([0, 1, 2], p=weights)

        # Both coordinates non-zero: the chain is absorbed.
        if current.all():
            break

    return current

# ------------------------------------------------------------------
# experiment
# ------------------------------------------------------------------
def counter(n_runs=10000, report_every=1000):
    """
    Run the chain ``n_runs`` times and count the four absorbing outcomes
    AA = (1,1), AB = (1,2), BA = (2,1), BB = (2,2).

    Parameters
    ----------
    n_runs : int
        Number of independent calls to ``chain_run``.
    report_every : int or None
        Print the running tally before run ``i`` whenever
        ``i % report_every == 0``.  ``0`` or ``None`` disables progress
        output.  The default matches the previous hard-coded interval.

    Returns
    -------
    dict
        Counts keyed by ``'AA'``, ``'AB'``, ``'BA'``, ``'BB'`` (in that
        order).  A run ending in any other state is printed to stdout and
        not counted.
    """
    labels = {(1, 1): 'AA', (1, 2): 'AB', (2, 1): 'BA', (2, 2): 'BB'}
    tally = {'AA': 0, 'AB': 0, 'BA': 0, 'BB': 0}

    for i in range(n_runs):
        # Progress output.  The old guard ``1000 and ...`` had a constant
        # truthy operand; it is now an explicit, switchable parameter.
        if report_every and i % report_every == 0:
            print(tally)

        s = chain_run()
        label = labels.get(tuple(np.asarray(s).tolist()))
        if label is None:                   # should never happen
            print("Unexpected absorbing state:", s)
        else:
            tally[label] += 1

    return tally


if __name__ == "__main__":
    # Fixed seed so repeated runs print the same numbers; drop it for fresh samples.
    np.random.seed(0)

    total_iters = 100000
    tally = counter(total_iters)
    print('Tally:', tally)

    # Reference distribution over the outcomes, listed in the same order as
    # `outcome_order` below.
    outcome_order = ('AA', 'AB', 'BA', 'BB')
    target = np.array([0.15, 0.5, 0.05, 0.3])
    probs = np.array([tally[name] for name in outcome_order]) / total_iters

    cross_entropy = (-target) @ np.log(probs)
    print('Cross Entropy:', cross_entropy)
