import argparse
import json
import logging
import math
import os

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


def setup_logging():
    """Configure root logging at INFO level and return this module's logger.

    Delegates to ``logging.basicConfig``, which does nothing if the root
    logger already has handlers configured.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    module_logger = logging.getLogger(__name__)
    return module_logger


def load_model_and_tokenizer(model_name, *, torch_dtype=torch.float16,
                             device_map="auto", trust_remote_code=True):
    """Load a causal language model and its tokenizer for perplexity scoring.

    Args:
        model_name: Hugging Face hub id or local path of the checkpoint.
        torch_dtype: dtype the model weights are loaded in (float16 by default).
        device_map: device placement passed through to ``from_pretrained``.
        trust_remote_code: whether to execute custom modeling/tokenizer code
            shipped with the checkpoint. Defaults to ``True`` to preserve the
            previous behavior. Note that this runs arbitrary Python from the
            repository, so only enable it for checkpoints you trust. It is
            applied to both the tokenizer and the model so they load consistently.

    Returns:
        ``(model, tokenizer)``. If the tokenizer defines no pad token, its EOS
        token is reused as the pad token.
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Loading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
    )

    # Some causal-LM tokenizers ship without a pad token; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def extract_assistant_response(response_text):
    """Return the assistant's turn from a raw chat transcript.

    Keeps only the text after the LAST ``<|im_start|>assistant:`` marker (the
    whole string if the marker is absent), then cuts it at the first
    ``<|im_end|>`` marker (if any), and strips surrounding whitespace.
    """
    # rpartition yields ('', '', text) when the marker is missing, so index 2
    # is either the tail after the last marker or the untouched input.
    after_marker = response_text.rpartition("<|im_start|>assistant:")[2]
    # partition yields (text, '', '') when the end tag is missing.
    before_end = after_marker.partition("<|im_end|>")[0]
    return before_end.strip()


