from prompts import SYSTEM_ROLE_1_DIAGRAM, NETLIST_INSTRUCTION_START, NETLIST_INSTRUCTION_DICT, NETLIST_INSTRUCTION_INLINE_COMMENT, NETLIST_INSTRUCTION_COMMENT
import base64
import json
import os
import datetime
from openai import OpenAI
import requests

# Module-level placeholder. Every function in this file builds its own local
# OpenAI client, so this global is never assigned by the code below.
client = None

# Sent as the "seed" field of every request payload.
SEED = 1234
# Folder the one-shot example diagram (<title>.png) is read from.
IMAGE_FOLDER = "../data/diagrams/"
# JSON file holding the list of problems (id / problem / netlist / images).
DATA_PATH = "../data/new/data.json"
# Folder the generated .jsonl batch input files are written to.
BATCH_FOLDER = "batches/"
# Folder the processed batch responses are written to.
OUTPUT_FOLDER = "outputs/"

def encode_image(path):
    """Return the contents of the file at *path* as a base64 text string.

    Args:
        path: Location of the file to read; it is opened in binary mode.

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str`` as UTF-8.
    """
    with open(path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def _labelled_image_parts(title, b64_data, mime="image/jpeg"):
    """Return the two content parts for one image: a "<title>:" caption and the image itself.

    Args:
        title: Caption text; a trailing colon is appended.
        b64_data: Base64-encoded image bytes.
        mime: Media type announced in the data URL.
    """
    return [
        {
            "type": "text",
            "text": f'{title}:'
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime};base64,{b64_data}"
            }
        }
    ]


def create_request(problem, image_dict, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT):
    """Build the JSON body of one chat-completions request.

    The user message consists of the problem text, a short pointer to the
    images (an empty text part when ``image_dict`` is empty), optionally the
    one-shot example diagram, and then one caption + image pair per entry of
    ``image_dict``.

    Args:
        problem: Problem statement placed first in the user message.
        image_dict: Mapping of image title -> base64-encoded image data.
        MODEL: Model name placed in the payload.
        SYSTEM_ROLE: Text of the system message.
        MAX_TOKENS: Value of the payload's "max_tokens" field.
        ONE_SHOT: When truthy, the diagram named by SYSTEM_ROLE_1_DIAGRAM is
            read from IMAGE_FOLDER and attached before the problem's images.

    Returns:
        A dict with "model", "messages", "max_tokens" and "seed" keys.
    """
    content = [
        {
            "type": "text",
            "text": problem
        },
        {
            "type": "text",
            "text": "\n\nRefer to the following images." if image_dict else ""
        }
    ]

    if ONE_SHOT:
        image_title = SYSTEM_ROLE_1_DIAGRAM
        image_path = IMAGE_FOLDER + image_title + ".png"
        # The example diagram is loaded from a .png file, so label it as PNG
        # rather than JPEG in the data URL.
        content += _labelled_image_parts(image_title, encode_image(image_path), mime="image/png")

    # NOTE(review): the format of the images stored in the data file is not
    # visible here; they keep the original "image/jpeg" label - confirm.
    for image_title in image_dict:
        content += _labelled_image_parts(image_title, image_dict[image_title])

    return {
        "model": MODEL,
        "messages": [
            {
                "role": "system",
                "content": SYSTEM_ROLE
            },
            {
                "role": "user",
                "content": content
            }
        ],
        "max_tokens": MAX_TOKENS,
        "seed": SEED
    }

def create_batch(BATCH_FILE_PATH, WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT, WNO=True):
    """Write a Batch API input file (.jsonl) with one request per problem in DATA_PATH.

    Args:
        BATCH_FILE_PATH: Destination .jsonl file. It is overwritten, so
            re-running an experiment does not append a second copy of every
            request (which would repeat each custom_id).
        WITH_NETLIST: When truthy, the problem's netlist and the matching
            netlist instructions are appended to the problem text.
        MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT: Forwarded to create_request.
        WNO: Only consulted when WITH_NETLIST is truthy. If truthy, problems
            whose netlist is None are skipped; otherwise they are sent with
            the unmodified problem text.
    """
    with open(DATA_PATH) as f:
        eeqas = json.load(f)

    with open(BATCH_FILE_PATH, 'w') as outfile:
        for eeqa in eeqas:
            id_ = eeqa["id"]
            problem = eeqa["problem"]

            # Modify for netlist
            if WITH_NETLIST:
                netlist = eeqa["netlist"]
                if netlist is None:
                    if WNO:
                        continue
                else:
                    problem += "\n" + NETLIST_INSTRUCTION_START

                    # Only attach relevant element instructions: collect the
                    # upper-cased first letter of every non-blank netlist line
                    # that has an entry in NETLIST_INSTRUCTION_DICT.
                    used_elements = set()
                    for line in netlist.splitlines():
                        stripped = line.strip()
                        if not stripped:
                            continue
                        first_letter = stripped[0]
                        if first_letter.isalpha():
                            first_letter = first_letter.upper()
                            if first_letter in NETLIST_INSTRUCTION_DICT:
                                used_elements.add(first_letter)

                    # Sorted so the prompt text is identical from run to run;
                    # plain set iteration order of strings changes between
                    # interpreter runs.
                    for element in sorted(used_elements):
                        problem += NETLIST_INSTRUCTION_DICT[element]

                    if ";" in netlist:
                        problem += NETLIST_INSTRUCTION_INLINE_COMMENT

                    if "*" in netlist:
                        problem += NETLIST_INSTRUCTION_COMMENT
                    problem += "\nThe netlist:\n"
                    problem += netlist

            images = eeqa["images"]

            payload = create_request(problem, images, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT)

            batch_item = {"custom_id": str(id_), "method": "POST", "url": "/v1/chat/completions", "body": payload}
            outfile.write(json.dumps(batch_item) + '\n')

