#!/usr/bin/env python
# -*- coding=utf8 -*-
"""
"""
import os
import openai
import json
import numpy as np
import time


class PromptCluster(object):
    """Cluster molecules into five bins (0-4) by prompting an OpenAI chat model.

    The task-specific instruction is selected from ``data_name``; molecules are
    read from the ``molecule_col`` column of the dataset handed to
    :meth:`gpt_clustering`. Per-batch results are cached on disk so that an
    interrupted run can be resumed without re-querying finished batches.
    """

    # The prompts ask for cluster ids 0 .. N_CLUSTERS - 1.
    N_CLUSTERS = 5

    def __init__(self, data_name, molecule_col):
        """
        Parameters:
        - data_name (str): Dataset name; selects the prompt and names the cache files.
        - molecule_col (str): Name of the dataset column holding the molecules.
        """
        self.data_name = data_name
        self.molecule_col = molecule_col

    def set_api_key(self, api_key=None):
        """
        Stores the OpenAI API key in the ``OPENAI_API_KEY`` environment variable.

        A missing or empty key leaves the environment untouched, so a key that
        was exported before starting the process keeps working. The key is
        never printed.

        Parameters:
        - api_key (str): Your OpenAI API key.
        """
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key

    def query_gpt(self, prompt, model="gpt-4o", temperature=0.7, max_tokens=100):
        """
        Sends a prompt to ChatGPT and returns the response.

        Parameters:
        - prompt (str): The input prompt to send to GPT.
        - model (str): The GPT model to use (default: "gpt-4o").
        - temperature (float): Sampling temperature for output diversity (default: 0.7).
        - max_tokens (int): Maximum number of tokens in the response (default: 100).

        Returns:
        - str: The stripped response text, or None if the request failed for
          any reason (the error is printed, not raised).
        """
        try:
            # The key is read from the environment on every call.
            openai.api_key = os.getenv('OPENAI_API_KEY')

            response = openai.ChatCompletion.create(
                model=model,
                messages=[{
                    "role": "user",
                    "content": prompt
                }],
                temperature=temperature,
                max_tokens=max_tokens,
            )

            reply = response['choices'][0]['message']['content']
            return reply.strip()

        except Exception as e:
            # Deliberate best-effort: callers treat None as "request failed".
            print(f"Error: {e}")
            return None

    def get_prefix_prompt(self):
        """
        Returns the instruction prefix for the task named in ``self.data_name``.

        The keywords are tested in order as substrings of ``data_name``; the
        first match wins. An empty string is returned when nothing matches.
        """
        redox_prefix = "You are a chemist, please use the molecules provided, group them into five clusters based on their redox potential (extremely low : 0, low: 1, medium: 2, high: 3, extremely high :4). Analyze the following features for each molecule: Number and type of electron-withdrawing groups (e.g., CF₃, NO₂, CN, halogens). Number and type of electron-donating groups (e.g., alkyl, methoxy, hydroxyl). Positioning of substituents on aromatic rings (meta, para, ortho). Presence of sulfur-containing functional groups (e.g., SCF₃, S=O). Degree of molecular polarity (based on F, O, or N atoms). Use these criteria to evaluate the redox potential qualitatively and cluster the molecules accordingly."
        solvation_prefix = "You are a chemist, please use the molecules provided, group them into five clusters (extremely low : 0, low: 1, medium: 2, high: 3, extremely high :4) based on predicted solvation energy. Consider these molecular features: Polarity and hydrogen bonding capability Molecular size and surface area; Number and type of charged/ionizable groups; Hydrophilic/hydrophobic balance. Use these criteria to evaluate the solvation energy qualitatively and cluster the molecules accordingly."
        kinase_prefix = "Act as a computational chemist. You have a list of SMILES strings for molecules, and I want to group them into 5 clusters (0 to 4) where cluster 0 has the lowest predicted kinase docking affinity and cluster 4 has the highest. Use your knowledge of kinase inhibitors to analyze these SMILES and assign cluster labels. Prioritize these criteria: 1. Presence of kinase-binding motifs (e.g., hinge-binding heterocycles, hydrophobic pockets). 2. Functional groups (e.g., hydrogen bond donors/acceptors, aromatic rings). 3. Molecular weight and polarity (smaller/lipophilic molecules often bind kinases better).4. Similarity to known kinase inhibitors (e.g., ATP analogs, tyrosine kinase inhibitors). Return the cluster labels (0–4) based on these criteria."
        laser_prefix = "Act as a computational chemist with expertise in photophysics. Group these SMILES strings into 5 clusters based on predicted fluorescence oscillator strength (relevant for lasers), from very low to very high: 1. Cluster 0: Very low (poor for lasing). 2. Cluster 1: Low (marginal for lasing). 3. Cluster 2: Moderate (potentially suitable for lasing). 4. Cluster 3: High (good for lasing).5. Cluster 4: Very high (excellent for lasing). Consider: 1. Conjugation length: Longer conjugation increases oscillator strength. 2. Aromaticity: Aromatic systems often have strong π-π* transitions. 3. Functional groups: Electron-donating/withdrawing groups alter oscillator strength.4. Molecular rigidity: Rigid molecules tend to have higher oscillator strength."
        pce_prefix = "You are a chemist, please use the molecules provided, group them into five clusters (extremely low : 0, low: 1, medium: 2, high: 3, extremely high :4) related to predicted photovoltaic conversion efficiency (pce) based on the following features:  Electronic structure indicators ( conjugation extent, aromatic systems, electron-rich/deficient regions, push-pull molecular design; Molecular architecture (planarity potential, π-system connectivity, structural rigidity, molecular size); Light harvesting features (conjugated backbone length, donor-acceptor patterns, chromophore presence, substituent effects). Use these criteria to evaluate the pce qualitatively and cluster the molecules accordingly."
        photoswitch_prefix = "Act as a computational chemist with expertise in photochemistry. I have a list of SMILES strings for organic molecules, and I want to group them into 5 clusters based on their predicted π-π* transition wavelengths: 1. Cluster 0: deep UV range (200–300 nm). 2. Cluster 1:UV range (300–400 nm). 3. Cluster 2: the blue/visible range (400–500 nm).4.Cluster 3: green/red/visible range (500–700 nm). 5. Cluster 4: near-infrared range (700–1000 nm). Use your knowledge of molecular structure and photochemistry to analyze the SMILES strings and assign cluster labels. Consider the following factors: a. Conjugation length: Longer conjugation typically shifts the π-π* transition to longer wavelengths. b. Aromaticity: Aromatic systems often have π-π* transitions in the UV/visible range. c.Functional groups: Electron-donating or electron-withdrawing groups can alter the HOMO-LUMO gap. d. Molecular planarity: Planar molecules tend to have stronger π-π* transitions."

        # Order matters: the first keyword contained in data_name is used.
        prefixes = (
            ('redox-mer', redox_prefix),
            ('solvation', solvation_prefix),
            ('kinase', kinase_prefix),
            ('laser', laser_prefix),
            ('pce', pce_prefix),
            ('photoswitch', photoswitch_prefix),
        )
        for keyword, prefix in prefixes:
            if keyword in self.data_name:
                return prefix
        return ""

    def _build_prompt(self, prefix, batch_mole):
        """Returns the full prompt: task prefix, answer-format rule, numbered molecules."""
        formatted_data = "\n".join(
            f"{idx}. {item}" for idx, item in enumerate(batch_mole, start=1)
        )
        return prefix + f"**Respond strictly with the counter and numerical cluster labels only. Do not include any additional text.**\nNow please cluster the following molecules:\n{formatted_data}\n"

    @staticmethod
    def _parse_response(response, batch_mole):
        """
        Pairs every molecule of a batch with the raw label token of its reply line.

        Blank lines are ignored and the last whitespace-separated token of each
        line is taken as the label, so "3", "1. 3" and "1. Cluster 3" all yield "3".

        Raises:
        - ValueError: if the reply has fewer non-blank lines than molecules.
        """
        lines = [line for line in response.split("\n") if line.strip()]
        if len(lines) < len(batch_mole):
            raise ValueError(
                f"Expected {len(batch_mole)} label lines but the model returned {len(lines)}."
            )
        # Extra trailing lines (if any) are dropped by zip.
        return [(mole, line.split()[-1]) for mole, line in zip(batch_mole, lines)]

    @classmethod
    def _normalize_label(cls, label):
        """
        Converts a raw label token into an int in ``range(N_CLUSTERS)``.

        For a hyphenated token such as "2-3" the part after the first hyphen is
        used. Tokens that are not integers or lie outside the valid range are
        replaced by a uniformly random cluster id, and a warning is printed.
        """
        text = str(label)
        if "-" in text:
            text = text.split("-")[1]
        try:
            value = int(text)
        except ValueError:
            value = -1
        if not 0 <= value < cls.N_CLUSTERS:
            print(f"Invalid cluster label {label!r}; assigning a random cluster.")
            # randint's upper bound is exclusive, giving 0 .. N_CLUSTERS - 1.
            value = int(np.random.randint(0, cls.N_CLUSTERS))
        return value

    def _result_path(self, output_dir, batch_idx, suffix):
        """Returns the cache file path of one batch (suffix selects json/npy)."""
        return os.path.join(output_dir, self.data_name + "_" + str(batch_idx) + suffix)

    def _cluster_batch(self, prefix, batch_mole, delay):
        """
        Queries the model for one batch and returns ``[(molecule, raw_label), ...]``.

        Waits ``delay`` seconds before each request and retries once when
        ``query_gpt`` reports a failure (it returns None for any API error,
        rate limits included).

        Raises:
        - RuntimeError: if both attempts fail.
        - ValueError: if the reply cannot be matched to the batch.
        """
        prompt = self._build_prompt(prefix, batch_mole)
        response = None
        for _attempt in range(2):
            time.sleep(delay)
            response = self.query_gpt(prompt, max_tokens=4000)
            if response is not None:
                break
            print(f"Request failed. Waiting for {delay} seconds...")
        if response is None:
            raise RuntimeError("GPT request failed twice; see the printed error above.")
        return self._parse_response(response, batch_mole)

    def _save_batch(self, output_dir, batch_idx, results):
        """Writes one batch's results both as JSON and as a ``.npy`` array."""
        with open(self._result_path(output_dir, batch_idx, '_data.json'), 'w') as f:
            json.dump(results, f)
        np.save(self._result_path(output_dir, batch_idx, "_result.npy"), results)

    def gpt_clustering(self, dataset, api_key=None, batch_size=200, delay=10.0,
                       output_dir=None):
        """
        Assigns an LLM-predicted cluster label to every molecule of ``dataset``.

        Molecules are processed in batches. A batch whose ``*_result.npy`` file
        already exists in ``output_dir`` is loaded instead of queried, unless
        its first stored label is the word "Cluster" (a known mis-parse), in
        which case it is queried again and the cache files are rewritten.
        Cached files are only meaningful for the ``batch_size`` they were
        created with.

        Parameters:
        - dataset (pandas.DataFrame): Must contain the ``molecule_col`` column.
        - api_key (str): Optional OpenAI key; when omitted the existing
          ``OPENAI_API_KEY`` environment variable is used.
        - batch_size (int): Molecules per request (default: 200).
        - delay (float): Seconds to wait before each request (default: 10.0).
        - output_dir (str): Cache folder (default:
          "benchmarks/MAT/data/llm_clustering/<data_name>/").

        Returns:
        - pandas.DataFrame: the same ``dataset`` object with an added integer
          ``llm_cluster`` column.

        Raises:
        - ValueError: for a non-positive ``batch_size``, an unusable model
          reply, or when labels and molecules cannot be matched.
        - RuntimeError: if a request keeps failing.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        self.set_api_key(api_key=api_key)
        prefix = self.get_prefix_prompt()
        molecules = dataset[self.molecule_col].to_list()
        n_batches = int(np.ceil(len(molecules) / float(batch_size)))

        if output_dir is None:
            output_dir = "benchmarks/MAT/data/llm_clustering/" + self.data_name + "/"
        os.makedirs(output_dir, exist_ok=True)

        mole_cluster_dict = {}
        n_labels = 0
        for batch_idx in range(n_batches):
            batch_mole = molecules[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            result_path = self._result_path(output_dir, batch_idx, "_result.npy")

            results = None
            if os.path.exists(result_path):
                results = np.load(result_path)
                if results[0][1] == "Cluster":
                    print("batch", batch_idx, "has wrong results")
                    results = None
            if results is None:
                print(f"Querying batch {batch_idx + 1}/{n_batches} ({len(batch_mole)} molecules)")
                results = self._cluster_batch(prefix, batch_mole, delay)
                self._save_batch(output_dir, batch_idx, results)

            for mole, label in results:
                # For a repeated molecule the label seen last is kept.
                mole_cluster_dict[str(mole)] = self._normalize_label(label)
                n_labels += 1

        if n_labels != len(dataset):
            raise ValueError(
                f"Got {n_labels} cluster labels for {len(dataset)} molecules; "
                "cached batch files may belong to a different dataset or batch size."
            )
        missing = [mole for mole in molecules if mole not in mole_cluster_dict]
        if missing:
            raise ValueError(f"No cluster label found for {len(missing)} molecule(s), e.g. {missing[0]!r}.")

        # Single vectorised lookup instead of one boolean mask per molecule.
        dataset["llm_cluster"] = dataset[self.molecule_col].map(mole_cluster_dict)
        return dataset


if __name__ == "__main__":
    pass
