import json
import random
from PIL import Image
from matplotlib import pyplot as plt
import os
from tqdm import tqdm
from utils import *



# Load the relation samples from a JSON file.
# NOTE: the path below is an empty placeholder and must be filled in with the
# location of the relations JSON file before this module can be imported.
with open("") as relations_file:
    relations = json.load(relations_file)


# Canonical question for each relation type, keyed by relation name.
# Every question instructs the model to reply with a single word.
prompt = {
    "nationality": "Please answer in one word. What country does the person in the picture come from?",
    "state": "Please use one word to answer the question. Which state does the place shown in the image belong to?",
    "sport": "Please use one word to answer the question. What sport does the team shown in the image play?",
    "profession": "Please use one word to answer the question. What is the career of the person in the image?",
    "gender": "Please use one word to answer the question. What do you think the gender of the person in the image is? (Male or Female)",
}

# Alternative wordings of each relation's question, keyed by relation name.
# Each relation has five paraphrases to draw from.
prompt_paraphrase = {
    "nationality": [
        "Please use one word to answer the question. From which nation does the individual portrayed originate?",
        "Please use one word to answer the question. What is the homeland of the person depicted in the image?",
        "Please use one word to answer the question. Which country does the individual in the photograph hail from?",
        "Please use one word to answer the question. From what geographical location does the person shown in the picture originate?",
        "Please use one word to answer the question. What is the nationality of the person captured in the image? ",
    ],
    "state": [
        "Please use one word to answer the question. In which state is the location depicted in the image situated?",
        "Please use one word to answer the question. Can you identify the state where the place in the picture is located?",
        "Please use one word to answer the question. To which state does the location shown in the photograph belong?",
        "Please use one word to answer the question. What state is the place featured in the image part of?",
        "Please use one word to answer the question. Where is the place in the image located within the United States?",
    ],
    "sport": [
        "Please use one word to answer the question. What kind of sport is the team in the picture involved in?",
        "Please use one word to answer the question. Which sport is the team depicted in the image associated with?",
        "Please use one word to answer the question. Can you identify the sport played by the team in the photo?",
        "Please use one word to answer the question. What sport does the team featured in the image participate in?",
        "Please use one word to answer the question. What is the sport that the team shown in the photo plays?",
    ],
    "profession": [
        "Please use one word to answer the question. What profession does the person in the image have?",
        "Please use one word to answer the question. What job does the individual in the picture do?",
        "Please use one word to answer the question. Can you tell me the occupation of the person shown in the image?",
        "Please use one word to answer the question. What line of work is the person in the photo involved in?",
        "Please use one word to answer the question. What is the professional role of the person depicted in the image?",
    ],
    "gender": [
        "Please use one word to answer the question. Can you identify the gender of the person shown in the image? (Male or Female)",
        "Please use one word to answer the question. How would you describe the gender of the individual in the picture? (Male or Female)",
        "Please use one word to answer the question. What gender do you believe the person in this photo is? (Male or Female)",
        "Please use one word to answer the question. Could you tell me what gender the person in the image appears to be? (Male or Female)",
        "Please use one word to answer the question. Based on the image, what gender do you think this person might be? (Male or Female)",
    ],
}


def generate_question(sample, image_token=None, ori_ginal_question=False):
    """Build the question text for a sample from its relation's prompt.

    Args:
        sample: Mapping with a "relation" key that selects the entry of the
            module-level ``prompt`` table.
        image_token: Optional prefix placed before the question. When truthy,
            the result is formatted as a QUESTION/ANSWER template.
        ori_ginal_question: When true, return the bare prompt with no
            decoration (takes precedence over ``image_token``).

    Returns:
        The bare prompt, the image-token QUESTION/ANSWER template, or the
        prompt followed by " The answer is: ".

    Raises:
        KeyError: If the sample's relation has no entry in ``prompt``.
    """
    base_question = prompt[sample["relation"]]

    if ori_ginal_question:
        return base_question

    if not image_token:
        # Text-only form: lead the model straight into the answer.
        return base_question + " The answer is: "

    return image_token + "\n QUESTION: " + base_question + "\n ANSWER: "

# Index the loaded samples by relation type:
#   relation2samples        : relation -> every sample carrying that relation
#   relation_candidate_tail : relation -> distinct "o_text" answers seen for it
relation2samples = {}
relation_candidate_tail = {}

for s in relations:
    rel = s["relation"]
    relation2samples.setdefault(rel, []).append(s)
    relation_candidate_tail.setdefault(rel, []).append(s["o_text"])

# De-duplicate the answers while keeping first-seen order.
# list(set(...)) would order string answers differently from one interpreter
# run to the next (string hashing is randomised per process), which makes any
# later random.choice over these lists irreproducible even with a fixed seed.
for rel in relation_candidate_tail:
    relation_candidate_tail[rel] = list(dict.fromkeys(relation_candidate_tail[rel]))

