from utils.base_prompts import Prompts
import random
import json
import numpy as np
import itertools
from typing import Dict, List


class MoleculePrompts(Prompts):
    """Chat-prompt builder for the drug-molecule optimisation task.

    Every prompt is a two-message chat (``system`` + ``user``) asking the model
    for SMILES strings wrapped in ``{"response": [...]}`` JSON. Three flavours
    are provided: a fresh guess, a variation of one previous guess
    ("neighbors"), and an improvement over a list of previous guesses ("opro").

    The methods read ``self.task.target`` and ``self.task.num_guesses``;
    ``self.task`` is presumably set by ``Prompts.__init__`` -- confirm against
    the base class.
    """

    def __init__(self, task, **kwargs):
        """Store the prompt templates.

        Args:
            task: Task object forwarded to ``Prompts.__init__``.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(task)

        # All three prompt flavours share one system instruction, so it is
        # written once and exposed under each of the attribute names.
        shared_system_prompt = """\
You are a helpful chatbot with high attention to detail who is not talkative and responds only with the answer and no \
additional conversation. All your responses should be in JSON format, i.e. '{key: value}', where the key is always \
"response" and the value can be a string, int, list, or dict, depending on the context."""
        self.system_prompt = shared_system_prompt
        self.system_prompt_for_neighbors = shared_system_prompt
        self.system_prompt_for_opro = shared_system_prompt

        # Placeholders: (target protein, number of guesses).
        self.user_prompt = """\
Your task is to find the optimal drug molecule that has both a high druglikeness (QED) as well as a strong binding \
affinity (vina) with the protein %s. For docking, lower is better (less than −10 is considered good) and for \
druglikeness, 1 is the best and 0 is the worst (greater than 0.8 is considered good). While both properties are \
important, the docking score is 10 times as important as the druglikeness score. If you propose an invalid molecule or \
make a repeat guess, you will get no score, so stick to valid SMILES strings.

Now, guess exactly n=%s new molecule(s).

(Note: give only a list of SMILES string(s) in the provided JSON format, e.g. {"response": ["SMILES1", "SMILES2", ...]})"""

        # Placeholders: (target protein, previous SMILES guess, number of variations).
        self.user_prompt_for_neighbors = """\
Your task is to find the optimal drug molecule that has both a high druglikeness (QED) as well as a strong binding \
affinity (vina) with the protein %s. For docking, lower is better (less than −10 is considered good) and for \
druglikeness, 1 is the best and 0 is the worst (greater than 0.8 is considered good). While both properties are \
important, the docking score is 10 times as important as the druglikeness score. If you propose an invalid molecule or \
make a repeat guess, you will get no score, so stick to valid SMILES strings!

Here is my guess for a molecule:
SMILES: %s

Now, guess exactly n=%s new variation(s) of my molecule that could improve the scores to reach the optimal molecule.

(Note: give only a list of SMILES string(s) in the provided JSON format, e.g. {"response": ["SMILES1", "SMILES2", ...]})"""

        # Placeholders: (target protein, formatted previous guesses, number of guesses).
        self.user_prompt_for_opro = """\
Your task is to find the optimal drug molecule that has both a high druglikeness (QED) as well as a strong binding \
affinity (vina) with the protein %s. For docking, lower is better (less than −10 is considered good) and for \
druglikeness, 1 is the best and 0 is the worst (greater than 0.8 is considered good). While both properties are \
important, the docking score is 10 times as important as the druglikeness score. If you propose an invalid molecule or \
make a repeat guess, you will get no score, so stick to valid SMILES strings!

Here are your top previous guesses (from worst to best):

%s

Now, guess exactly n=%s new molecules that could score higher than your previous guesses.

