from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import json
from tqdm import tqdm
import re
import os
from pprint import pprint
import random
import argparse
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse
import glob
import warnings
import datetime
from PIL import Image
# Silence UserWarning messages raised from the transformers package.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")


# --- Image resizing configuration (read by smart_resize / multi_size_process_image) ---
# Upper bound on width * height of an image fed to the model (451584 == 672 * 672).
MAX_PIXELS = 451584
# Resized dimensions are snapped to multiples of this value.
FACTOR = 28  
# Short-side targets tried for each image when MULTI_SCALE_ENABLED is True.
TEST_SIZES = [308, 336, 364]
# Short-side target used in single-scale mode: smaller images are upscaled to it.
MIN_SIZE = 672 
# When True, each sample is run once per TEST_SIZES entry and the predicted boxes are fused.
MULTI_SCALE_ENABLED = False  

# --- Model selection ---
# RUN_NAME is embedded in result file names; note that it contains a "/" here,
# so result paths built from it include a sub-directory.
RUN_NAME = "Qwen2.5-VL-TAC-RESAMPLEIOUKL-warmup300-3000/checkpoint-1000"
# Alternative model location:
# BASE_MODEL_PATH = f"model/{RUN_NAME}"
BASE_MODEL_PATH = f"VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}"

# Number of prompts per generate() call.
BSZ=4
# Directory holding one "<dataset>.json" annotation file per entry of TEST_DATASETS.
DATA_ROOT = "VLM-R1_Data/rec_jsons_processed"

# --- Dataset selection: exactly one TEST_DATASETS / IMAGE_ROOT pair should be active ---
# IMAGE_ROOT is joined with each sample's 'image' field to locate the image file.
#TEST_DATASETS = ['refcoco_val', 'refcoco_testA', 'refcoco_testB', 'refcocop_val', 'refcocop_testA', 'refcocop_testB', 'refcocog_val', 'refcocog_test']
#IMAGE_ROOT = "VLM-R1_Data/coco"

TEST_DATASETS = ['refgta_subsample']
IMAGE_ROOT = "VLM-R1_Data/refgta"

# TEST_DATASETS = ['lisa_test']
# IMAGE_ROOT = "VLM-R1_Data/lisa_test"

#TEST_DATASETS = ['refcoco_train', 'refcocop_train', 'refcocog_train']
#IMAGE_ROOT = "VLM-R1_Data/coco"

# Result-file name template. NOTE(review): the code visible in this file builds the
# same path with an inline f-string and never formats this constant — confirm before removing.
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"

# --- Prompt template: "{Question}" is replaced by each sample's 'problem' text ---
# Commented-out lines are alternative prompt wordings; only one may be active.
#QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
#QUESTION_TEMPLATE = "{Question} Report the bbox coordinates in JSON format."
#QUESTION_TEMPLATE = "{Question} First output the thinking process and answer in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
QUESTION_TEMPLATE = "{Question} First output the thinking process then summarize the answer in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
#QUESTION_TEMPLATE = "{Question} First output the thinking process then summarize the answer and review the question to check the answer in <think> </think> tags. And then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."

def compute_intersection_area(box1, box2):
    """Return the overlap area of two ``[x1, y1, x2, y2]`` boxes.

    Returns 0 when either box is ``None`` or when the boxes do not overlap
    (touching edges count as no overlap).
    """
    if box1 is None or box2 is None:
        return 0

    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # An empty or inverted overlap rectangle means the boxes are disjoint.
    if right <= left or bottom <= top:
        return 0

    return (right - left) * (bottom - top)


