import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoConfig
from tqdm import tqdm
import gc
import os

# from models.llama3.modeling_llama3_pruned import LlamaForCausalLM as PrunedLlamaForCausalLM
from models.llama2.modeling_llama2_pruned import LlamaForCausalLM as PrunedLlamaForCausalLM
from models.phi2.modeling_phi2_pruned import PhiForCausalLM as PrunedPhiForCausalLM

# import multi-choice evaluation functions from dataset_loader
from lib.dataset_loader import (
    load_mc_dataset, 
    format_mc_example, 
    evaluate_mc_example, 
    evaluate_mc_dataset,
    build_wikitext_ids,
    build_c4_ids,
    build_ptb_ids,
    sample_wikitext_sequences,
    calculate_perplexity
)

# Set seed for reproducibility
def set_seed(seed):
    """Seed the NumPy, PyTorch and Python random number generators.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    # Same three generators, same order as before; each call is independent.
    for seed_fn in (np.random.seed, torch.random.manual_seed, random.seed):
        seed_fn(seed)

# ===================== Dataset Loading Functions =====================

# Load and process wikitext2 dataset
def get_wikitext2(tokenizer, seqlen=2048, split='test'):
    """Load WikiText-2 (raw) and tokenize it as one long sequence.

    Args:
        tokenizer: Callable tokenizer; invoked as
            ``tokenizer(text, return_tensors='pt')``.
        seqlen: Accepted for signature parity with the other loaders;
            not used by this function.
        split: ``'train'`` loads the training split; any other value loads
            the test split.

    Returns:
        The tokenizer output for the concatenated text.
    """
    use_train = split == 'train'

    # Training text is joined with single spaces, test text with blank lines.
    if use_train:
        rows = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
        separator = " "
    else:
        rows = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        separator = "\n\n"

    corpus = separator.join(rows['text'])
    encoded = tokenizer(corpus, return_tensors='pt')

    token_count = encoded.input_ids.shape[1]
    print(f"Loaded WikiText-2 ({split}): {token_count} tokens")

    return encoded

def get_ptb(tokenizer, seqlen=2048, split='test'):
    """Load Penn Treebank and tokenize it as one long sequence.

    Args:
        tokenizer: Callable tokenizer; invoked with ``return_tensors='pt'``
            and ``add_special_tokens=True``.
        seqlen: Accepted for signature parity with the other loaders;
            not used by this function.
        split: Dataset split name passed straight to ``load_dataset``.

    Returns:
        The tokenizer output for the concatenated text.
    """
    dataset = load_dataset("ptb_text_only", "penn_treebank", split=split)

    # The text column is "sentence" when present, otherwise "text".
    if "sentence" in dataset.column_names:
        column = "sentence"
    else:
        column = "text"

    corpus = "\n\n".join(dataset[column])
    encoded = tokenizer(corpus, return_tensors='pt', add_special_tokens=True)

    token_count = encoded.input_ids.shape[1]
    print(f"Loaded PTB ({split}): {token_count} tokens")
    return encoded

def get_c4(tokenizer, seqlen=2048, split='test'):
    """Load a local C4 validation shard and tokenize it as one long sequence.

    Args:
        tokenizer: Callable tokenizer; invoked with ``return_tensors='pt'``
            and ``add_special_tokens=True``.
        seqlen: Accepted for signature parity with the other loaders;
            not used by this function.
        split: Only used in the log message; the data always comes from the
            hard-coded shard path below.

    Returns:
        The tokenizer output for the concatenated text.
    """
    shard = load_dataset(
        "json",
        data_files="xxx/cache/huggingface/datasets/c4-validation.00000-of-00008.json.gz",
        split="train",
    )

    # Every document in the shard is used, separated by blank lines.
    corpus = "\n\n".join(shard["text"])
    encoded = tokenizer(corpus, return_tensors='pt', add_special_tokens=True)

    token_count = encoded.input_ids.shape[1]
    print(f"Loaded C4 ({split}): {token_count} tokens")

    return encoded

# ===================== Evaluation Functions =====================

# Evaluate perplexity on a tokenized corpus (used for WikiText-2, PTB and C4)
def eval_ppl_wikitext(model, testenc, bs=1, device=None, seqlen=2048,
                      per_sample_save_path='xxx/pruned_model_per_sample_ppls.pkl'):
    """
    Evaluate perplexity of a language model on a tokenized corpus.

    The token stream is cut into ``numel // seqlen`` non-overlapping windows
    of ``seqlen`` tokens (trailing tokens are dropped). Each window is scored
    with next-token cross-entropy and the corpus perplexity is the exponential
    of the mean negative log-likelihood over all predicted tokens.

    Args:
        model: Callable taking a ``(batch, seqlen)`` LongTensor and returning
            either a logits tensor or an object with a ``.logits`` attribute.
        testenc: Object exposing ``input_ids``, sliced here as a
            ``(1, n_tokens)`` tensor.
        bs: Number of windows per forward pass.
        device: Device the inputs are moved to; ``None`` leaves them in place.
        seqlen: Window length in tokens (must be at least 2).
        per_sample_save_path: File that receives a pickled list with one
            perplexity per forward pass (per window when ``bs == 1``).
            Missing parent directories are created. ``None`` disables saving.

    Returns:
        float: Perplexity value

    Raises:
        ValueError: If ``seqlen < 2`` or the corpus holds fewer than
            ``seqlen`` tokens.
    """
    import pickle

    if seqlen < 2:
        raise ValueError(f"seqlen must be at least 2, got {seqlen}")

    # Get input IDs
    token_ids = testenc.input_ids

    # Calculate number of samples
    nsamples = token_ids.numel() // seqlen
    if nsamples == 0:
        raise ValueError(
            f"Corpus has {token_ids.numel()} tokens, fewer than seqlen={seqlen}; "
            "nothing to evaluate"
        )
    print(f"nsamples {nsamples}")

    # Summed negative log likelihood of every batch
    nlls = []
    # Perplexity of every batch (one entry per forward pass)
    per_sample_ppls = []

    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"sample {i}")

        # The last batch may be smaller than bs
        j = min(i + bs, nsamples)

        inputs = token_ids[:, (i * seqlen):(j * seqlen)]
        if device is not None:
            inputs = inputs.to(device)
        # Keep inputs as Long type for the embedding layer
        inputs = inputs.reshape(j - i, seqlen).to(torch.long)

        with torch.no_grad():
            # Forward pass through the model with mixed precision
            with torch.cuda.amp.autocast(dtype=torch.float16):
                outputs = model(inputs)
            lm_logits = outputs.logits if hasattr(outputs, 'logits') else outputs

            # Shift logits and labels for next token prediction. The logits
            # are upcast so the loss is not rounded to half precision.
            shift_logits = lm_logits[:, :-1, :].float()
            shift_labels = inputs[:, 1:]

            # Summed NLL over every predicted token in the batch
            neg_log_likelihood = torch.nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction='sum',
            )

        nlls.append(neg_log_likelihood)
        per_sample_ppls.append(
            torch.exp(neg_log_likelihood / shift_labels.numel()).item()
        )

    # Each window contributes seqlen - 1 predicted tokens
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * (seqlen - 1)))

    # Save per-batch perplexities; create the directory first so a missing
    # folder cannot discard a finished evaluation.
    if per_sample_save_path is not None:
        save_dir = os.path.dirname(per_sample_save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(per_sample_save_path, 'wb') as f:
            pickle.dump(per_sample_ppls, f)
        print(f"saved per_sample_ppls to {per_sample_save_path}")

    # Empty CUDA cache to save memory
    torch.cuda.empty_cache()

    return ppl.item()

# ===================== Model Analysis =====================

def analyze_model_sparsity(model):
    """
    Analyze sparsity of model parameters

    Counts exactly-zero entries in every tensor yielded by
    ``model.named_parameters()``, prints one line per parameter plus an
    overall summary, and returns the collected numbers.

    Args:
        model: The model to analyze

    Returns:
        dict: Sparsity statistics with keys ``layer_stats`` (list of dicts
        with ``name``, ``param_count``, ``zeros_count``, ``sparsity``),
        ``total_params``, ``zero_params`` and ``overall_sparsity``.
    """
    print("\n===== Model Sparsity Analysis =====")

    total_params = 0
    zero_params = 0

    # Analyze layer by layer
    layer_stats = []

    for name, param in model.named_parameters():
        param_count = param.numel()
        zeros_count = (param == 0).sum().item()
        total_params += param_count
        zero_params += zeros_count
        # A parameter can have zero elements; report 0 sparsity for it
        # instead of dividing by zero, as the overall ratio below does.
        sparsity = zeros_count / param_count if param_count > 0 else 0

        layer_stats.append({
            "name": name,
            "param_count": param_count,
            "zeros_count": zeros_count,
            "sparsity": sparsity
        })

        print(f"{name} - sparsity {sparsity:.4f} ({zeros_count}/{param_count})")

    # Calculate overall sparsity
    overall_sparsity = zero_params / total_params if total_params > 0 else 0

    print(f"\nOverall model sparsity: {overall_sparsity:.4f} ({zero_params}/{total_params})")
    print(f"Total parameters: {total_params/1e6:.2f}M, Non-zero parameters: {(total_params-zero_params)/1e6:.2f}M")

    return {
        "layer_stats": layer_stats,
        "total_params": total_params,
        "zero_params": zero_params,
        "overall_sparsity": overall_sparsity
    }

# ===================== Main Function =====================

def main():
    """Command-line entry point: load a pruned Phi model and evaluate it.

    Parses the CLI options, loads the pruning vectors, tokenizer and model
    from ``--pruned_model_path``, then walks through ``--datasets``:
    "wikitext", "ptb" and "c4" are scored by perplexity, every other name is
    forwarded to ``evaluate_mc_dataset``. A summary is printed at the end.
    """
    cli = argparse.ArgumentParser(description="Evaluate pruned model on multiple datasets")
    cli.add_argument('--pruned_model_path', type=str,
                     default="xxx/llms/pruning/DISP/phi/phi-2-070",
                     help='Path to the pruned model directory')
    cli.add_argument('--pruned_ratio', type=float, default=0.7,
                     help='Pruned ratio for evaluation')
    cli.add_argument('--seq_len', type=int, default=2048,
                     help='Sequence length for evaluation')
    cli.add_argument('--batch_size', type=int, default=1,
                     help='Batch size for evaluation')
    cli.add_argument('--num_samples', type=int, default=None,
                     help='Number of samples per dataset (None for all)')
    cli.add_argument('--seed', type=int, default=42,
                     help='Random seed')
    cli.add_argument('--device', type=str, default="cuda:1",
                     help='Device for evaluation')
    # Perplexity benchmarks ("wikitext", "ptb", "c4") are not in the default
    # list; they can be requested explicitly on the command line.
    cli.add_argument('--datasets', type=str, nargs='+',
                     default=["winogrande", "arc-e", "arc-c", "hellaswag", "piqa", "boolq", "obqa"],
                     help='Datasets to evaluate')
    cli.add_argument('--skip_wikitext', action='store_true',
                     help='Skip WikiText evaluation')
    args = cli.parse_args()

    set_seed(args.seed)

    # Absolute model directory, also placed first on sys.path for imports
    model_path = os.path.abspath(args.pruned_model_path)
    if model_path not in sys.path:
        sys.path.insert(0, model_path)

    # The pruning vectors are attached to the model class before it is built.
    # NOTE(review): presumably read by the pruned model's constructor - confirm
    # in models.phi2.modeling_phi2_pruned.
    vectors = torch.load(f'{args.pruned_model_path}/vectors-{args.pruned_ratio}.pt', map_location='cpu')
    PrunedPhiForCausalLM.cfgs = vectors

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    print(f"Loading model from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = PrunedPhiForCausalLM.from_pretrained(model_path, config=config, trust_remote_code=True)

    # Fall back to CPU when CUDA is unavailable
    if torch.cuda.is_available():
        device = torch.device(args.device)
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Move to the device first, then cast to float16, then switch to eval mode
    model = model.to(device)
    model = model.to(torch.float16)
    model.eval()

    # Record the evaluation sequence length on the model object
    model.seqlen = args.seq_len

    def perplexity_for(loader):
        """Tokenize the test split with ``loader`` and score it."""
        testenc = loader(tokenizer, seqlen=args.seq_len, split='test')
        return eval_ppl_wikitext(
            model=model,
            testenc=testenc,
            bs=args.batch_size,
            device=device,
            seqlen=args.seq_len
        )

    # Results keyed by dataset name
    results = {}

    for dataset_name in args.datasets:
        key = dataset_name.lower()

        if key == "wikitext":
            # --skip_wikitext turns this entry into a no-op
            if not args.skip_wikitext:
                print("\n===== Evaluating WikiText-2 Perplexity =====")
                ppl = perplexity_for(get_wikitext2)
                print(f"WikiText-2 perplexity: {ppl:.4f}")
                results["wikitext"] = {"perplexity": ppl}
        elif key == "ptb":
            print("\n===== Evaluating PTB Perplexity =====")
            ppl = perplexity_for(get_ptb)
            print(f"PTB perplexity: {ppl:.4f}")
            results["ptb"] = {"perplexity": ppl}
        elif key == "c4":
            print("\n===== Evaluating C4 Perplexity =====")
            ppl = perplexity_for(get_c4)
            print(f"C4 perplexity: {ppl:.4f}")
            results["c4"] = {"perplexity": ppl}
        else:
            # Anything else is treated as a multiple-choice benchmark.
            print(f"\n===== Evaluating {dataset_name} =====")
            # NOTE(review): split="test" is passed here; an earlier comment
            # claimed dataset_loader maps it to validation - confirm there.
            results[dataset_name] = evaluate_mc_dataset(
                model=model,
                tokenizer=tokenizer,
                dataset_name=dataset_name,
                device=device,
                num_examples=args.num_samples,
                split="test"
            )

        # Release memory between datasets
        gc.collect()
        torch.cuda.empty_cache()

    print("\n===== Overall Evaluation Summary =====")

    for dataset_name, result in results.items():
        is_perplexity = dataset_name.lower() in ("wikitext", "ptb", "c4")
        if is_perplexity:
            print(f"{dataset_name}: Perplexity = {result['perplexity']:.4f}")
        else:
            # Multiple-choice results are read via their 'acc' / 'acc_norm' keys
            print(f"{dataset_name}: Accuracy = {result['acc']:.4f}, Normalized Accuracy = {result['acc_norm']:.4f}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()