def send_batch(API_KEY, EXPERIMENT_NAME, WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT, WNO=True):
    """Create the batch input file for an experiment, upload it and start a batch job.

    The batch is tagged with metadata "description" = EXPERIMENT_NAME and
    "time" = the current local time in ISO format. The batch status and errors
    reported right after creation are printed.

    Args:
        API_KEY: OpenAI API key used to build the client.
        EXPERIMENT_NAME: Used for the batch file name
            (BATCH_FOLDER + EXPERIMENT_NAME + ".jsonl") and as the batch
            description.
        WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT, WNO:
            Forwarded to create_batch.
    """
    client = OpenAI(api_key=API_KEY)

    BATCH_FILE_PATH = BATCH_FOLDER + EXPERIMENT_NAME + ".jsonl"
    create_batch(BATCH_FILE_PATH, WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT, WNO)
    print("Batch created.")

    # Close the upload handle as soon as the upload call returns.
    with open(BATCH_FILE_PATH, "rb") as batch_file:
        batch_input_file = client.files.create(
            file=batch_file,
            purpose="batch"
        )

    batch = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": EXPERIMENT_NAME,
            "time": f"{datetime.datetime.now().isoformat()}"
        }
    )

    print("Batch sent.")
    # Retrieve once so status and errors come from the same snapshot.
    current = client.batches.retrieve(batch.id)
    print("Batch status: ", current.status)
    print("Errors: ", current.errors)

def create_batch_nno(BATCH_FILE_PATH, WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT):
    """Write a Batch API input file containing only the problems that have no netlist.

    Problems whose "netlist" entry is not None are skipped; the remaining ones
    are sent with their unmodified problem text.

    Args:
        BATCH_FILE_PATH: Destination .jsonl file. It is overwritten, so
            re-running an experiment does not append a second copy of every
            request (which would repeat each custom_id).
        WITH_NETLIST: Accepted for signature parity with create_batch; not
            used by this function.
        MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT: Forwarded to create_request.
    """
    with open(DATA_PATH) as f:
        eeqas = json.load(f)

    with open(BATCH_FILE_PATH, 'w') as outfile:
        for eeqa in eeqas:
            id_ = eeqa["id"]
            problem = eeqa["problem"]

            if eeqa["netlist"] is not None:
                continue

            images = eeqa["images"]

            payload = create_request(problem, images, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT)

            batch_item = {"custom_id": str(id_), "method": "POST", "url": "/v1/chat/completions", "body": payload}
            outfile.write(json.dumps(batch_item) + '\n')

def send_batch_nno(API_KEY, EXPERIMENT_NAME, WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT):
    """Create, upload and start the "nno-" batch (problems without a netlist) for an experiment.

    The batch is tagged with metadata "description" = "nno-" + EXPERIMENT_NAME
    and "time" = the current local time in ISO format. The batch status and
    errors reported right after creation are printed.

    Args:
        API_KEY: OpenAI API key used to build the client.
        EXPERIMENT_NAME: Used for the batch file name
            (BATCH_FOLDER + "nno-" + EXPERIMENT_NAME + ".jsonl") and, with the
            "nno-" prefix, as the batch description.
        WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT:
            Forwarded to create_batch_nno.
    """
    client = OpenAI(api_key=API_KEY)

    BATCH_FILE_PATH = BATCH_FOLDER + "nno-" + EXPERIMENT_NAME + ".jsonl"
    create_batch_nno(BATCH_FILE_PATH, WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT)
    print("Batch created.")

    # Close the upload handle as soon as the upload call returns.
    with open(BATCH_FILE_PATH, "rb") as batch_file:
        batch_input_file = client.files.create(
            file=batch_file,
            purpose="batch"
        )

    batch = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": "nno-" + EXPERIMENT_NAME,
            "time": f"{datetime.datetime.now().isoformat()}"
        }
    )

    print("Batch sent.")
    # Retrieve once so status and errors come from the same snapshot.
    current = client.batches.retrieve(batch.id)
    print("Batch status: ", current.status)
    print("Errors: ", current.errors)