# Relation names in first-seen order; used to sample a *different* relation.
relations_key = list(relation_candidate_tail.keys())

def generate_candidate_tail_entity_text(relation, original_tail):
    """Pick a random answer for ``relation`` that differs from ``original_tail``.

    Args:
        relation: Relation name; must be a key of ``relation_candidate_tail``.
        original_tail: The answer that must not be returned.

    Returns:
        One entry of ``relation_candidate_tail[relation]`` that is not equal
        to ``original_tail``.

    Raises:
        KeyError: If ``relation`` has no candidate list.
        ValueError: If the relation has no candidate other than
            ``original_tail`` (previously this case looped forever).
    """
    candidates = relation_candidate_tail[relation]

    # Without at least one differing candidate the rejection loop below can
    # never terminate, so fail loudly instead of hanging.
    if all(candidate == original_tail for candidate in candidates):
        raise ValueError(
            "relation {!r} has no candidate tail other than {!r}".format(
                relation, original_tail
            )
        )

    # Rejection sampling: redraw until the pick differs from the original.
    candidate_tail = random.choice(candidates)
    while candidate_tail == original_tail:
        candidate_tail = random.choice(candidates)

    return candidate_tail


def generate_loc_triples(relation):
    """Pick a random sample whose relation differs from ``relation``.

    Args:
        relation: Relation name to avoid.

    Returns:
        A random entry of ``relation2samples`` for a randomly chosen relation
        other than ``relation``.

    Raises:
        ValueError: If ``relations_key`` contains no relation other than
            ``relation`` (previously this case looped forever, or raised
            IndexError when ``relations_key`` was empty).
    """
    # Without at least one other relation the rejection loop below can never
    # terminate, so fail loudly instead of hanging.
    if all(key == relation for key in relations_key):
        raise ValueError(
            "no relation other than {!r} is available".format(relation)
        )

    # Rejection sampling: redraw until the relation differs from the input.
    candidate_relation = random.choice(relations_key)
    while candidate_relation == relation:
        candidate_relation = random.choice(relations_key)

    return random.choice(relation2samples[candidate_relation])

if __name__ == '__main__':
    # Script entry point: build one record per usable sample, shuffle the
    # records, and write a train/test split to relation_train.json and
    # relation_test.json in the current directory.

    def check_img_file(img):
        """Return whether ``event_path + img`` exists and passes ``IsValidImage``.

        NOTE(review): this helper is not called anywhere in this script, and
        ``event_path`` / ``IsValidImage`` are not defined in this file --
        presumably they come from ``from utils import *``; confirm.
        """
        if not os.path.exists(event_path + img):
            return False
        return IsValidImage(event_path + img)

    seed = 10300
    data_split = 0.9  # fraction of the shuffled records written to the train file

    from utils import setup_seed
    setup_seed(seed)  # presumably seeds the RNGs used below -- see utils
    datas = []

    for sample in tqdm(relations):
        relation = sample["relation"]
        q1 = generate_question(sample, ori_ginal_question=True)
        q2 = random.choice(prompt_paraphrase[relation])

        ori = sample['o_text']
        alt = generate_candidate_tail_entity_text(relation, ori)

        # Two *different* images are needed (one for the record, one as its
        # rephrase). Skip samples that cannot supply them: an empty list used
        # to crash random.choice, and a list of identical entries used to make
        # the redraw loop below spin forever.
        images = sample["s_image"]
        if all(img == images[0] for img in images):
            continue

        im1 = random.choice(images)
        im2 = im1
        while im2 == im1:
            im2 = random.choice(images)

        # Locality example: image, question and answer taken from a sample of
        # a different relation.
        loc_sample = generate_loc_triples(relation)
        loc_m = random.choice(loc_sample["s_image"])
        loc_a = loc_sample["o_text"]
        loc_q = generate_question(loc_sample, ori_ginal_question=True)

        datas.append(
            {
                "src": q1,
                "pred": ori,
                "alt": alt,
                "rephrase": q2,
                "image_rephrase": im2,
                "image": im1,
                # The same fixed text-only question/answer pair is attached
                # to every record.
                "loc": "nq question: where is the pause key on a dell laptop",
                "loc_ans": "Ctrl+Fn+F11",
                "m_loc": loc_m,
                "m_loc_q": loc_q,
                "m_loc_a": loc_a,
            }
        )

    random.shuffle(datas)

    # Compute the boundary once so both slices are guaranteed to agree.
    split_index = int(data_split * len(datas))
    train = datas[:split_index]
    test = datas[split_index:]

    print("Train: {} , Test: {}".format(len(train), len(test)))

    with open("relation_train.json", "w") as f:
        json.dump(train, f)

    with open("relation_test.json", "w") as f:
        json.dump(test, f)


