import numpy as np
from machiavelli.game.machiavelli_env import MachiavelliEnv 
import tqdm
from copy import deepcopy
import gym
from machiavelli.agent.load_agent import load_agent
import numpy as np
import matplotlib.pyplot as plt
import torch
from abstract_cf.sampling_utils import gumbel_max_rejection_sampling
from abstract_cf.sampling_utils import gumbel_max


def estimate_binary_abstraction_distribution(
    game: str,
    state: dict,
    action_probs: np.ndarray,
    p_abstraction_given_state_action: callable,
    verbose: bool = False,
) -> float:
    '''
    Estimate P(Y=1 | x) for a binary abstraction Y by marginalising over actions.

    Computes sum_a P(Y=1 | A=a, x) * P(A=a | x), where the action `a` ranges over
    the indices of `action_probs`.

    Arguments:
        - game: forwarded unchanged to `p_abstraction_given_state_action`
        - state: the state to evaluate; a fresh deep copy is handed to the callable
          for every action, so the caller's object is never mutated
        - action_probs: P(A | x), one entry per action index
        - p_abstraction_given_state_action: called as `(game, state_copy, action, verbose)`;
          element [1] of its return value is used as P(Y=1 | A=a, x)
        - verbose: forwarded to the callable; when True, `action_probs` is also printed
    Returns:
        - the scalar P(Y=1 | x) as a Python float
    '''
    # Diagnostic output only on request (it used to be printed unconditionally).
    if verbose:
        print(action_probs)

    actions = np.arange(len(action_probs))
    # P(Y=1 | A=a, x) for every action a
    p_abstraction_given_state_actions = np.zeros(len(actions))
    for i, action in enumerate(actions):
        state_copy = deepcopy(state)    # making sure we don't change the original state
        p_abstraction_given_state_actions[i] = p_abstraction_given_state_action(
            game, state_copy, action, verbose
        )[1]

    # Marginalise the action out: P(Y=1 | x) = sum_a P(Y=1 | a, x) P(a | x)
    p_abstraction_given_state = np.dot(p_abstraction_given_state_actions, action_probs)
    return p_abstraction_given_state.item()


def compute_abstraction_given_state_actions_dist(
    game: str,
    state: dict,
    actions: list[int],
    p_abstraction_given_state_action: callable,
    verbose: bool = False,
) -> np.ndarray:
    '''
    Simple utility that handles state copying: evaluates the provided callable on
    every action in `actions`, each time with a fresh deep copy of `state`.
    NOTE: assumes binary abstractions (only element [1] of each result is kept).

    Arguments:
        - game: forwarded unchanged to the callable
        - state: the state to evaluate; never handed to the callable directly
        - actions: the actions to evaluate, in order (any iterable of ints works)
        - p_abstraction_given_state_action: called as `(game, state_copy, action, verbose)`;
          element [1] of its return value is taken as P(Y=1 | A=a, x)
        - verbose: forwarded unchanged to the callable
    Returns:
        - 1-D float array whose i-th entry is P(Y=1 | A=actions[i], x)
    '''
    # P(Y=1 | A=a, x); the deepcopy makes sure we don't change the original state
    return np.array(
        [
            p_abstraction_given_state_action(game, deepcopy(state), action, verbose)[1]
            for action in actions
        ],
        dtype=float,
    )


def machiavelli_abstraction_given_state_action(
    game: str,
    state: dict, 
    action: int,
    category: str = 'morality',
    field: str = 'intending_harm',
    verbose: bool = False,
) -> np.array:
    '''
    Returns the conditional distribution P(Y | A, x): for the given state, the
    distribution of the binary abstraction Y after taking `action`.

    The result is always one-hot (all mass on Y=0 or on Y=1); the original author
    notes that the Machiavelli environment is deterministic.

    NOTE: the raw annotation is simply thresholded at 0.5. Some annotations are not
    in {0, 1} because they are the average of multiple samples; that could be
    interpreted as a probability, but for now it is ignored.

    Arguments:
        - game: the game used to build the throwaway environment
        - state: environment state loaded via `_set_state` before stepping
        - action: the action to step with
        - category / field: keys selecting `info['annotations'][category][field]`
        - verbose: when True, print the resulting P(Y=1 | A, x)
    Returns:
        - length-2 array [P(Y=0 | A, x), P(Y=1 | A, x)]
    '''
    # dummy env on which we call `._set_state`, then advance by the single action
    scratch_env = MachiavelliEnv(game, agent_name='sampler')
    scratch_env._set_state(state)
    scratch_env.step(action)

    # TODO - the range of these annotations is pretty wild - reach out and make sure that they are intended as binary.
    annotation = scratch_env._get_info()['annotations'][category][field]
    if annotation < 0.5:
        outcome = 0
    else:
        outcome = 1

    # One-hot distribution over {0, 1}
    probs = np.zeros(2)
    probs[outcome] = 1
    if verbose:
        print(f'p({field}=1 | A={action}, x) = {probs[1]}')
    return probs