def check_batches_status(API_KEY, NUM = 10):
    """Print a short status report for the batches returned by the API.

    For each batch the description, id, status and output file id are printed,
    plus the "time" metadata entry when present and the errors when the batch
    status is "failed".

    Args:
        API_KEY: OpenAI API key used to build the client.
        NUM: Number of batches to print before stopping. The counter is
            decremented before it is compared with 0, so a value of 0 or less
            never triggers the stop and every listed batch is printed.
    """
    client = OpenAI(api_key=API_KEY)
    print("------------\nAll batches:\n------------")
    remaining = NUM
    for b in client.batches.list(limit=20):
        remaining -= 1
        # Batches created without metadata (or without a description) must
        # not abort the listing.
        metadata = b.metadata or {}
        # The "output file" label now shows the output file id; it previously
        # printed the input file id.
        print("description: ", metadata.get("description"), "\n  - id: ", b.id, "\n  - status: ", b.status, "\n  - output file:", b.output_file_id)
        if "time" in metadata:
            print("  - time: ", metadata["time"])
        if b.status == "failed":
            print("  - ERRORS:", b.errors)
        print("===========")
        if remaining == 0:
            break

def process_responses(API_KEY, EXPERIMENT_NAME):
    """Download the results of an experiment's completed batch and save them as one JSON array.

    Every listed batch whose status is "completed" and whose metadata
    description equals EXPERIMENT_NAME is read. Each output line becomes
    ``{"problem_id": int(custom_id), "response": <first choice's message content>}``.
    The records of all matching batches are written together as a single JSON
    array to OUTPUT_FOLDER + "resp-" + EXPERIMENT_NAME + ".json", replacing any
    previous file so the result is always one valid JSON document.

    Args:
        API_KEY: OpenAI API key used to build the client.
        EXPERIMENT_NAME: Batch description to look for; also used in the
            output file name.
    """
    client = OpenAI(api_key=API_KEY)
    print("Processing...")

    records = []
    found = False
    for b in client.batches.list(limit=20):
        # Batches without metadata or description can never match.
        description = (b.metadata or {}).get("description")
        if b.status != "completed" or description != EXPERIMENT_NAME:
            continue

        found = True
        print("  - found", EXPERIMENT_NAME)

        file_response = client.files.content(b.output_file_id)
        # Split on '\n' only: the file is JSON Lines, and str.splitlines()
        # would also break on other separator characters inside a record.
        for line in file_response.text.strip().split('\n'):
            if not line.strip():
                continue  # empty output file or stray blank line
            resp = json.loads(line)
            records.append({
                "problem_id": int(resp["custom_id"]),
                "response": resp["response"]["body"]["choices"][0]["message"]["content"],
            })

    if not found:
        print("  - no completed batch found for", EXPERIMENT_NAME)
        return

    OUT_PATH = OUTPUT_FOLDER + "resp-" + EXPERIMENT_NAME + ".json"
    with open(OUT_PATH, 'w') as outfile:
        outfile.write(json.dumps(records))

    print("Responses processed.")

def process_responses_nno(API_KEY, EXPERIMENT_NAME):
    """Download the results of an experiment's completed "nno-" batch and save them as one JSON array.

    Every listed batch whose status is "completed" and whose metadata
    description equals "nno-" + EXPERIMENT_NAME is read. Each output line
    becomes
    ``{"problem_id": int(custom_id), "response": <first choice's message content>}``.
    The records of all matching batches are written together as a single JSON
    array to OUTPUT_FOLDER + "nno-" + EXPERIMENT_NAME + ".json", replacing any
    previous file so the result is always one valid JSON document.

    Args:
        API_KEY: OpenAI API key used to build the client.
        EXPERIMENT_NAME: Experiment name without the "nno-" prefix; the prefix
            is added for both the description match and the output file name.
    """
    client = OpenAI(api_key=API_KEY)
    print("Processing...")

    records = []
    found = False
    for b in client.batches.list(limit=20):
        # Batches without metadata or description can never match.
        description = (b.metadata or {}).get("description")
        if b.status != "completed" or description != "nno-" + EXPERIMENT_NAME:
            continue

        found = True
        print("  - found", EXPERIMENT_NAME)

        file_response = client.files.content(b.output_file_id)
        # Split on '\n' only: the file is JSON Lines, and str.splitlines()
        # would also break on other separator characters inside a record.
        for line in file_response.text.strip().split('\n'):
            if not line.strip():
                continue  # empty output file or stray blank line
            resp = json.loads(line)
            records.append({
                "problem_id": int(resp["custom_id"]),
                "response": resp["response"]["body"]["choices"][0]["message"]["content"],
            })

    if not found:
        print("  - no completed batch found for", "nno-" + EXPERIMENT_NAME)
        return

    OUT_PATH = OUTPUT_FOLDER + "nno-" + EXPERIMENT_NAME + ".json"
    with open(OUT_PATH, 'w') as outfile:
        outfile.write(json.dumps(records))

    print("Responses processed.")

