import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import json
import os
import argparse
import time
from typing import List, Tuple, Dict, Any
import sys
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

def parse_args():
    """Declare the command-line interface and parse ``sys.argv``.

    Options are registered in this order: model, data, generation,
    device/precision, output, and visualization settings.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    cli = argparse.ArgumentParser(
        description='Inference script for Countdown ES models - exactly matching training'
    )
    opt = cli.add_argument

    # --- Model ---
    opt('--model_path', type=str, required=True,
        help='Path to the saved model directory')
    opt('--base_model_name', type=str, default='',
        help='Base model name for tokenizer loading')
    opt('--hf_cache_dir', type=str, default='',
        help='Hugging Face cache directory')
    opt('--reward_type', type=str, default='grpo_reward',
        help='Reward type (toy_reward, grpo_reward)')

    # --- Data (mirrors es_llm_v4_countdown.py) ---
    opt('--train_data_path', type=str, default='',
        help='Path to training data JSON file')
    opt('--eval_data_path', type=str, default='',
        help='Path to evaluation data JSON file')
    opt('--train_samples', type=int, default=1000,
        help='Number of training samples to evaluate (data_sample from training)')
    opt('--eval_samples', type=int, default=100,
        help='Number of evaluation samples to evaluate')
    opt('--eval_offset', type=int, default=-100,
        help='Offset for evaluation data (negative means from end)')

    # --- Generation ---
    opt('--max_new_tokens', type=int, default=1024,
        help='Maximum number of new tokens to generate')
    opt('--do_sample', action='store_true',
        help='Whether to use sampling instead of greedy decoding')
    opt('--temperature', type=float, default=0.8,
        help='Temperature for sampling')
    opt('--top_p', type=float, default=0.9,
        help='Top-p for nucleus sampling')
    opt('--batch_size', type=int, default=None,
        help='Batch size for inference (default: min(32, dataset_size))')

    # --- Device / precision ---
    opt('--device', type=str, default='auto',
        help='Device to use (auto, cuda, cpu)')
    opt('--torch_dtype', type=str, default='auto',
        choices=['auto', 'float16', 'bfloat16', 'float32'],
        help='Torch dtype for model loading')
    opt('--mixed_precision', type=str, default='fp16',
        choices=['no', 'fp16', 'bf16'],
        help='Mixed precision mode (matching training)')

    # --- Output ---
    opt('--output_dir', type=str, default=None,
        help='Directory to save inference results (default: based on model name)')
    opt('--save_responses', action='store_true',
        help='Save individual responses to file')
    opt('--verbose', action='store_true',
        help='Print verbose output')
    opt('--show_examples', type=int, default=5,
        help='Number of examples to show in detail')

    # --- Visualization ---
    opt('--generate_plots', action='store_true',
        help='Generate visualization plots')
    opt('--plot_style', type=str, default='seaborn-v0_8',
        help='Matplotlib style for plots')

    parsed = cli.parse_args()
    return parsed

def load_data(data_path: str, num_samples: int = None, offset: int = 0) -> List[Tuple[str, str]]:
    """Load dataset from JSON file - exactly matching es_llm_v4_countdown.py format.

    The file must hold a JSON array of objects, each with ``'context'`` and
    ``'target'`` keys.

    Args:
        data_path: Path to the JSON file.
        num_samples: Keep at most this many items (counted after ``offset`` is
            applied); ``None`` keeps all of them.
        offset: Index of the first item to keep. A negative value counts from
            the end (``-100`` means "the last 100 items"); if it reaches past
            the start of the data, every item is kept.

    Returns:
        List of ``(context, target)`` tuples.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        KeyError: If an item lacks ``'context'`` or ``'target'``.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    # JSON text is UTF-8; don't let decoding depend on the platform locale.
    with open(data_path, 'r', encoding='utf-8') as f:
        data_json = json.load(f)

    dataset = [(item['context'], item['target']) for item in data_json]

    # Slicing handles both signs: a negative start counts from the end and is
    # clamped to 0 when it would reach past the beginning.
    dataset = dataset[offset:]

    if num_samples is not None and num_samples < len(dataset):
        dataset = dataset[:num_samples]

    return dataset