def calculate_perplexity(model, tokenizer, text, max_length=2048):
    """Return the perplexity of ``text`` under ``model``.

    The text is tokenized and truncated to at most ``max_length`` tokens, so
    for longer texts only the prefix is scored. The model is then run with
    ``labels=input_ids``, and the returned loss is exponentiated. For Hugging
    Face causal LMs, that loss is the mean next-token cross-entropy (labels
    are shifted inside the model).

    Args:
        model: causal LM called as ``model(input_ids, labels=input_ids)``;
            must expose ``.device`` and return an object with a scalar ``.loss``.
        tokenizer: callable returning a mapping with an ``"input_ids"`` tensor.
        text: text to score.
        max_length: maximum number of tokens kept after truncation.

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: if the (possibly truncated) text has fewer than two tokens.
            In that case there is no next-token prediction to score, and the
            loss would be NaN.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].to(model.device)

    num_tokens = input_ids.shape[-1]
    if num_tokens < 2:
        raise ValueError(
            f"need at least 2 tokens to compute perplexity, got {num_tokens}"
        )

    with torch.no_grad():
        loss = model(input_ids, labels=input_ids).loss

    # Upcast before exponentiating. If the loss comes back in float16, any
    # value above ~11.1 would overflow exp() to inf in half precision.
    return torch.exp(loss.float()).item()


def _json_sort_key(filename):
    """Sort key: numerically named files (``"12.json"``) first, in numeric order.

    A name whose stem is not an integer sorts after them, alphabetically,
    instead of raising ``ValueError`` and aborting the run.
    """
    stem = filename.split('.')[0]
    try:
        return (0, int(stem), filename)
    except ValueError:
        return (1, 0, filename)


def process_json_files(data_dir, model, tokenizer, max_length=2048):
    """Compute the perplexity of the assistant response in every JSON file of ``data_dir``.

    Each ``*.json`` file directly inside ``data_dir`` is expected to hold an
    object with a ``"response"`` string and, optionally, a ``"problem"`` field.
    Files are processed in numeric order of their name stem; non-numeric names
    come last. A file that is unreadable, lacks a response, fails to score, or
    yields a non-finite perplexity is logged and skipped.

    Args:
        data_dir: directory to scan (not recursive).
        model: model passed through to ``calculate_perplexity``.
        tokenizer: tokenizer passed through to ``calculate_perplexity``.
        max_length: token limit forwarded to ``calculate_perplexity``.

    Returns:
        ``(results, avg_perplexity)``. ``results`` is a list of dicts with keys
        ``file``, ``problem``, ``response_length`` (characters) and
        ``perplexity``. ``avg_perplexity`` is their arithmetic mean, or 0 if no
        file succeeded.
    """
    logger = logging.getLogger(__name__)

    json_files = sorted(
        (f for f in os.listdir(data_dir) if f.endswith('.json')),
        key=_json_sort_key,
    )

    results = []

    logger.info(f"Processing {len(json_files)} JSON files...")

    for filename in tqdm(json_files, desc="Calculating perplexity"):
        file_path = os.path.join(data_dir, filename)

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            response_text = data.get('response', '')
            if not response_text:
                logger.warning(f"No response found in {filename}")
                continue

            assistant_response = extract_assistant_response(response_text)
            if not assistant_response:
                logger.warning(f"No assistant response found in {filename}")
                continue

            perplexity = calculate_perplexity(
                model, tokenizer, assistant_response, max_length=max_length
            )
        except Exception as e:
            # Best-effort batch: one bad file must not abort the whole run.
            logger.error(f"Error processing {filename}: {str(e)}")
            continue

        # A NaN/inf value would silently poison the average, and json.dump
        # would write it as a non-standard NaN/Infinity token.
        if not math.isfinite(perplexity):
            logger.warning(f"Non-finite perplexity ({perplexity}) for {filename}; skipping")
            continue

        results.append({
            'file': filename,
            'problem': data.get('problem', ''),
            'response_length': len(assistant_response),
            'perplexity': perplexity,
        })

        logger.debug(f"{filename}: Perplexity = {perplexity:.4f}")

    valid_files = len(results)
    avg_perplexity = (
        math.fsum(r['perplexity'] for r in results) / valid_files if valid_files > 0 else 0
    )

    logger.info(f"Processed {valid_files} files successfully")
    logger.info(f"Average perplexity: {avg_perplexity:.4f}")

    return results, avg_perplexity


def save_results(results, avg_perplexity, output_file):
    """Write a summary and the per-file details to ``output_file`` as UTF-8 JSON.

    Args:
        results: per-file result dicts, each containing a ``'perplexity'`` key.
        avg_perplexity: average perplexity to record in the summary.
        output_file: path of the JSON file to (over)write.

    Returns:
        The ``summary`` dict that was written. Its min/max are 0 when
        ``results`` is empty.
    """
    logger = logging.getLogger(__name__)

    perplexities = [r['perplexity'] for r in results]
    summary = {
        'total_files': len(results),
        'average_perplexity': avg_perplexity,
        'min_perplexity': min(perplexities, default=0),
        'max_perplexity': max(perplexities, default=0),
    }
    output_data = {
        'summary': summary,
        'details': results,
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Results saved to {output_file}")
    return summary


def main():
    """Command-line entry point: score every response in ``--data_dir`` and save a report.

    Logs an error and returns early (without raising) if the data directory is
    missing, the model fails to load, or no file yields a valid perplexity.
    """
    parser = argparse.ArgumentParser(description="Calculate perplexity for DeepSeek responses")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../asset/insert_response/DeepSeek-R1-Distill-Qwen-1.5B_MATH-500_134_32784",
        help="Directory containing JSON response files"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        # NOTE: this default is the distilled checkpoint itself (the same model
        # named in the default data directory), not the base model it was
        # distilled from.
        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        help="Name of the base model for perplexity calculation"
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="perplexity_results.json",
        help="Output file for results"
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=2048,
        help="Maximum sequence length for tokenization"
    )

    args = parser.parse_args()
    if args.max_length < 2:
        # Perplexity needs at least one next-token prediction, i.e. two tokens.
        parser.error("--max_length must be at least 2")

    logger = setup_logging()

    # isdir rather than exists: a path to a regular file would otherwise pass
    # this check and then crash in os.listdir.
    if not os.path.isdir(args.data_dir):
        logger.error(f"Data directory does not exist or is not a directory: {args.data_dir}")
        return

    try:
        model, tokenizer = load_model_and_tokenizer(args.model_name)
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        return

    # Forward --max_length; previously the flag was parsed but never used.
    results, avg_perplexity = process_json_files(
        args.data_dir, model, tokenizer, max_length=args.max_length
    )

    if not results:
        logger.error("No valid results found")
        return

    summary = save_results(results, avg_perplexity, args.output_file)

    print("\n=== Perplexity Calculation Summary ===")
    print(f"Data directory: {args.data_dir}")
    print(f"Base model: {args.model_name}")
    print(f"Total files processed: {len(results)}")
    print(f"Average perplexity: {avg_perplexity:.4f}")
    print(f"Min perplexity: {summary['min_perplexity']:.4f}")
    print(f"Max perplexity: {summary['max_perplexity']:.4f}")
    print(f"Results saved to: {args.output_file}")


if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not on import.
    main()