def find_max_intersection_bbox(bboxes):
    """Pick one representative box out of several ``[x1, y1, x2, y2]`` candidates.

    Entries that are ``None`` or do not have exactly four coordinates are
    ignored. This filter now also applies to single-element inputs, which
    previously were returned unvalidated.

    Selection rule (unchanged, and note it is *not* "largest intersection"
    despite the function name):

    * If at least one box fails to overlap every other box, return the box
      that overlaps the fewest others; ties are broken by the smallest summed
      overlap area, then by input order.
    * If every box overlaps every other box, return the box with the
      smallest area (first one on ties).

    Args:
        bboxes: list of boxes (or ``None``/empty).

    Returns:
        The selected box object from the input, or ``None`` when there is no
        valid box.
    """
    if not bboxes:
        return None

    valid_bboxes = [box for box in bboxes if box is not None and len(box) == 4]
    if not valid_bboxes:
        return None
    if len(valid_bboxes) == 1:
        return valid_bboxes[0]

    n_boxes = len(valid_bboxes)
    intersection_counts = [0] * n_boxes
    intersection_areas = [0] * n_boxes

    # Overlap is symmetric, so each unordered pair is visited once and
    # credited to both boxes.
    for i in range(n_boxes):
        box1 = valid_bboxes[i]
        for j in range(i + 1, n_boxes):
            box2 = valid_bboxes[j]

            x1 = max(box1[0], box2[0])
            y1 = max(box1[1], box2[1])
            x2 = min(box1[2], box2[2])
            y2 = min(box1[3], box2[3])

            if x1 < x2 and y1 < y2:
                overlap = (x2 - x1) * (y2 - y1)
                intersection_counts[i] += 1
                intersection_counts[j] += 1
                intersection_areas[i] += overlap
                intersection_areas[j] += overlap

    min_count = min(intersection_counts)
    if min_count < n_boxes - 1:
        # Some box misses at least one other box: choose among the least-connected ones.
        candidates = [i for i, count in enumerate(intersection_counts) if count == min_count]
        chosen = min(candidates, key=lambda i: intersection_areas[i])
        return valid_bboxes[chosen]

    # All boxes mutually overlap: fall back to the smallest box.
    return min(valid_bboxes, key=lambda box: (box[2] - box[0]) * (box[3] - box[1]))


def multi_size_process_image(img, test_sizes=TEST_SIZES):
    """Produce one resized copy of ``img`` per short-side target in ``test_sizes``.

    For each target, the image is passed through ``smart_resize`` when it
    exceeds ``MAX_PIXELS`` or its short side is below the target; otherwise
    the original dimensions are kept. ``img.resize`` is called in both cases,
    so every entry holds a new image object.

    Args:
        img: image object exposing ``.size`` as (width, height) and ``.resize``.
        test_sizes: iterable of short-side targets.

    Returns:
        A list of dicts with keys ``'min_size'``, ``'img'`` and ``'info'``
        (``'original_size'``, ``'resized_size'`` as (width, height), and the
        boolean ``'resized'``).
    """
    orig_width, orig_height = img.size
    original_size = (orig_width, orig_height)

    # Both quantities are independent of the target size.
    over_pixel_budget = orig_width * orig_height > MAX_PIXELS
    short_side = min(orig_width, orig_height)

    variants = []
    for min_size in test_sizes:
        if over_pixel_budget or short_side < min_size:
            new_height, new_width = smart_resize(orig_height, orig_width, min_size=min_size)
        else:
            new_height, new_width = orig_height, orig_width

        target_size = (new_width, new_height)
        variants.append({
            'min_size': min_size,
            'img': img.resize(target_size, Image.LANCZOS),
            'info': {
                'original_size': original_size,
                'resized_size': target_size,
                'resized': target_size != original_size
            }
        })

    return variants


