import copy
import re
import datasets
import os
import random
import json
import io
import numpy as np
from PIL import Image
from collections import defaultdict

# Forget-split identifier (presumably a percentage). It selects which
# virtual-text / virtual-image annotation files are read below and is part of
# the output dataset directory name.
forget_rate = 50

# All relabelling below draws from the global `random` module; seed it so the
# generated unlearning dataset is reproducible from run to run.
random_seed = 0
random.seed(random_seed)

# Collect, per biography field, every value observed across the MLLMU-Bench
# profiles. The Mask_Task edit uses these pools to swap a ground-truth value
# for another value of the same field.
mllmu = datasets.load_dataset("MLLMMU/MLLMU-Bench", "Full_Set")
exclude_keys = {"Parents", "Fun Facts", "Description"}
key_value_map = defaultdict(list)
for s in mllmu["train"]["biography"]:
    bio = json.loads(s)
    for key, value in bio.items():
        if key not in exclude_keys:
            key_value_map[key].append(value)
for k, v in key_value_map.items():
    print(f"{k}: {len(v)}")
print("Loaded mllmu dataset.")

# Virtual captions that replace the real biography answers. An annotation's
# image_id minus 29000 is the MLLMU-Bench profile ID; several annotations may
# share one ID, and their captions are joined into a single answer.
mllmu_offset = 29000
with open(f"./output/mixed_virtual_text/forget_mixed_{forget_rate}_5cpts.json", "r", encoding='utf_8') as f:
    vir_txts = json.load(f)

id2captions = defaultdict(list)
for ann in vir_txts:
    img_ID = ann["image_id"] - mllmu_offset
    # Drop every dot (e.g. after abbreviations) so the only full stop is the
    # terminator appended here. Trailing whitespace is trimmed first so the
    # terminator follows the last word directly.
    caption = ann['caption'].replace(".", "").rstrip()
    id2captions[img_ID].append(caption + ".")
id2ans = {img_ID: " ".join(parts) for img_ID, parts in id2captions.items()}
print(f"Loaded {len(id2ans)} virtual text.")

# Keep only profiles that received virtual text. Filtering on the ID column
# alone avoids decoding every row's image just to test membership.
df = mllmu["train"].filter(lambda ID: int(ID) in id2ans, input_columns="ID")
assert df.num_rows == len(id2ans), f"{df.num_rows} not equal to {len(id2ans)}"

# Build the answer column in df's own row order, looked up by ID, so each
# profile receives its own captions regardless of annotation-file order.
vir_ans = [id2ans[int(ID)] for ID in df["ID"]]
text_replaced = df.remove_columns(['biography', 'answer', 'Classification_Task', 'Generation_Task', 'Mask_Task'])
text_replaced = text_replaced.add_column("answer", vir_ans)

print("Editing Classification_Task ...")
# Materialise a plain list of the task dicts. The in-place edits below are
# then exactly what add_column receives; df's stored data is not modified.
Classification_Task = list(df["Classification_Task"])
answer_letters = ['A', 'B', 'C', 'D']
for s in Classification_Task:
    # Only the 'Image_Textual_Questions' entries are relabelled; any other keys
    # of the task dict are carried over unchanged.
    for q in s['Image_Textual_Questions']:
        # Choose a random letter other than the current answer so the
        # unlearning target never coincides with it (when it is one of A-D).
        wrong_letters = [c for c in answer_letters if c != q['Correct_Answer']]
        q['Correct_Answer'] = random.choice(wrong_letters)
text_replaced = text_replaced.add_column("Classification_Task", Classification_Task)
# Before/after for the same profile: df row 0 corresponds to text_replaced row 0.
print(df["Classification_Task"][0])
print(text_replaced["Classification_Task"][0])

print("Editing Generation_Task ...")
# Materialise a plain list so the in-place edits are what add_column receives.
Generation_Task = list(df["Generation_Task"])
for s in Generation_Task:
    # Every generation question's target becomes a refusal.
    for q in s:
        q['Ground_Truth'] = "Sorry, I don't know."
text_replaced = text_replaced.add_column("Generation_Task", Generation_Task)
# Before/after for the same profile: df row 0 corresponds to text_replaced row 0.
print(df["Generation_Task"][0])
print(text_replaced["Generation_Task"][0])

print("Editing Mask_Task ...")
# Materialise a plain list so the in-place edits are what add_column receives.
Mask_Task = list(df["Mask_Task"])
for s in Mask_Task:
    for q in s:
        original = q['Ground_Truth']
        # Value pool of the first biography field whose observed values contain
        # this ground truth (fields scanned in key_value_map insertion order).
        pool = next((v for v in key_value_map.values() if original in v), None)
        if pool is None:
            print(f"Failed to find {original}")
            q['Ground_Truth'] = "unknown"
            continue
        # Draw only from values that differ from the original. A draw equal to
        # the original would leave the answer unchanged.
        alternatives = [x for x in pool if x != original]
        if alternatives:
            q['Ground_Truth'] = random.choice(alternatives)
        else:
            q['Ground_Truth'] = "unknown"
text_replaced = text_replaced.add_column("Mask_Task", Mask_Task)
# Before/after for the same profile: df row 0 corresponds to text_replaced row 0.
print(df["Mask_Task"][0])
print(text_replaced["Mask_Task"][0])

# Virtual images listed in the annotation file each become a new instance
# whose answer is that annotation's caption.
# NOTE(review): this path begins with ".results" while other paths use
# "./output" — confirm it is not meant to be "./results".
vir_images_path = ".results/mixed_mllmu/T2I_mixed"
vir_images_ann_path = f"./output/mixed_virtual_images_filename_mixed_forget_{forget_rate}.json"
with open(vir_images_ann_path, "r", encoding="utf-8") as f:
    vir_images_ann = json.load(f)

image_bytes_list = []
splited_answers = []
for ann in vir_images_ann:
    filename = ann["filename"]
    assert filename.endswith('.png')
    img_path = os.path.join(vir_images_path, filename)
    # Re-encode immediately so only compressed JPEG bytes are kept in memory,
    # not every decoded image; `with` closes the source file and the buffer.
    with Image.open(img_path) as src, io.BytesIO() as buf:
        src.convert('RGB').save(buf, format="JPEG")
        image_bytes_list.append(buf.getvalue())
    splited_answers.append(ann["caption"])

print(f"Loaded {len(image_bytes_list)} virtual images and captions.")

# The new instances get consecutive string IDs starting at vir_img_offset.
vir_img_offset = 30000 # 1001 for mllmu only
instance_ID = [str(i) for i in range(vir_img_offset, vir_img_offset + len(image_bytes_list))]

image_replaced = datasets.Dataset.from_dict(
    {'image': image_bytes_list, 'ID': instance_ID, 'answer': splited_answers}
)
# Declare the raw-bytes column as an Image feature (presumably to match the
# MLLMU 'image' column for concatenation — confirm against the source schema).
feat_dict = dict(image_replaced.features)
feat_dict['image'] = datasets.Image()
image_replaced = image_replaced.cast(datasets.Features(feat_dict))


# Stack the text-edited profiles and then the virtual-image instances into one
# dataset, and store it as the single "train" split.
parts = [text_replaced, image_replaced]
cated_ds = datasets.concatenate_datasets(parts)
n_total = len(cated_ds)
print(f"Concatenated as {n_total} instances.")

save_dir = f"./output/virtual_unlearn_datasets/mixed_mllmu_unlearn_{forget_rate}_5captions_mixed_imgs"
mllmu_cons_unlearn = datasets.DatasetDict(train=cated_ds)
mllmu_cons_unlearn.save_to_disk(save_dir)


