# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

"""
ALOHA VLM reward evaluation script (task-success judgement).
Task: head[t] -> Yes/No (whether the task succeeded).
"""

import argparse
import json
import os
import random
import re
import sys

sys.path.append('.')

import torch
import torch.distributed as dist
from PIL import Image
from tqdm import tqdm

from eval.aloha_eval_utils import setup_models, vlm_pred, set_seeds


# Matches the first standalone "yes"/"no" token in lower-cased model output.
# Word boundaries keep words such as "not", "know" or "unknown" from being
# read as a "no" answer.
_YES_NO_PATTERN = re.compile(r'\b(yes|no)\b')


def _load_jsonl(path):
    """Read a JSON Lines file into a list of records, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _load_rgb(path):
    """Open an image, convert it to RGB and close the underlying file."""
    with Image.open(path) as img:
        # convert() returns a fully loaded copy, so it stays usable after
        # the file handle is closed.
        return img.convert('RGB')


def _parse_yes_no(prediction):
    """Map raw model text to 'Yes.', 'No.' or 'Unknown'.

    The first standalone "yes"/"no" token (case-insensitive) decides the
    label; text containing neither is labelled 'Unknown'.
    """
    match = _YES_NO_PATTERN.search(prediction.lower())
    if match is None:
        return 'Unknown'
    return 'Yes.' if match.group(1) == 'yes' else 'No.'


def _count_tp_fp_fn(results, label):
    """Count true positives, false positives and false negatives for `label`."""
    tp = fp = fn = 0
    for r in results:
        gt = r['ground_truth'].strip()
        pred = r['prediction_label'].strip()
        if gt == label:
            if pred == label:
                tp += 1
            else:
                fn += 1
        elif pred == label:
            fp += 1
    return tp, fp, fn


def _precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, f1); each is 0.0 when its denominator is 0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


def main():
    """Evaluate a VLM on Yes/No task-success judgement across all ranks.

    Steps:
      1. Parse CLI arguments and initialise the NCCL process group
         (the default env-based rendezvous is used, so the script is
         presumably launched via torchrun or an equivalent launcher).
      2. Load the model, the prompt text and the JSONL dataset.
      3. Randomly sample up to ``--num_samples`` records (seed 42), drop
         records with a repeated ``id`` and give each rank a contiguous shard.
      4. Run batched inference, label every prediction as Yes./No./Unknown
         and save the first input image of every misclassified sample.
      5. Gather all per-rank results; rank 0 computes accuracy and F1
         metrics and writes ``vlm_reward_results.json``,
         ``error_samples.jsonl`` and ``vlm_reward_summary.json`` to
         ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--jsonl_path', type=str, required=True)
    parser.add_argument('--prompt_path', type=str, required=True)
    parser.add_argument('--image_dir', type=str, required=True)
    parser.add_argument('--num_samples', type=int, default=100)
    parser.add_argument('--max_new_tokens', type=int, default=10)
    parser.add_argument('--max_mem_per_gpu', type=str, default="80GiB")
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for parallel inference')
    args = parser.parse_args()

    # A zero batch size makes range() raise and a negative one silently
    # evaluates nothing, so reject both before any expensive setup.
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')
    batch_size = args.batch_size

    dist.init_process_group(backend="nccl")
    # get_rank() is the global rank; the modulo maps it onto a device
    # visible to this process.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = rank == 0
    device_index = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    device = f"cuda:{device_index}"

    set_seeds(42)

    if is_main:
        print("Loading models...")
    # The VAE model and VAE transform are returned by setup_models but are
    # not used by this script.
    model, _vae_model, tokenizer, new_token_ids, _vae_transform, vit_transform = setup_models(
        args.base_dir, args.model_path, device
    )

    with open(args.prompt_path, 'r', encoding='utf-8') as f:
        prompt_text = f.read().strip()

    data = _load_jsonl(args.jsonl_path)

    if is_main:
        print(f"Total samples in dataset: {len(data)}")

    # Re-seed so every rank draws the same subset.
    random.seed(42)

    if len(data) > args.num_samples:
        if is_main:
            print(f"Randomly sampling {args.num_samples} samples from {len(data)} total samples (before deduplication)...")
        sampled_indices = sorted(random.sample(range(len(data)), args.num_samples))
        sampled_data = [data[i] for i in sampled_indices]
    else:
        if is_main:
            print(f"Using all {len(data)} samples (less than num_samples={args.num_samples})")
        sampled_indices = list(range(len(data)))
        sampled_data = data

    if is_main:
        print(f"Sampled {len(sampled_data)} samples (first 10 indices: {sampled_indices[:10]})")

    # Keep only the first record for each id.
    seen_ids = set()
    deduplicated_data = []
    duplicate_count = 0
    for item in sampled_data:
        item_id = item['id']
        if item_id not in seen_ids:
            seen_ids.add(item_id)
            deduplicated_data.append(item)
        else:
            duplicate_count += 1

    if is_main:
        print(f"After deduplication: {len(deduplicated_data)} unique samples (removed {duplicate_count} duplicates from sampled data)")

    # Contiguous sharding; the last rank also takes the remainder.
    data = deduplicated_data
    num_samples_after_dedup = len(data)
    samples_per_gpu = num_samples_after_dedup // world_size
    start_idx = rank * samples_per_gpu
    end_idx = start_idx + samples_per_gpu if rank < world_size - 1 else num_samples_after_dedup
    local_data = data[start_idx:end_idx]

    if is_main:
        print(f"Each GPU processing {samples_per_gpu} samples")

    os.makedirs(args.output_dir, exist_ok=True)

    error_images_dir = os.path.join(args.output_dir, 'error_images')
    os.makedirs(error_images_dir, exist_ok=True)

    results = []

    if is_main:
        print(f"Evaluating {len(local_data)} samples...")

    for batch_idx in tqdm(range(0, len(local_data), batch_size), desc="Evaluating"):
        batch_items = local_data[batch_idx:batch_idx + batch_size]

        # One list of RGB images per sample.
        input_images = [
            [_load_rgb(os.path.join(args.image_dir, img_file)) for img_file in item['image']]
            for item in batch_items
        ]

        human_msgs = [item['conversations'][0]['value'] for item in batch_items]
        questions = [
            human_msg.replace('<image>', '').replace('<prompt>', prompt_text).strip()
            for human_msg in human_msgs
        ]
        ground_truths = [item['conversations'][1]['value'] for item in batch_items]

        # NOTE(review): the size of the first image is passed for the whole
        # batch — this presumably assumes uniform image sizes; confirm.
        outputs = vlm_pred(
            model, tokenizer, new_token_ids, vit_transform,
            questions, input_images,
            original_image_size=input_images[0][0].size,
            num_samples=1,
            do_sample=False,
            temperature=1.0,
            max_length=args.max_new_tokens,
            device=device
        )

        for i, item in enumerate(batch_items):
            prediction = outputs[0][i]
            ground_truth = ground_truths[i]
            prediction_label = _parse_yes_no(prediction)

            # Fall back to the position within this rank's shard when the
            # record carries no explicit frame index.
            frame_idx = item.get('frame', batch_idx + i)

            ground_truth_clean = ground_truth.strip()
            is_correct = 1 if prediction_label == ground_truth_clean else 0

            result = {
                'id': item['id'],
                'episode_id': item['episode_id'],
                'frame': frame_idx,
                'question': questions[i],
                'ground_truth': ground_truth,
                'prediction': prediction,
                'prediction_label': prediction_label,
                'correct': is_correct
            }
            results.append(result)

            if is_correct == 0:
                gt_label = ground_truth_clean.replace('.', '').replace(' ', '_')
                pred_label = prediction_label.replace('.', '').replace(' ', '_')
                error_image_filename = f"error_rank{rank}_id_{item['id']}_ep{item['episode_id']}_frame{frame_idx}_gt{gt_label}_pred{pred_label}.jpg"
                error_image_path = os.path.join(error_images_dir, error_image_filename)

                # Only the first image of the sample is saved.
                input_images[i][0].save(error_image_path)
                result['error_image_path'] = error_image_path
                if is_main:
                    print(f"Saved error image: {error_image_filename}")

        if is_main and batch_idx % (batch_size * 10) == 0:
            print(f"Processed {min(batch_idx + batch_size, len(local_data))}/{len(local_data)} samples")

    dist.barrier()

    if is_main:
        print("\nAll ranks finished inference. Gathering results...")

    # Results are exchanged as JSON strings between ranks.
    results_json = json.dumps(results)

    gathered_results = [None] * world_size
    dist.all_gather_object(gathered_results, results_json)

    all_results = []
    for rank_results_json in gathered_results:
        if rank_results_json:
            all_results.extend(json.loads(rank_results_json))

    all_results.sort(key=lambda x: x['id'])

    if is_main:
        total = len(all_results)
        correct = sum(r['correct'] for r in all_results)
        accuracy = correct / total if total > 0 else 0

        tp_yes, fp_yes, fn_yes = _count_tp_fp_fn(all_results, 'Yes.')
        tp_no, fp_no, fn_no = _count_tp_fp_fn(all_results, 'No.')

        precision_yes, recall_yes, f1_yes = _precision_recall_f1(tp_yes, fp_yes, fn_yes)
        precision_no, recall_no, f1_no = _precision_recall_f1(tp_no, fp_no, fn_no)

        macro_f1 = (f1_yes + f1_no) / 2.0
        # Micro-averaging pools the counts of both classes before scoring.
        _, _, micro_f1 = _precision_recall_f1(
            tp_yes + tp_no, fp_yes + fp_no, fn_yes + fn_no
        )

        with open(os.path.join(args.output_dir, 'vlm_reward_results.json'), 'w') as f:
            json.dump(all_results, f, indent=2)

        error_count = total - correct
        error_images_count = sum(1 for r in all_results if 'error_image_path' in r)
        error_samples_jsonl_path = os.path.join(args.output_dir, 'error_samples.jsonl')

        print("\n" + "=" * 50)
        print("VLM Reward Evaluation Results")
        print("=" * 50)
        print(f"Total samples: {total}")
        print(f"Correct predictions: {correct}")
        print(f"Error predictions: {error_count}")
        print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
        print("\n--- F1 Score Metrics ---")
        print(f"Macro F1: {macro_f1:.4f}")
        print(f"Micro F1: {micro_f1:.4f}")
        if error_count > 0:
            print(f"\nError images saved: {error_images_count}/{error_count}")
            print(f"Error images directory: {error_images_dir}")
            print(f"Error samples JSONL: {error_samples_jsonl_path}")
        print("=" * 50)

        all_error_sample_ids = {r['id'] for r in all_results if r['correct'] == 0}

        error_samples_data = []
        if all_error_sample_ids:
            print(f"\nExtracting {len(all_error_sample_ids)} unique error samples from original JSONL file...")

            # Re-read the original records so the error file keeps every
            # original field; the first record wins for a repeated id.
            error_samples_dict = {}
            duplicate_count = 0
            with open(args.jsonl_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    item_id = item.get('id')
                    if item_id in all_error_sample_ids:
                        if item_id not in error_samples_dict:
                            error_samples_dict[item_id] = item
                        else:
                            duplicate_count += 1

            error_samples_data = list(error_samples_dict.values())
            error_samples_data.sort(key=lambda x: x.get('id', 0))

            # Duplicates and missing ids are reported independently so that
            # neither message hides the other.
            missing_count = len(all_error_sample_ids) - len(error_samples_data)
            if duplicate_count > 0:
                print(f"  Note: Found {len(error_samples_data)} unique error samples (removed {duplicate_count} duplicate entries with same id)")
            if missing_count > 0:
                print(f"  Warning: Found {len(error_samples_data)} error samples in JSONL, but {missing_count} error sample ids not found in original file")
            if duplicate_count == 0 and missing_count == 0:
                print(f"  Extracted {len(error_samples_data)} unique error samples")

            with open(error_samples_jsonl_path, 'w', encoding='utf-8') as f:
                for item in error_samples_data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')

            print(f"Saved {len(error_samples_data)} error samples to: {error_samples_jsonl_path}")

        summary = {
            'total_samples': total,
            'correct_predictions': correct,
            'error_predictions': error_count,
            'accuracy': accuracy,
            'f1_scores': {
                'macro_f1': macro_f1,
                'micro_f1': micro_f1,
                'yes': {
                    'precision': precision_yes,
                    'recall': recall_yes,
                    'f1': f1_yes,
                    'tp': tp_yes,
                    'fp': fp_yes,
                    'fn': fn_yes
                },
                'no': {
                    'precision': precision_no,
                    'recall': recall_no,
                    'f1': f1_no,
                    'tp': tp_no,
                    'fp': fp_no,
                    'fn': fn_no
                }
            },
            'error_images_saved': error_images_count,
            'error_images_dir': error_images_dir if error_count > 0 else None,
            'error_samples_count': len(error_samples_data),
            'error_samples_jsonl_path': error_samples_jsonl_path if error_samples_data else None
        }
        with open(os.path.join(args.output_dir, 'vlm_reward_summary.json'), 'w') as f:
            json.dump(summary, f, indent=2)

    dist.destroy_process_group()


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