def smart_resize(height, width, factor=FACTOR, min_pixels=784, max_pixels=MAX_PIXELS, min_size=MIN_SIZE):
    """Compute a model-friendly target size for an image.

    The image is scaled down when ``width * height`` exceeds ``max_pixels``;
    otherwise it is scaled up when its short side is below ``min_size``.
    If neither applies the input size is returned untouched (and is *not*
    snapped to ``factor``). Scaled sizes are snapped to multiples of
    ``factor`` and corrected when snapping distorts the aspect ratio by more
    than 1%.

    Args:
        height: original image height in pixels.
        width: original image width in pixels.
        factor: both returned dimensions are multiples of this value whenever
            a resize happens.
        min_pixels: unused; kept so existing callers keep working.
        max_pixels: upper bound on ``new_height * new_width``.
        min_size: desired minimum short side when upscaling. The pixel budget
            takes precedence, so very elongated images can end up below it.

    Returns:
        ``(new_height, new_width)`` as ints. Note the (height, width) order.
    """
    orig_ratio = width / height

    if width * height > max_pixels:
        # Shrink uniformly so the area lands on the pixel budget.
        scale_factor = (max_pixels / (width * height)) ** 0.5
    elif min(width, height) < min_size:
        # Grow uniformly so the short side reaches min_size.
        scale_factor = min_size / min(width, height)
    else:
        return height, width

    new_height = height * scale_factor
    new_width = width * scale_factor

    # Integer "round up to a multiple" idiom applied to float values: a value
    # that lies less than 1 above a multiple of `factor` snaps down to it,
    # everything else snaps up. Kept as-is so resize results do not change.
    new_height_factor = ((new_height + factor - 1) // factor) * factor
    new_width_factor = ((new_width + factor - 1) // factor) * factor

    adjusted_ratio = new_width_factor / new_height_factor

    if abs(adjusted_ratio - orig_ratio) / orig_ratio > 0.01:
        # Snapping distorted the aspect ratio: keep the snapped short side and
        # re-derive the long side from the original ratio.
        if width < height:
            new_width = new_width_factor
            new_height = round(new_width / orig_ratio / factor) * factor
        else:
            new_height = new_height_factor
            new_width = round(new_height * orig_ratio / factor) * factor
    else:
        new_height = new_height_factor
        new_width = new_width_factor

    new_height = max(new_height, factor)
    new_width = max(new_width, factor)

    # Upscaling (or snapping up) may overshoot the pixel budget; shrink again.
    # Each side is clamped to `factor` so flooring can never produce a
    # zero-sized dimension for extremely elongated images.
    while new_height * new_width > max_pixels and new_height > factor and new_width > factor:
        scale_down = (max_pixels / (new_height * new_width)) ** 0.5
        new_height = max(factor, int(new_height * scale_down / factor) * factor)
        new_width = max(factor, int(new_width * scale_down / factor) * factor)

    return int(new_height), int(new_width)


def scale_bbox_to_original(bbox, original_size, resized_size):
    """Map a box predicted on a resized image back to original-image pixels.

    Args:
        bbox: ``[x1, y1, x2, y2]`` in resized-image coordinates, or ``None``.
        original_size: ``(width, height)`` of the original image.
        resized_size: ``(width, height)`` of the image the box refers to.

    Returns:
        A new list of four ints (truncated toward zero), or ``None`` when
        ``bbox`` is ``None``.
    """
    if bbox is None:
        return None

    orig_w, orig_h = original_size
    resized_w, resized_h = resized_size

    x_scale = orig_w / resized_w
    y_scale = orig_h / resized_h

    # x coordinates use the width ratio, y coordinates the height ratio.
    scales = (x_scale, y_scale, x_scale, y_scale)
    return [int(bbox[k] * scales[k]) for k in range(4)]


def setup_distributed():
    """Bind this process to its GPU and join the NCCL process group.

    The GPU index is read from the ``LOCAL_RANK`` environment variable
    (defaulting to 0).

    Returns:
        ``(local_rank, world_size, rank)``.
    """
    gpu_index = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(gpu_index)

    dist.init_process_group(backend="nccl")

    return gpu_index, dist.get_world_size(), dist.get_rank()

# Join the process group; every rank evaluates on its own GPU.
local_rank, world_size, rank = setup_distributed()
device = f"cuda:{local_rank}"
print(f"Process {rank} using {device}")

os.makedirs("./logs", exist_ok=True)

# Collect (model_path, step) pairs to evaluate. If BASE_MODEL_PATH contains
# "checkpoint-<step>" entries each one is evaluated; otherwise BASE_MODEL_PATH
# itself is evaluated and recorded as step 0.
checkpoints = []
checkpoint_dirs = glob.glob(f"{BASE_MODEL_PATH}/checkpoint-*")

for checkpoint_dir in checkpoint_dirs:
    # Accept "checkpoint-<digits>" optionally followed by "-<anything>".
    # Other entries (e.g. "checkpoint-final", "checkpoint-100.bak") used to
    # abort the whole run with a ValueError and are now skipped.
    step_match = re.match(r"checkpoint-(\d+)(?:-|$)", os.path.basename(checkpoint_dir))
    if step_match is None:
        if rank == 0:
            print(f"Skipping unrecognized checkpoint entry: {checkpoint_dir}")
        continue
    checkpoints.append((checkpoint_dir, int(step_match.group(1))))

if not checkpoints:
    checkpoints.append((BASE_MODEL_PATH, 0))

# Evaluate in training order.
checkpoints.sort(key=lambda x: x[1])

# Only rank 0 aggregates results, so only rank 0 owns the summary container.
if rank == 0:
    summary_results = {ds: [] for ds in TEST_DATASETS}

def extract_bbox_answer(content):
    """Extract the first ``[x1, y1, x2, y2]`` box from a model response.

    The text inside the first ``<answer>...</answer>`` pair is searched when
    present; otherwise the whole response is searched. Only non-negative
    decimal numbers are recognised.

    Bracketed groups with malformed numbers such as ``1.2.3`` or ``...`` are
    skipped instead of raising ``ValueError`` from ``float()``, so one bad
    generation cannot abort scoring of the whole dataset.

    Args:
        content: raw decoded model output.

    Returns:
        A list of four numbers (ints when the value is integral, floats
        otherwise), or ``[0, 0, 0, 0]`` when no box is found.
    """
    answer_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    content_to_search = answer_match.group(1).strip() if answer_match else content

    # One well-formed number: "12", "12.", "12.5" or ".5" -- exactly the
    # digit/dot strings that float() accepts.
    number = r'(\d+(?:\.\d*)?|\.\d+)'
    bbox_pattern = r'\[' + r',\s*'.join([number] * 4) + r'\]'
    bbox_match = re.search(bbox_pattern, content_to_search)

    if bbox_match is None:
        return [0, 0, 0, 0]

    coords = [float(group) for group in bbox_match.groups()]
    return [int(coord) if coord.is_integer() else coord for coord in coords]

def iou(box1, box2):
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    The intersection is only counted when the overlap is wider and taller
    than one unit; thinner overlaps are treated as empty, so two identical
    boxes of side 1 score 0.0. Returns 0.0 when the union is not positive.
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2]-1, box2[2]-1)
    bottom = min(box1[3]-1, box2[3]-1)

    overlaps = left < right and top < bottom
    inter = (right-left+1)*(bottom-top+1) if overlaps else 0

    area1 = (box1[2]-box1[0])*(box1[3]-box1[1])
    area2 = (box2[2]-box2[0])*(box2[3]-box2[1])
    union = area1 + area2 - inter

    if union > 0:
        return float(inter)/union
    return 0.0