(Note: give only a list of SMILES string(s) in the provided JSON format, e.g. {"response": ["SMILES1", "SMILES2", ...]})"""

    @staticmethod
    def _extract_first_smiles(raw_guess):
        """Return the first SMILES string found in a raw JSON model reply.

        Args:
            raw_guess: JSON text expected to look like
                ``{"response": ["SMILES", ...]}`` or ``{"response": "SMILES"}``.

        Returns:
            The first SMILES string, or ``None`` when the reply cannot be
            parsed or holds no non-empty string in that position.
        """
        try:
            response = json.loads(raw_guess)["response"]
            # A bare string is itself the answer; indexing it would keep only
            # its first character.
            candidate = response if isinstance(response, str) else response[0]
        except (ValueError, TypeError, LookupError):
            # Malformed JSON, wrong container type, missing key or empty list.
            return None
        if isinstance(candidate, str) and candidate:
            return candidate
        return None

    def build_prompt(self, use_alternate=False) -> List[Dict[str, str]]:
        """
        Constructs a single prompt of the following format:
            [
                {"content": system_prompt, "role": "system"},
                {"content": user_prompt, "role": "user"},
            ]

        Args:
        - use_alternate: When true, ask for exactly 2 molecules instead of
          ``self.task.num_guesses``.

        Returns:
        - List[Dict[str, str]]: A single prompt
        """
        # The user template has two placeholders (target, n); both variants
        # must supply the target.
        num_guesses = 2 if use_alternate else self.task.num_guesses
        return [
            {"content": self.system_prompt, "role": "system"},
            {"content": self.user_prompt % (self.task.target, num_guesses), "role": "user"},
        ]

    def build_dataset(
        self,
        task_id,
        min_training_size=50,
        max_training_size=80,
        max_seq_len=2048,
    ):
        """Build the (training, validation, test) datasets for the task.

        The training set starts as a single example built from
        ``build_prompt()`` and is repeated until it holds more than
        ``min_training_size`` entries, then truncated to at most the last
        ``max_training_size`` entries. Validation and test sets are empty.

        Args:
            task_id: Unused; kept for interface compatibility.
            min_training_size: Size below which the training set is expanded
                by repetition.
            max_training_size: Upper bound on the returned training set size.
            max_seq_len: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(training_dataset, validation_dataset, test_dataset)``.
        """
        training_dataset = [
            {"prompt": self.build_prompt(), "problem": [self.task.target], "solution": [self.task.target]}
        ]
        print("Training dataset size:", len(training_dataset))
        if len(training_dataset) < min_training_size:
            # Repeat every example enough times to exceed the minimum size.
            repeats = (min_training_size // len(training_dataset)) + 1
            expanded_training_dataset = [item for item in training_dataset for _ in range(repeats)]
            training_dataset = expanded_training_dataset[-max_training_size:]
            print("Extending training dataset to:", len(training_dataset))
        else:
            training_dataset = training_dataset[-max_training_size:]
            print("Clipping training dataset size to:", len(training_dataset))
        print("Example of training prompt:\n", training_dataset[-1]["prompt"])

        validation_dataset = []
        test_dataset = []

        return training_dataset, validation_dataset, test_dataset

    def get_neighborhood_samples_prompt(self, target_input, target_output, alt=False):
        """Build a prompt asking for one variation of a previous guess.

        Args:
            target_input: Unused; kept for interface compatibility.
            target_output: Previous model output. Its first element is parsed
                as a JSON reply; if no SMILES string can be extracted,
                ``target_output`` itself is inserted into the prompt as-is.
            alt: Unused; kept for interface compatibility.

        Returns:
            A two-message chat prompt (system + user).
        """
        try:
            first_guess = target_output[0]
        except (TypeError, LookupError):
            first_guess = None
        smiles = self._extract_first_smiles(first_guess)
        # Best-effort fallback: show the raw output when it is not a JSON reply.
        output = smiles if smiles is not None else target_output
        return [
            {"content": self.system_prompt_for_neighbors, "role": "system"},
            {"content": self.user_prompt_for_neighbors % (self.task.target, output, 1), "role": "user"},
        ]

    def get_opro_samples_prompt(self, target_input, target_output, alt=False):
        """Build a prompt asking for one guess that beats the previous ones.

        Args:
            target_input: Unused; kept for interface compatibility.
            target_output: Iterable of raw JSON replies. The first SMILES
                string of each reply is listed in the prompt, in the given
                order; replies that cannot be parsed are skipped.
            alt: Unused; kept for interface compatibility.

        Returns:
            A two-message chat prompt (system + user).
        """
        context = [
            smiles
            for smiles in (self._extract_first_smiles(guess) for guess in target_output)
            if smiles is not None
        ]

        return [
            {"content": self.system_prompt_for_opro, "role": "system"},
            {
                "content": self.user_prompt_for_opro % (self.task.target, "SMILES: " + "\n\nSMILES: ".join(context), 1),
                "role": "user",
            },
        ]
