import re
from enum import Enum
from typing import Union, List

import torch
import torch.nn.functional as F
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from abstract_cf.text_generation.learned_abstraction import LearnedAbstractionPipeline


# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


# Gendered third-person singular English pronouns (subject, object,
# possessive and reflexive forms). Entries are lower-case, so words must be
# lower-cased before membership tests.
male_pronouns = {
    'he',       # subject
    'him',      # object
    'his',      # possessive
    'himself',  # reflexive
}
female_pronouns = {
    'she',      # subject
    'her',      # object / possessive determiner
    'hers',     # possessive pronoun
    'herself',  # reflexive
}


def check_matching_ids(abstract_samples: dict, token_samples: dict):
    """
    Verify that the abstract and token-level sample dicts share exactly the same IDs.

    Only the keys of the two dicts are compared; values are ignored.

    Raises:
        ValueError: if the key sets differ. The message lists the IDs found only
            in ``token_samples`` (missing in abstract) and those found only in
            ``abstract_samples`` (missing in token).
    """
    abstract_keys = set(abstract_samples)
    token_keys = set(token_samples)
    if abstract_keys == token_keys:
        return

    only_in_token = token_keys - abstract_keys
    only_in_abstract = abstract_keys - token_keys
    raise ValueError(
        f"Sample ID mismatch detected! Missing in abstract: {only_in_token}, "
        f"missing in token: {only_in_abstract}"
    )



class PronounCategory(Enum):
    """Which gendered pronouns a piece of text contains.

    Each member's string value doubles as the key used when tallying categories.
    """

    MIXED = 'mixed'              # both male and female pronouns occur
    ONLY_MALE = 'only_male'      # male pronouns only
    ONLY_FEMALE = 'only_female'  # female pronouns only
    NO_PRONOUNS = 'no_pronouns'  # no gendered pronouns at all


def check_pronoun_category(generation: str) -> PronounCategory:
    """
    Classify a text by which gendered pronouns (``male_pronouns`` /
    ``female_pronouns``) it contains.

    Words are extracted case-insensitively as runs of ASCII letters rather than
    by whitespace splitting. Pronouns attached to punctuation are therefore
    still detected ("him.", "his,", "(she", "his/her"), as are contractions
    ("he's" -> "he", "s").

    Args:
        generation: The text to inspect.

    Returns:
        ``PronounCategory.MIXED`` if both male and female pronouns occur,
        ``ONLY_MALE`` / ``ONLY_FEMALE`` if only one kind occurs, and
        ``NO_PRONOUNS`` otherwise.
    """
    # str.split() would keep punctuation glued to words ("him." != "him"),
    # silently missing most sentence-final or comma-adjacent pronouns.
    words = re.findall(r"[a-z]+", generation.lower())
    has_male = any(word in male_pronouns for word in words)
    has_female = any(word in female_pronouns for word in words)

    if has_male and has_female:
        return PronounCategory.MIXED
    if has_male:
        return PronounCategory.ONLY_MALE
    if has_female:
        return PronounCategory.ONLY_FEMALE
    return PronounCategory.NO_PRONOUNS


def count_pronoun_categories(generations: list[str]) -> dict[str, int]:
    """
    Count how many generated texts fall into each pronoun category.

    Parameters:
    generations (list of str): The generated texts to classify, each via
        ``check_pronoun_category``. Any iterable of strings is accepted.

    Returns:
    dict: Maps each ``PronounCategory`` value to its count. The keys are
          'only_male', 'only_female', 'mixed' and 'no_pronouns', in that
          order. A category with no matching text is present with count 0.
    """
    # Pre-seed every category so absent ones still appear with a zero count.
    counts = {
        category.value: 0
        for category in (
            PronounCategory.ONLY_MALE,
            PronounCategory.ONLY_FEMALE,
            PronounCategory.MIXED,
            PronounCategory.NO_PRONOUNS,
        )
    }
    for generation in generations:
        counts[check_pronoun_category(generation).value] += 1
    return counts


