import io
import json
import os
import pickle
import uuid

import numpy as np
import PIL
import requests
from tqdm import tqdm

from src.dataset import MNISTSumKOrigDataset, LeafDataset


# Qualtrics connection settings. Each value can be overridden through an
# environment variable so credentials do not have to live in source control.
# SECURITY(review): the fallback token below has been committed to the
# repository; it should be revoked in Qualtrics and the literal removed once
# everyone sets QUALTRICS_API_TOKEN instead.
API_TOKEN = os.environ.get("QUALTRICS_API_TOKEN", "40sxJWWJU4Ct7zWBjawkWwi48KNT9POK0caPDBWK")
DATA_CENTER = os.environ.get("QUALTRICS_DATA_CENTER", "iad1")  # For example: "iad1"
BASE_URL = f"https://{DATA_CENTER}.qualtrics.com/API/v3/"

# Headers for requests whose body is JSON.
JSON_HEADERS = {
    "Content-Type": "application/json",
    "X-API-TOKEN": API_TOKEN,
}

# Headers for multipart uploads. Content-Type is deliberately omitted so that
# `requests` can generate it together with the multipart boundary.
MULTIPART_HEADERS = {
    "X-API-TOKEN": API_TOKEN,
}

# ==============================
# Qualtrics API Functions
# ==============================
def create_survey(survey_name, timeout=30):
    """
    Create a new Qualtrics survey named ``survey_name``.

    Args:
        survey_name: Display name of the new survey.
        timeout: Seconds to wait for the HTTP request before ``requests``
            raises a timeout error. Without a timeout the call could block
            indefinitely on a stalled connection.

    Returns:
        The new survey's ID on success, or None if the API answered with a
        non-200 status (the response body is printed).
    """
    url = BASE_URL + "survey-definitions"
    payload = {"SurveyName": survey_name, "Language": "EN", "ProjectCategory": "CORE"}
    response = requests.post(url, headers=JSON_HEADERS, json=payload, timeout=timeout)

    if response.status_code != 200:
        print(f"Error creating survey '{survey_name}':", response.text)
        return None

    survey_id = response.json()["result"]["SurveyID"]
    print(f"Survey '{survey_name}' created with ID: {survey_id}")
    return survey_id

def add_question(survey_id, question_payload, block_id, timeout=30):
    """
    Add a question to a survey, targeting the given block.

    Args:
        survey_id: ID of the survey to modify.
        question_payload: Question definition sent as the JSON body
            (e.g. the dict returned by ``build_question_payload``).
        block_id: Block to place the question in, sent as the ``blockId``
            query parameter. NOTE: ``requests`` omits query parameters whose
            value is None, so a None block_id sends no ``blockId`` at all.
        timeout: Seconds to wait for the HTTP request before ``requests``
            raises a timeout error.

    Returns:
        The new question's ID on success, or None if the API answered with a
        non-200 status (the response body is printed).
    """
    url = BASE_URL + f"survey-definitions/{survey_id}/questions"
    response = requests.post(
        url,
        headers=JSON_HEADERS,
        json=question_payload,
        params={"blockId": block_id},
        timeout=timeout,
    )

    if response.status_code != 200:
        print("Error adding question:", response.text)
        return None

    return response.json()["result"]["QuestionID"]

def upload_image(img: "PIL.Image.Image", library_id="UR_54s7YmMwkzELa5M", folder="NeSy", timeout=30):
    """
    Upload an image as a PNG graphic to a Qualtrics graphics library.

    The image is PNG-encoded in memory and sent as a multipart upload under a
    random UUID-based file name. Earlier versions wrote a shared ``temp.png``
    into the working directory and never closed the file handle they opened
    on it; in-memory encoding avoids both problems.

    Args:
        img: Image to upload; anything with a PIL-style ``save(fp, format=...)``.
        library_id: ID of the Qualtrics library that receives the graphic.
        folder: Library folder to file the graphic under.
        timeout: Seconds to wait for the upload before ``requests`` raises a
            timeout error.

    Returns:
        The new graphic's ID on success, or None if the API answered with a
        non-200 status (the response body is printed).
    """
    url = BASE_URL + f"libraries/{library_id}/graphics"

    png_buffer = io.BytesIO()
    img.save(png_buffer, format="PNG")
    name = str(uuid.uuid4()) + ".png"
    payload = {
        "file": (name, png_buffer.getvalue(), "image/png"),
        # A (None, value) tuple makes requests send a plain form field, not a file.
        "folder": (None, folder),
    }
    response = requests.post(url, headers=MULTIPART_HEADERS, files=payload, timeout=timeout)
    if response.status_code != 200:
        print("Error uploading image:", response.text)
        return None
    return response.json()["result"]["id"]


