from prompts import SYSTEM_ROLE_1_DIAGRAM, NETLIST_INSTRUCTION_START, NETLIST_INSTRUCTION_DICT, NETLIST_INSTRUCTION_INLINE_COMMENT, NETLIST_INSTRUCTION_COMMENT
import google.generativeai as genai
import json
import datetime
import os
import time

# Folder holding the diagram images. send_request builds each upload path as
# IMAGE_FOLDER + <image title> + ".png". The path is relative, so it is
# resolved against the current working directory of the process.
IMAGE_FOLDER = "../data/diagrams/"
# JSON dataset loaded by the run_* functions. Each record is read through its
# "id", "problem", "netlist" and "images" keys.
DATA_PATH = "../data/new/data.json"

# REQUEST
def send_request(problem, image_dict, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT):
    """Send one problem and its diagram images to a Gemini model.

    Args:
        problem: Prompt text, sent as the first element of the request.
        image_dict: Iterable of image titles (iterating a dict yields its
            keys). Each title is uploaded from IMAGE_FOLDER + title + ".png".
            ``None`` or an empty collection means no problem images.
        MODEL: Model name passed to ``genai.GenerativeModel``.
        SYSTEM_ROLE: System instruction passed to ``genai.GenerativeModel``.
        MAX_TOKENS: Value used for ``max_output_tokens``.
        ONE_SHOT: When truthy, the image titled SYSTEM_ROLE_1_DIAGRAM is
            uploaded first, ahead of the problem's own images.

    Returns:
        The text of the first part of the first candidate when that candidate
        finished with reason 1, otherwise ``None``. Failures while generating
        or while reading the response are printed and also yield ``None``.

    Raises:
        Whatever ``genai.upload_file`` raises: uploads are not guarded, so a
        missing image file aborts the call.
    """
    model = genai.GenerativeModel(
        model_name=MODEL,
        system_instruction=SYSTEM_ROLE,
    )

    image_titles = list(image_dict or ())
    if ONE_SHOT:
        image_titles.insert(0, SYSTEM_ROLE_1_DIAGRAM)

    images = [
        genai.upload_file(path=IMAGE_FOLDER + title + ".png", display_name=title)
        for title in image_titles
    ]

    try:
        response = model.generate_content(
            [problem] + images,
            generation_config={"max_output_tokens": MAX_TOKENS},
        )
    except Exception as e:
        # Best-effort batch call: report and give up on this problem only.
        print("COULD NOT RUN: ", e)
        # Without this return the code below would touch an unbound
        # `response` and report a misleading secondary error.
        return None

    try:
        candidate = response.candidates[0]
        # 1 is the STOP value of the API's FinishReason enum (normal end).
        if candidate.finish_reason == 1:
            return candidate.content.parts[0].text
        print("API request failed: finish_reason =", candidate.finish_reason)
    except Exception as e:
        # Covers responses with no candidates or no parts.
        print("NO RESPONSE", e)

    return None


# Scratch file that receives one JSON line per processed problem while a run
# is in progress; it is deleted once the final output has been written.
_TEMP_OUT_PATH = "outputs/temp_file.json"


def _timestamp():
    """Return the current local time as an ISO-8601 string for log lines."""
    return datetime.datetime.now().isoformat()


def _add_netlist_to_problem(problem, netlist):
    """Return `problem` extended with netlist instructions and the netlist.

    The element letter of a netlist line is the first character of its first
    whitespace-separated token, upper-cased. Only letters that are keys of
    NETLIST_INSTRUCTION_DICT contribute an instruction, each at most once.
    """
    # A dict is used as an ordered set so instructions follow the order in
    # which element letters first appear, keeping the prompt reproducible.
    used_elements = {}
    for line in netlist.splitlines():
        tokens = line.split()
        if not tokens:
            continue  # blank line
        first_letter = tokens[0][0]
        if first_letter.isalpha():
            first_letter = first_letter.upper()
            if first_letter in NETLIST_INSTRUCTION_DICT:
                used_elements[first_letter] = None

    pieces = [problem, "\n" + NETLIST_INSTRUCTION_START]
    pieces.extend(NETLIST_INSTRUCTION_DICT[element] for element in used_elements)
    if ";" in netlist:
        pieces.append(NETLIST_INSTRUCTION_INLINE_COMMENT)
    if "*" in netlist:
        pieces.append(NETLIST_INSTRUCTION_COMMENT)
    pieces.append("\n\nNetlist:\n")
    pieces.append(netlist)
    return "".join(pieces)