def compute_perplexity(
    texts: Union[str, List[str]],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stride: int = 128,
    show_pbar: bool = False,
) -> torch.Tensor:
    # adapted from XXXX
    """
    Compute each text's mean negative log-likelihood per token, using a sliding window.

    Despite the function name, the value returned is the average NLL (in nats)
    per scored token, not the perplexity. Apply ``torch.exp`` to the result to
    get perplexity.

    Sliding window:
        Each window spans at most the model's context size and advances by
        ``stride`` tokens. Only tokens not scored by an earlier window add to
        the loss, so each token is scored at most once, with as much left
        context as the window allows. (With ``stride`` equal to the context
        size, the first token of every later window gets no context and is
        not scored.)

    Label shift:
        Targets are shifted by one position relative to the logits. This
        matches the loss the model computes internally when given ``labels``.

    Padding:
        The batch is tokenized with right padding. The tokenizer's
        ``padding_side`` is forced to ``'right'`` for the call and restored
        afterwards. Padding positions are never scored.

    Args:
        texts: A single text or a list of texts.
        model: Causal language model. Its context size is read from
            ``config.n_positions`` (GPT-2 style) or else
            ``config.max_position_embeddings`` (e.g. Llama).
        tokenizer: Tokenizer matching ``model``. NOTE: it presumably needs a
            pad token set, because the batch is tokenized with ``padding=True``.
        stride: Number of tokens each window advances by. Must be in
            ``(0, context size]``.
        show_pbar: If True, display a tqdm progress bar over the windows.

    Returns:
        A 1D tensor with one entry per input text, on the module-level
        ``device``. Each entry is that text's mean NLL per scored token. A text
        with fewer than two tokens has nothing to score and yields NaN.

    Raises:
        ValueError: if the model config exposes no context size, or if
            ``stride`` is outside ``(0, context size]``.
    """
    if isinstance(texts, str):
        texts = [texts]

    # The scoring below assumes each row's real tokens occupy positions
    # [0, length), i.e. right padding. Some tokenizers default to left
    # padding, so force right padding for this call only.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = 'right'
    try:
        encodings = tokenizer(texts, padding=True, return_tensors='pt')
    finally:
        tokenizer.padding_side = original_padding_side

    input_ids = encodings.input_ids.to(device)             # (B, L)
    attention_mask = encodings.attention_mask.to(device)   # (B, L)
    B, L = input_ids.size()

    # GPT-2 configs call the context size `n_positions`; most other
    # architectures (e.g. Llama) call it `max_position_embeddings`.
    max_model_length = getattr(model.config, 'n_positions', None) or getattr(
        model.config, 'max_position_embeddings', None
    )
    if not max_model_length:
        raise ValueError('Cannot determine the model context size from model.config')
    if not 0 < stride <= max_model_length:
        # A stride larger than the window would silently skip tokens.
        raise ValueError(f'stride must be in (0, {max_model_length}], got {stride}')

    # Number of real (non-padding) tokens per example.
    lengths = attention_mask.sum(dim=1)  # (B,)

    nll_sum = torch.zeros(B, device=device)
    token_count = torch.zeros(B, device=device)
    # Absolute position up to which each example has already been scored.
    prev_end = torch.zeros(B, dtype=torch.long, device=device)

    for begin_loc in tqdm.tqdm(range(0, L, stride), disable=not show_pbar):
        end_loc = min(begin_loc + max_model_length, L)
        if end_loc - begin_loc < 2:
            # After the label shift a single-token window has no
            # (context, target) pair, so skip it before running the model.
            continue

        # Each example scores the absolute positions [prev_end, current_end):
        # tokens not scored by an earlier window, clipped to its real length.
        current_end = lengths.clamp(max=end_loc)  # (B,)
        if not (current_end > prev_end).any():
            continue

        window_input_ids = input_ids[:, begin_loc:end_loc]
        window_attention_mask = attention_mask[:, begin_loc:end_loc]

        # Select targets by absolute position. Masking a fixed-length prefix of
        # the window would score padding and drop real tokens of any example
        # shorter than the padded batch.
        positions = torch.arange(begin_loc, end_loc, device=device)  # (W,)
        to_score = (positions >= prev_end.unsqueeze(1)) & (
            positions < current_end.unsqueeze(1)
        )  # (B, W)
        target_ids = window_input_ids.masked_fill(~to_score, -100)

        with torch.no_grad():
            outputs = model(window_input_ids, attention_mask=window_attention_mask)
            logits = outputs.logits  # (B, W, vocab_size)

        # Shift so the logits at position t are scored against the token at
        # t + 1, as the model does internally when given `labels`.
        shift_logits = logits[:, :-1, :]   # (B, W - 1, vocab_size)
        shift_targets = target_ids[:, 1:]  # (B, W - 1)
        loss_vals = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_targets.reshape(-1),
            reduction='none',
        ).view(B, -1)  # (B, W - 1); 0 wherever the target is -100

        scored = shift_targets != -100  # (B, W - 1)
        nll_sum += (loss_vals * scored.float()).sum(dim=1)
        token_count += scored.sum(dim=1).float()

        prev_end = current_end
        if torch.all(prev_end >= lengths):
            break

    # Mean NLL per scored token; NaN (0 / 0) where nothing could be scored.
    # Perplexity would be torch.exp(nll_sum / token_count).
    return nll_sum / token_count


