import asyncio
import random
import aiohttp
import re
import math
from aiohttp import ClientSession
import nest_asyncio
import copy
import openai
from tqdm.auto import tqdm
nest_asyncio.apply()

from datasets import load_dataset

def get_sampled_gsm8k(fraction=0.01, seed=42):
    dataset = load_dataset("gsm8k", "main", split="train")
    print(f"Full GSM8K train set: {len(dataset)}")

    sampled_dataset = dataset.train_test_split(test_size=fraction, seed=seed)['test']
    print(f"Sample subset: {len(sampled_dataset)}")

    return sampled_dataset


async def async_get_single_answer(prompt, session, api_key, temperature=0):
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "qwen/qwen-2.5-7b-instruct",
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": temperature
    }
    while True:
        try:
            await asyncio.sleep(random.uniform(0.5, 1.5))
            async with session.post(url, json=payload, headers=headers) as response:
                if response.status == 429:
                    print("Rate limit exceeded. Waiting before retrying...")
                    await asyncio.sleep(10)
                    continue
                response_json = await response.json()
                answer = response_json['choices'][0]['message']['content']
                full_match = re.search(r"<Answer>(.*?)</Answer>", answer, re.DOTALL)
                if full_match:
                    inner_text = full_match.group(1)
                    number_match = re.search(r"([0-9]+(?:\.[0-9]+)?)", inner_text)
                    if number_match:
                        extracted_answer = number_match.group(1)
                        return extracted_answer
                    else:
                        return "WRONG"
                else:
                    return "WRONG"
        except Exception as e:
            await asyncio.sleep(5)

async def async_get_batch_answers(prompts, Instruction, api_key, temperature=0, batch_size=64, delay_between_batches=10):
    answers = []
    print(prompts[0].format(Instruction))
    async with ClientSession() as session:
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]
            tasks = [
                async_get_single_answer(prompt.format(Instruction), session, api_key, temperature)
                for prompt in batch
            ]
            batch_results = await asyncio.gather(*tasks)
            answers.extend(batch_results)
            if i + batch_size < len(prompts):
                print(f"Batch {i//batch_size + 1} completed, sleeping for {delay_between_batches} seconds...")
                await asyncio.sleep(delay_between_batches)
    return answers

def extract_groundtruth_number(text):
    match = re.search(r"####\s*([0-9]+(?:\.[0-9]+)?)", text)
    if match:
        return float(match.group(1))
    return None

def compute_accuracy(predictions, dataset, tol=1e-5):
    correct = 0
    total = len(predictions)

    for pred_str, sample in zip(predictions, dataset):
        try:
            pred_value = float(pred_str.strip())
        except ValueError:
            print("Invalid numeric prediction:", pred_str)
            continue

        gt_value = extract_groundtruth_number(sample["answer"])
        if gt_value is None:
            print("Failed to extract ground truth from sample:", sample)
            continue

        if math.isclose(pred_value, gt_value, rel_tol=tol):
            correct += 1

    return correct / total if total > 0 else 0

class env:
  def __init__(self):
    self.sampled_dataset = get_sampled_gsm8k()
    self.prompts = [
        sample["question"].strip() + "\nProvide your final result within the tags <Answer>FINAL_NUMERICAL_ANSWER</Answer>.\n" + "\n{}"
        for sample in self.sampled_dataset
    ]
    self.sample_count = 0
    self.API_key = "[API_KEY]"

  def evaluator(self, Instruction):
    answer = asyncio.run(async_get_batch_answers(self.prompts, Instruction, self.API_key))
    accuracy = compute_accuracy(answer, self.sampled_dataset)
    print(f"Instruction: {Instruction}, score={int(round(accuracy*100))}")
    return int(round(accuracy*100))

  def random_sample_exemplars(self):
    sampled_samples = self.sampled_dataset.shuffle(seed=42 + self.sample_count).select(range(3))
    self.sample_count += 1
    return sampled_samples

