import torch
from utils.data_utils import *
import numpy as np
from tqdm import tqdm


def get_model_memory(model):
    """
    Return the number of bytes occupied by a module's tensors.

    Every tensor yielded by ``model.parameters()`` and ``model.buffers()``
    contributes ``nelement() * element_size()`` bytes; the two totals are added.
    """
    def _bytes_of(tensors):
        # Footprint of each tensor = element count x bytes per element.
        return sum(t.nelement() * t.element_size() for t in tensors)

    weights_bytes = _bytes_of(model.parameters())
    buffers_bytes = _bytes_of(model.buffers())
    return weights_bytes + buffers_bytes

@torch.inference_mode()
def ppl_eval(model, tokenizer, datasets=('wikitext2', 'ptb', 'c4'),
             model_seq_len=2048, batch_size=32, device="cuda"):
    """
    Evaluate model perplexity (PPL) on language-modeling benchmarks.

    For each dataset name, batches from ``get_test_data`` are scored with
    next-token cross-entropy, and PPL = exp(sum NLL / number of scored tokens).
    Label positions equal to the tokenizer's pad id are excluded from the sum.

    Args:
        model: causal LM, called as ``model(input_ids=..., use_cache=False)``;
            its output must expose ``.logits``.
        tokenizer: used only to resolve the pad id (``pad_token_id``, falling
            back to ``eos_token_id``; if both are None nothing is masked).
        datasets: iterable of dataset names passed to ``get_test_data``
            (e.g. 'wikitext2', 'ptb', 'c4').
        model_seq_len: sequence length passed to ``get_test_data``.
        batch_size: batch size passed to ``get_test_data``.
        device: device the model and every batch are moved to.

    Returns:
        dict mapping dataset name -> perplexity (float); ``nan`` for a dataset
        that yielded no scorable tokens.
    """
    model.to(device).eval()
    memory_bytes = get_model_memory(model)
    print(f"memory usage: {memory_bytes / (1024**3):.2f} GB")

    # Resolve the pad id once. Compare against None explicitly: 0 is a valid
    # token id and must not fall through to eos the way `a or b` would.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")

    ppls = {}
    for name in datasets:
        # Load evaluation dataset
        test_loader = get_test_data(name, tokenizer, seq_len=model_seq_len, batch_size=batch_size)

        total_nll = 0.0  # negative log-likelihood sum
        total_tok = 0    # number of valid tokens

        for bi, batch in enumerate(test_loader):
            batch = batch.to(device)  # assumed [B, T] token ids — TODO confirm get_test_data output
            out = model(input_ids=batch, use_cache=False)
            # Position t predicts token t+1: drop the last logit and the first
            # label. Slice before .float() so only the needed part is upcast.
            logits = out.logits[:, :-1, :].float()  # [B, T-1, V]
            labels = batch[:, 1:]                     # [B, T-1]
            if pad_id is None:
                mask = torch.ones_like(labels, dtype=torch.bool)
            else:
                mask = labels != pad_id               # [B, T-1]

            if not mask.any():
                continue
            valid_logits = logits[mask]   # [N_valid, V]
            valid_labels = labels[mask]   # [N_valid]
            loss = loss_fct(valid_logits, valid_labels)
            batch_nll = loss.item()

            total_nll += batch_nll
            total_tok += valid_labels.numel()
            # Debug info for first few batches
            if bi < 2:
                max_logit = float(logits.abs().max().item())
                print(f"[ppl-debug:{name}] batch={bi}, max|logit|={max_logit:.2f}, "
                      f"valid_tok={valid_labels.numel()}, batch_nll={batch_nll:.2f}")

        if total_tok == 0:
            # exp(0) = 1.0 would masquerade as a perfect score; report "undefined".
            print(f"[ppl_eval] {name}: no valid tokens were scored; reporting nan")
            ppls[name] = float("nan")
            continue
        # Compute perplexity: exp(total NLL / number of tokens)
        ppls[name] = float(np.exp(total_nll / total_tok))

    print("PPL after pruning:", ppls)
    return ppls