def send_single_id(Q_ID, API_KEY, EXPERIMENT_NAME, WITH_NETLIST, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT, WNO=True):
    """Send the problem with id Q_ID as a single synchronous chat-completions request.

    Successful responses are appended, one JSON object per line, to
    "outputs/temp_file.json". Afterwards every line of that file is collected
    into a JSON array written to "outputs/gpt_individual_reruns.json"
    (overwriting it) and the temp file is removed. Because the temp file is
    opened in append mode, lines left behind by an earlier interrupted run are
    included as well.

    A failed request (non-200 status) or a response body without the expected
    structure is reported on stdout and not recorded.

    Args:
        Q_ID: Value compared against each problem's "id".
        API_KEY: OpenAI API key sent as a bearer token.
        EXPERIMENT_NAME: Only used in the progress message.
        WITH_NETLIST: When truthy, the problem's netlist and the matching
            netlist instructions are appended to the problem text.
        MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT: Forwarded to create_request.
        WNO: Only consulted when WITH_NETLIST is truthy. If truthy, a problem
            whose netlist is None is skipped; otherwise it is sent with the
            unmodified problem text.
    """
    with open(DATA_PATH) as f:
        eeqas = json.load(f)

    TEMP_OUT_PATH = "outputs/temp_file.json"
    FINAL_OUT_PATH = "outputs/gpt_individual_reruns.json"

    print(f"Running id={Q_ID} for experiment {EXPERIMENT_NAME}")

    with open(TEMP_OUT_PATH, 'a') as outfile:
        for eeqa in eeqas:
            id_ = eeqa["id"]
            if id_ != Q_ID:
                continue

            problem = eeqa["problem"]

            # Modify for netlist
            if WITH_NETLIST:
                netlist = eeqa["netlist"]
                if netlist is None:
                    if WNO:
                        continue
                else:
                    problem += "\n" + NETLIST_INSTRUCTION_START

                    # Only attach relevant element instructions: collect the
                    # upper-cased first letter of every non-blank netlist line
                    # that has an entry in NETLIST_INSTRUCTION_DICT.
                    used_elements = set()
                    for line in netlist.splitlines():
                        stripped = line.strip()
                        if not stripped:
                            continue
                        first_letter = stripped[0]
                        if first_letter.isalpha():
                            first_letter = first_letter.upper()
                            if first_letter in NETLIST_INSTRUCTION_DICT:
                                used_elements.add(first_letter)

                    # Sorted so the prompt text is identical from run to run;
                    # plain set iteration order of strings changes between
                    # interpreter runs.
                    for element in sorted(used_elements):
                        problem += NETLIST_INSTRUCTION_DICT[element]

                    if ";" in netlist:
                        problem += NETLIST_INSTRUCTION_INLINE_COMMENT

                    if "*" in netlist:
                        problem += NETLIST_INSTRUCTION_COMMENT
                    problem += "\nThe netlist:\n"
                    problem += netlist

            images = eeqa["images"]

            payload = create_request(problem, images, MODEL, SYSTEM_ROLE, MAX_TOKENS, ONE_SHOT)

            HEADERS = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {API_KEY}"
            }

            response = requests.post("https://api.openai.com/v1/chat/completions", headers=HEADERS, json=payload)

            # Without a usable answer there is nothing to record; report the
            # problem and move on instead of failing on an unassigned variable.
            if response.status_code != 200:
                print("API request failed:", response.status_code, response.text)
                continue

            try:
                resp = response.json()["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError, ValueError):
                # ValueError covers a body that is not valid JSON at all.
                print("Response JSON format unexpected:", response.text)
                continue

            outfile.write(json.dumps({"problem_id": id_, "response": resp}) + "\n")

            print(f'{datetime.datetime.now().isoformat()} - processed problem_id={id_}')

    print('done')

    # Turn the line-per-record temp file into a single JSON array.
    with open(TEMP_OUT_PATH) as f:
        output_elts = [json.loads(line) for line in f if line.strip()]

    with open(FINAL_OUT_PATH, 'w') as f:
        f.write(json.dumps(output_elts))

    os.remove(TEMP_OUT_PATH)
