import os, json
import logging
from PIL import Image
from Benchmarks.MMMU.Eval import DOMAIN_CAT2SUB_CAT, CAT_SHORT2LONG
from Benchmarks.MMMU.Eval import evaluate, parse_open_response, calculate_ins_level_acc

class MMMUDataset:
    """Map-style dataset over MMMU samples stored as JSONL annotations.

    ``data_dir`` is scanned for immediate sub-directories; each one that
    contains an ``annotation.jsonl`` file contributes one sample per valid
    line. Every stored sample is the decoded JSON object with an ``"images"``
    key guaranteed to hold a list (empty when absent or null).

    Args:
        data_dir: Root directory holding one sub-directory per annotation file.
        num_samples: Optional cap on the number of samples kept. A falsy
            value (``None`` or ``0``) keeps everything.
    """

    def __init__(self, data_dir, num_samples=None):
        self.samples = []
        self.data_dir = data_dir

        # os.listdir() yields entries in arbitrary order; sorting keeps the
        # sample order (and therefore the num_samples prefix) reproducible.
        for entry in sorted(os.listdir(data_dir)):
            sub_dir = os.path.join(data_dir, entry)
            if not os.path.isdir(sub_dir):
                continue
            ann_path = os.path.join(sub_dir, "annotation.jsonl")
            if not os.path.exists(ann_path):
                continue
            with open(ann_path, "r", encoding="utf-8") as f:
                for line in f:
                    # Blank lines (e.g. a trailing newline) are not errors.
                    if not line.strip():
                        continue
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError as e:
                        logging.warning(f"⚠️ Skip: {e}")
                        continue
                    if not isinstance(item, dict):
                        logging.warning(
                            f"⚠️ Skip: expected a JSON object, got {type(item).__name__}"
                        )
                        continue
                    # Normalise a missing or null "images" field to a list so
                    # __getitem__ can always iterate over it.
                    item["images"] = item.get("images") or []
                    self.samples.append(item)

        if num_samples:
            self.samples = self.samples[:num_samples]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` with its images opened and converted to RGB.

        The result is a dict with keys ``images``, ``question``, ``answers``
        (``None`` when the annotation has no ``answer``), ``options`` (``[]``
        when absent) and ``question_id``.

        Raises:
            KeyError: If the annotation lacks ``question`` or ``id``.
        """
        item = self.samples[idx]

        images = []
        for path in item["images"]:
            # The context manager closes the source file; convert() returns
            # an independent copy that stays usable afterwards.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))

        return {
            "images": images,
            "question": item["question"],
            "answers": item.get("answer"),
            "options": item.get("options", []),
            "question_id": item["id"],
        }


def load_dataset(data_dir, num_samples=None):
    """Build an ``MMMUDataset`` from ``data_dir``; return ``None`` on failure.

    Progress is reported on stdout. Any ``Exception`` raised while loading is
    logged via ``logging.error`` and swallowed, in which case ``None`` is
    returned instead of a dataset.

    Args:
        data_dir: Directory passed through to ``MMMUDataset``.
        num_samples: Optional sample cap passed through to ``MMMUDataset``.
    """
    loaded = None
    try:
        print("🚀 Loading MMMU dataset")
        loaded = MMMUDataset(data_dir=data_dir, num_samples=num_samples)
        print(f"✅ Loaded {len(loaded)} samples from {data_dir}")
    except Exception as err:
        logging.error(f"⛔ Error loading MMMU dataset: {err}")
        # Drop a dataset that was built before the failure occurred.
        loaded = None
    return loaded

def _group_mmmu_by_category(pairs):
    """Bucket ``(data_id, value)`` pairs by the category embedded in the id.

    The category is the ``_``-joined middle of the id, i.e. everything
    between its first and last ``_``-separated fields.
    """
    grouped = {}
    for data_id, value in pairs:
        category = "_".join(data_id.split("_")[1:-1])
        grouped.setdefault(category, {})[data_id] = value
    return grouped


def evaluate_mmmu_results(result_dir, model_name, filename_suffix, answer_path, num_samples):
    """Score saved MMMU predictions against the answer file and save a summary.

    Predictions are read from
    ``{result_dir}/Inference/{model_name}_MMMU{filename_suffix}`` (a JSON list
    of objects with ``question_id`` and ``answer``) and ground truth from
    ``answer_path`` (a JSON object mapping each id to an entry with
    ``question_type`` and ``ground_truth``). Every category listed in
    ``CAT_SHORT2LONG`` that has both predictions and answers is scored with
    ``evaluate``; per-category, per-domain and overall figures are printed and
    written to ``{result_dir}/Eval/{model_name}_MMMU_results.json``.

    Args:
        result_dir: Root directory containing ``Inference`` and ``Eval``.
        model_name: Model identifier used in the input and output file names.
        filename_suffix: Suffix (including extension) of the predictions file.
        answer_path: Path to the ground-truth JSON file.
        num_samples: Unused; kept so existing callers keep working.

    Returns:
        dict mapping ``"Overall-<domain>"``, each category name and
        ``"Overall"`` to ``{"num": ..., "acc": ...}``.

    Raises:
        KeyError: If a predicted id is missing from the answers of its
            category.
    """
    result_path = f"{result_dir}/Inference/{model_name}_MMMU{filename_suffix}"
    # Context managers close the files deterministically instead of relying
    # on garbage collection.
    with open(result_path, encoding="utf-8") as f:
        output_list = json.load(f)
    with open(answer_path, encoding="utf-8") as f:
        answer_dict = json.load(f)

    output_dict_w_cat = _group_mmmu_by_category(
        (item["question_id"], item["answer"]) for item in output_list
    )
    answer_dict_w_cat = _group_mmmu_by_category(answer_dict.items())

    evaluation_result = {}
    for category in CAT_SHORT2LONG.values():
        print("Evaluating: {}".format(category))
        if category not in output_dict_w_cat or category not in answer_dict_w_cat:
            print("Skipping {} for not found".format(category))
            continue
        cat_outputs = output_dict_w_cat[category]
        cat_answers = answer_dict_w_cat[category]

        examples_to_eval = []
        for data_id, parsed_pred in cat_outputs.items():
            question_type = cat_answers[data_id]['question_type']
            # Only non-multiple-choice predictions go through the open-response
            # parser; multiple-choice predictions are passed on unchanged.
            if question_type != 'multiple-choice':
                parsed_pred = parse_open_response(parsed_pred)

            examples_to_eval.append({
                "id": data_id,
                "question_type": question_type,
                "answer": cat_answers[data_id]['ground_truth'],
                "parsed_pred": parsed_pred,
            })

        # evaluate() also returns per-example judgements, which are not used.
        _, metric_dict = evaluate(examples_to_eval)
        metric_dict.update({"num_example": len(examples_to_eval)})
        evaluation_result[category] = metric_dict

    printable_results = {}
    for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
        in_domain_cat_results = {
            cat_name: evaluation_result[cat_name]
            for cat_name in in_domain_cats
            if cat_name in evaluation_result
        }
        in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
        in_domain_data_num = sum(
            cat_results['num_example'] for cat_results in in_domain_cat_results.values()
        )
        printable_results['Overall-' + domain] = {
            "num": int(in_domain_data_num),
            "acc": round(in_domain_ins_acc, 3),
        }
        for cat_name, cat_results in in_domain_cat_results.items():
            printable_results[cat_name] = {
                "num": int(cat_results['num_example']),
                "acc": round(cat_results['acc'], 3),
            }

    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    print(printable_results)
    printable_results['Overall'] = {
        "num": sum(cat_results['num_example'] for cat_results in evaluation_result.values()),
        "acc": round(all_ins_acc, 3),
    }
    print(f"\nOverall Accuracy: {round(all_ins_acc, 3)}")

    # Create the output directory on demand so a fresh result_dir does not
    # fail at the very last step.
    eval_dir = f"{result_dir}/Eval"
    os.makedirs(eval_dir, exist_ok=True)
    with open(f"{eval_dir}/{model_name}_MMMU_results.json", "w", encoding="utf-8") as f:
        json.dump(printable_results, f, indent=4)
    return printable_results