import json
from pathlib import Path

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
from tqdm import tqdm

from scipy import stats
from statsmodels.stats.contingency_tables import mcnemar


def get_word_probability(
    model, tokenizer, context, word_of_interest, bow_token_ids=None
):
    """
    Calculate the probability of a word given context using BOW tokenization correction.

    The word probability is the product of its subword probabilities. When
    ``bow_token_ids`` is given, that product is multiplied by
    ``sum(P(bow | context + word)) / sum(P(bow | context))``, i.e. the
    probability mass on beginning-of-word tokens right after the word divided
    by the same mass at the position where the word starts.

    Args:
        model: The language model. Called as ``model(input_ids)``; must expose
            ``.device`` and return an object with ``.logits`` of shape
            [batch, seq_len, vocab_size].
        tokenizer: The tokenizer. Called with ``return_tensors="pt"`` and
            ``add_special_tokens=False``; must return a mapping with "input_ids".
        context: The text context before the word
        word_of_interest: The word whose probability we want to calculate
        bow_token_ids: Optional sequence of beginning-of-word token IDs.
            ``None`` or an empty sequence disables the correction.

    Returns:
        float: The corrected probability of the word

    Raises:
        ValueError: If the context tokenizes to zero tokens (there is no
            position to predict the word from) or if no word tokens remain
            after removing the context tokens from the full tokenization.
    """

    # Strip surrounding whitespace so exactly one space separates context and word.
    context = context.strip()
    word_of_interest = word_of_interest.strip()

    # Tokenize context
    context_tokens = tokenizer(context, return_tensors="pt", add_special_tokens=False)[
        "input_ids"
    ].to(model.device)

    # Tokenize context + word together to get correct word tokenization
    full_text = context + " " + word_of_interest
    full_tokens = tokenizer(full_text, return_tensors="pt", add_special_tokens=False)[
        "input_ids"
    ].to(model.device)

    # Word tokens are whatever follows the context tokens.
    # NOTE(review): this assumes the context tokenization is a prefix of the
    # full tokenization -- confirm for the tokenizer in use.
    n_context = context_tokens.shape[1]
    word_token_ids = full_tokens[0, n_context:].tolist()

    if n_context == 0:
        # Without this guard the first position would be -1 and silently read
        # the logits of the LAST token instead.
        raise ValueError("context must contain at least one token")
    if not word_token_ids:
        raise ValueError(
            f"no tokens found for word {word_of_interest!r} after the context"
        )

    # Inference only: the result is returned as a plain float, so no graph is needed.
    with torch.no_grad():
        logits = model(full_tokens).logits  # Shape: [batch, seq_len, vocab_size]

    # Calculate p(sw | sw<t) - probability of subword sequence.
    # Logits at index p are the prediction for the token at index p + 1.
    # log(softmax(.)) is kept as-is so results stay numerically identical to
    # get_steered_word_probability.
    log_prob_subwords = 0.0
    for offset, token_id in enumerate(word_token_ids):
        pos = n_context + offset - 1
        probs = torch.softmax(logits[0, pos], dim=0)
        log_prob_subwords += torch.log(probs[token_id])

    prob_word = torch.exp(log_prob_subwords)

    # Apply BOW correction only when beginning-of-word token ids were supplied.
    if bow_token_ids is not None and len(bow_token_ids) > 0:
        # Distribution over the token following the complete word:
        # P(next_token | context + word)
        pos_after_word = n_context + len(word_token_ids) - 1
        probs_after_word = torch.softmax(logits[0, pos_after_word], dim=0)

        # Distribution at the position where the word starts: P(next_token | context)
        pos_start = n_context - 1
        probs_at_start = torch.softmax(logits[0, pos_start], dim=0)

        # Index tensor must live on the same device as the logits.
        bow_index = torch.tensor(
            bow_token_ids, dtype=torch.long, device=logits.device
        )
        numerator = probs_after_word[bow_index].sum()
        denominator = probs_at_start[bow_index].sum()

        # A zero denominator leaves the uncorrected probability in place.
        if denominator > 0:
            prob_word = prob_word * (numerator / denominator)

    # cast to float for writeback
    if isinstance(prob_word, torch.Tensor) and prob_word.numel() == 1:
        prob_word = prob_word.item()

    return prob_word