@torch.inference_mode()
def accuracy_eval(
    model,
    tokenizer,
    datasets=('arc_easy', 'winogrande',
              'hellaswag', 'piqa'),
    model_seq_len=2048,
    batch_size=32,
    device="cuda",
    max_samples=1000,
    local_cache_root="/root/.cache/huggingface",
    offline=True
):
    """
    Evaluate multiple-choice accuracy on commonsense reasoning benchmarks.

    Supported names: 'arc_easy', 'winogrande', 'hellaswag', 'piqa'. For each
    example, every candidate answer is rendered into a prompt together with
    the question. The whole prompt is scored by the negative summed next-token
    cross-entropy under the model, and the highest-scoring candidate is the
    prediction.

    Args:
        model: causal LM called as ``model(**tokenizer_output, use_cache=False)``;
            its output must expose ``.logits``.
        tokenizer: tokenizer used to encode prompts and to resolve the pad id.
        datasets: iterable of benchmark names to evaluate.
        model_seq_len: prompts are truncated to this many tokens.
        batch_size: accepted for API compatibility; candidates are scored one
            sequence at a time.
        device: device the model and inputs are moved to.
        max_samples: cap on the number of examples for any split other than 'test'.
        local_cache_root: Hugging Face cache root; datasets are read from
            ``<local_cache_root>/datasets``.
        offline: if True, set HF_DATASETS_OFFLINE / TRANSFORMERS_OFFLINE
            (without overriding values already present in the environment).

    Returns:
        dict mapping benchmark name -> accuracy in percent. Benchmarks that are
        unknown or cannot be loaded are skipped with a message.
    """
    import os  # explicit: `os` is otherwise only reachable via the star import

    if offline:
        # Run in offline mode (use local cache, no network download).
        # NOTE(review): some `datasets` versions read HF_DATASETS_OFFLINE only at
        # import time, so setting it here may have no effect — confirm.
        os.environ.setdefault("HF_DATASETS_OFFLINE", "1")
        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

    model.to(device).eval()
    accuracies = {}
    # Primary dataset configs: (dataset name, subset, split)
    cfg = {
        'arc_easy':       ('ai2_arc', 'ARC-Easy', 'test'),
        'winogrande':     ('winogrande', 'winogrande_xl', 'validation'),
        'hellaswag':      ('Rowan/hellaswag', None, 'validation'),
        'piqa':           ('baber/piqa', None, 'validation'),
    }
    # Alternative configs, tried in order when the primary one cannot be loaded
    fallbacks = {
        'arc_easy': [
            ('ai2_arc', 'ARC-Easy', 'validation')
        ],
        'winogrande': [
            ('winogrande', 'winogrande_xl', 'validation'),
        ],
        'hellaswag': [
            ('hellaswag', None, 'validation'),
        ],
        'piqa': [
            ('piqa', None, 'validation'),
        ]
    }

    # Resolve the pad id once. Compare against None explicitly: 0 is a valid
    # token id and must not fall through to eos the way `a or b` would.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    loss_fct = torch.nn.CrossEntropyLoss(reduction='sum')

    def _try_load_local(ds_name, subset, split):
        """Try to load dataset from local cache."""
        kwargs = {
            "split": split,
            "download_mode": "reuse_dataset_if_exists",
            "cache_dir": os.path.join(local_cache_root, "datasets"),
        }
        if subset is None:
            return load_dataset(ds_name, **kwargs)
        return load_dataset(ds_name, subset, **kwargs)

    def _safe_load(name):
        """Load `name` via its primary config, then its fallbacks.

        Returns (dataset, split_actually_loaded), or None if the name is
        unknown or every attempt failed.
        """
        if name not in cfg:
            print(f"[accuracy_eval] Skipping {name}: unknown dataset (supported: {sorted(cfg)})")
            return None
        tries = [cfg[name]] + fallbacks.get(name, [])
        last_err = None
        for ds_name, subset, split in tries:
            try:
                return _try_load_local(ds_name, subset, split), split
            except Exception as e:  # best-effort: any loader failure moves on to the next config
                last_err = e
        print(f"[accuracy_eval] Skipping {name}: local cache could not be loaded (both primary and fallback failed). Last error: {last_err}")
        return None

    def preprocess(example, name):
        """Convert raw dataset example into a unified multiple-choice format."""
        if name == 'winogrande':
            ans = int(str(example['answer']).strip()) - 1
            return {
                'question': example['sentence'],
                'choices': [example['option1'], example['option2']],
                'answer': max(0, min(1, ans)),
                'template': 'winogrande',
            }
        if name == 'arc_easy':
            labels = list(example['choices']['label'])
            key = example['answerKey']
            # ARC answer keys are usually letters, but some items use digits
            # ('1'..'4'); look the key up in the item's own label list.
            answer = labels.index(key) if key in labels else ord(key) - ord('A')
            return {
                'question': example['question'],
                'choices': list(example['choices']['text']),
                'answer': answer,
                'template': 'science',
            }
        if name == 'hellaswag':
            return {
                'question': example['ctx'],
                'choices': list(example['endings']),
                'answer': int(example['label']),
                'template': 'hellaswag',
            }
        if name == 'piqa':
            return {
                'question': example['goal'],
                'choices': [example['sol1'], example['sol2']],
                'answer': int(example['label']),
                'template': 'piqa',
            }
        return None

    def render_prompt(q, c, tmpl):
        """Render input text prompt depending on dataset template."""
        if tmpl == 'winogrande':
            return q.replace("_", c)
        if tmpl == 'science':
            return f"Science Question: {q}\nPossible Answer: {c}"
        if tmpl == 'hellaswag':
            return f"Context: {q}\nContinuation: {c}"
        if tmpl == 'piqa':
            return f"Goal: {q}\nSolution: {c}"
        return f"{q}\nThe correct answer is: {c}"

    def _score_text(text):
        """Return -(summed next-token NLL) of `text`; -inf if no token is scorable."""
        inp = tokenizer(
            text,
            max_length=model_seq_len,
            truncation=True,
            return_tensors="pt"
        )
        inp = {k: v.to(device) for k, v in inp.items()}
        out = model(**inp, use_cache=False)
        logits = out.logits[:, :-1, :].float()   # [1, T-1, V]
        labels = inp['input_ids'][:, 1:]          # [1, T-1]
        # Align sequence lengths defensively in case the model's output length differs.
        T = min(logits.size(1), labels.size(1))
        logits = logits[:, :T, :]
        labels = labels[:, :T]
        if pad_id is None:
            mask = torch.ones_like(labels, dtype=torch.bool)
        else:
            mask = labels != pad_id
        if not mask.any():
            return float('-inf')
        return -loss_fct(logits[mask], labels[mask]).item()  # higher score = better

    for name in datasets:
        loaded = _safe_load(name)
        if loaded is None:
            continue
        test_set, split = loaded
        # Cap non-test splits for faster evaluation. Use the split that was
        # actually loaded, since a fallback config may differ from cfg.
        if split != 'test':
            test_set = test_set.select(range(min(max_samples, len(test_set))))

        correct, total = 0, 0
        for ex in tqdm(test_set, desc=f"Evaluating {name}"):
            proc = preprocess(ex, name)
            if proc is None or not proc['choices']:
                continue
            scores = [
                _score_text(render_prompt(proc['question'], c, proc['template']))
                for c in proc['choices']
            ]
            # Choose answer with highest score (lowest loss)
            pred = int(np.argmax(scores))
            correct += int(pred == proc['answer'])
            total += 1
        acc = 100.0 * correct / max(total, 1)
        accuracies[name] = acc
        print(f"{name} accuracy: {acc:.2f}%")

    print("Accuracy after pruning:", {k: f"{v:.2f}%" for k, v in accuracies.items()})
    return accuracies