''' A state intervention is a transformation of the "initial state" as defined in the `abstract_counterfactual` module '''
from abstract_cf.text_generation.state import State
from copy import deepcopy
from transformers import AutoTokenizer, AutoModelForCausalLM
import ravfogel_lm_counterfactuals.utils as ravfogel_utils 
from abstract_cf.text_generation.utils import sample_from_model

import torch
import yaml 

# Default torch backend for this module: MPS if available, otherwise CUDA,
# otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def gender_steering(
    state: State,
    model_name: str | None = None,    # this is the interventional model name
    model: AutoModelForCausalLM | None = None,
    tokenizer_name: str | None = None,
    tokenizer: AutoTokenizer | None = None,
) -> State:
    """Return a copy of ``state`` whose sampling policy uses an interventional model.

    The inputs of the state are left untouched; only ``policy_function`` is
    replaced by one that samples from ``model`` via ``sample_from_model``.

    Parameters:
        state: the original state to intervene upon (it is deep-copied, not modified).
        model_name: name passed to ``ravfogel_utils.get_counterfactual_model``
            when ``model`` is not given.
        model: an already-loaded interventional model (takes precedence over ``model_name``).
        tokenizer_name: name passed to ``AutoTokenizer.from_pretrained`` when
            ``tokenizer`` is not given.
        tokenizer: an already-loaded tokenizer (takes precedence over ``tokenizer_name``).

    Returns:
        A new State whose ``policy_function`` samples from the interventional model.
    """
    # Validate up front so a missing argument fails with a clear message instead of
    # an opaque error from ``from_pretrained(None)``.
    assert model is not None or model_name is not None, "Either model or model_name must be provided"
    assert tokenizer is not None or tokenizer_name is not None, "Either tokenizer or tokenizer_name must be provided"

    if model is None:
        model = ravfogel_utils.get_counterfactual_model(model_name)
    if tokenizer is None:
        # TODO does this tokenizer exist, or should we load the one from the factual model?
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # The loaded tokenizer gets eos as its pad token (a passed-in tokenizer is not modified).
        tokenizer.pad_token = tokenizer.eos_token

    def policy_function(inputs, n_samples, batch_size, max_length):
        # Only the first element of sample_from_model's return value is exposed
        # as the policy output (same contract as before the intervention).
        return sample_from_model(
            model,
            tokenizer,
            inputs,
            n_samples,
            batch_size=batch_size,
            max_length=max_length,
        )[0]

    int_state = deepcopy(state)
    int_state.policy_function = policy_function
    return int_state


def token_replacement_uniform(
    state: State,
    tokenizer_name: str | None = None,
    model: AutoModelForCausalLM | None = None,
    tokenizer: AutoTokenizer | None = None,
) -> State:
    """Replace one uniformly-chosen position per example with a uniformly random token id.

    For every row of ``state.inputs['input_ids']`` one position in ``[0, seq_len)``
    is drawn uniformly and overwritten with a token id drawn uniformly from
    ``[0, tokenizer.vocab_size)``. Padding positions are not excluded.
    The state's sampling function is not changed.

    Parameters:
        state: the original state to intervene upon (it is deep-copied, not modified).
        tokenizer_name: name passed to ``AutoTokenizer.from_pretrained`` when
            ``tokenizer`` is not given.
        model: unused; accepted so this function can be called with the same
            keyword arguments as the other registered interventions
            (``get_intervention_fn`` passes ``model=`` and ``tokenizer=``).
        tokenizer: an already-loaded tokenizer (takes precedence over ``tokenizer_name``).

    Returns:
        A new State with modified ``inputs`` and re-decoded ``inputs_text``
        (a single string when the batch has one example, else a list).
    """
    assert tokenizer is not None or tokenizer_name is not None, "Either tokenizer or tokenizer_name must be provided"
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    int_state = deepcopy(state)
    input_ids = int_state.inputs['input_ids']
    b_size, seq_len = input_ids.shape
    ids_device = input_ids.device

    # Create every index/value tensor on the same device (and dtype) as input_ids so
    # the advanced-indexing assignment below does not mix devices (e.g. CPU vs MPS/CUDA).
    rows = torch.arange(b_size, device=ids_device)
    to_replace = torch.randint(0, seq_len, (b_size,), device=ids_device)
    random_tokens = torch.randint(
        0, tokenizer.vocab_size, (b_size,), device=ids_device, dtype=input_ids.dtype
    )
    # In-place write into the deep-copied inputs dict.
    input_ids[rows, to_replace] = random_tokens

    inputs_text = [tokenizer.decode(tokens) for tokens in input_ids]
    if len(inputs_text) == 1:
        inputs_text = inputs_text[0]
    int_state.inputs_text = inputs_text

    # we do not change the sampling function of the state
    return int_state


def _final_token_replacement_most_likely(
    inputs: dict,
    model_name: str | None = None,
    model: AutoModelForCausalLM | None = None,
) -> dict:
    assert model is not None or model_name is not None, "Either model or model_name must be provided"
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(model_name)

    # Ensure that the model is on the same device as the inputs.
    device = inputs['input_ids'].device
    model = model.to(device)

    b_size, seq_len = inputs['input_ids'].shape

    # Determine the final token index for each example.
    # If an attention mask is provided, we use it to compute the number of non-padding tokens.
    # Otherwise, we assume that the entire sequence is valid.
    if 'attention_mask' in inputs:
        # attention_mask: shape (batch_size, seq_len), with 1 for valid tokens and 0 for padding.
        final_indices = (inputs['attention_mask'].sum(dim=1) - 1).tolist()
    else:
        final_indices = [seq_len - 1] * b_size

    for i in range(b_size):
        final_index = final_indices[i]
        original_token = inputs['input_ids'][i, final_index].item()

        # If there is no context (i.e. final_index is 0) then skip this example.
        if final_index == 0:
            continue

        # For an autoregressive model, to predict the token at the final position we provide the context
        # as the tokens preceding the final token.
        context = inputs['input_ids'][i, :final_index].unsqueeze(0)
        with torch.no_grad():
            outputs = model(context)
        # Obtain the logits for the prediction of the next token.
        logits = outputs.logits[0, -1, :].clone()

        # Exclude the original token from being selected.
        logits[original_token] = -float('inf')

        # Select the token with the highest logit.
        replacement_token = torch.argmax(logits)
        inputs['input_ids'][i, final_index] = replacement_token

    return inputs