# True only when every (checkpoint, dataset) pair already has a results file;
# the search stops at the first missing file.
# NOTE(review): each rank evaluates this against its own view of the
# filesystem — presumably ./logs is shared between ranks; confirm.
all_completed = all(
    os.path.exists(f"./logs/rec_results_{ds}_{RUN_NAME}_{steps}.json")
    for model_path, steps in checkpoints
    for ds in TEST_DATASETS
)


# Main evaluation loop: for every checkpoint and dataset without a results
# file, each rank runs generation on its shard of the (deterministically
# shuffled) dataset; rank 0 then gathers all outputs, scores them with
# IoU > 0.5 and writes the per-dataset results JSON.
if not all_completed:
    for model_path, steps in checkpoints:
        if rank == 0:
            print(f"\n{'='*50}")
            print(f"{'='*50}")

        for ds in TEST_DATASETS:
            output_path = f"./logs/rec_results_{ds}_{RUN_NAME}_{steps}.json"

            # All ranks must take the same branch, otherwise the collectives
            # below would deadlock: skip if ANY rank already sees the file.
            file_exists = os.path.exists(output_path)
            file_exists_tensor = torch.tensor([1 if file_exists else 0], device=device)
            dist.all_reduce(file_exists_tensor, op=dist.ReduceOp.MAX)
            file_exists = bool(file_exists_tensor.item())

            if file_exists:
                dist.barrier()
                continue

            # Load the model lazily, once per checkpoint; it is deleted after
            # the dataset loop so the next checkpoint triggers a fresh load.
            if 'model' not in locals():
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    attn_implementation="flash_attention_2",
                    device_map={"": local_rank},
                    local_files_only=True,
                )

                processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)

            ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
            with open(ds_path, "r") as f:
                data = json.load(f)
            # Fixed seed: every rank produces the same order, so global
            # indices refer to the same samples on all ranks.
            random.seed(42)
            random.shuffle(data)
            data = data[:]

            # Contiguous shards; the last rank also takes the remainder.
            per_rank_data = len(data) // world_size
            start_idx = rank * per_rank_data
            end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
            rank_data = data[start_idx:end_idx]

            # messages, image_info and sample_indices are parallel lists with
            # one entry per prompt (several per sample in multi-scale mode).
            messages = []
            image_info = []
            sample_indices = []
            image_size_stats = {
                "original": {"min_size": float('inf'), "max_size": 0},
                "resized": {"min_size": float('inf'), "max_size": 0, "count": 0}
            }

            for i, x in enumerate(rank_data):
                global_idx = start_idx + i
                image_path = os.path.join(IMAGE_ROOT, x['image'])

                with Image.open(image_path) as img:
                    orig_width, orig_height = img.size
                    orig_min_size = min(orig_width, orig_height)
                    orig_max_size = max(orig_width, orig_height)

                    # Track both extremes (the minimum used to be computed but never stored).
                    image_size_stats["original"]["min_size"] = min(image_size_stats["original"]["min_size"], orig_min_size)
                    image_size_stats["original"]["max_size"] = max(image_size_stats["original"]["max_size"], orig_max_size)

                    if MULTI_SCALE_ENABLED:
                        size_results = multi_size_process_image(img, test_sizes=TEST_SIZES)

                        for size_result in size_results:
                            min_size = size_result['min_size']
                            resized_img = size_result['img']
                            img_info = size_result['info']

                            resized_width, resized_height = img_info['resized_size']
                            image_size_stats["resized"]["min_size"] = min(image_size_stats["resized"]["min_size"], min(resized_width, resized_height))
                            image_size_stats["resized"]["max_size"] = max(image_size_stats["resized"]["max_size"], max(resized_width, resized_height))
                            image_size_stats["resized"]["count"] += 1

                            # Resized copies are written next to the source image and
                            # referenced by file path in the prompt.
                            temp_dir = os.path.join(os.path.dirname(image_path), "temp_resized")
                            os.makedirs(temp_dir, exist_ok=True)
                            temp_path = os.path.join(temp_dir, f"temp_size_{min_size}_{os.path.basename(image_path)}")

                            resized_img.save(temp_path)

                            message = [{
                                "role": "user",
                                "content": [
                                    {"type": "image", "image": f"file://{temp_path}"},
                                    {"type": "text", "text": QUESTION_TEMPLATE.format(Question=x['problem'])}
                                ]
                            }]

                            messages.append(message)
                            image_info.append({
                                'original_size': (orig_width, orig_height),
                                'resized_size': (resized_width, resized_height),
                                'resized': img_info['resized'],
                                'min_size': min_size,
                                'sample_idx': global_idx
                            })
                            sample_indices.append(global_idx)
                    else:
                        need_resize = (orig_width * orig_height > MAX_PIXELS) or (min(orig_width, orig_height) < MIN_SIZE)

                        if need_resize:
                            new_height, new_width = smart_resize(orig_height, orig_width, min_size=MIN_SIZE)

                            image_size_stats["resized"]["min_size"] = min(image_size_stats["resized"]["min_size"], min(new_width, new_height))
                            image_size_stats["resized"]["max_size"] = max(image_size_stats["resized"]["max_size"], max(new_width, new_height))
                            image_size_stats["resized"]["count"] += 1

                            resized_img = img.resize((new_width, new_height), Image.LANCZOS)

                            temp_dir = os.path.join(os.path.dirname(image_path), "temp_resized")
                            os.makedirs(temp_dir, exist_ok=True)
                            temp_path = os.path.join(temp_dir, f"temp_{os.path.basename(image_path)}")

                            resized_img.save(temp_path)

                            actual_image_path = temp_path
                            image_info.append({
                                'original_size': (orig_width, orig_height),
                                'resized_size': (new_width, new_height),
                                'resized': True,
                                'sample_idx': global_idx
                            })
                        else:
                            actual_image_path = image_path
                            image_info.append({
                                'original_size': (orig_width, orig_height),
                                'resized_size': (orig_width, orig_height),
                                'resized': False,
                                'sample_idx': global_idx
                            })

                        message = [{
                            "role": "user",
                            "content": [
                                {"type": "image", "image": f"file://{actual_image_path}"},
                                {"type": "text", "text": QUESTION_TEMPLATE.format(Question=x['problem'])}
                            ]
                        }]
                        messages.append(message)
                        sample_indices.append(global_idx)

            rank_outputs = []

            # Greedy batched generation; the progress bar is shown on rank 0 only.
            for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
                batch_messages = messages[i:i + BSZ]

                text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]

                image_inputs, video_inputs = process_vision_info(batch_messages)
                inputs = processor(
                    text=text,
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    padding_side="left",
                    return_tensors="pt",
                )
                inputs = inputs.to(device)

                generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=800, do_sample=False)

                # Drop the prompt tokens so only the newly generated text is decoded.
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                batch_output_text = processor.batch_decode(
                    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )

                rank_outputs.extend(batch_output_text)

            rank_results = {
                "results": [(idx, output, img_info) for idx, output, img_info in zip(sample_indices, rank_outputs, image_info)],
                "stats": image_size_stats
            }

            # Collective call: every rank receives every rank's results, and no
            # rank proceeds before all ranks have finished generating.
            gathered_results = [None] * world_size
            dist.all_gather_object(gathered_results, rank_results)

            if rank == 0:
                combined_stats = {
                    "original": {"min_size": float('inf'), "max_size": 0},
                    "resized": {"min_size": float('inf'), "max_size": 0, "count": 0}
                }

                # Flat list of every (sample_idx, output, img_info) triple. In
                # multi-scale mode a sample contributes one triple per scale, so
                # the triples must not be collapsed into a per-sample slot here.
                gathered_entries = []

                for results_package in gathered_results:
                    rank_stats = results_package["stats"]
                    combined_stats["original"]["min_size"] = min(combined_stats["original"]["min_size"], rank_stats["original"]["min_size"])
                    combined_stats["original"]["max_size"] = max(combined_stats["original"]["max_size"], rank_stats["original"]["max_size"])

                    if rank_stats["resized"]["count"] > 0:
                        combined_stats["resized"]["min_size"] = min(combined_stats["resized"]["min_size"], rank_stats["resized"]["min_size"])
                        combined_stats["resized"]["max_size"] = max(combined_stats["resized"]["max_size"], rank_stats["resized"]["max_size"])
                        combined_stats["resized"]["count"] += rank_stats["resized"]["count"]

                    gathered_entries.extend(results_package["results"])

                # float('inf') would be serialized as the non-standard JSON token Infinity.
                if combined_stats["original"]["min_size"] == float('inf'):
                    combined_stats["original"]["min_size"] = None
                if combined_stats["resized"]["count"] == 0:
                    combined_stats["resized"] = None
                elif combined_stats["resized"]["min_size"] == float('inf'):
                    combined_stats["resized"]["min_size"] = None

                covered_indices = {idx for idx, _, _ in gathered_entries}
                missing_outputs = sum(1 for idx in range(len(data)) if idx not in covered_indices)
                if missing_outputs:
                    print(f"Warning: {missing_outputs} of {len(data)} samples have no model output")

                final_output = []
                correct_number = 0

                if MULTI_SCALE_ENABLED:
                    # Group the per-scale predictions of each sample.
                    samples_results = {}
                    for idx, output, img_info in gathered_entries:
                        if output is None or img_info is None:
                            continue

                        sample_idx = img_info.get('sample_idx', idx)
                        model_answer = extract_bbox_answer(output)

                        # Boxes are predicted in resized-image pixels; map them back.
                        if img_info['resized'] and model_answer:
                            model_answer = scale_bbox_to_original(
                                model_answer,
                                img_info['original_size'],
                                img_info['resized_size']
                            )

                        samples_results.setdefault(sample_idx, []).append({
                            'bbox': model_answer,
                            'original_output': output,
                            'min_size': img_info.get('min_size', MIN_SIZE),
                            'img_info': img_info
                        })

                    for idx, input_example in enumerate(data):
                        if idx not in samples_results:
                            continue

                        size_results = samples_results[idx]
                        # Keep the filtered records themselves so the index chosen
                        # below refers to the same list it was computed on.
                        boxed_results = [res for res in size_results if res['bbox'] is not None]
                        bboxes = [res['bbox'] for res in boxed_results]

                        if not bboxes:
                            final_bbox = None
                            best_output = size_results[0]['original_output'] if size_results else None
                        else:
                            final_bbox = find_max_intersection_bbox(bboxes)

                            if final_bbox:
                                # Report the raw output whose box overlaps the fused box most.
                                similarities = [compute_intersection_area(final_bbox, res['bbox'])
                                                for res in boxed_results]
                                best_result_idx = similarities.index(max(similarities))
                                best_output = boxed_results[best_result_idx]['original_output']
                            else:
                                best_output = size_results[0]['original_output']

                        ground_truth = input_example['solution']
                        correct = 0
                        if final_bbox is not None and ground_truth is not None:
                            if iou(final_bbox, ground_truth) > 0.5:
                                correct = 1
                        correct_number += correct

                        multi_size_details = [
                            {
                                'min_size': res['min_size'],
                                'bbox': res['bbox'],
                                'iou_with_gt': iou(res['bbox'], ground_truth) if ground_truth else 0
                            }
                            for res in boxed_results
                        ]

                        result = {
                            'image': input_example['image'],
                            'question': input_example['problem'],
                            'ground_truth': ground_truth,
                            'model_output': best_output,
                            'extracted_answer': final_bbox,
                            'correct': correct,
                            'multi_size_details': multi_size_details,
                            'fusion_method': 'min_intersection_count'
                        }
                        final_output.append(result)
                else:
                    # Single-scale: exactly one output per sample, stored by global index.
                    all_outputs = [None] * len(data)
                    all_image_info = [None] * len(data)
                    for idx, output, img_info in gathered_entries:
                        if idx < len(all_outputs):
                            all_outputs[idx] = output
                            all_image_info[idx] = img_info

                    for input_example, model_output, img_info in zip(data, all_outputs, all_image_info):
                        if model_output is None:
                            continue

                        original_output = model_output
                        ground_truth = input_example['solution']
                        model_answer = extract_bbox_answer(original_output)

                        if img_info['resized']:
                            model_answer = scale_bbox_to_original(
                                model_answer,
                                img_info['original_size'],
                                img_info['resized_size']
                            )

                        correct = 0
                        if model_answer is not None:
                            if iou(model_answer, ground_truth) > 0.5:
                                correct = 1
                        correct_number += correct

                        result = {
                            'image': input_example['image'],
                            'question': input_example['problem'],
                            'ground_truth': ground_truth,
                            'model_output': original_output,
                            'extracted_answer': model_answer,
                            'correct': correct,
                            'image_was_resized': img_info['resized'],
                            'original_size': img_info['original_size'],
                            'resized_size': img_info['resized_size'] if img_info['resized'] else None
                        }
                        final_output.append(result)

                processed_count = len(final_output)
                accuracy = correct_number / processed_count * 100 if processed_count > 0 else 0

                # RUN_NAME may contain "/", so the results file can live in a sub-directory.
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)

                with open(output_path, "w") as f:
                    json.dump({
                        'checkpoint': steps,
                        'accuracy': accuracy,
                        'results': final_output,
                        'processed_count': processed_count,
                        'correct_count': correct_number,
                        'image_size_stats': combined_stats,
                        'timestamp': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'multi_scale_enabled': MULTI_SCALE_ENABLED,
                        'test_sizes': TEST_SIZES if MULTI_SCALE_ENABLED else [MIN_SIZE]
                    }, f, indent=2)

                if ds in summary_results:
                    summary_results[ds].append({
                        'checkpoint': steps,
                        'accuracy': accuracy,
                        'output_path': output_path,
                        'processed_count': processed_count
                    })

                print("-"*100)

                # Best-effort removal of the temporary resized images written by
                # all ranks (safe here: the gather above guarantees generation is done).
                temp_dirs = set()
                for example in data:
                    image_path = os.path.join(IMAGE_ROOT, example['image'])
                    temp_dir = os.path.join(os.path.dirname(image_path), "temp_resized")
                    temp_dirs.add(temp_dir)

                    base_name = os.path.basename(image_path)
                    temp_names = [f"temp_{base_name}"]
                    if MULTI_SCALE_ENABLED:
                        temp_names.extend(f"temp_size_{min_size}_{base_name}" for min_size in TEST_SIZES)

                    for temp_name in temp_names:
                        try:
                            os.remove(os.path.join(temp_dir, temp_name))
                        except OSError:
                            # File never existed or is already gone.
                            pass

                # Images may live in several directories; try every temp dir, not
                # just the last one seen. rmdir only succeeds on empty directories.
                for temp_dir in temp_dirs:
                    try:
                        os.rmdir(temp_dir)
                    except OSError:
                        pass

            dist.barrier()

        # Free GPU memory before the next checkpoint is loaded.
        if 'model' in locals():
            del model
            del processor
            torch.cuda.empty_cache()