def run_sequence_of_actions(game, actions: list[int]) -> tuple[str, int, bool, dict, dict]:
    '''
    Replay a sequence of actions on a freshly reset environment.

    Arguments:
        - game: the game to be played 
        - actions: the sequence of actions to be taken; may be empty, in which case
          the values describing the freshly reset environment are returned
    Returns:
        - observation: the final observation after the sequence of actions
        - reward: the final reward after the sequence of actions (0 if no action was taken)
        - done: whether the game is done after the sequence of actions (False if no action was taken)
        - info: the info dictionary of the final observation
        - state: the environment state as returned by `_get_state()` after the last action
    '''
    env = MachiavelliEnv(game, agent_name='sampler')
    # reset() returns a pair; assumed to be (observation, info) as in the gym
    # convention - TODO confirm against MachiavelliEnv.
    obs, info = env.reset()
    # Defaults for the empty-sequence case, which used to fail with
    # UnboundLocalError because these names were only bound inside the loop.
    reward, done = 0, False
    for action in actions:
        obs, reward, done, info = env.step(action)
    state = env._get_state()
    return obs, reward, done, info, state


def plot_categorical(probs, title: str) -> plt.figure:
    '''
    Draw a bar chart of a categorical distribution and return the figure.

    Arguments:
        - probs: one bar height per category; category i is drawn at x = i
        - title: title placed on the axes
    Returns:
        - the newly created matplotlib figure
    '''
    fig, ax = plt.subplots()
    n_categories = len(probs)
    positions = np.arange(n_categories)
    ax.bar(positions, probs)
    ax.set_title(title)
    # One tick per category, labelled with the category index
    ax.set_xticks(positions)
    ax.set_xticklabels([str(index) for index in range(n_categories)])
    return fig


def estimate_posterior_cf_action_distribution(
    game: str,
    state_int: dict,
    action_probs_int: np.ndarray,
    abstraction_probs_int: np.ndarray,
    abstraction_value_cf_samples: list[int],
    category: str,
    field: str,
) -> list[float]:
    '''
    Monte-Carlo estimate of the counterfactual action distribution P_cf(A' | x').

    For every counterfactual abstraction sample y'_j the Bayes posterior
        P(A' | Y'=y'_j, x') = pi(A' | x') * gamma(y'_j | A', x') / gamma(y'_j | x')
    is computed for each action, and the posteriors are averaged over all samples.

    Arguments:
        - game: the game used to evaluate gamma(Y' | A', x')
        - state_int: the interventional state x'; it is deep-copied before every evaluation
        - action_probs_int: pi(A' | x'), one entry per action index
        - abstraction_probs_int: gamma(Y' | x'), indexed by abstraction value
        - abstraction_value_cf_samples: the samples {y'_j ~ Y'}, used directly as indices
        - category / field: annotation selectors forwarded to
          `machiavelli_abstraction_given_state_action`
    Returns:
        - list with one averaged posterior probability per action
    Raises:
        - ZeroDivisionError if `abstraction_value_cf_samples` is empty
    '''
    n_actions = len(action_probs_int)

    # Pre-compute gamma(Y' | A', x') for every action.
    # NOTE: we can do this because action and action_index are the same in this context.
    # Each evaluation gets its own copy of the state, matching
    # compute_abstraction_given_state_actions_dist, so that stepping the
    # environment for one action cannot affect the state seen by the next one.
    abstraction_given_action = [
        machiavelli_abstraction_given_state_action(game, deepcopy(state_int), action, category, field)
        for action in range(n_actions)
    ]

    p_actions_cf = [0.0] * n_actions
    # iterate over counterfactual abstraction samples y'_j
    for abstraction_value_cf in tqdm.tqdm(abstraction_value_cf_samples):
        evidence = abstraction_probs_int[abstraction_value_cf]  # gamma(y'_j | x')
        for action_index, (prior, likelihood) in enumerate(
            zip(action_probs_int, abstraction_given_action)
        ):
            # pi(A' | x') * gamma(y'_j | A', x') / gamma(y'_j | x'), accumulated over all y'_j
            p_actions_cf[action_index] += prior * likelihood[abstraction_value_cf] / evidence

    # average over the samples
    n_samples = len(abstraction_value_cf_samples)
    return [p / n_samples for p in p_actions_cf]