def get_sentence_probs(model, tokenizer, sentence: str) -> dict:
    """Calculate the sentence LOG probability and the perplexity of a sentence.

    The sentence is encoded with ``tokenizer.encode`` and scored in a single
    forward pass. The first token is not scored because it has no left context.

    Args:
        model: The language model. Called as ``model(input_ids)``; element 0
            of its output is used as the logits, and it must expose ``.device``.
        tokenizer: The tokenizer; must provide ``encode(sentence)``.
        sentence: The text to score.

    Returns:
        dict: ``{"Perplexity": float, "Sentence_logprob": float}``.
        "Sentence_logprob" is the sum of the log probabilities of tokens
        2..T and "Perplexity" is ``exp(-Sentence_logprob / (T - 1))``.
        For a single-token sentence nothing is scored, so the log probability
        is 0.0 and the perplexity is NaN (0 / 0).
    """
    # encode sentence -> (1, T)
    tokens = (
        torch.tensor(tokenizer.encode(sentence), dtype=torch.long)
        .unsqueeze(dim=0)
        .to(model.device)
    )

    # Scoring only: avoid building an autograd graph for every sentence.
    with torch.no_grad():
        logits = model(tokens)[0]  # only use logits, not the loss
        log_probs = F.log_softmax(logits, dim=-1)  # (B,T,C)

        # The prediction made at position i-1 scores the token at position i,
        # so pair log_probs[:-1] with tokens[1:] and pick them all in one gather.
        targets = tokens[0, 1:].to(log_probs.device)
        picked = log_probs[0, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    token_logprobs = picked.tolist()

    # Accumulate left to right in Python floats.
    sentence_logprob = 0.0
    for token_logprob in token_logprobs:
        sentence_logprob += token_logprob

    perplexity = torch.exp(-torch.tensor(sentence_logprob) / len(token_logprobs))

    return {"Perplexity": perplexity.item(), "Sentence_logprob": sentence_logprob}


def get_hidden_states(model, tokenizer, text, layer=-1, pooling=None):
    """
    Run the model on a text and return the hidden states of one layer.

    Args:
        model: The language model; called with the tokenizer output as keyword
            arguments plus ``output_hidden_states=True``
        tokenizer: The tokenizer (invoked without special tokens)
        text: Input text
        layer: Index into ``outputs.hidden_states`` (-1 for the final layer
               before the LM head, 0 for the embedding output, 1-N for
               transformer layers)
        pooling: Reduction over the token axis:
               - "first": first token's hidden state (batch, hidden_dim)
               - "last": last token's hidden state (batch, hidden_dim)
               - "mean": mean across tokens (batch, hidden_dim)
               - anything else (including None): no reduction,
                 (batch, seq_len, hidden_dim)

    Returns:
        torch.Tensor: Hidden states, shape depends on pooling strategy
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    model_inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        # hidden_states is a tuple: (embedding_output, layer1, ..., final_layer)
        layer_output = model(**model_inputs, output_hidden_states=True).hidden_states[
            layer
        ]

    if pooling == "first":
        return layer_output[:, 0, :]  # (batch, hidden_dim)
    if pooling == "last":
        return layer_output[:, -1, :]  # (batch, hidden_dim)
    if pooling == "mean":
        return layer_output.mean(dim=1)  # (batch, hidden_dim)

    # No recognised pooling requested: hand back every token position.
    return layer_output  # (batch, seq_len, hidden_dim)


def get_steering_vector(model, tokenizer, positive_texts, negative_texts=None,
                        layer=-1, pooling="last"):
    """
    Build a steering vector by averaging hidden states over example texts.

    Each text is passed through get_hidden_states with the given layer and
    pooling, its batch dimension is squeezed away, and the per-text vectors
    are stacked and averaged. With negative_texts the result is the
    difference (mean of positives - mean of negatives); without it, only the
    positive mean is returned.

    Args:
        model: The language model
        tokenizer: The tokenizer
        positive_texts: List of texts capturing the target concept
        negative_texts: Optional list of contrasting texts (recommended)
        layer: Which layer to extract from (-1 for final layer)
        pooling: How to pool each text ("last", "mean", "first")

    Returns:
        torch.Tensor: Steering vector; shape (hidden_dim,) for the pooled modes

    Example:
        # Contrastive (recommended)
        pos = ["I am confident", "I know the answer", "I am certain"]
        neg = ["I am unsure", "I don't know", "I am uncertain"]
        vec = get_steering_vector(model, tokenizer, pos, neg, layer=12)

        # Single-sided (less robust)
        concepts = ["Amsterdam is the capital", "The capital is Amsterdam", ...]
        vec = get_steering_vector(model, tokenizer, concepts, layer=12)
    """

    def _mean_representation(texts):
        # One pooled vector per text, batch dim removed, then averaged.
        vectors = [
            get_hidden_states(
                model, tokenizer, text, layer=layer, pooling=pooling
            ).squeeze(0)
            for text in texts
        ]
        return torch.stack(vectors).mean(dim=0)

    concept_mean = _mean_representation(positive_texts)

    # Single-sided mode: no contrast set to subtract.
    if negative_texts is None:
        return concept_mean

    # Direction pointing from the negative examples towards the positive ones.
    return concept_mean - _mean_representation(negative_texts)


def get_model_layers(model):
    """
    Locate the container of transformer layers on a model.

    Probes, in order: ``model.model.layers`` (OLMo, Llama, Mistral),
    ``model.transformer.h`` (GPT-2, GPT-Neo) and ``model.model.decoder``
    (OPT, returning ``model.model.decoder.layers``).

    Raises:
        ValueError: If none of the known attribute layouts is present.
    """
    # (container attribute, attribute that must exist on it, extra path to follow)
    layouts = (
        ("model", "layers", ()),  # OLMo, Llama, Mistral
        ("transformer", "h", ()),  # GPT-2, GPT-Neo
        ("model", "decoder", ("layers",)),  # OPT
    )

    for container_name, marker_name, tail in layouts:
        if not hasattr(model, container_name):
            continue
        container = getattr(model, container_name)
        if not hasattr(container, marker_name):
            continue

        found = getattr(container, marker_name)
        for attr in tail:
            found = getattr(found, attr)
        return found

    raise ValueError(f"Unknown model architecture: {type(model)}")


class SteeringHook:
    """
    Context manager for applying steering vectors to model hidden states.

    On entry a forward hook is registered on one transformer layer (looked up
    through get_model_layers); the hook adds ``scale * steering_vec`` to that
    layer's output hidden states. On exit the hook is removed again.

    Usage:
        steering_vec = pos_vec - neg_vec  # Shape: (1, hidden_dim) or (hidden_dim,)

        with SteeringHook(model, layer=16, steering_vec=steering_vec, scale=1.0):
            outputs = model.generate(...)
    """

    def __init__(self, model, layer, steering_vec, scale=1.0, token_positions=None):
        """
        Args:
            model: The language model
            layer: Which layer to apply steering (0-indexed transformer layer)
            steering_vec: Vector to add, shape (hidden_dim,) or (1, hidden_dim)
            scale: Multiplier for the steering vector strength
            token_positions: Which token positions to steer:
                - None: All positions
                - "last": Only last token
                - list / tuple / range of ints: Specific positions

        Raises:
            ValueError: If token_positions is none of the supported forms.
                Such values used to be ignored silently, which ran the model
                without any steering.
        """
        is_last = isinstance(token_positions, str) and token_positions == "last"
        is_sequence = isinstance(token_positions, (list, tuple, range))
        if token_positions is not None and not is_last and not is_sequence:
            raise ValueError(
                "token_positions must be None, 'last', or a list/tuple/range "
                f"of ints, got {token_positions!r}"
            )

        self.model = model
        self.layer = layer
        self.steering_vec = steering_vec.squeeze()  # Ensure shape is (hidden_dim,)
        self.scale = scale
        self.token_positions = token_positions
        self.hook_handle = None

    def _steering_hook(self, module, input, output):
        """Forward hook: return ``output`` with the steering vector added.

        ``output`` is either the hidden-state tensor itself or a tuple whose
        first element is that tensor; any remaining tuple elements are passed
        through unchanged.
        """
        # output is typically (hidden_states, ...) or just hidden_states
        if isinstance(output, tuple):
            hidden_states, rest = output[0], output[1:]
        else:
            hidden_states, rest = output, None

        # Match device AND dtype of the activations. Without the dtype cast,
        # adding e.g. a float32 vector to half-precision activations promotes
        # the result and hands a different dtype to the following layers.
        vec = (self.steering_vec.to(hidden_states.device) * self.scale).to(
            hidden_states.dtype
        )

        if self.token_positions is None:
            # Add to all positions (broadcast over batch and sequence)
            hidden_states = hidden_states + vec
        else:
            if isinstance(self.token_positions, str):
                positions = [-1]  # "last"
            else:
                positions = self.token_positions
            # These paths edit the layer's output tensor in place. Positions
            # index the sequence axis of the tensor seen by this forward pass.
            for pos in positions:
                hidden_states[:, pos, :] = hidden_states[:, pos, :] + vec

        if rest is not None:
            return (hidden_states,) + rest
        return hidden_states

    def __enter__(self):
        layers = get_model_layers(self.model)
        self.hook_handle = layers[self.layer].register_forward_hook(self._steering_hook)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.hook_handle is not None:
            self.hook_handle.remove()
            self.hook_handle = None  # allow the manager to be entered again
        # Never suppress exceptions raised inside the with-block.
        return False


def get_steered_word_probability(model, tokenizer, context, word, steering_vec, layer,
                                  scale=1.0, token_positions=None, bow_token_ids=None):
    """
    Compute the probability of a word after a context while steering is active.

    The forward pass over "context word" runs inside a SteeringHook, so the
    given layer's hidden states are shifted by ``scale * steering_vec``. The
    word probability is the product of its subword probabilities, optionally
    multiplied by the beginning-of-word (BOW) correction factor.

    Args:
        model: The language model
        tokenizer: The tokenizer
        context: Input context text (before the word)
        word: The word whose probability to compute
        steering_vec: Vector to add, shape (hidden_dim,) or (1, hidden_dim)
        layer: Which layer to apply steering (0-indexed)
        scale: Multiplier for steering strength (default 1.0)
        token_positions: Where to apply steering (None=all, "last", or list of ints)
        bow_token_ids: List of beginning-of-word token IDs for BOW correction.
                       If provided and non-empty, the subword probability is
                       multiplied by (BOW mass after the word) / (BOW mass at
                       the word start).

    Returns:
        float: Probability of the word given the steered context

    Example:
        vec = get_steering_vector(model, tokenizer, pos_texts, neg_texts, layer=12)

        prob = get_steered_word_probability(
            model, tokenizer,
            context="The capital of France is",
            word="Amsterdam",
            steering_vec=vec,
            layer=12,
            scale=1.5,
            bow_token_ids=bow_token_ids
        )
    """
    context = context.strip()
    word = word.strip()

    def _encode(text):
        # Token ids without special tokens, moved to the model's device.
        ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
        return ids.to(model.device)

    context_ids = _encode(context)
    # Context and word are tokenized jointly so the word gets the tokens it
    # would have in running text.
    sequence_ids = _encode(context + " " + word)

    n_context = context_ids.shape[1]
    word_ids = sequence_ids[0, n_context:].tolist()

    # Steered forward pass
    with SteeringHook(model, layer, steering_vec, scale, token_positions), torch.no_grad():
        logits = model(sequence_ids).logits

    def _next_token_probs(position):
        # Distribution over the token that follows `position`.
        return torch.softmax(logits[0, position], dim=0)

    # Sum of subword log probabilities; token k of the word is predicted from
    # the position just before it.
    total_log_prob = 0.0
    for k, token_id in enumerate(word_ids):
        total_log_prob += torch.log(_next_token_probs(n_context + k - 1)[token_id])

    result = torch.exp(total_log_prob)

    use_bow_correction = bow_token_ids is not None and len(bow_token_ids) > 0
    if use_bow_correction:
        after_word = _next_token_probs(n_context + len(word_ids) - 1)
        at_word_start = _next_token_probs(n_context - 1)

        bow_index = torch.tensor(bow_token_ids, device=logits.device)
        bow_mass_after = after_word[bow_index].sum()
        bow_mass_start = at_word_start[bow_index].sum()

        # With zero BOW mass at the word start the uncorrected value is kept.
        if bow_mass_start > 0:
            result = result * (bow_mass_after / bow_mass_start)

    # Cast to float
    if isinstance(result, torch.Tensor) and result.numel() == 1:
        return result.item()
    return result


def obtain_outputs(model, tokenizer, data, bow_token_ids, model_details, task="fb"):
    """
    Score every row of ``data`` with the model and collect one record per row.

    Args:
        model: The language model
        tokenizer: The tokenizer
        data: DataFrame containing the data. For task "fb" the columns
            "passage", "start", "end", "item_id" and "item" are read; for any
            other task "sentence_good", "sentence_bad" and "index" are read.
        bow_token_ids: List of token IDs that mark the beginning of words
        model_details: Mapping whose entries are merged into every record
        task: Task type (default: "fb"); any other value is treated as blimp

    Returns:
        list[dict]: One record per row. "fb" records hold the probabilities of
        the start and end words and the more probable one as "prediction";
        blimp records hold perplexities and log probabilities of both
        sentences and "correct" (1 if the good sentence has the higher log
        probability, else 0).
    """
    records = []

    with torch.no_grad():
        for _, row in tqdm(data.iterrows(), total=data.shape[0]):

            if task != "fb":  # task is blimp
                good = get_sentence_probs(model, tokenizer, row["sentence_good"])
                bad = get_sentence_probs(model, tokenizer, row["sentence_bad"])
                good_is_likelier = good["Sentence_logprob"] > bad["Sentence_logprob"]

                records.append(
                    {
                        "index": row["index"],
                        **model_details,
                        "p_good": good["Perplexity"],
                        "p_bad": bad["Perplexity"],
                        "logp_good": good["Sentence_logprob"],
                        "logp_bad": bad["Sentence_logprob"],
                        "correct": int(good_is_likelier),
                    }
                )
                continue

            # The passage ends in a masked slot; drop it to obtain the context.
            context = row["passage"].replace("[MASK].", "")
            first_word, second_word = row["start"], row["end"]

            start_prob = get_word_probability(
                model, tokenizer, context, first_word, bow_token_ids
            )
            end_prob = get_word_probability(
                model, tokenizer, context, second_word, bow_token_ids
            )

            # Ties go to the end word.
            if start_prob > end_prob:
                prediction = first_word
            else:
                prediction = second_word

            records.append(
                {
                    "item_id": row["item_id"],
                    "item": row["item"],
                    **model_details,
                    "context": context,
                    "token_c1": first_word,
                    "start_prob": start_prob,
                    "end_prob": end_prob,
                    "token_c2": second_word,
                    "prediction": prediction,
                }
            )

    return records


def obtain_pmis(model, tokenizer, rm_data, model_details):
    """
    Score the correct and incorrect answer option of every row in ``rm_data``.

    For each row two scores are computed per option:
      * a naive average log-likelihood: the log probability of
        ``story + "\\n" + option`` divided by the number of tokens in that text;
      * a PMI-style score: that same joint log probability minus the log
        probability of the option on its own.
    The option with the higher PMI-style score is the model's prediction.

    Args:
        model: The language model (switched to eval mode here)
        tokenizer: The tokenizer
        rm_data: DataFrame with columns "item_id", "correct_option"
            ("a" or "b"), "option_a", "option_b", "story_id", "story_format"
            and "story_text"
        model_details: Mapping whose entries are merged into every record

    Returns:
        list[dict]: One record per row with both scores for both options,
        the predicted option text ("model_pred") and "model_correct" (1/0).
    """
    output = []
    model.eval()

    def _sequence_logprob(text):
        return get_sentence_probs(model, tokenizer, text)["Sentence_logprob"]

    # Pure scoring: four forward passes per row, none of which needs gradients.
    with torch.no_grad():
        for _, row in tqdm(rm_data.iterrows(), total=rm_data.shape[0]):

            item_id = row["item_id"]
            correct = row["correct_option"]
            incorrect = "a" if correct == "b" else "b"

            option_good = row[f"option_{correct}"]
            option_bad = row[f"option_{incorrect}"]

            story_id = row["story_id"]
            story_format = row["story_format"]
            story = row["story_text"]

            # Build each story+option text once and reuse it for scoring and length.
            text_good = story + "\n" + option_good
            text_bad = story + "\n" + option_bad

            # Naive average log likelihood of both options
            log_prob_good = _sequence_logprob(text_good)
            avg_log_prob_good = log_prob_good / len(tokenizer.encode(text_good))

            log_prob_bad = _sequence_logprob(text_bad)
            avg_log_prob_bad = log_prob_bad / len(tokenizer.encode(text_bad))

            # PMI-style score: joint log probability minus the option's own
            pmi_good = log_prob_good - _sequence_logprob(option_good)
            pmi_bad = log_prob_bad - _sequence_logprob(option_bad)

            # Determine if the model chose the correct option (ties count as wrong)
            prefers_good = pmi_good > pmi_bad

            output.append(
                {
                    "item_id": item_id,
                    "story_id": story_id,
                    "story_format": story_format,
                    "option_good": option_good,
                    "option_bad": option_bad,
                    **model_details,
                    "average_log_likelihood_good": avg_log_prob_good,
                    "average_log_likelihood_bad": avg_log_prob_bad,
                    "pmi_good": pmi_good,
                    "pmi_bad": pmi_bad,
                    "model_pred": option_good if prefers_good else option_bad,
                    "model_correct": 1 if prefers_good else 0,
                }
            )
    return output


def read_rm():
    """Load the recursive mind reading data.

    Reads ``data/rm_stories.csv`` and ``data/rm_questions.csv`` (relative to
    the working directory), joins them on "story_id" and adds an "item_id"
    column of the form
    ``<story_id>_<story_format>_<question_format>_<question_type>_<question_level>``.

    Returns:
        pd.DataFrame: The merged stories/questions table with "item_id".
    """
    data_dir = Path("data")
    questions = pd.read_csv(data_dir / "rm_questions.csv")
    stories = pd.read_csv(data_dir / "rm_stories.csv")

    merged = stories.merge(questions, on="story_id")

    # Assemble the id left to right; the numeric-looking ends are cast to str,
    # the middle columns are concatenated as they are.
    item_id = merged["story_id"].astype(str)
    for column in ("story_format", "question_format", "question_type"):
        item_id = item_id + "_" + merged[column]
    item_id = item_id + "_" + merged["question_level"].astype(str)

    merged["item_id"] = item_id
    return merged


def list_all_models_to_test(runs_dir="runs", file_names=None):
    """Collect the "model_id" of every model listed in the run JSONL files.

    Each file holds one JSON object per line. Blank lines and lines whose
    first non-whitespace characters are "#" or "//" are skipped.

    Args:
        runs_dir: Directory containing the JSONL files (default: "runs").
        file_names: Names of the files to read, in order. Defaults to the
            four standard model lists.

    Returns:
        list[str]: The "model_id" values in file and line order.

    Raises:
        json.JSONDecodeError: If a non-comment line is not valid JSON.
        KeyError: If an entry has no "model_id".
    """
    if file_names is None:
        file_names = (
            "additional_chkpt_models.jsonl",
            "base_instruct_models.jsonl",
            "olmo2_7B_all_branches.jsonl",
            "olmo2_sampled_branches.jsonl",
        )

    runs_dir = Path(runs_dir)
    all_model_ids = []

    for file_name in file_names:
        with open(runs_dir / file_name, "r", encoding="utf-8") as f:
            for line in f:
                entry = line.strip()
                # A line read from a file still contains its newline, so the
                # emptiness test has to be made on the stripped text.
                if not entry or entry.startswith(("#", "//")):
                    continue
                all_model_ids.append(json.loads(entry)["model_id"])

    return all_model_ids


def extract_k2v2_stage(k2v2_stage_name, k2v2_model_id):
    """
    Split a K2-V2 checkpoint name into (stage, step).

    The name is split on "_". When the model id contains "instruct"
    (case-insensitive) the name looks like ``stage_3_0017500`` and the result
    is ``("3", "0017500")``; otherwise it looks like ``base_0090000`` and the
    result is ``(None, "0090000")``.
    """
    is_instruct = "instruct" in k2v2_model_id.lower()
    fields = k2v2_stage_name.split("_")

    if not is_instruct:
        # base checkpoints carry no stage, only the step
        return None, fields[1]

    return fields[1], fields[2]


def sample_checkpoints(df, stage1_samples=50, stage2_samples=5):
    """
    Sample rows with equal spacing across the token range.

    Stage 1 rows are sampled at ``stage1_samples`` equally spaced token
    targets between their minimum and maximum token count. Stage 2 rows are
    sampled per ingredient: all ingredients share ``stage2_samples`` targets
    spanning from the smallest minimum to the smallest maximum token count,
    and an ingredient that runs beyond that common maximum gets
    ``stage2_samples - 1`` further targets up to its own maximum. For every
    target the row with the closest token count is taken; a row is never
    selected twice.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame with columns: train_stage, tokens, ingredient
    stage1_samples : int
        Number of targets for stage 1 (default: 50)
    stage2_samples : int
        Number of shared targets per ingredient in stage 2 (default: 5)

    Returns:
    --------
    pd.DataFrame
        The selected rows of ``df`` in selection order (stage 1 first, then
        stage 2 ingredients in sorted order). Stage 2 rows without an
        ingredient are never selected.
    """
    sampled_indices = []
    seen = set()  # O(1) duplicate check; the list keeps the selection order

    def _pick_nearest(frame, targets):
        # Select, for each target, the row whose token count is closest.
        for target in targets:
            idx = (frame["tokens"] - target).abs().idxmin()
            if idx not in seen:
                seen.add(idx)
                sampled_indices.append(idx)

    # Process Stage 1
    stage1_df = df[df["train_stage"] == 1]
    if len(stage1_df) > 0:
        _pick_nearest(
            stage1_df,
            np.linspace(
                stage1_df["tokens"].min(), stage1_df["tokens"].max(), stage1_samples
            ),
        )

    # Process Stage 2 - shared targets over the common range, then extend for
    # ingredients with longer ranges
    stage2_df = df[df["train_stage"] == 2]
    ingredient_frames = {
        ingredient: stage2_df[stage2_df["ingredient"] == ingredient]
        for ingredient in sorted(stage2_df["ingredient"].dropna().unique())
    }

    # Guard: stage 2 rows may exist without any non-null ingredient, in which
    # case there is nothing to take a min()/max() over.
    if ingredient_frames:
        overall_min = min(f["tokens"].min() for f in ingredient_frames.values())
        # Shortest maximum = the range every ingredient covers
        common_max = min(f["tokens"].max() for f in ingredient_frames.values())

        shared_targets = list(np.linspace(overall_min, common_max, stage2_samples))

        for ingredient_df in ingredient_frames.values():
            targets = list(shared_targets)

            ing_max = ingredient_df["tokens"].max()
            if ing_max > common_max:
                # Extra targets for the part beyond the common range; the
                # first one equals common_max and is already a shared target.
                targets.extend(np.linspace(common_max, ing_max, stage2_samples)[1:])

            _pick_nearest(ingredient_df, targets)

    # Return subset of original dataframe
    return df.loc[sampled_indices]

    



    