# Task IDs, indexed as
#   task_ids[task][setting][model] -> {'acf_task_id': ..., 'tlcf_task_id': ...}
# where task is 'gender_steering' or 'token_replacement', setting is
# 'unsupervised' or 'supervised', and model is a model name
# ('gpt2-xl', 'llama-3.2-1B').
# NOTE(review): the 32-hex-char IDs look like experiment-tracker task/run IDs,
# and 'acf' / 'tlcf' presumably denote the abstract vs token-level
# counterfactual runs; confirm. The original note read "Y sample space
# constructed from both", presumably meaning both runs of each pair; confirm.
task_ids = {
    'gender_steering': {
        'unsupervised': {
            'gpt2-xl': {
                # online Y. NOTE(review): presumably the Y sample space for
                # this pair was built online; confirm.
                'acf_task_id': '85162ea6fde341c1a68f923e9d98d0e9',
                'tlcf_task_id': 'ecc60770b70a45e2976245911576e843'
            },
        },
        'supervised': {
            'gpt2-xl': {
                'acf_task_id': 'b0daf401104c46269a94c416ed71ad21',
                'tlcf_task_id': '97603626245c496ca0771ff6d236e20b'
            },
        },
        # NOTE(review): disabled 'tlcf' setting; its IDs were never filled in.
        # 'tlcf': {
        #     'gpt2-xl': {
        #         'acf_task_id': '',
        #         'tlcf_task_id': ''
        #     },
        # }
    },

    'token_replacement': {
        'unsupervised': {
            'gpt2-xl': {
                'acf_task_id': 'e1dcada3add54d4bb67e8947f758c27d',
                'tlcf_task_id': 'eb6b8bdd02b14879a0800cc9837ab274'
            },
            'llama-3.2-1B': {
                'acf_task_id': '40547e8686cb4f338bfb5322d77041e4',
                'tlcf_task_id': '06944d08915a48a199b5c197588a3dfa'
            }
        },
        'supervised': {
            'gpt2-xl': {
                'acf_task_id': '0050d418db4f4412b2067b3edc72a662',
                'tlcf_task_id': 'e6e055b00f9a4c599fa7ea471cfc296e'
            },
            'llama-3.2-1B': {
                'acf_task_id': '78f65b271a924798be82e3118dcfdf66',
                'tlcf_task_id': 'a45fa03f4d234f6da2ecb88905fe40d0'
            }
        },
    }
}

# Checkpoint directories of the supervised learned-abstraction models, keyed
# by task. The paths are relative, so they resolve against the process's
# current working directory, not this file's location.
_LEARNED_ABSTRACTIONS_ROOT = '../model_data/learned_abstractions'
supervised_abstraction_paths = {
    'gender_steering': f'{_LEARNED_ABSTRACTIONS_ROOT}/profession/checkpoint-5994',
    'token_replacement': f'{_LEARNED_ABSTRACTIONS_ROOT}/emotion/checkpoint-568',
}