# Rank 0 only: merge results from this run with result files already on disk,
# compute per-dataset statistics, write the summary JSON and print a table.
if rank == 0:
    for ds in TEST_DATASETS:
        for model_path, steps in checkpoints:
            # The evaluation loop names files with the numeric step (0 for the
            # base model). "_base" is still accepted as a fallback for step 0.
            candidate_paths = [f"./logs/rec_results_{ds}_{RUN_NAME}_{steps}.json"]
            if steps == 0:
                candidate_paths.append(f"./logs/rec_results_{ds}_{RUN_NAME}_base.json")

            output_path = next((path for path in candidate_paths if os.path.exists(path)), None)
            if output_path is None:
                continue

            try:
                with open(output_path, 'r') as f:
                    result_data = json.load(f)
                accuracy = result_data.get('accuracy', 0)
                processed_count = result_data.get('processed_count', 0)
            except Exception as e:
                # A corrupt or unreadable file must not prevent the summary.
                print(f"Error reading {output_path}: {e}")
                continue

            if ds not in summary_results:
                summary_results[ds] = []

            # Results produced in this run are already recorded; do not duplicate them.
            existing_entry = next((item for item in summary_results[ds]
                                   if item['checkpoint'] == steps), None)

            if existing_entry is None:
                summary_results[ds].append({
                    'checkpoint': steps,
                    'accuracy': accuracy,
                    'output_path': output_path,
                    'processed_count': processed_count
                })

    summary_path = f"./logs/{RUN_NAME}_{TEST_DATASETS}_accuracy_summary.json"
    # RUN_NAME may contain "/", in which case the summary lives in a
    # sub-directory of ./logs that nothing else creates.
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)

    for ds in summary_results:
        if summary_results[ds]:
            summary_results[ds].sort(key=lambda x: x['checkpoint'])

            best_checkpoint = max(summary_results[ds], key=lambda x: x['accuracy'])
            accuracy_data = [item['accuracy'] for item in summary_results[ds]]

            summary_results[ds] = {
                'checkpoints': summary_results[ds],
                'statistics': {
                    'best_checkpoint': best_checkpoint['checkpoint'],
                    'best_accuracy': best_checkpoint['accuracy'],
                    'mean_accuracy': sum(accuracy_data) / len(accuracy_data) if accuracy_data else 0,
                    'min_accuracy': min(accuracy_data) if accuracy_data else 0,
                    'max_accuracy': max(accuracy_data) if accuracy_data else 0
                }
            }

    with open(summary_path, "w") as f:
        json.dump({
            'model': RUN_NAME,
            'datasets': summary_results,
            'timestamp': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'total_checkpoints': len(checkpoints),
            'multi_scale_enabled': MULTI_SCALE_ENABLED,
            'test_sizes': TEST_SIZES if MULTI_SCALE_ENABLED else [MIN_SIZE]
        }, f, indent=2)


    print("=" * 70)
    print("-" * 70)

    for ds in summary_results:
        # Datasets without any result keep their empty list and have no statistics.
        if not summary_results[ds]:
            print(f"{ds:<15} no results found")
            continue
        stats = summary_results[ds]['statistics']
        print(f"{ds:<15} {stats['best_checkpoint']:<12} {stats['best_accuracy']:.2f}%{' ':>5} {stats['mean_accuracy']:.2f}%{' ':>5} {stats['min_accuracy']:.2f}%")

    print("=" * 70)