# Extra explanatory HTML appended to a question, keyed by the label being asked
# about. The strings are kept verbatim (including their leading and trailing
# whitespace) so the generated question text is unchanged. Labels that are not
# listed get no reference section.
_SYMBOL_REFERENCE_HTML = {
    "ovate": """
        <div><p>For reference, ovate looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_DYyUwzH8fZjBUFI"/></div>
        <div><p>It is egg-shaped; with the broadest part near the base.</p></div>
        """,
    "elliptical": """
        <div><p>For reference, elliptical looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_qyXsUK60FcIerBk"/></div>
        <div><p>It is oval-shaped; with the broadest part near the center.</p></div>
        """,
    "lanceolate": """
        <div><p>For reference, lanceolate looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_YEDnp1m5uw4z30e"/></div>
        <div><p>It is long, wider in the middle, shaped like a lance tip.</p></div>
        """,
    "oblong": """
        <div><p>For reference, oblong looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_7dDKpXbr9N0gqvk"/></div>
        <div><p>It has an elongated form with slightly parallel sides; roughly rectangular.</p></div>
        """,
    "obovate": """
        <div><p>For reference, obovate looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_qCwwv21qUtuGJBe"/></div>
        <div><p>It is teardrop-shaped, stem attaches to the tapering end.</p></div>
        """,
    "entire": """
        <div><p>For reference, an entire leaf margin looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_iISdOmHaae8GDdc"/></div>
        <div><p>It is even; with a smooth margin; without toothing.</p></div>
        """,
    "indented": """
        <div><p>For reference, an indented leaf margin has noticeable inward curves or notches, creating a wavy, lobed, or deeply cut appearance.</p></div>
        """,
    "lobed": """
        <div><p>For reference, a lobed leaf margin looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_JstKgaXtYwGglMn"/></div>
        <div><p>It is indented, with the indentations not reaching the center.</p></div>
        """,
    "serrate": """
        <div><p>For reference, a serrate leaf margin looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_4dls5dHTvs0cc83"/></div>
        <div><p>It is saw-toothed; with asymmetrical teeth pointing forward</p></div>
        """,
    "serrulate": """
        <div><p>For reference, a serrulate leaf margin looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_hDpAyHHrzLxpLG8"/></div>
        <div><p>It is finely serrated.</p></div>
        """,
    "undulate": """
        <div><p>For reference, an undulate leaf margin looks like the following:</p></div>
        <div><img src="https://upenn.co1.qualtrics.com/ControlPanel/Graphic.php?IM=IM_D9Kyy4AdXMNNpzR"/></div>
        <div><p>It is wavy.</p></div>
        """,
}