def final_token_replacement_most_likely(
    state: State,
    model_name: str | None = None,
    model: AutoModelForCausalLM | None = None,
    tokenizer_name: str | None = None,
    tokenizer: AutoTokenizer | None = None,
) -> State:
    """
    For each example in the state's batch, replaces the final non-padding token with the most likely token
    predicted by the model (using a forward pass) given the preceding context, excluding the token
    that is already present.

    The state's ``inputs_text`` is re-encoded with ``tokenizer`` (the interventional
    tokenizer). Batches are padded when the tokenizer has a pad token, so lists of
    texts with different lengths can be turned into a single tensor; the final
    token is then determined by the attention mask.

    Parameters:
        state: the original state to intervene upon (it is deep-copied, not modified).
        model_name: the name of the model to load (AutoModelForCausalLM) when ``model`` is not given.
        model: an already-loaded model (takes precedence over ``model_name``).
        tokenizer_name: the name of the tokenizer to load when ``tokenizer`` is not given.
        tokenizer: an already-loaded tokenizer (takes precedence over ``tokenizer_name``).

    Returns:
        A new State instance with the intervention applied; ``inputs_text`` is a
        single string when the batch has one example, else a list of strings.
    """
    assert tokenizer is not None or tokenizer_name is not None, "Either tokenizer or tokenizer_name must be provided"
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Same pad-token convention as the other interventional tokenizers in this
        # module, so that batched texts can be padded below.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    assert model is not None or model_name is not None, "Either model or model_name must be provided"
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(model_name)

    int_state = deepcopy(state)
    # we are re-encoding with the interventional tokenizer here
    int_inputs = tokenizer(
        int_state.inputs_text,
        return_tensors='pt',
        truncation=True,
        # Without padding, a batch of texts with different lengths cannot be stacked
        # into one tensor. If the tokenizer has no pad token, keep the old unpadded
        # behavior rather than raising.
        padding=tokenizer.pad_token is not None,
    ).to(device)
    inputs = _final_token_replacement_most_likely(
        int_inputs,
        model=model,
    )

    int_state.inputs = inputs
    # Update the state's text representation; special tokens (including eos used
    # as padding) are dropped from the decoded text.
    inputs_text = [
        tokenizer.decode(tokens, skip_special_tokens=True)
        for tokens in int_state.inputs['input_ids']
    ]
    if len(inputs_text) == 1:
        inputs_text = inputs_text[0]
    int_state.inputs_text = inputs_text

    return int_state


# Maps the ``intervention_name`` value of an intervention YAML config (read by
# get_intervention_fn) to the function that applies that intervention to a State.
_INTERVENTION_REGISTRY = dict(
    gender_steering=gender_steering,
    token_replacement_uniform=token_replacement_uniform,
    final_token_replacement_most_likely=final_token_replacement_most_likely,
)


def get_intervention_fn(intervention_config: str) -> callable:
    """Build a one-argument intervention function ``state -> State`` from a YAML config.

    Keys read from the config:
        intervention_name: key into ``_INTERVENTION_REGISTRY``.
        intervention_kwargs.model_name: model to load once, up front.
        intervention_kwargs.tokenizer_name: tokenizer to load once, up front
            (its ``pad_token`` is set to its ``eos_token``).
        ravfogel_loading: if truthy, load the model with
            ``ravfogel_utils.get_counterfactual_model``; otherwise with
            ``AutoModelForCausalLM.from_pretrained``. Treated as False when absent.

    The returned function calls the registered intervention with
    ``model=`` and ``tokenizer=`` set to the preloaded objects.

    Raises:
        ValueError: if ``intervention_name`` is not a registered intervention
            (raised before any model is loaded).
    """
    with open(intervention_config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    intervention_name = config['intervention_name']
    # Resolve the intervention first so a typo in the config fails fast, instead of
    # after the (slow) model load and only on the first call.
    try:
        intervention = _INTERVENTION_REGISTRY[intervention_name]
    except KeyError:
        raise ValueError(
            f"Unknown intervention_name {intervention_name!r}; "
            f"expected one of {sorted(_INTERVENTION_REGISTRY)}"
        ) from None

    int_model_name = config['intervention_kwargs']['model_name']
    int_tokenizer_name = config['intervention_kwargs']['tokenizer_name']

    # we load them here so that we only have to do it once
    if config.get('ravfogel_loading', False):
        int_model = ravfogel_utils.get_counterfactual_model(int_model_name)
    else:
        int_model = AutoModelForCausalLM.from_pretrained(int_model_name)
    int_tokenizer = AutoTokenizer.from_pretrained(int_tokenizer_name)
    int_tokenizer.pad_token = int_tokenizer.eos_token

    def intervention_fn(state):
        return intervention(
            state,
            model=int_model,
            tokenizer=int_tokenizer,
        )

    return intervention_fn

