import torch
import numpy as np
from transformers import LogitsProcessorList, TopKLogitsWarper, PreTrainedTokenizer, PreTrainedTokenizerFast
from logits_processors.primal_threshold_processor_fix_k import PrimalThresholdProcessor
from pathlib import Path
from typing import List, Union, Optional, Dict, Tuple

def calculate_perplexity(model, tokenizer, texts: Union[str, List[str]], device: str, only_generated: bool = True, prefix_length: int = 0, max_length: Optional[int] = None) -> List[float]:
    """Compute one perplexity value per input text.

    The texts are tokenized as a single right-padded batch, passed through
    ``model`` once, and scored with next-token cross-entropy. Perplexity is
    ``exp`` of the mean loss over the scored tokens of each text.

    Args:
        model: Called as ``model(**encodings)``; the result must expose a
            ``logits`` attribute.
        tokenizer: Callable tokenizer with a ``padding_side`` attribute. The
            attribute is temporarily set to ``"right"`` and always restored,
            even when tokenization fails.
        texts: A single string or a list of strings. An empty list yields an
            empty result without touching the tokenizer or the model.
        device: Target passed to ``.to(device)`` on the tokenizer output.
        only_generated: When true and ``prefix_length > 0``, only tokens at
            positions ``>= prefix_length`` contribute to the loss.
        prefix_length: Number of leading token positions to exclude when
            ``only_generated`` is set. It is applied to token positions as
            produced by the tokenizer.
        max_length: Forwarded to the tokenizer as the truncation limit.

    Returns:
        A list with one float per text; ``nan`` for a text that has no scored
        tokens (for example when it is not longer than ``prefix_length``).
    """
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return []

    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "right"
    try:
        encodings = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(device)
    finally:
        # Restore the caller's setting even if tokenization raised.
        tokenizer.padding_side = original_padding_side

    attention_mask = encodings.attention_mask

    # Tokens eligible for scoring: real (non-padding) tokens, optionally
    # restricted to positions at or beyond the prefix.
    score_mask = attention_mask.bool()
    if only_generated and prefix_length > 0:
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        score_mask = score_mask & (positions >= prefix_length)

    with torch.no_grad():
        logits = model(**encodings).logits

    # The logits at position t predict the token at position t + 1, so drop
    # the last logit and the first label/mask entry to align them.
    shift_logits = logits[..., :-1, :]
    shift_labels = encodings.input_ids[..., 1:]
    shift_mask = score_mask[..., 1:]

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    token_losses = loss_fct(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    ).view(shift_labels.shape)

    # Fill instead of multiplying by the mask so that a non-finite loss at an
    # unscored position cannot leak into the row sum.
    token_losses = token_losses.masked_fill(~shift_mask, 0.0)
    token_counts = shift_mask.sum(dim=1)
    mean_losses = token_losses.sum(dim=1) / token_counts.clamp(min=1)

    perplexities = torch.exp(mean_losses).masked_fill(token_counts == 0, float("nan"))
    return perplexities.tolist()