def build_question_payload(img, symbol, input_str, output_str, choice_names):
    """
    Build the Qualtrics payload for a forced-response multiple-choice question.

    Uploads ``img`` via ``upload_image`` and embeds it at the top of the
    question HTML. Below the image come ``input_str`` and the formatted
    ``output_str``. If ``symbol`` has an entry in ``_SYMBOL_REFERENCE_HTML``,
    that reference description is appended.

    Args:
        img: Image shown above the question; passed to ``upload_image``.
        symbol: Label being asked about. It is substituted for ``{symbol}`` in
            ``output_str`` and used to look up the reference HTML.
        input_str: Description of the image, rendered as "The above is ...".
        output_str: Question prompt, used as a ``str.format`` template with
            ``symbol=``. Any literal braces must be escaped as ``{{``/``}}``.
        choice_names: Answer labels, keyed "1", "2", ... in the given order.

    Returns:
        The question payload dict, or None if the image upload failed (an
        error is printed). This matches the None-on-error convention of the
        other helpers in this file. Previously a failed upload silently
        produced a question with a broken ``IM=None`` image link.
    """
    # upload image
    img_id = upload_image(img)
    if img_id is None:
        print(f"Error building question for '{symbol}': image upload failed")
        return None
    img_url = f"https://{DATA_CENTER}.qualtrics.com/ControlPanel/Graphic.php?IM={img_id}"

    # build question html
    question_html = f"""
        <div>
            <img src="{img_url}"/>
        </div>
        <div>
            <p>The above is {input_str}</p>
            <p>{output_str.format(symbol=symbol)}</p>
        </div>
    """
    question_html += _SYMBOL_REFERENCE_HTML.get(symbol, "")

    # build choices; dicts preserve insertion order, so the keys double as ChoiceOrder
    choices = {str(i + 1): {"Display": choice} for i, choice in enumerate(choice_names)}

    # build payload
    payload = {
        "QuestionText": question_html,
        "QuestionType": "MC",
        "Selector": "SAVR",  # single answer, vertical layout
        "SubSelector": "TX",
        "Configuration": {
            "QuestionDescriptionOption": "UseText",
        },
        "Choices": choices,
        "ChoiceOrder": list(choices),
        "Validation": {
            "Settings": {
                "ForceResponse": "ON",
                "ForceResponseType": "ON",
                "Type": "None",
            }
        },
        # "RecodeValues": {"1": val_1, "2": val_2},
    }

    return payload


def create_block(survey_id, block_name, timeout=30):
    """
    Create a new "Standard" block in a survey.

    Args:
        survey_id: ID of the survey to add the block to.
        block_name: Text stored as the block's ``Description``.
        timeout: Seconds to wait for the HTTP request before ``requests``
            raises a timeout error. Without a timeout the call could block
            indefinitely.

    Returns:
        The new block's ID on success, or None if the API answered with a
        non-200 status (the response body is printed).
    """
    url = BASE_URL + f"survey-definitions/{survey_id}/blocks"
    payload = {"Type": "Standard", "Description": block_name}
    response = requests.post(url, headers=JSON_HEADERS, json=payload, timeout=timeout)

    if response.status_code != 200:
        print("Error creating block:", response.text)
        return None

    return response.json()["result"]["BlockID"]

def add_questions_to_block(survey_id, block_id, question_ids, timeout=30):
    """
    Set a block's elements to the given questions with a PUT request.

    Args:
        survey_id: ID of the survey that owns the block.
        block_id: ID of the block to update.
        question_ids: Question IDs to list in the block, in order.
        timeout: Seconds to wait for the HTTP request before ``requests``
            raises a timeout error.

    Returns:
        True if the API answered with status 200, False otherwise (the
        response body is printed). The previous version returned None in
        both cases.
    """
    url = BASE_URL + f"survey-definitions/{survey_id}/blocks/{block_id}"
    # BlockElements must be a flat list of element objects. The previous
    # version wrapped that list in a second list ([[...]]).
    payload = {
        "BlockElements": [{"Type": "Question", "QuestionID": qid} for qid in question_ids],
    }
    # NOTE(review): as a PUT this presumably replaces the block's existing
    # elements rather than appending. The Update Block endpoint may also
    # expect fields such as "Type"/"Description"; confirm against the API docs.
    response = requests.put(url, headers=JSON_HEADERS, json=payload, timeout=timeout)

    if response.status_code != 200:
        print("Error adding questions to block:", response.text)
        return False
    print(f"Questions added to block {block_id}")
    return True


np.random.seed(0)

# data = MNISTSumKOrigDataset(root="data", train=True, download=True, k=5)
data = LeafDataset(train=False)
# Shuffle (at most) the first 200 test examples using the fixed seed above.
test_data_ids = list(range(min(200, len(data))))
shuf = np.random.permutation(test_data_ids)
test_data = [data[int(i)] for i in shuf[:200]]
# Element 1 of each example is presumably its ground-truth class label; TODO
# confirm against LeafDataset. Only the commented-out filtering code uses it.
gt = [example[1] for example in test_data]