class Template_Padding:
  def __init__(self, env):
    self.exemplars_template = """text:
{}
score:
{}

"""
    self.exemplars_template_2 = """Problem:
Q: {}
A: <INS>

Ground truth Answer:
{}

"""
    self.env = env
    samples_evaluator = []
    initial_instruction = "Let's think step by step."
    samples_evaluator.append(self.env.evaluator(initial_instruction))
    self.samples = list(zip([initial_instruction], samples_evaluator))
    self.samples.sort(key=lambda x: x[1], reverse=True)

    self.input_exemplars_1 = ""
    self.input_exemplars_2 = ""
    self.Add_Padding()

    self.prompt_1 = """Your task is to generate the instruction <INS>. Below are some previous instructions with their scores. The score ranges from 0 to 100.
"""
    self.prompt_2 = """Below are some problems.
"""
    self.prompt_3 = """Generate an instruction that is different from all the instructions <INS> above, and has a higher score than all the instructions <INS> above. The instruction should begin with <INS> and end with </INS>. The instruction should be concise, effective, and generally applicable to all problems above.""" # Meta_Prompt
    self.API_key = "[API_KEY]"
    self.result_lst = []

    self.steps=20

  def Add_Padding(self):
    self.samples.sort(key=lambda x: x[1], reverse=False)

    samples_copy = copy.deepcopy(self.samples)

    sample_set = list(set(samples_copy))
    sample_set.sort(key=lambda x: x[1], reverse=False)

    selected_samples = sample_set[-20:]

    self.input_exemplars_1 = ""
    for sample in selected_samples:
        self.input_exemplars_1 += self.exemplars_template.format(sample[0], sample[1])

    self.input_exemplars_2 = ""
    sampled_exemplars = self.env.random_sample_exemplars()
    for sample in sampled_exemplars:
      self.input_exemplars_2 += self.exemplars_template_2.format(sample["question"], extract_groundtruth_number(sample["answer"]))


  def Add_Sample(self, new_instruction, re=False):
    score = self.env.evaluator(new_instruction)
    self.samples.append((new_instruction, score))
    if re == True:
      return score

  def Add_All(self, Meta_Prompt):
    prompt_template =  self.prompt_1 + "\n{}" + self.prompt_2 + "\n{}" + Meta_Prompt
    return prompt_template.format(self.input_exemplars_1, self.input_exemplars_2)

  async def Prompt_LLM_Epoch(self, Meta_Prompt, epoch=7):
    async def fetch_one():
      max_retries = 3
      for attempt in range(max_retries):
        response = await openai.ChatCompletion.acreate(
          model="openai/gpt-4-turbo",
          messages=[{"role": "user", "content": self.Add_All(Meta_Prompt)}],
          temperature=1,
        )
        content = response.choices[0].message['content'].strip()
        match = re.search(r"<INS>(.*?)</INS>", content, re.DOTALL)
        if match:
            extracted_content = match.group(1).strip()
        else:
            extracted_content = content

        print(extracted_content)
        self.Add_Sample(extracted_content)
        return extracted_content

      print("[Warning] Failed after 3 attempts.")

    tasks = [fetch_one() for _ in range(epoch)]
    answers = await asyncio.gather(*tasks)
    return answers

  def Prompt_LLM_Epoch_Temp_0(self, Meta_Prompt):
    max_retries = 10
    print(self.Add_All(Meta_Prompt))
    for attempt in range(max_retries):
      try:
        response = openai.ChatCompletion.create(
          model="openai/gpt-4-turbo",
          messages=[{"role": "user", "content": self.Add_All(Meta_Prompt)}],
          temperature=0
        )
        content = response.choices[0].message['content'].strip()
        match = re.search(r"<INS>(.*?)</INS>", content, re.DOTALL)
        if match:
          extracted_content = match.group(1).strip()
        else:
          extracted_content = content

        score = self.Add_Sample(extracted_content, re=True)
        return score
      except Exception as e:
        print(f"[Retry {attempt+1}] Exception during LLM call: {e}")

    print("[Error] Failed after 10 attempts. Returning None.")
    return ""

  def Prompt_LLM_Step(self, Meta_Prompt):
    openai.api_key = self.API_key
    openai.api_base = "https://openrouter.ai/api/v1"
    for i in tqdm(range(self.steps), desc="Prompting LLM " + str(self.steps) + " steps: " + "gpt-4-turbo"):
      reward = self.Prompt_LLM_Epoch_Temp_0(Meta_Prompt)
      asyncio.run(self.Prompt_LLM_Epoch(Meta_Prompt))
      self.Add_Padding()
      print("Step %d: The Best Performance, INSTRUCTION=%s, score=%f"%(i + 1, self.samples[0][0], self.samples[0][1]))
      sorted_samples = sorted(self.samples, key=lambda x: x[1], reverse=True)
      self.result_lst.append(sorted_samples[0][1])
      print(self.result_lst)
    print(f"Final Instruction: {self.samples[0][0]}")
    return reward

new_env = env()
new_Template_Padding = Template_Padding(new_env)
new_Template_Padding.Prompt_LLM_Step("Generate an instruction that is different from all the instructions <INS> above, and has a higher score than all the instructions <INS> above. The instruction should begin with <INS> and end with </INS>. The instruction should be concise, effective, and generally applicable to all problems above.")