import json, os
import logging
from PIL import Image
from Benchmarks.AI2D.Eval import evaluate

class AI2DDataset:
    """
    Index-addressable view over AI2D question metadata and diagram images.

    Every ``*.json`` file directly inside ``question_path`` is read as one
    question record that must provide ``image_id``, ``question`` and
    ``answer``; ``choices`` is optional and defaults to an empty list.
    Images are looked up lazily as ``<image_dir>/<image_id>.png``.
    """

    def __init__(self, image_dir, question_path, num_samples=None):
        """
        Load the question metadata.

        Args:
            image_dir: Directory containing the ``<image_id>.png`` files.
            question_path: Directory containing one JSON file per question.
            num_samples: If truthy, keep only the first ``num_samples``
                questions (in sorted filename order). ``None`` or ``0``
                keeps everything.
        """
        self.image_dir = image_dir
        self.questions = []

        # os.listdir returns entries in arbitrary order; sort so that the
        # dataset order (and the num_samples prefix) is reproducible.
        json_files = sorted(
            name for name in os.listdir(question_path) if name.endswith('.json')
        )
        for json_file in json_files:
            with open(os.path.join(question_path, json_file), encoding='utf-8') as f:
                metadata = json.load(f)
            self.questions.append({
                'image_id': metadata['image_id'],
                'question': metadata['question'],
                'answer': metadata['answer'],
                'choices': metadata.get('choices', []),
            })

        if num_samples:
            # Slicing already clamps to the list length.
            self.questions = self.questions[:num_samples]

    def __len__(self):
        """Return the number of loaded questions."""
        return len(self.questions)

    def __getitem__(self, idx):
        """
        Return the sample at ``idx`` with its image decoded to RGB.

        The returned dict has the keys ``image``, ``question``, ``answers``,
        ``options`` (the choices list rendered with ``str``), ``image_id``
        and ``question_id`` (same value as ``image_id``).
        """
        item = self.questions[idx]
        image_path = os.path.join(self.image_dir, f"{item['image_id']}.png")
        # convert() yields an independent in-memory image, so the file
        # handle behind the opened image can be released right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return {
            "image": image,
            "question": item["question"],
            "answers": item["answer"],
            "options": str(item["choices"]),
            "image_id": item["image_id"],
            "question_id": item["image_id"],
        }
        
def load_dataset(image_dir, question_path, num_samples=None):
    """
    Build an AI2DDataset and report progress on stdout.

    Any exception raised while loading is logged and swallowed; in that
    case None is returned instead of a dataset.
    """
    try:
        print("🚀 Loading AI2D dataset")
        dataset = AI2DDataset(image_dir, question_path, num_samples)
        print(f"✅ Loaded {len(dataset)} samples from {image_dir}")
    except Exception as e:
        logging.error(f"⛔ Error loading dataset: {e}")
        return None
    else:
        return dataset

def evaluate_ai2d_results(data_dir, result_dir, model_name, filename_suffix, num_samples):
    """
    Evaluate AI2D dataset results by comparing predicted answers with ground truth answers.

    Predictions are read from
    ``<result_dir>/Inference/<model_name>_AI2D<filename_suffix>`` (a JSON list
    of records with ``question_id`` and ``answer``). Ground truth comes from
    the JSON metadata files in ``data_dir``. Each pair is scored with
    ``evaluate``; choices are not consulted here.

    The summary is written to
    ``<result_dir>/Eval/<model_name>_AI2D_results.json`` and the accuracy
    (a percentage) is returned.

    Notes:
        * Results without a ground-truth entry are logged and not scored, but
          they still count towards the denominator (``len(results)``).
        * ``num_samples`` is accepted for interface compatibility and is not
          used by this function.

    Raises:
        Exception: any error is logged and then re-raised unchanged.
    """
    try:
        result_path = f"{result_dir}/Inference/{model_name}_AI2D{filename_suffix}"
        with open(result_path, encoding="utf-8") as f:
            results = json.load(f)

        # Read answers straight from the loaded metadata. Iterating the
        # dataset itself would open and decode every image just to get them.
        dataset = AI2DDataset(data_dir, data_dir)
        ground_truth = {
            question['image_id']: question['answer']
            for question in dataset.questions
        }

        correct = 0
        total = len(results)
        for result in results:
            image_id = result['question_id']
            pred_answer = result['answer']
            true_answer = ground_truth.get(image_id)

            # Compare against None explicitly: falsy answers such as 0 or ""
            # are legitimate ground truth and must not be skipped.
            if true_answer is None:
                logging.warning(f"Skipping {image_id}: No ground truth found")
                continue

            correct += evaluate(pred_answer, true_answer)

        accuracy = correct / total * 100 if total > 0 else 0
        print(f"✅ Accuracy: {accuracy:.2f}% ({correct}/{total} correct)")

        output_data = {"results": results, "accuracy": accuracy}
        # Make sure the target directory exists so a long evaluation is not
        # lost at the final write.
        os.makedirs(f"{result_dir}/Eval", exist_ok=True)
        output_path = f"{result_dir}/Eval/{model_name}_AI2D_results.json"
        with open(output_path, "w", encoding="utf-8") as out_file:
            json.dump(output_data, out_file, indent=4)
        return accuracy

    except Exception as e:
        logging.error(f"⛔ Error evaluating results: {e}")
        # Bare raise keeps the original traceback intact.
        raise