def find_repeated_substrings(tokens: List[int], min_len: int = 2) -> bool:
    """Report whether ``tokens`` contains a run immediately followed by a copy of itself.

    Only back-to-back repetitions count: ``[1, 2, 1, 2]`` matches, while
    ``[1, 2, 3, 1, 2]`` does not because the two ``[1, 2]`` runs are not
    adjacent.

    Args:
        tokens: Sequence of token ids to scan.
        min_len: Minimum length of the repeated run. Values below 1 are
            treated as 1.

    Returns:
        ``True`` if some run of at least ``min_len`` tokens is directly
        followed by an identical run, otherwise ``False``.
    """
    # An empty run trivially equals its neighbour, so a non-positive min_len
    # would otherwise report a repetition for every input (even an empty one).
    min_len = max(min_len, 1)
    n = len(tokens)

    # A run of `length` tokens plus its copy needs 2 * length tokens, hence
    # the upper bounds on both the run length and the start index.
    return any(
        tokens[start:start + length] == tokens[start + length:start + 2 * length]
        for length in range(min_len, n // 2 + 1)
        for start in range(n - 2 * length + 1)
    )


def generate_text_with_processor(
    model,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prefix_ids: torch.Tensor,
    max_length: int,
    processor: Optional[Union[TopKLogitsWarper, PrimalThresholdProcessor, List[Union[TopKLogitsWarper, PrimalThresholdProcessor]]]] = None,
    temperature: float = 1.0,
    device: str = "cuda",
    batch_size: Optional[int] = None
) -> List[str]:
    """Generate continuations for ``prefix_ids`` and return them as decoded text.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Supplies ``pad_token_id`` / ``eos_token_id`` and decodes
            the generated ids with ``batch_decode``.
        prefix_ids: Prompt token ids; moved to ``device`` before generation.
            NOTE(review): the attention mask is all ones, so this assumes the
            prefixes contain no padding - confirm against callers.
        max_length: Forwarded to ``generate`` as ``max_length``.
        processor: ``None``, a single processor, or a list of processors.
            A lone ``TopKLogitsWarper`` is translated into the ``top_k``
            generation argument; anything else is passed through a
            ``LogitsProcessorList`` with ``top_k=0`` and ``top_p=1.0``.
        temperature: Sampling is enabled when ``temperature > 0``; otherwise
            ``do_sample`` is false and ``temperature`` is passed as ``None``.
        device: Device the prefixes are moved to.
        batch_size: When a positive number and ``prefix_ids`` has more than
            one dimension, the prefixes are generated in chunks of at most
            this many rows. ``None`` or a non-positive value generates
            everything in one call.

    Returns:
        Decoded texts (special tokens skipped), one per row returned by
        ``generate``, in input order.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length <= 0:
        raise ValueError("max_length must be positive")

    prefix_ids = prefix_ids.to(device)

    sampling = temperature > 0

    # Fall back to EOS explicitly when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generation_kwargs = {
        "max_length": max_length,
        "do_sample": sampling,
        "temperature": temperature if sampling else None,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }

    logits_processor = LogitsProcessorList()

    if isinstance(processor, TopKLogitsWarper):
        # Let generate() build the top-k warper itself from the stored k.
        generation_kwargs["top_k"] = processor.top_k
        generation_kwargs["top_p"] = 1.0
    elif processor is not None:
        if isinstance(processor, list):
            logits_processor.extend(processor)
        else:
            logits_processor.append(processor)
        # Explicit top_k / top_p values so only the custom processors filter.
        generation_kwargs["top_k"] = 0
        generation_kwargs["top_p"] = 1.0

    if len(logits_processor) > 0:
        generation_kwargs["logits_processor"] = logits_processor

    # Honour batch_size by splitting along the first (row) dimension.
    if batch_size is not None and batch_size > 0 and prefix_ids.dim() > 1:
        chunks = torch.split(prefix_ids, batch_size, dim=0)
    else:
        chunks = (prefix_ids,)

    generated_texts = []
    with torch.no_grad():
        for chunk in chunks:
            outputs = model.generate(
                chunk,
                attention_mask=torch.ones_like(chunk),
                **generation_kwargs,
            )
            generated_texts.extend(
                tokenizer.batch_decode(outputs, skip_special_tokens=True)
            )
    return generated_texts


def _transform_results_to_new_format(results_list: List[Dict], all_processor_names_list: List[str]) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[Union[float, None]]]]]:
    texts_output = {"original": []}
    perplexity_output = {"original": []}
    repetition_output = {"original": []}

    for processor_name in all_processor_names_list:
        texts_output[processor_name] = []
        perplexity_output[processor_name] = []
        repetition_output[processor_name] = []

    for res_item in results_list:
        texts_output["original"].append(res_item.get("original_text", ""))
        perplexity_output["original"].append(res_item.get("original_perplexity", float('nan')))
        repetition_output["original"].append(res_item.get("original_has_repetition", None))

        for processor_name in all_processor_names_list:
            gen_info = res_item.get("generations", {}).get(processor_name)
            if gen_info:
                texts_output[processor_name].append(gen_info.get("text", ""))
                perplexity_output[processor_name].append(gen_info.get("perplexity", float('nan')))
                repetition_output[processor_name].append(gen_info.get("has_repetition", None))
            else:
                texts_output[processor_name].append("")
                perplexity_output[processor_name].append(float('nan'))
                repetition_output[processor_name].append(None)
    
    metrics_output_final = {
        "perplexity_data": perplexity_output,
        "repetition_data": repetition_output
    }
    return texts_output, metrics_output_final