def extract_model_response(generated_text: str) -> str:
    """Return the text after the last ``"assistant:"`` marker, stripped - exactly matching training script.

    When the marker does not occur, the input is returned unchanged (and not
    stripped).
    """
    marker = "assistant:"
    if marker not in generated_text:
        return generated_text
    _, _, tail = generated_text.rpartition(marker)
    return tail.strip()

def extract_numbers_and_target(input_text: str, target_text: str) -> Tuple[List[int], int]:
    """Parse the puzzle's number list and target value - exactly matching training script.

    Numbers come from the text between the first ``[`` and the first ``]`` in
    ``input_text``. Only whitespace-separated, purely-digit tokens are kept, so
    a comma-separated list such as ``"[1, 2]"`` yields ``[2]``. Negative
    numbers are dropped too.

    Returns:
        ``(numbers, target)``. ``numbers`` is ``None`` when either bracket is
        missing. ``target`` is ``None`` unless ``target_text`` is purely
        digits; surrounding whitespace also makes it ``None``.
    """
    open_pos = input_text.find("[")
    close_pos = input_text.find("]")
    if open_pos == -1 or close_pos == -1:
        numbers = None
    else:
        inner = input_text[open_pos + 1:close_pos]
        numbers = [int(tok) for tok in inner.split() if tok.isdigit()]

    target = int(target_text) if target_text.isdigit() else None
    return numbers, target

def evaluate_batch(model, tokenizer, input_texts: List[str], target_texts: List[str], 
                  device, args, verbose: bool = False) -> List[Dict[str, Any]]:
    """Generate completions for a batch of prompts and score them - matching training-script evaluation.

    Args:
        model: Causal LM exposing a Hugging Face style ``generate``.
        tokenizer: Matching tokenizer; it is called with left padding, so a pad
            token must be set.
        input_texts: Prompts, one per sample.
        target_texts: Target strings aligned index-by-index with ``input_texts``.
        device: Device the input tensors are moved to (e.g. ``'cuda'``/``'cpu'``).
        args: CLI namespace; reads ``max_new_tokens`` and ``do_sample``, plus
            ``temperature``/``top_p`` when sampling.
        verbose: Print a progress line.

    Returns:
        One dict per sample: prompt, target, decoded text, extracted response,
        parsed numbers/target, and the ``reward`` / ``reward_info`` returned by
        the module-level ``reward_function``.
    """
    if verbose:
        print(f"Batch evaluating {len(input_texts)} samples...")

    # torch.cuda.synchronize() rejects non-CUDA devices, so only call it when the
    # inputs actually live on a GPU (e.g. not for --device cpu on a CUDA machine).
    on_cuda = torch.cuda.is_available() and torch.device(device).type == 'cuda'

    tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, padding_side="left")
    input_ids = tokenized_inputs["input_ids"].to(device)
    attention_mask = tokenized_inputs["attention_mask"].to(device)

    generate_kwargs = {
        'max_new_tokens': args.max_new_tokens,
        'do_sample': args.do_sample,
    }
    if args.do_sample:
        # Use the sampling settings the CLI exposes (and save_results() reports);
        # they are only meaningful when sampling, so greedy runs omit them.
        generate_kwargs['temperature'] = args.temperature
        generate_kwargs['top_p'] = args.top_p

    all_results = []
    with torch.inference_mode():
        outputs = model.generate(input_ids, attention_mask=attention_mask, **generate_kwargs)
        if on_cuda:
            torch.cuda.synchronize(device)

    for i, output in enumerate(outputs):
        try:
            generated_text = tokenizer.decode(output, skip_special_tokens=True)
        except TypeError:
            # Fallback for tokenizers whose decode() rejects this input: decode
            # token by token, dropping ids that map to no token.
            tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
            filtered = [t for t in tokens if t is not None]
            generated_text = tokenizer.convert_tokens_to_string(filtered)

        model_response = extract_model_response(generated_text)

        input_text = input_texts[i]
        target_text = target_texts[i]
        numbers, target = extract_numbers_and_target(input_text, target_text)

        reward_result = reward_function(model_response, numbers, target)
        reward = reward_result["reward"]
        reward_info = reward_result["reward_info"]

        all_results.append({
            'input_text': input_text,
            'target_text': target_text,
            'generated_text': generated_text,
            'model_response': model_response,
            'numbers': numbers,
            'target': target,
            'reward': reward,
            'reward_info': reward_info
        })

    # Drop device tensors before the next batch so the cache can actually be freed.
    del input_ids, attention_mask, outputs
    if on_cuda:
        torch.cuda.synchronize(device)
        torch.cuda.empty_cache()

    return all_results