def _run_experiment(final_out_path, selected, WITH_NETLIST, ONE_SHOT, API_KEY,
                    MODEL, SYSTEM_ROLE, MAX_TOKENS, WNO):
    """Query the model for every selected dataset record and save the replies.

    Args:
        final_out_path: File that receives the JSON list of
            ``{"problem_id": ..., "response": ...}`` records.
        selected: Callable taking a dataset record and returning whether it
            should be processed.
        WITH_NETLIST: When truthy, records with a netlist get the netlist and
            its instructions appended to the problem text.
        ONE_SHOT, MODEL, SYSTEM_ROLE, MAX_TOKENS: Forwarded to send_request.
        API_KEY: Passed to ``genai.configure``.
        WNO: Only consulted when WITH_NETLIST is truthy. Records whose netlist
            is ``None`` are skipped if WNO is truthy, and sent without a
            netlist otherwise.
    """
    print(f"{_timestamp()} - configuring api key")
    genai.configure(api_key=API_KEY)
    print(f"{_timestamp()} - starting")

    with open(DATA_PATH, encoding="utf-8") as f:
        eeqas = json.load(f)

    os.makedirs(os.path.dirname(_TEMP_OUT_PATH), exist_ok=True)

    records = []
    # Truncate rather than append: leftover lines from an interrupted earlier
    # run must not leak into this run's results.
    with open(_TEMP_OUT_PATH, "w") as outfile:
        for eeqa in eeqas:
            if not selected(eeqa):
                continue

            id_ = eeqa["id"]
            problem = eeqa["problem"]

            if WITH_NETLIST:
                netlist = eeqa["netlist"]
                if netlist is not None:
                    problem = _add_netlist_to_problem(problem, netlist)
                elif WNO:
                    continue

            resp = send_request(problem, eeqa["images"], MODEL, SYSTEM_ROLE,
                                MAX_TOKENS, ONE_SHOT)

            record = {"problem_id": id_, "response": resp}
            records.append(record)

            # Persist each result immediately so partial progress can be
            # inspected if the run is interrupted.
            outfile.write(json.dumps(record) + "\n")
            outfile.flush()

            print(f'{_timestamp()} - processed problem_id={id_}')
            # NOTE: no delay is applied between requests (no rate limiting).

    print('done')

    out_dir = os.path.dirname(final_out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(final_out_path, 'w') as f:
        f.write(json.dumps(records))

    os.remove(_TEMP_OUT_PATH)


def run_all(WITH_NETLIST, ONE_SHOT, API_KEY, MODEL, SYSTEM_ROLE, MAX_TOKENS, EXPERIMENT_NAME, WNO=True):
    """Run every dataset record and write "outputs/resp-<EXPERIMENT_NAME>.json".

    When WITH_NETLIST is truthy, records without a netlist are skipped if WNO
    is truthy and sent without a netlist otherwise. Returns ``None``.
    """
    _run_experiment(
        "outputs/resp-" + EXPERIMENT_NAME + ".json",
        lambda eeqa: True,
        WITH_NETLIST=WITH_NETLIST,
        ONE_SHOT=ONE_SHOT,
        API_KEY=API_KEY,
        MODEL=MODEL,
        SYSTEM_ROLE=SYSTEM_ROLE,
        MAX_TOKENS=MAX_TOKENS,
        WNO=WNO,
    )


def rerun_ids(WITH_NETLIST, ONE_SHOT, API_KEY, MODEL, SYSTEM_ROLE, MAX_TOKENS, EXPERIMENT_NAME, ids, WNO=True):
    """Run only the records whose "id" is in `ids`.

    Results go to "outputs/rerun-<EXPERIMENT_NAME>.json". Netlist handling is
    the same as in run_all. Returns ``None``.
    """
    _run_experiment(
        "outputs/rerun-" + EXPERIMENT_NAME + ".json",
        lambda eeqa: eeqa["id"] in ids,
        WITH_NETLIST=WITH_NETLIST,
        ONE_SHOT=ONE_SHOT,
        API_KEY=API_KEY,
        MODEL=MODEL,
        SYSTEM_ROLE=SYSTEM_ROLE,
        MAX_TOKENS=MAX_TOKENS,
        WNO=WNO,
    )


def run_nno(WITH_NETLIST, ONE_SHOT, API_KEY, MODEL, SYSTEM_ROLE, MAX_TOKENS, EXPERIMENT_NAME):
    """Run only the records whose "netlist" is ``None``.

    Results go to "outputs/nno-<EXPERIMENT_NAME>.json". The problem text is
    sent unmodified; WITH_NETLIST is accepted but not used. Returns ``None``.
    """
    _run_experiment(
        "outputs/nno-" + EXPERIMENT_NAME + ".json",
        lambda eeqa: eeqa["netlist"] is None,
        # Every selected record has no netlist, so nothing is ever appended.
        WITH_NETLIST=False,
        ONE_SHOT=ONE_SHOT,
        API_KEY=API_KEY,
        MODEL=MODEL,
        SYSTEM_ROLE=SYSTEM_ROLE,
        MAX_TOKENS=MAX_TOKENS,
        WNO=True,
    )


def run_single_id(Q_ID, WITH_NETLIST, ONE_SHOT, API_KEY, MODEL, SYSTEM_ROLE, MAX_TOKENS, EXPERIMENT_NAME, WNO=True):
    """Run only the records whose "id" equals Q_ID.

    Results go to "outputs/gemini_individual_reruns.json", which is opened in
    write mode, so each call replaces the file's previous contents.
    EXPERIMENT_NAME is used only in the initial log line. Netlist handling is
    the same as in run_all. Returns ``None``.
    """
    print(f"Running id={Q_ID} for experiment {EXPERIMENT_NAME}")

    _run_experiment(
        "outputs/gemini_individual_reruns.json",
        lambda eeqa: eeqa["id"] == Q_ID,
        WITH_NETLIST=WITH_NETLIST,
        ONE_SHOT=ONE_SHOT,
        API_KEY=API_KEY,
        MODEL=MODEL,
        SYSTEM_ROLE=SYSTEM_ROLE,
        MAX_TOKENS=MAX_TOKENS,
        WNO=WNO,
    )
