from typing import Dict, Optional, Any, List
import torch
import torch.nn as nn
from tqdm import tqdm

from constants import DEVICE_CPU, TOKENIZER_RETURN_TENSORS
from project_utils.timers import SegmentTimer


def evaluate_language_model_wikitext2(
        model: nn.Module,
        tokenizer: Any,
        batch_size: int = 1,
        max_length: int = 2048,
        stride: int = 512,
        max_samples: Optional[int] = None,
        device: torch.device = torch.device(DEVICE_CPU),
        print_interval: int = 2,  # Print every N samples
) -> Dict[str, Any]:
    """
    Evaluate a causal LM on the WikiText-2 test split using non-overlapping
    chunks, compatible with the reference implementation.

    The test split is joined with blank lines, tokenized in a single call and
    cut into consecutive chunks of ``seqlen`` tokens; a trailing partial chunk
    is dropped. Each chunk is run through the model on its own, and the
    reported perplexity is ``exp(sum(chunk_nll) / (num_chunks * seqlen))``.

    Args:
        model: Module called as ``model(input_ids)`` whose result exposes
            ``.logits``. If it has ``config.max_position_embeddings``, that
            value caps the chunk length.
        tokenizer: Callable invoked as ``tokenizer(text, <return-tensors
            kwarg>='pt')`` whose result exposes ``.input_ids``.
        batch_size: Accepted for interface compatibility; not used here
            (chunks are evaluated one at a time).
        max_length: Upper bound on the chunk length in tokens.
        stride: Accepted for interface compatibility; not used here
            (chunks never overlap).
        max_samples: Optional cap on the number of chunks evaluated.
        device: Device the token ids are moved to before evaluation.
        print_interval: Print running statistics every this many chunks
            (the last chunk always prints). ``0`` disables the periodic
            prints and leaves only the final one.

    Returns:
        Dict with keys ``perplexity``, ``avg_inference_time_ms``,
        ``total_tokens``, ``num_sequences``,
        ``total_evaluation_time_seconds``, ``data_loading_time_seconds`` and
        ``tokenization_time_seconds``. When no full chunk could be
        evaluated, ``perplexity`` is NaN and the counts are zero.

    Raises:
        ImportError: If the ``datasets`` package is not installed.
        ValueError: If the effective chunk length is smaller than 1.
    """
    try:
        from datasets import load_dataset  # type: ignore[import]
    except ImportError as err:
        raise ImportError("Please install datasets: pip install datasets") from err

    eval_timer = SegmentTimer()

    # Load test split
    dataset = load_dataset('Salesforce/wikitext', 'wikitext-2-raw-v1', split='test')
    data_loading_time = eval_timer.segment("Dataset Loading", print_time=False)

    # Concatenate text and tokenize
    text = '\n\n'.join(dataset['text'])
    encodings = tokenizer(text, **{TOKENIZER_RETURN_TENSORS: 'pt'})
    tokenization_time = eval_timer.segment("Tokenization", print_time=False)

    input_ids = encodings.input_ids.to(device)
    seq_len = input_ids.size(1)

    # Use the model's context length when it advertises one; a model without
    # a ``config`` (or with the attribute unset) falls back to ``max_length``.
    model_context = getattr(
        getattr(model, 'config', None), 'max_position_embeddings', None
    )
    seqlen = max_length if model_context is None else min(max_length, model_context)
    if seqlen < 1:
        raise ValueError(
            f"Effective chunk length must be >= 1, got {seqlen}"
        )

    model.eval()

    nsamples = seq_len // seqlen
    if max_samples is not None:
        # Clamp so that a negative cap cannot yield a negative chunk count.
        nsamples = max(0, min(nsamples, max_samples))

    nlls: List[torch.Tensor] = []
    inference_times: List[float] = []

    # CUDA kernels are launched asynchronously; without an explicit
    # synchronisation the timer would only measure launch overhead.
    target_device = torch.device(device)
    on_cuda = target_device.type == "cuda"

    # Mean-reduced cross entropy; stateless, so build it once.
    loss_fct = nn.CrossEntropyLoss()

    with torch.no_grad():
        # Context manager guarantees the bar is closed even if a forward
        # pass raises.
        with tqdm(range(nsamples), desc="Evaluating perplexity") as pbar:
            for i in pbar:
                batch_timer = SegmentTimer()

                # Non-overlapping chunks
                batch = input_ids[:, (i * seqlen):((i + 1) * seqlen)]
                batch_timer.segment("data_transfer", print_time=False)

                # Forward pass - get logits without computing loss
                outputs = model(batch)
                lm_logits = outputs.logits
                if on_cuda:
                    torch.cuda.synchronize(target_device)

                inference_time = batch_timer.segment("inference", print_time=False) * 1000.0
                inference_times.append(inference_time)

                # Shift logits and labels (predict next token)
                shift_logits = lm_logits[:, :-1, :].contiguous()
                shift_labels = batch[:, 1:].contiguous()

                # Compute loss manually
                loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # The loss is a mean over the (seqlen - 1) predicted tokens.
                # It is scaled by seqlen (as the reference does) rather than
                # by (seqlen - 1); since the token count below also uses
                # seqlen per chunk, the factor cancels and the perplexity is
                # exp(mean of the per-chunk mean losses).
                neg_log_likelihood = loss.float() * seqlen
                nlls.append(neg_log_likelihood)

                # Print intermediate results; a zero interval keeps only the
                # final summary instead of dividing by zero.
                is_last = (i + 1) == nsamples
                if (print_interval and (i + 1) % print_interval == 0) or is_last:
                    current_nll = torch.stack(nlls).sum()
                    current_tokens = (i + 1) * seqlen
                    current_ppl = torch.exp(current_nll / current_tokens)
                    avg_inf_time = sum(inference_times) / len(inference_times)

                    print(f"\n[Sample {i + 1}/{nsamples}] "
                          f"PPL: {current_ppl.item():.4f} | "
                          f"Avg inference: {avg_inf_time:.2f}ms | "
                          f"Current batch: {inference_time:.2f}ms")

    # Compute perplexity: exp(sum(NLL) / total_tokens)
    total_tokens = nsamples * seqlen
    if nlls:
        total_nll = torch.stack(nlls).sum()
        perplexity = torch.exp(total_nll / total_tokens).item()
    else:
        # No full chunk was available (text shorter than seqlen, or
        # max_samples <= 0): perplexity is undefined rather than an error.
        perplexity = float('nan')

    total_eval_time = eval_timer.segment("Evaluation Complete", print_time=False)

    return {
        'perplexity': perplexity,
        'avg_inference_time_ms': sum(inference_times) / len(inference_times) if inference_times else 0.0,
        'total_tokens': total_tokens,
        'num_sequences': nsamples,
        'total_evaluation_time_seconds': total_eval_time,
        'data_loading_time_seconds': data_loading_time,
        'tokenization_time_seconds': tokenization_time,
    }