survey_name = "Leaf LLM Mistakes"

# When True, the script stops after printing the sampled indices, before it
# creates anything in Qualtrics. This replaces a bare `1/0` that used to crash
# the script at this point, after an empty survey had already been created.
# Set to False to actually build the survey.
DRY_RUN = True

# for i in range(20):
#     question_payload = build_question_payload(test_data[i][0][0], "an image of a leaf", "classify the leaf's margin as one of the following: {'entire', 'indented', 'lobed', 'serrate', 'serrulate', 'undulate'}", ["entire", "indented", "lobed", "serrate", "serrulate", "undulate"])
#     block_id = create_block(survey_id, "margin")
#     question_id = add_question(survey_id, question_payload, block_id)
#     print("Question ID:", question_id)

# for i in range(20):
#     question_payload = build_question_payload(test_data[i][0][0], "an image of a leaf", "classify the leaf's texture as one of the following: {'glossy', 'leathery', 'smooth', 'rough'}.", ["glossy", "leathery", "smooth", "rough"])
#     block_id = create_block(survey_id, "margin")
#     question_id = add_question(survey_id, question_payload, block_id)
#     print("Question ID:", question_id)

# get test data that is either Alstonia Scholaris, Terminalia Arjuna, Citrus limon, or Punica granatum
# data = []
# new_gt = []
# for i in range(len(test_data)):
#     if gt[i] in ["Alstonia Scholaris", "Punica granatum"]:
#         data.append(test_data[i])
#         new_gt.append(gt[i])

# undulate_hard = [33, 40, 46, 59, 75, 76, 83, 92, 105, 110, 115, 123, 134, 148, 149, 151, 153, 154, 162, 167, 179, 181, 188, 190]
# data = [test_data[i] for i in undulate_hard]

# # sample 10 ovate and 10 not ovate
# data = [test_data[i] for i in ovate_hard]
# data_ovate = [data[i] for i in range(len(data)) if true_ovate[i]]
# data_not_ovate = [data[i] for i in range(len(data)) if not true_ovate[i]]
# data = data_ovate[:10] + data_not_ovate[:10]

wrong_margin = [7, 22, 39, 41, 50, 57, 60, 61, 70, 72, 80, 86, 90, 102, 108]
wrong_margin_true = ['entire', 'serrulate', 'serrulate', 'indented', 'indented', 'indented', 'entire', 'serrulate', 'serrulate', 'entire', 'serrulate', 'serrulate', 'entire', 'entire', 'entire']

wrong_shape = [1, 8, 10, 13, 18, 19, 20, 22, 23, 24, 25, 31, 34, 38, 39, 42, 46, 60, 61, 62, 69, 70, 72, 74, 76, 77, 80, 83, 85, 86, 88, 90, 100, 102, 103, 104, 105, 106, 108]
wrong_shape_true = ['obovate', 'obovate', 'oblong', 'obovate', 'obovate', 'obovate', 'elliptical', 'elliptical', 'elliptical', 'elliptical', 'obovate', 'elliptical', 'elliptical', 'obovate', 'elliptical', 'oblong', 'oblong', 'elliptical', 'elliptical', 'oblong', 'obovate', 'elliptical', 'elliptical', 'elliptical', 'elliptical', 'elliptical', 'elliptical', 'elliptical', 'oblong', 'elliptical', 'elliptical', 'elliptical', 'oblong', 'elliptical', 'oblong', 'obovate', 'elliptical', 'obovate', 'elliptical']

wrong_texture = [1, 7, 10, 12, 18, 20, 22, 23, 24, 27, 31, 38, 39, 41, 42, 47, 50, 61, 62, 63, 69, 70, 80, 83, 85, 86, 88, 96, 100, 103, 104, 105, 106, 108]
wrong_texture_true = ['leathery', 'leathery', 'leathery', 'rough', 'leathery', 'leathery', 'smooth', 'leathery', 'smooth', 'rough', 'smooth', 'leathery', 'smooth', 'leathery', 'leathery', 'rough', 'leathery', 'leathery', 'leathery', 'rough', 'leathery', 'leathery', 'smooth', 'smooth', 'leathery', 'smooth', 'smooth', 'rough', 'leathery', 'leathery', 'leathery', 'rough', 'leathery', 'rough']