def evaluate_single_sample(model, tokenizer, input_text: str, target_text: str, 
                          device, args, verbose: bool = False) -> Dict[str, Any]:
    """Score one prompt (compatibility fallback for per-sample callers).

    Wraps the prompt and target in one-element lists, delegates to
    ``evaluate_batch`` and returns that batch's first result dict.
    """
    single_batch = evaluate_batch(
        model,
        tokenizer,
        input_texts=[input_text],
        target_texts=[target_text],
        device=device,
        args=args,
        verbose=verbose,
    )
    return single_batch[0]

def _summarize_results(all_results: List[Dict[str, Any]], dataset_name: str,
                       eval_time: float) -> Dict[str, Any]:
    """Reduce per-sample result dicts to the statistics dict returned by evaluate_dataset."""
    num_samples = len(all_results)
    # Cast to Python floats so every statistic stays JSON-serializable even if the
    # reward function returns ints (np.min/np.max would otherwise yield np.int64).
    rewards = [float(r['reward']) for r in all_results]
    format_rewards = [float(r['reward_info']['format_reward']) for r in all_results]
    answer_rewards = [float(r['reward_info']['answer_reward']) for r in all_results]

    # "High reward" means reward >= 1.0; "correct" means answer_reward > 0.
    high_reward_count = sum(1 for r in rewards if r >= 1.0)
    correct_count = sum(1 for r in answer_rewards if r > 0)

    if num_samples:
        avg_reward = sum(rewards) / num_samples
        avg_format_reward = sum(format_rewards) / num_samples
        avg_answer_reward = sum(answer_rewards) / num_samples
        std_reward = float(np.std(rewards))
        min_reward = float(np.min(rewards))
        max_reward = float(np.max(rewards))
        p25_reward, p50_reward, p75_reward = (float(p) for p in np.percentile(rewards, [25, 50, 75]))
        high_reward_percentage = high_reward_count / num_samples * 100
        accuracy = correct_count / num_samples * 100
    else:
        # Nothing was evaluated: report zeros instead of dividing by zero or
        # calling np.min on an empty list.
        avg_reward = avg_format_reward = avg_answer_reward = 0.0
        std_reward = min_reward = max_reward = 0.0
        p25_reward = p50_reward = p75_reward = 0.0
        high_reward_percentage = accuracy = 0.0

    return {
        'dataset_name': dataset_name,
        'num_samples': num_samples,
        'avg_reward': avg_reward,
        'avg_format_reward': avg_format_reward,
        'avg_answer_reward': avg_answer_reward,
        'std_reward': std_reward,
        'min_reward': min_reward,
        'max_reward': max_reward,
        'p25_reward': p25_reward,
        'p50_reward': p50_reward,
        'p75_reward': p75_reward,
        'high_reward_count': high_reward_count,
        'high_reward_percentage': high_reward_percentage,
        'correct_count': correct_count,
        'accuracy': accuracy,
        'eval_time': eval_time,
        'all_results': all_results
    }


