# action_agent.py
import json
import os
import random
import random
import re

import openai
from openai import OpenAI
from rdkit import Chem
# Read the API key from the environment instead of hard-coding a secret in
# source. Falling back to "" keeps the previous import-time behaviour when the
# variable is unset (the original hard-coded value was also "").
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
client = OpenAI(api_key=openai.api_key)


class ActionAgentLLM:
    """LLM-backed "action agent" that proposes molecule edits as SMILES strings.

    Each generation round sends the constraints, up to ``_MAX_CONTEXT_ITEMS``
    retrieved context entries, and the accumulated chain-of-thought/feedback
    to a chat-completion model. It then parses the ``FINAL SMILES:`` line
    from the reply.
    """

    # Matches the "FINAL SMILES:" marker, tolerating markdown bold around the
    # label and/or the colon, e.g. "**FINAL SMILES:**" or "**FINAL SMILES**:".
    _FINAL_SMILES_RE = re.compile(r"^\**\s*FINAL SMILES\s*\**\s*:\s*\**")
    # Returned when the reply contains no usable FINAL SMILES line.
    _FALLBACK_SMILES = "C"
    # Maximum number of retrieved-context entries included in the prompt.
    _MAX_CONTEXT_ITEMS = 50

    def __init__(self, dataset, retrieved_context, constraints, model_name="o1", llm_client=None):
        """Store the generation inputs.

        Args:
            dataset: Stored on the instance; not used by the methods of this class.
            retrieved_context: Reference molecules shown to the model. A
                list/tuple is rendered one entry per line. Anything else is
                rendered with ``str`` after slicing to the first
                ``_MAX_CONTEXT_ITEMS`` entries.
            constraints: Mapping read with ``.get``. ``run_generation`` uses the
                ``*_min``/``*_max`` keys; ``run_generation_exact`` uses ``qed``,
                ``logp`` and ``mw``; both use ``must_contain``.
            model_name: Chat-completion model name passed to the API.
            llm_client: Optional OpenAI-compatible client (anything exposing
                ``chat.completions.create``). When None, the module-level
                ``client`` is used.
        """
        self.dataset = dataset
        self.retrieved_context = retrieved_context
        self.constraints = constraints
        self.model_name = model_name
        self.llm_client = llm_client

        # Optionally keep a local copy of chain-of-thought
        self.chain_of_thought = []
        # Evaluator feedback; appended by receive_feedback() and included in
        # every subsequent prompt.
        self.feedback_log = []

    def run_generation(self, candidate_smiles, chain_of_thought_history: list):
        """Run one chain-of-thought generation round against range constraints.

        Args:
            candidate_smiles: Current candidate to edit. May be empty/None on
                the first round, in which case the model starts from scratch.
            chain_of_thought_history: Prior reasoning lines (strings) to continue.

        Returns:
            ``(candidate_smiles, thought_trace)``: the parsed SMILES (``"C"`` if
            the reply has no usable ``FINAL SMILES:`` line) and the remaining
            reply lines.
        """
        get = self.constraints.get
        constraints_text = (
            f"- QED range: {get('qed_min','N/A')} to {get('qed_max','N/A')}\n"
            f"- LogP range: {get('logp_min','N/A')} to {get('logp_max','N/A')}\n"
            f"- Must contain: {get('must_contain', 'None')}\n"
            f"- MW range: {get('mw_min','N/A')} to {get('mw_max','N/A')}\n"
        )
        system_prompt = self._system_prompt(context_note="", goal="closer to")
        user_prompt = self._user_prompt(constraints_text, candidate_smiles, chain_of_thought_history)
        return self._parse_response(self._query_llm(system_prompt, user_prompt))

    def run_generation_exact(self, candidate_smiles, chain_of_thought_history: list):
        """Run one chain-of-thought generation round against point-target constraints.

        Same as ``run_generation``, but the constraints are single target values
        (``qed``, ``logp``, ``mw``) and the model is asked to get as close as
        possible to them.

        Returns:
            ``(candidate_smiles, thought_trace)`` as in ``run_generation``.
        """
        get = self.constraints.get
        # Each constraint on its own line (QED and LogP previously lacked "\n"
        # and ran together in the prompt).
        constraints_text = (
            f"- QED: {get('qed','N/A')}\n"
            f"- LogP: {get('logp','N/A')}\n"
            f"- Must contain: {get('must_contain', 'None')}\n"
            f"- MW: {get('mw','N/A')}\n"
        )
        system_prompt = self._system_prompt(
            context_note=" that are close to the constraints",
            goal="as close as possible to",
        )
        user_prompt = self._user_prompt(constraints_text, candidate_smiles, chain_of_thought_history)
        return self._parse_response(self._query_llm(system_prompt, user_prompt))

    def receive_feedback(self, feedback: str):
        """Incorporate evaluator's feedback into the chain-of-thought."""
        self.feedback_log.append(f"Feedback from evaluator: {feedback}")

    @staticmethod
    def _system_prompt(context_note, goal):
        """Build the system prompt; the two generation modes differ only in *context_note* and *goal*."""
        return (
            "You are an expert 'action agent' in molecule generation and material science. "
            "Your task is to propose a new molecule (SMILES string) by adding, deleting, "
            "or substituting fragments from a current candidate smiles (blank if this is the first iteration). "
            f"You will be given some retrieved molecules{context_note} as learning context. "
            "The retrieved molecules are all existing molecules that meet the current constraints. "
            "The connection part inside these molecule data shows how the fragments are attached and form a completed molecule. "
            "You need to learn from these molecules, but you can choose fragments that are outside of the given molecules. "
            f"Output a final SMILES that moves {goal} meeting the constraints. "
            "Do not propose imaginary chemistry. If uncertain, do best effort. Pay great attention to the validity of the molecule.\n\n"
            "Show the detailed generation process with what action you did at each step. At the end of your reasoning, produce a line 'FINAL SMILES: X' that is the new candidate. "
            "The output format is something like this: Step 1: From xxx add xx to get yy. \n Step 2: From yy add xxxx to get zz. ... \n Step n: From zz sub xxxx with aa to get zzz. \n FINAL SMILES: zzz "
        )

    def _user_prompt(self, constraints_text, candidate_smiles, chain_of_thought_history):
        """Assemble the user message: constraints, retrieved context, prior reasoning, instruction."""
        # Caller-supplied history first, then this agent's own notes and feedback.
        full_chain_of_thought = (
            list(chain_of_thought_history or []) + self.chain_of_thought + self.feedback_log
        )
        return (
            f"Current constraints:\n{constraints_text}\n\n"
            "Retrieved molecule data:\n" + self._format_context(self.retrieved_context) + "\n\n"
            "Chain-of-Thought so far:\n" + "\n".join(str(step) for step in full_chain_of_thought) + "\n\n"
            f"Please continue on the given chain of thought, start from the current candidate: {candidate_smiles}, or start from scratch if not given. The final line of the output is only 'FINAL SMILES: ...', WITHOUT any extra symbols like * (NO **FINAL SMILES...) or any other contents."
        )

    @classmethod
    def _format_context(cls, context):
        """Render up to ``_MAX_CONTEXT_ITEMS`` retrieved entries as prompt text."""
        if context is None:
            return ""
        items = context[: cls._MAX_CONTEXT_ITEMS]
        if isinstance(items, (list, tuple)):
            # One entry per line. The previous "\n".join(str(items)) joined the
            # characters of the list's repr, one character per line.
            return "\n".join(str(item) for item in items)
        # Strings and other sliceable containers: use their own text form.
        return str(items)

    def _query_llm(self, system_prompt, user_prompt):
        """Send one chat-completion request and return the stripped reply text."""
        # getattr: instances restored without __init__ (e.g. unpickled from an
        # older version of this class) may lack llm_client.
        api = getattr(self, "llm_client", None)
        if api is None:
            api = client
        response = api.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        # message.content may be None; treat that as an empty reply so the
        # parser falls back instead of raising AttributeError.
        return (response.choices[0].message.content or "").strip()

    @classmethod
    def _parse_response(cls, content):
        """Split a model reply into ``(candidate_smiles, thought_trace)``.

        The last line matching the FINAL SMILES marker supplies the candidate.
        All other lines, stripped, form the thought trace. A missing or empty
        candidate falls back to ``_FALLBACK_SMILES``.
        """
        thought_trace = []
        candidate_smiles = None
        for raw_line in content.split("\n"):
            line = raw_line.strip()
            match = cls._FINAL_SMILES_RE.match(line)
            if match is None:
                # Everything else is chain-of-thought
                thought_trace.append(line)
                continue
            value = line[match.end():]
            for quote in ("'", '"', "`"):
                value = value.replace(quote, "")
            # Strip markdown emphasis and a trailing sentence period only at the
            # ends: an interior "." is the SMILES disconnected-structure separator.
            candidate_smiles = value.strip().strip("*.").strip()

        if not candidate_smiles:
            candidate_smiles = cls._FALLBACK_SMILES
        return candidate_smiles, thought_trace
