import json
import logging, random
from PIL import Image
import os

from Benchmarks.GQA.Eval import eval_pred_list
from Benchmarks.GQA.PostProcess import postprocess_answer_gqa

class GQA_Dataset:
    """Indexable view over a GQA question file and its image directory.

    The question file is a JSON object mapping each question id to a record
    that provides at least the ``question``, ``answer`` and ``imageId`` keys.
    Records are kept in file order; images are opened lazily on item access.
    """

    def __init__(self, quetion_path: str, image_dir: str, num_samples: int = None):
        """Load the question file and keep at most ``num_samples`` records.

        Args:
            quetion_path: Path to the questions JSON file. (The misspelled
                name is kept because callers pass it by keyword.)
            image_dir: Directory containing ``<imageId>.jpg`` files.
            num_samples: Maximum number of questions to keep, in file order.
                A falsy value (``None`` or ``0``) keeps every question.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(quetion_path) as f:
            self.questions_json = json.load(f)
        self.questions = []
        self.image_dir = image_dir
        self.num_samples = num_samples

        for count, (question_id, question) in enumerate(self.questions_json.items(), start=1):
            self.questions.append({
                'question_id': question_id,
                'question': question['question'],
                'answers': question['answer'],
                'image_id': question['imageId'],
            })
            if num_samples and count >= num_samples:
                break

    def __len__(self):
        """Return the number of questions retained by the constructor."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` together with its image converted to RGB.

        The image is read from ``{image_dir}/{image_id}.jpg`` on every access.
        """
        question = self.questions[idx]
        image_path = f"{self.image_dir}/{question['image_id']}.jpg"
        # convert() returns an independent image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return {
            'image_id': question['image_id'],
            'image': image,
            'question_id': question['question_id'],
            'question': question['question'],
            'answers': question['answers'],
        }

def load_gqa_dataset(question_path, image_dir, num_samples=None):
    """Build a ``GQA_Dataset``, reporting progress on stdout.

    Any exception raised while loading is logged through ``logging.error``
    and swallowed; in that case ``None`` is returned instead of a dataset.
    """
    try:
        print("🚀 Loading GQA dataset")
        gqa = GQA_Dataset(question_path, image_dir, num_samples)
        summary = f"✅ Loaded {len(gqa)} samples from {question_path}"
        print(summary)
    except Exception as err:
        logging.error(f"⛔ Error loading GQA dataset: {err}")
        return None
    return gqa

def _resolve_gqa_question_file(data_dir, data_subtype, num_samples):
    """Return the question file to score against, creating the truncated copy if needed."""
    question_path = f'{data_dir}/questions/{data_subtype}_balanced_questions.json'
    if not num_samples:
        return question_path

    ques_file = f'{data_dir}/eval/questions/{data_subtype}_balanced_questions_{num_samples}.json'
    if os.path.exists(ques_file):
        # A truncated copy from an earlier run is reused as-is.
        return ques_file

    os.makedirs(os.path.dirname(ques_file), exist_ok=True)
    with open(question_path, encoding='utf-8') as f:
        all_questions = json.load(f)

    # Keep the first `num_samples` questions in file order.
    subset = {}
    for count, (question_id, question) in enumerate(all_questions.items(), start=1):
        subset[question_id] = question
        if count >= num_samples:
            break

    with open(ques_file, 'w', encoding='utf-8') as f:
        json.dump(subset, f)
    return ques_file


def evaluate_gqa_results(
    data_dir,
    result_dir,
    filename_suffix,
    num_samples,
    model_name,
    data_subtype='testdev',
    n_show_error=2,
):
    """Score saved GQA predictions and write the evaluation report to disk.

    Predictions are read from
    ``{result_dir}/Inference/{model_name}_GQA{filename_suffix}`` (a JSON list
    of records with ``question_id`` and ``answer`` keys) and scored with
    ``eval_pred_list`` using ``postprocess_answer_gqa`` as answer processor.
    The report is written to ``{result_dir}/Eval/{model_name}_GQA_results.json``.

    Args:
        data_dir: Root directory holding ``questions/``; truncated question
            files are cached under ``eval/questions/``.
        result_dir: Root directory holding ``Inference/`` and ``Eval/``.
        filename_suffix: Suffix (including extension) of the predictions file.
        num_samples: When truthy, evaluate against a copy of the question
            file truncated to its first ``num_samples`` entries.
        model_name: Prefix of the predictions and report file names.
        data_subtype: Prefix of the balanced-questions file name.
        n_show_error: Maximum number of zero-score predictions to print as
            examples; ``0`` or ``None`` prints none.

    Returns:
        The ``scores`` object returned by ``eval_pred_list`` (its
        ``'accuracy'`` entry is printed).
    """
    ques_file = _resolve_gqa_question_file(data_dir, data_subtype, num_samples)
    res_file = f'{result_dir}/Inference/{model_name}_GQA{filename_suffix}'

    with open(res_file, encoding='utf-8') as f:
        results = json.load(f)
    with open(ques_file, encoding='utf-8') as f:
        questions = json.load(f)

    pred_list = []
    # First prediction seen for each question id, for the error examples below.
    first_result_by_id = {}
    for result in results:
        question_id = result['question_id']
        pred_list.append({
            "pred_answer": result['answer'],
            "question": questions.get(question_id, {}),
            "question_id": question_id,
        })
        first_result_by_id.setdefault(question_id, result)

    scores, dist, eval_qa = eval_pred_list(pred_list, answer_processor=postprocess_answer_gqa)
    print("\n⭐ Overall Accuracy is: %.02f\n" % scores['accuracy'])

    failed_ids = [qid for qid, score in eval_qa.items() if score == 0]
    n_examples = min(len(failed_ids), max(n_show_error or 0, 0))
    if n_examples:
        print('\nExample of low-accuracy answers:')
        for qid in random.sample(failed_ids, n_examples):
            # .get() keeps the report going even if an id cannot be matched.
            result = first_result_by_id.get(qid, {})
            question = questions.get(qid, {})
            print('\nGenerated answer (accuracy %.02f)' % eval_qa[qid])
            print("Answer: %s" % result.get('answer', ''))
            print("Ground truth: %s" % question.get('answer', ''))
            print("Question id: %s " % qid, "Question: %s\n" % question.get('question', ''))

    output_data = {"scores": scores, "dist": dist, "eval_qa": eval_qa}
    eval_path = f"{result_dir}/Eval/{model_name}_GQA_results.json"
    # Create the report directory up front so the dump cannot fail on a fresh result_dir.
    os.makedirs(os.path.dirname(eval_path), exist_ok=True)
    with open(eval_path, "w", encoding='utf-8') as f:
        json.dump(output_data, f, indent=4)

    return scores