def evaluate_dataset(model, tokenizer, dataset: List[Tuple[str, str]], 
                    device, args, dataset_name: str, batch_size: int = None) -> Dict[str, Any]:
    """Evaluate model on a dataset using batch processing.

    Args:
        model, tokenizer, device: Forwarded to ``evaluate_batch``.
        dataset: ``(context, target)`` pairs as produced by ``load_data``.
        args: CLI namespace; reads ``verbose`` and ``show_examples`` here and is
            forwarded to ``evaluate_batch``.
        dataset_name: Label used in printed output and in the returned stats.
        batch_size: Prompts per ``generate`` call; ``None`` means
            ``min(32, len(dataset))``.

    Returns:
        Stats dict with averages, spread, percentiles, high-reward and accuracy
        counts, wall-clock ``eval_time`` and the per-sample ``all_results``.
        An empty ``dataset`` yields all-zero statistics.

    Raises:
        ValueError: If an explicit ``batch_size`` is not positive.
    """
    print(f"\n=== Evaluating on {dataset_name} dataset ({len(dataset)} samples) ===")

    if not dataset:
        print(f"Warning: {dataset_name} dataset is empty; skipping evaluation.")
        return _summarize_results([], dataset_name, 0.0)

    if batch_size is None:
        batch_size = min(32, len(dataset))
    elif batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    print(f"Using batch size: {batch_size}")

    num_batches = (len(dataset) - 1) // batch_size + 1
    all_results = []
    start_time = time.time()

    for batch_start in range(0, len(dataset), batch_size):
        batch_end = min(batch_start + batch_size, len(dataset))
        batch_dataset = dataset[batch_start:batch_end]

        if args.verbose:
            print(f"Processing batch {batch_start//batch_size + 1}/{num_batches} (samples {batch_start+1}-{batch_end})...")

        input_texts = [item[0] for item in batch_dataset]
        target_texts = [item[1] for item in batch_dataset]

        batch_results = evaluate_batch(model, tokenizer, input_texts, target_texts,
                                       device, args, verbose=args.verbose)
        all_results.extend(batch_results)

        # Show examples from the first batch only.
        if batch_start == 0:
            for i, result in enumerate(batch_results[:args.show_examples]):
                print(f"\n--- Example {i+1} ---")
                print(f"Input: {result['input_text']}")
                print(f"Target: {result['target_text']}")
                print(f"Model Response: {result['model_response']}")
                print(f"Reward: {result['reward']:.4f} (Format: {result['reward_info']['format_reward']:.4f}, Answer: {result['reward_info']['answer_reward']:.4f})")

    eval_time = time.time() - start_time
    return _summarize_results(all_results, dataset_name, eval_time)