def _select_mistakes(indices, labels, k=None):
    """
    Return the test_data examples at ``indices`` with their labels.

    If ``k`` is given, keep a random subset of ``k`` pairs. The subset is drawn
    with ``np.random.permutation``, so it follows the global seed. The result
    is ``(examples, labels, indices)``, all in the same (possibly shuffled)
    order.

    Raises:
        ValueError: if ``indices`` and ``labels`` have different lengths,
            which would silently misalign examples and labels.
    """
    if len(indices) != len(labels):
        raise ValueError(
            f"{len(indices)} indices but {len(labels)} labels; the lists must align"
        )
    examples = [test_data[i] for i in indices]
    if k is None:
        return examples, list(labels), list(indices)
    order = np.random.permutation(len(examples))[:k]
    return (
        [examples[j] for j in order],
        [labels[j] for j in order],
        [indices[j] for j in order],
    )


data_margin = _select_mistakes(wrong_margin, wrong_margin_true)[0]

# randomly subsample to 15 (the shape permutation is drawn before the texture one)
data_shape, wrong_shape_true, sampled_shape_idx = _select_mistakes(wrong_shape, wrong_shape_true, k=15)
print("wrong shape index:", sampled_shape_idx)

# randomly subsample to 15
data_texture, wrong_texture_true, sampled_texture_idx = _select_mistakes(wrong_texture, wrong_texture_true, k=15)
print("wrong texture index:", sampled_texture_idx)

if DRY_RUN:
    print("DRY_RUN is True: stopping before creating the survey.")
    raise SystemExit(0)

# Create the survey only once we know we will populate it, so aborted runs do
# not leave empty surveys behind.
survey_id = create_survey(survey_name)
if survey_id is None:
    raise SystemExit(f"Could not create survey '{survey_name}'; aborting.")
# elliptical_hard = [1, 5, 7, 18, 20, 22, 23, 24, 25, 26, 29, 35, 38, 44, 45, 64, 66, 67, 82, 84, 85, 87, 90, 91, 96, 103, 107, 108, 112, 116, 125, 126, 127, 130, 135, 146, 147, 153, 155, 157, 159, 162, 163, 165, 168, 171, 172, 173, 174, 178, 183, 184, 185, 194, 195]
# true_elliptical = [False, True, False, False, False, True, True, False, True, True, True, False, True, True, True, False, False, False, False, True, True, True, True, False, False, True, False, False, True, True, True, True, True, True, False, False, True, True, True, True, False, True, True, True, False, True, True, False, False, True, True, True, False, False, False]

# data = [test_data[i] for i in elliptical_hard]
# data_elliptical = [data[i] for i in range(len(data)) if true_elliptical[i]]
# data_not_elliptical = [data[i] for i in range(len(data)) if not true_elliptical[i]]
# data = data_elliptical[:10] + data_not_elliptical[:10]



# # randomly sample half from Alstonia Scholaris and half from Punica granatum
# shuf = np.random.permutation(len(data))
# data1 = [data[i] for i in shuf if new_gt[i] == "Alstonia Scholaris"][:10]
# data2 = [data[i] for i in shuf if new_gt[i] == "Punica granatum"][:10]
# data = data1 + data2
# print(len(data))

# block_id = create_block(survey_id, "margin")
# for i in range(20):
#     question_payload = build_question_payload(data[i][0][0], "an image of a leaf", "classify the leaf's texture as one of the following: {'glossy', 'leathery', 'smooth', 'rough'}.", ["glossy", "leathery", "smooth", "rough"])
#     question_id = add_question(survey_id, question_payload, block_id)
#     print("Question ID:", question_id)

# Margin block: for each sampled margin mistake, ask whether the leaf's margin
# is the corresponding label from wrong_margin_true.
block_id = create_block(survey_id, "margin")
if block_id is None:
    # With block_id=None, requests would drop the blockId query parameter and
    # the questions would not reach this block, so skip them instead.
    print("Skipping margin questions: could not create the 'margin' block.")