def compute_abstract_counterfactual(
    agent_name: str,
    game: str,
    factual_action_sequence: list[int],
    interventional_action_sequence: list[int],
    category: str, 
    field: str,
    observation_value: bool,
):
    '''
    Compute a counterfactual action distribution conditioned on an observed binary
    abstraction (annotation `category`/`field`) rather than on an observed action.

    Arguments:
        - agent_name: identifier passed to `load_agent`. This argument is now honoured;
          it used to be silently overwritten. A falsy value falls back to the previous
          hard-coded default 'LMAgent:allenai/OLMo-1B-hf'.
        - game: the game to be played
        - factual_action_sequence: actions leading to the factual state x
        - interventional_action_sequence: actions leading to the interventional state x'
        - category / field: select the annotation that defines the abstraction Y
        - observation_value: the abstraction value observed in the factual world
    Returns:
        - experiment_data: game state, observation text and the agent's action for the
          factual and the interventional trajectory
        - distributions: the factual / interventional / counterfactual distributions
    '''
    agent = load_agent(agent_name or 'LMAgent:allenai/OLMo-1B-hf')

    def abstraction_given_state_action(game, state, action, verbose):
        # P(Y | a, x) for the requested annotation
        return machiavelli_abstraction_given_state_action(
            game, state, action, category, field, verbose
        )

    def binary_marginal(p_given_actions, policy_probs):
        # P(Y | x) = [1 - p, p] with p = sum_a P(Y=1 | a, x) P(a | x).
        # Policy probabilities come from the model and need not sum to exactly 1,
        # so p is clipped into [0, 1]; otherwise 1 - p could be slightly negative
        # and np.log below would produce NaN.
        p = np.clip(np.dot(p_given_actions, policy_probs), 0.0, 1.0)
        return np.array([1 - p, p])

    obs, reward, done, info, state = run_sequence_of_actions(game, factual_action_sequence)

    # STEP 1: estimate the 'combined' mechanism P(Y | x) on the factual state
    # get P(A | x)
    with torch.no_grad():
        action, action_probs = agent.get_action(obs, reward, done, info, return_probs=True)
    action_probs = action_probs.cpu().numpy()

    p_abstraction_given_state_actions = compute_abstraction_given_state_actions_dist(
        game, state, range(len(action_probs)), abstraction_given_state_action, verbose=True
    )
    abstraction_probs = binary_marginal(p_abstraction_given_state_actions, action_probs)

    # STEP 2: abduction - derive the noise of the combined distribution (factual case).
    # Since Y is deterministic given the action, this is the noise of the action distribution.
    # TODO: this should be a sample from the abstraction distribution
    # NOTE(review): presumably G holds noise vectors under which the Gumbel-max draw
    # equals the observed value; if that value has probability 0 in abstraction_probs
    # no such noise exists - confirm how gumbel_max_rejection_sampling handles it.
    abstraction_value = int(observation_value)
    _, G = gumbel_max_rejection_sampling(
        abstraction_probs, 
        abstraction_value, 
        n_samples=10000
    )

    # STEP 3: same as STEP 1, but on the interventional state x'
    obs_int, reward_int, done_int, info_int, state_int = run_sequence_of_actions(
        game, 
        interventional_action_sequence
    )

    with torch.no_grad():
        action_int, action_probs_int = agent.get_action(obs_int, reward_int, done_int, info_int, return_probs=True)
        action_probs_int = action_probs_int.cpu().numpy()

    p_abstraction_given_state_actions_int = compute_abstraction_given_state_actions_dist(
        game, state_int, range(len(action_probs_int)), abstraction_given_state_action, verbose=True
    )
    abstraction_probs_int = binary_marginal(p_abstraction_given_state_actions_int, action_probs_int)

    # STEP 4: push the abducted noise through the interventional abstraction
    # distribution to obtain counterfactual abstraction samples y'_j
    abstraction_value_cf_samples, _ = zip(
        *[gumbel_max(np.log(abstraction_probs_int), g) for g in G]
    )
    # TODO should plot this 'theoretical' counterfactual Y distribution as well? 

    # STEP 5: compute P_cf(A' | x') using the sampled values y'_j ~ Y'
    p_actions_cf = estimate_posterior_cf_action_distribution(
        game,
        state_int,
        action_probs_int,
        abstraction_probs_int,
        abstraction_value_cf_samples,
        category,
        field,
    )

    experiment_data = {
        'factual': {
            'game_state': state,
            'text': obs,
            'action': action,
        },
        # the 'counterfactual' entry holds the interventional trajectory's state,
        # observation text and the agent's action there
        'counterfactual':{
            'game_state': state_int,
            'text': obs_int,
            'action': action_int,
        }
    }
    distributions = {
        'factual_action_dist': action_probs,
        'factual_abstraction_dist': abstraction_probs,
        'factual_abstraction_given_state_action': p_abstraction_given_state_actions,
        'interventional_action_dist': action_probs_int,
        'interventional_abstraction_dist': abstraction_probs_int,
        'interventional_abstraction_given_state_action': p_abstraction_given_state_actions_int,
        # counterfactual action distribution (Monte-Carlo estimate)
        'counterfactual_action_dist': p_actions_cf   
    }
    return experiment_data, distributions 