def create_visualizations(train_stats: Dict[str, Any], eval_stats: Dict[str, Any], 
                         output_dir: str, args) -> None:
    """Plot train vs. eval reward distributions (density histogram + box plot).

    Does nothing unless ``args.generate_plots`` is set. Otherwise writes
    ``<output_dir>/plots/reward_distributions.png`` at 300 dpi.

    Args:
        train_stats: Stats dict from ``evaluate_dataset``; only
            ``'all_results'`` (each with a ``'reward'``) is read.
        eval_stats: Same, for the eval split.
        output_dir: Directory under which a ``plots/`` subdirectory is created.
        args: CLI namespace; reads ``generate_plots`` and ``plot_style``.
    """
    if not args.generate_plots:
        return

    print("Generating visualization plots...")

    try:
        plt.style.use(args.plot_style)
    except (OSError, ValueError):
        # Unknown style name (available styles vary across Matplotlib versions):
        # fall back rather than abort, but don't swallow unrelated errors.
        print(f"Plot style {args.plot_style!r} is unavailable; using 'default'")
        plt.style.use('default')

    colors = {'train': '#2E86AB', 'eval': '#A23B72', 'combined': '#F18F01'}

    plots_dir = os.path.join(output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    train_rewards = [r['reward'] for r in train_stats['all_results']]
    eval_rewards = [r['reward'] for r in eval_stats['all_results']]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        ax1.hist(train_rewards, bins=30, alpha=0.7, color=colors['train'],
                 label=f'Train (n={len(train_rewards)})', density=True)
        ax1.hist(eval_rewards, bins=30, alpha=0.7, color=colors['eval'],
                 label=f'Eval (n={len(eval_rewards)})', density=True)
        ax1.axvline(1.0, color='red', linestyle='--', alpha=0.8, label='High Reward Threshold (1.0)')
        ax1.set_xlabel('Reward')
        ax1.set_ylabel('Density')
        ax1.set_title('Reward Distribution Comparison')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.boxplot([train_rewards, eval_rewards], patch_artist=True,
                    boxprops=dict(facecolor=colors['combined'], alpha=0.7))
        # Label the boxes (default positions 1 and 2) through the axis instead of
        # boxplot(labels=...), which Matplotlib 3.9 renamed to tick_labels.
        ax2.set_xticks([1, 2])
        ax2.set_xticklabels(['Train', 'Eval'])
        ax2.axhline(1.0, color='red', linestyle='--', alpha=0.8, label='High Reward Threshold')
        ax2.set_ylabel('Reward')
        ax2.set_title('Reward Distribution (Box Plot)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(plots_dir, 'reward_distributions.png'), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print(f"Plots saved to: {plots_dir}")

def _to_jsonable(obj):
    """``json.dump`` fallback: unwrap NumPy/torch scalars and arrays via ``tolist()``."""
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_results(results: Dict[str, Any], output_dir: str, args):
    """Save evaluation results to files.

    Always writes ``summary.json``. It holds the model path, each split's
    statistics without the per-sample list, and the generation settings.
    With ``args.save_responses`` it also writes
    ``train_detailed_results.json`` and ``eval_detailed_results.json`` with the
    per-sample result dicts.

    Args:
        results: ``{'train_stats': ..., 'eval_stats': ...}`` as returned by
            ``evaluate_dataset``.
        output_dir: Target directory; created if missing.
        args: CLI namespace supplying the recorded configuration.
    """
    os.makedirs(output_dir, exist_ok=True)

    summary = {
        'model_path': args.model_path,
        'train_stats': {k: v for k, v in results['train_stats'].items() if k != 'all_results'},
        'eval_stats': {k: v for k, v in results['eval_stats'].items() if k != 'all_results'},
        'generation_config': {
            'max_new_tokens': args.max_new_tokens,
            'do_sample': args.do_sample,
            # Sampling knobs only matter when sampling is on.
            'temperature': args.temperature if args.do_sample else None,
            'top_p': args.top_p if args.do_sample else None,
            'batch_size': args.batch_size,
            'mixed_precision': args.mixed_precision,
        }
    }

    # Stats and rewards may contain NumPy/torch scalars that the json module
    # cannot serialize natively; _to_jsonable converts them.
    summary_path = os.path.join(output_dir, 'summary.json')
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, default=_to_jsonable)
    print(f"Summary saved to: {summary_path}")

    if args.save_responses:
        for split, label in (('train', 'Train'), ('eval', 'Eval')):
            details_path = os.path.join(output_dir, f'{split}_detailed_results.json')
            with open(details_path, 'w', encoding='utf-8') as f:
                json.dump(results[f'{split}_stats']['all_results'], f, indent=2,
                          default=_to_jsonable)
            print(f"{label} detailed results saved to: {details_path}")

def _load_tokenizer(name_or_path: str, cache_dir):
    """Load the slow (``use_fast=False``) tokenizer for ``name_or_path``, trusting remote code."""
    return AutoTokenizer.from_pretrained(
        name_or_path,
        cache_dir=cache_dir,
        use_fast=False,
        trust_remote_code=True
    )


def _print_comparison(train_stats: Dict[str, Any], eval_stats: Dict[str, Any]) -> None:
    """Print train vs. eval metrics, their gaps, and a coarse fit verdict."""
    print("\n=== Final Comparison ===")
    for label, stats in (("Train", train_stats), ("Eval", eval_stats)):
        print(f"{label} performance:")
        print(f"  - Average reward: {stats['avg_reward']:.4f} ± {stats['std_reward']:.4f}")
        print(f"  - Average answer reward: {stats['avg_answer_reward']:.4f}")
        print(f"  - Accuracy (answer_reward > 0): {stats['accuracy']:.1f}% ({stats['correct_count']}/{stats['num_samples']})")

    reward_gap = train_stats['avg_reward'] - eval_stats['avg_reward']
    answer_reward_gap = train_stats['avg_answer_reward'] - eval_stats['avg_answer_reward']
    accuracy_gap = train_stats['accuracy'] - eval_stats['accuracy']

    print("Performance gaps:")
    print(f"  - Reward gap: {reward_gap:.4f}")
    print(f"  - Answer reward gap: {answer_reward_gap:.4f}")
    print(f"  - Accuracy gap: {accuracy_gap:.1f}%")

    # Verdict compares average answer reward only.
    if train_stats['avg_answer_reward'] > eval_stats['avg_answer_reward']:
        print("Model appears to have overfit to training data")
    elif eval_stats['avg_answer_reward'] > train_stats['avg_answer_reward']:
        print("Model generalizes well to evaluation data")
    else:
        print("Similar performance on train and eval data")


def main():
    """Evaluate a saved Countdown model on train/eval splits and report rewards.

    Steps:
    1. Parse the CLI.
    2. Bind the selected ``reward_function`` as a module global; ``evaluate_batch`` uses it.
    3. Load the model and tokenizer.
    4. Evaluate both splits.
    5. Save the summary and details, optionally plot.
    6. Print a train-vs-eval comparison.

    Raises:
        ValueError: For an unknown ``--reward_type``.
    """
    args = parse_args()

    # evaluate_batch() looks up `reward_function` as a module-level name, so the
    # imports below must bind the global rather than a local.
    global reward_function
    if args.reward_type == 'grpo_reward':
        from countdown_task import reward_function
        print("Using original GRPO reward function")
    elif args.reward_type == 'toy_reward':
        from countdown_task_toy import reward_function
        print("Using modified GRPO reward function")
    else:
        raise ValueError(f"Invalid reward type: {args.reward_type}, please choose from grpo_reward or toy_reward")

    if args.output_dir is None:
        model_name = os.path.basename(args.model_path.rstrip('/'))
        batch_suffix = f"_batch{args.batch_size}" if args.batch_size else ""
        args.output_dir = f"./inference_results_{model_name}{batch_suffix}"

    print("=== Countdown ES Model Inference Script ===")
    print(f"Model path: {args.model_path}")
    print(f"Train data: {args.train_data_path} (samples: {args.train_samples})")
    print(f"Eval data: {args.eval_data_path} (samples: {args.eval_samples}, offset: {args.eval_offset})")
    print(f"Output directory: {args.output_dir}")
    print(f"Mixed precision: {args.mixed_precision}")
    print(f"Reward type: {args.reward_type}")

    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"Using device: {device}")

    # With --torch_dtype auto, the dtype follows --mixed_precision.
    if args.torch_dtype == 'auto':
        if args.mixed_precision == 'fp16':
            torch_dtype = torch.float16
        elif args.mixed_precision == 'bf16':
            torch_dtype = torch.bfloat16
        else:
            torch_dtype = torch.float16 if device == 'cuda' else torch.float32
    else:
        torch_dtype = getattr(torch, args.torch_dtype)
    print(f"Using torch dtype: {torch_dtype}")

    # The CLI default for --hf_cache_dir is ''. An empty string is not None, so it
    # would be forwarded as a literal (relative) cache path; map it to None.
    cache_dir = args.hf_cache_dir or None

    print(f"\nLoading model from {args.model_path}...")
    # There is no fallback model, so a model-load failure must propagate here.
    # Only the tokenizer gets a fallback, below; catching both together would
    # leave `model` unbound.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        cache_dir=cache_dir,
        device_map="auto" if device == 'cuda' else None,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        # attn_implementation="flash_attention_2"  # Exactly as in training script
    )
    try:
        tokenizer = _load_tokenizer(args.model_path, cache_dir)
    except Exception as e:
        print(f"Error loading from {args.model_path}: {e}")
        print(f"Trying to load tokenizer from base model: {args.base_model_name}")
        tokenizer = _load_tokenizer(args.base_model_name, cache_dir)

    # First verify that tokenizer is a proper object and not a boolean
    if not hasattr(tokenizer, 'pad_token'):
        print(f"Error: tokenizer is not a valid tokenizer object (got {type(tokenizer)})")
        print("Trying to load tokenizer again from base model...")
        tokenizer = _load_tokenizer(args.base_model_name, cache_dir)

    # Batched generation pads the prompts, which needs a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Set tokenizer.pad_token = tokenizer.eos_token ({tokenizer.eos_token})")

    # On 'cuda' the model was already placed by device_map="auto".
    if device != 'cuda' and hasattr(model, 'to'):
        model = model.to(device)

    model.eval()
    print("Model loaded successfully!")

    # Data paths are resolved relative to this script's directory; os.path.join
    # keeps an absolute path unchanged.
    print("\nLoading datasets...")
    train_dataset = load_data(
        os.path.join(os.path.dirname(__file__), args.train_data_path),
        num_samples=args.train_samples,
        offset=0
    )
    eval_dataset = load_data(
        os.path.join(os.path.dirname(__file__), args.eval_data_path),
        num_samples=args.eval_samples,
        offset=args.eval_offset
    )

    train_stats = evaluate_dataset(model, tokenizer, train_dataset, device, args, "Train", batch_size=args.batch_size)
    eval_stats = evaluate_dataset(model, tokenizer, eval_dataset, device, args, "Eval", batch_size=args.batch_size)

    results = {
        'train_stats': train_stats,
        'eval_stats': eval_stats
    }
    save_results(results, args.output_dir, args)

    create_visualizations(train_stats, eval_stats, args.output_dir, args)

    _print_comparison(train_stats, eval_stats)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()