#!/usr/bin/env python
# -*- coding=utf8 -*-
"""
"""
import os
import openai
import torch
import json
from benchmarks.MAT.material_benchmarks import MATBench


# Set the OpenAI API key for the current process
def set_api_key(api_key):
    """
    Sets the OpenAI API key in this process's environment (OPENAI_API_KEY).

    Running ``export`` through ``os.system`` would only modify a short-lived
    child shell, never the running interpreter, and would splice the secret
    into a shell command line. Writing ``os.environ`` makes the key visible to
    ``os.getenv('OPENAI_API_KEY')`` here and to child processes started later.

    Parameters:
    - api_key (str): Your OpenAI API key.

    Raises:
    - ValueError: if ``api_key`` is empty or None.
    """
    if not api_key:
        raise ValueError("api_key must be a non-empty string")
    os.environ["OPENAI_API_KEY"] = api_key


# Function to generate a response from GPT
def query_gpt(prompt, model="gpt-4o", temperature=0.7, max_tokens=100):
    """
    Sends a single-message prompt to an OpenAI chat model and returns the reply.

    Parameters:
    - prompt (str): The input prompt to send to GPT.
    - model (str): The GPT model to use (default: "gpt-4o").
    - temperature (float): Sampling temperature for output diversity (default: 0.7).
    - max_tokens (int): Maximum number of tokens in the response (default: 100).

    Returns:
    - str | None: The reply text with surrounding whitespace stripped, or None
      if anything in the request/response handling raised (the error is printed).
    """
    try:
        # Only take the key from the environment when it is actually set, so a
        # key assigned to ``openai.api_key`` elsewhere is not clobbered with None.
        env_key = os.getenv('OPENAI_API_KEY')
        if env_key:
            openai.api_key = env_key

        # NOTE(review): ``openai.ChatCompletion`` is the pre-1.0 openai-python
        # interface; it was removed in openai>=1.0 — confirm the pinned version.
        response = openai.ChatCompletion.create(
            model=model,
            messages=[{
                "role": "user",
                "content": prompt
            }],
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # First choice's message text.
        reply = response['choices'][0]['message']['content']
        return reply.strip()

    except Exception as e:
        # Best-effort boundary: report and signal failure with None.
        print(f"Error: {e}")
        return None


# Task-specific instruction prepended to each clustering prompt.
_REDOX_PREFIX = (
    "You are a chemist, please use the molecules provided, group them into five "
    "clusters based on their redox potential (extremely low : 0, low: 1, medium: 2, "
    "high: 3, extremely high :4). Analyze the following features for each molecule: "
    "Number and type of electron-withdrawing groups (e.g., CF₃, NO₂, CN, halogens). "
    "Number and type of electron-donating groups (e.g., alkyl, methoxy, hydroxyl). "
    "Positioning of substituents on aromatic rings (meta, para, ortho). "
    "Presence of sulfur-containing functional groups (e.g., SCF₃, S=O). "
    "Degree of molecular polarity (based on F, O, or N atoms). "
    "Use these criteria to evaluate the redox potential qualitatively and cluster "
    "the molecules accordingly."
)

# (substring of data_name, prompt prefix) pairs; the first match wins.
# Only the redox prompt is written so far; the others are empty placeholders.
_PROMPT_PREFIXES = (
    ("redox-mer", _REDOX_PREFIX),
    ("solvation", ""),
    ("kinase", ""),
    ("laser", ""),
    ("pce", ""),
    ("photoswitch", ""),
)


def _select_prompt_prefix(data_name):
    """Return the prefix of the first task key contained in data_name, or ''."""
    return next((text for key, text in _PROMPT_PREFIXES if key in data_name), "")


def _build_prompt(prefix, batch):
    """Format one batch of molecules (numbered from 1) into a clustering prompt."""
    formatted_data = "\n".join(f"{pos}. {item[0]}" for pos, item in enumerate(batch, start=1))
    return (
        prefix
        + f"\nNow we have a list of molecules:\n{formatted_data}\n"
        "Respond strictly with the cluster numerical labels only. Do not include any additional text."
    )


def gpt_clustering(dataset, data_name, target_col, num_clusters=5, batch_size=500):
    """
    Builds GPT clustering prompts for the molecules in ``dataset[target_col]``
    and attaches the collected labels to ``dataset`` as a "cluster" column.

    Each entry's first element (``item[0]``) is what is listed in the prompt;
    presumably entries are (molecule, target) pairs — TODO confirm against
    MATBench.

    NOTE: the GPT query / label-parsing step is currently disabled (see the
    TODO in the loop), so no labels are collected. For a non-empty dataset
    this raises ValueError instead of attaching a column.

    Parameters:
    - dataset (pandas.DataFrame): Data to label; modified in place.
    - data_name (str): Dataset name; substrings such as 'redox-mer' select the
      task-specific prompt prefix.
    - target_col (str): Column holding the molecule entries.
    - num_clusters (int): Currently unused; the redox prompt hard-codes five clusters.
    - batch_size (int): Molecules per prompt (default: 500).

    Returns:
    - The same ``dataset`` object with a "cluster" column added.

    Raises:
    - ValueError: if ``batch_size`` < 1, or the number of collected labels
      does not match ``len(dataset)``.
    """
    # The OpenAI key is taken from the OPENAI_API_KEY environment variable by
    # query_gpt; secrets must never be hard-coded in source.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    prefix = _select_prompt_prefix(data_name)
    molecules = dataset[target_col].tolist()
    cluster_labels = []

    # Stride by batch_size so the final partial batch is also prompted.
    for batch_idx, start in enumerate(range(0, len(molecules), batch_size)):
        batch_mole = molecules[start:start + batch_size]
        prompt = _build_prompt(prefix, batch_mole)
        print(prompt)
        # TODO: the query / parsing step below is disabled, so cluster_labels
        # stays empty. When re-enabling: the parser expects "N. label" lines
        # while the prompt asks for bare labels, and `np` is not imported here.
        # response = query_gpt(prompt, max_tokens=4096)
        # labels = response.split("\n")
        # results = []
        # for j, mole in enumerate(batch_mole):
        #     label = labels[j].split()[1]
        #     results.append((mole[0], mole[1], label))
        #     cluster_labels.append(label)
        # with open(f"{data_name}_{batch_idx}_data.json", "w") as f:
        #     json.dump(results, f)
        # with open(f"{data_name}_{batch_idx}_result.npy", "wb") as outfile:
        #     np.save(outfile, results)

    # Validate before assigning so a mismatch yields a clear error rather
    # than a pandas length error.
    if len(cluster_labels) != len(dataset):
        raise ValueError(
            f"Expected {len(dataset)} cluster labels, got {len(cluster_labels)}; "
            "the GPT query step in gpt_clustering is currently disabled."
        )
    dataset["cluster"] = cluster_labels
    return dataset


# Example usage
if __name__ == "__main__":

    data_name = "redox-mer"
    f_model = "gpt2-medium"
    finetuning = True
    prompt_type = "just-smiles"
    iupac = False
    seed = 665
    feature_reduction = "average"
    mat_bench = MATBench(
        data_name=data_name,
        run_subset_only=False,
        feature_type=f_model,
        finetuning=finetuning,
        iupac=iupac,
        prompt_type=prompt_type,
        randseed=seed,
        feature_reduction=feature_reduction,
    )

    # dataset = gpt_clustering(mat_bench.dataset, data_name, mat_bench.target_col, num_clusters=5)