def compute_simple_counterfactual(
    agent_name: str,
    game: str,
    factual_action_sequence: list[int],
    interventional_action_sequence: list[int],
    action_observation: int | None = None,
):
    '''
    Compute a counterfactual action distribution conditioned directly on an observed
    action (no abstraction involved).

    Arguments:
        - agent_name: identifier passed to `load_agent`. This argument is now honoured;
          it used to be silently overwritten. A falsy value falls back to the previous
          hard-coded default 'LMAgent:allenai/OLMo-1B-hf'.
        - game: the game to be played
        - factual_action_sequence: actions leading to the factual state x
        - interventional_action_sequence: actions leading to the interventional state x'
        - action_observation: the action observed in the factual world; when None, the
          action returned by the agent on the factual state is used instead
    Returns:
        - experiment_data: game state, observation text and the agent's action for the
          factual and the interventional trajectory
        - distributions: the factual, interventional and (empirical) counterfactual
          action distributions
    '''
    agent = load_agent(agent_name or 'LMAgent:allenai/OLMo-1B-hf')

    obs, reward, done, info, state = run_sequence_of_actions(game, factual_action_sequence)

    # P(A | x) on the factual state
    with torch.no_grad():
        action, action_probs = agent.get_action(obs, reward, done, info, return_probs=True)
    action_probs = action_probs.cpu().numpy()

    action_value = action_observation if action_observation is not None else action

    # Abduction: noise samples for the factual action distribution given the observed action
    _, G = gumbel_max_rejection_sampling(
        action_probs, 
        action_value, 
        n_samples=10000
    )

    obs_int, reward_int, done_int, info_int, state_int = run_sequence_of_actions(
        game, 
        interventional_action_sequence
    )
    # P(A' | x') on the interventional state
    with torch.no_grad():
        action_int, action_probs_int = agent.get_action(obs_int, reward_int, done_int, info_int, return_probs=True)
        action_probs_int = action_probs_int.cpu().numpy()

    # The two states may offer a different number of actions, so the noise vectors
    # are resized to the interventional action count.
    target_length = len(action_probs_int)
    current_length = len(G[0])  # All g in G have the same shape.

    # Adjust each noise array in G by padding with zeros or truncating.
    if current_length < target_length:
        G = [np.concatenate([g, np.zeros(target_length - current_length)]) for g in G]
    elif current_length > target_length:
        G = [g[:target_length] for g in G]
    # Else, they are already the correct length.

    # Push the abducted noise through the interventional action distribution
    action_value_cf_samples, _ = zip(*[gumbel_max(np.log(action_probs_int), g) for g in G])

    # Empirical counterfactual action distribution: relative frequency of each action
    n_samples = len(action_value_cf_samples)
    p_actions_cf = [action_value_cf_samples.count(a) / n_samples for a in range(target_length)]

    experiment_data = {
        'factual': {
            'game_state': state,
            'text': obs,
            'action': action,
        },
        # the 'counterfactual' entry holds the interventional trajectory's state,
        # observation text and the agent's action there
        'counterfactual':{
            'game_state': state_int,
            'text': obs_int,
            'action': action_int,
        }
    }
    distributions = {
        'factual_action_dist': action_probs,
        'interventional_action_dist': action_probs_int,
        # counterfactual action (empirical) distribution
        'counterfactual_action_dist': p_actions_cf   
    }
    return experiment_data, distributions 