else:
    for example, margin_label in zip(data_margin, wrong_margin_true):
        question_payload = build_question_payload(
            example[0][0],
            margin_label,
            "an image of a leaf",
            "Is the leaf's margin {symbol}?",
            ["Yes", "No"],
        )
        question_id = add_question(survey_id, question_payload, block_id)
        print("Question ID:", question_id)

# block_id = create_block(survey_id, "shape")
# for i in range(20):
#     question_payload = build_question_payload(data[i][0][0], "an image of a leaf", "classify the leaf's shape as one of the following: {'elliptical', 'lanceolate', 'oblong', 'obovate', 'ovate'}", ["elliptical", "lanceolate", "oblong", "obovate", "ovate"])
#     question_id = add_question(survey_id, question_payload, block_id)
#     print("Question ID:", question_id)

# ovate_hard = [0, 2, 4, 8, 12, 14, 15, 17, 19, 21, 32, 33, 37, 39, 40, 41, 43, 48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 68, 70, 74, 77, 80, 94, 98, 99, 100, 101, 104, 111, 120, 122, 124, 131, 133, 136, 137, 138, 139, 140, 141, 142, 143, 150, 151, 152, 160, 169, 170, 175, 176, 182, 186, 187, 189, 190, 191, 197]
# true_ovate = [False, True, True, True, False, True, False, True, False, False, True, True, True, True, True, False, True, False, True, False, False, False, True, True, False, False, True, True, False, False, False, False, False, False, False, True, False, False, False, False, False, True, False, True, False, False, False, False, False, False, True, True, False, False, False, False, False, True, False, False, False, False, True, False, True, False, False]

# data = [test_data[i] for i in ovate_hard]
# data_ovate = [data[i] for i in range(len(data)) if true_ovate[i]]
# data_not_ovate = [data[i] for i in range(len(data)) if not true_ovate[i]]
# data = data_ovate[:10] + data_not_ovate[:10]

# Shape block: for each sampled shape mistake, ask whether the leaf's shape
# is the corresponding label from wrong_shape_true.
block_id = create_block(survey_id, "shape")
if block_id is None:
    # With block_id=None, requests would drop the blockId query parameter and
    # the questions would not reach this block, so skip them instead.
    print("Skipping shape questions: could not create the 'shape' block.")
else:
    for example, shape_label in zip(data_shape, wrong_shape_true):
        question_payload = build_question_payload(
            example[0][0],
            shape_label,
            "an image of a leaf",
            "Is the leaf's shape {symbol}?",
            ["Yes", "No"],
        )
        question_id = add_question(survey_id, question_payload, block_id)
        print("Question ID:", question_id)


# smooth_hard = [1, 18, 20, 22, 23, 24, 25, 26, 29, 35, 38, 45, 64, 66, 67, 82, 84, 85, 87, 91, 96, 103, 107, 108, 112, 116, 125, 126, 127, 130, 135, 147, 153, 155, 157, 159, 162, 165, 168, 171, 172, 173, 174, 178, 183, 195]
# true_smooth = [False, False, False, False, True, False, True, False, False, False, False, False, False, False, False, False, True, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, True, False, False, False, False, False, False, False, True, False]

# data = [test_data[i] for i in smooth_hard]
# data_smooth = [data[i] for i in range(len(data)) if true_smooth[i]]
# data_not_smooth = [data[i] for i in range(len(data)) if not true_smooth[i]]
# data = data_smooth[:10] + data_not_smooth[:10]

# Texture block: for each sampled texture mistake, ask whether the leaf's
# texture is the corresponding label from wrong_texture_true.
block_id = create_block(survey_id, "texture")
if block_id is None:
    # With block_id=None, requests would drop the blockId query parameter and
    # the questions would not reach this block, so skip them instead.
    print("Skipping texture questions: could not create the 'texture' block.")
else:
    for example, texture_label in zip(data_texture, wrong_texture_true):
        question_payload = build_question_payload(
            example[0][0],
            texture_label,
            "an image of a leaf",
            "Is the leaf's texture {symbol}?",
            ["Yes", "No"],
        )
        question_id = add_question(survey_id, question_payload, block_id)
        print("Question ID:", question_id)