import asyncio
import random
import aiohttp
import re
import math
from aiohttp import ClientSession
import nest_asyncio
import copy
import openai
from tqdm.auto import tqdm
import pickle
import pandas as pd
import numpy as np

import torch
from torch import nn
from datasets import load_dataset

# env.evaluator() calls asyncio.run(), and it is reached from inside the
# coroutines that Template_Padding.Prompt_LLM_Epoch runs under asyncio.run();
# nest_asyncio patches asyncio so such nested run() calls are allowed.
nest_asyncio.apply()

# Load the pool of candidate meta-prompts. AdversarialBandit reads the "Z2"
# (prompt text) and "Embedding" columns of this data.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("PATH/DATA_FILE.pkl", "rb") as f:
    data = pickle.load(f)

prompt_data = pd.DataFrame(data)
prompt_data.head()  # result is discarded; only visible in an interactive/notebook session

from datasets import load_dataset

def get_sampled_gsm8k(fraction=0.01, seed=42):
    """Return a reproducible random subset of the GSM8K ``main`` train split.

    The subset is the ``test`` side of a seeded ``train_test_split`` with
    ``test_size=fraction``, so ``fraction`` is the share of training rows
    kept. The sizes of the full split and of the subset are printed.
    """
    train_split = load_dataset("gsm8k", "main", split="train")
    print(f"Full GSM8K train set: {len(train_split)}")

    splits = train_split.train_test_split(test_size=fraction, seed=seed)
    subset = splits["test"]
    print(f"Sample subset: {len(subset)}")
    return subset

def get_full_gsm8k_test():
    """Load the complete GSM8K ``main`` test split, print its size and return it."""
    test_split = load_dataset("gsm8k", "main", split="test")
    print(f"Full GSM8K test set: {len(test_split)}")
    return test_split


def extract_predicted_number(text):
    """Extract the numeric answer from a model reply.

    Looks for the first ``<Answer>...</Answer>`` span and returns the first
    number inside it as a string. A leading sign is kept and thousands
    separators are removed, so "-1,234.5" becomes "-1234.5".

    Returns:
        str: the number, or the sentinel ``"WRONG"`` when the tags or a
        number inside them are missing.
    """
    full_match = re.search(r"<Answer>(.*?)</Answer>", text, re.DOTALL)
    if not full_match:
        return "WRONG"
    # Allow an optional sign and comma-grouped digits; the original pattern
    # turned "-5" into "5" and "1,234" into "1".
    number_match = re.search(r"[+-]?[0-9][0-9,]*(?:\.[0-9]+)?", full_match.group(1))
    if not number_match:
        print("No valid number found in <Answer> tags, Answer is WRONG")
        return "WRONG"
    return number_match.group(0).replace(",", "")


async def async_get_single_answer(prompt, session, api_key, temperature=0, max_retries=None):
    """Ask the OpenRouter-hosted Qwen model one question and parse its answer.

    Args:
        prompt: full user message sent to the model.
        session: an open ``aiohttp.ClientSession`` used for the POST.
        api_key: OpenRouter bearer token.
        temperature: sampling temperature forwarded in the payload.
        max_retries: how many times to retry after a failed attempt (HTTP 429
            or any exception while requesting/parsing). ``None`` (default)
            retries forever, matching the historical behaviour.

    Returns:
        str: the number extracted by ``extract_predicted_number``, or
        ``"WRONG"`` if the reply has no parsable answer or the retry budget is
        exhausted.
    """
    url = "https://openrouter.ai/api/v1/chat/completions"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "qwen/qwen-2.5-7b-instruct",
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": temperature
    }
    attempts = 0
    while True:
        attempts += 1
        try:
            # Random jitter spreads out the concurrent requests of a batch.
            await asyncio.sleep(random.uniform(0.5, 1.5))
            async with session.post(url, json=payload, headers=headers) as response:
                if response.status == 429:
                    print("Rate limit exceeded. Waiting before retrying...")
                    await asyncio.sleep(10)
                else:
                    response_json = await response.json()
                    # A KeyError/TypeError here (e.g. an {"error": ...} body or
                    # null content) falls through to the retry handler below.
                    answer = response_json['choices'][0]['message']['content']
                    return extract_predicted_number(answer)
        except Exception as e:
            print(f"An error occurred: {e}")
            print("Retrying...")
            await asyncio.sleep(5)
        if max_retries is not None and attempts > max_retries:
            print(f"Giving up after {attempts} attempts, Answer is WRONG")
            return "WRONG"

async def async_get_batch_answers(prompts, Instruction, api_key, temperature=0, batch_size=64, delay_between_batches=10):
    """Answer every prompt, with ``Instruction`` filled in, in rate-limited batches.

    Each prompt is a ``str.format`` template whose placeholder receives
    ``Instruction``. Prompts are sent ``batch_size`` at a time over one shared
    HTTP session. The coroutine sleeps ``delay_between_batches`` seconds
    between batches, but not after the last one. The first filled prompt is
    printed for inspection.

    Returns:
        list[str]: one answer per prompt, in input order, as produced by
        ``async_get_single_answer``. Empty when ``prompts`` is empty.
    """
    answers = []
    if not prompts:
        # Nothing to ask; previously prompts[0] raised IndexError here.
        return answers
    print(prompts[0].format(Instruction))
    async with ClientSession() as session:
        for start in range(0, len(prompts), batch_size):
            batch = prompts[start:start + batch_size]
            tasks = [
                async_get_single_answer(prompt.format(Instruction), session, api_key, temperature)
                for prompt in batch
            ]
            answers.extend(await asyncio.gather(*tasks))
            if start + batch_size < len(prompts):
                print(f"Batch {start//batch_size + 1} completed, sleeping for {delay_between_batches} seconds...")
                await asyncio.sleep(delay_between_batches)
    return answers

def extract_groundtruth_number(text):
    """Parse the final answer after the ``####`` marker of a GSM8K solution.

    Accepts an optional sign, comma thousands separators and a decimal part.
    For example ``"#### 1,600"`` gives ``1600.0``; previously it gave ``1.0``.

    Returns:
        float | None: the answer, or ``None`` when no ``####`` number is found.
    """
    match = re.search(r"####\s*([+-]?[0-9][0-9,]*(?:\.[0-9]+)?)", text)
    if match:
        return float(match.group(1).replace(",", ""))
    return None

def compute_accuracy(predictions, dataset, tol=1e-5):
    """Return the fraction of predictions that match the GSM8K ground truth.

    Predictions and samples are paired positionally. A prediction that is not
    a float-parsable string (e.g. the ``"WRONG"`` sentinel) or whose sample
    has no ``####`` answer is reported and counted as incorrect. The
    denominator is always ``len(predictions)``. Values are compared with
    ``math.isclose`` using relative tolerance ``tol``.

    Returns:
        float: accuracy in [0, 1], or ``0`` when ``predictions`` is empty.
    """
    n_predictions = len(predictions)
    hits = 0

    for prediction, example in zip(predictions, dataset):
        try:
            predicted = float(prediction.strip())
        except ValueError:
            print("Invalid numeric prediction:", prediction)
            continue

        expected = extract_groundtruth_number(example["answer"])
        if expected is None:
            print("Failed to extract ground truth from sample:", example)
            continue

        hits += math.isclose(predicted, expected, rel_tol=tol)

    if n_predictions > 0:
        return hits / n_predictions
    return 0

class env:
  """GSM8K environment that scores candidate instructions by model accuracy.

  Holds a small sampled train subset (used while searching) and the full test
  split (used by ``test_env``). Each question becomes a ``str.format`` template
  whose trailing ``{}`` slot receives the instruction under evaluation.
  """

  # Appended to every question; the final "{}" is the slot for the instruction.
  _PROMPT_SUFFIX = "\nProvide your final result within the tags <Answer>FINAL_NUMERICAL_ANSWER</Answer>.\n" + "\n{}"

  def __init__(self):
    self.sampled_dataset = get_sampled_gsm8k()
    self.prompts = self._build_prompts(self.sampled_dataset)
    self.sample_count = 0  # advances the shuffle seed in random_sample_exemplars
    self.test_dataset = get_full_gsm8k_test()
    self.test_prompts = self._build_prompts(self.test_dataset)

    self.API_key = "[API_KEY]"

  @classmethod
  def _build_prompts(cls, dataset):
    """Turn each sample's ``question`` into a format template with one ``{}`` slot.

    Literal braces in the question are doubled so that
    ``prompt.format(instruction)`` reproduces them. Without this, a question
    containing ``{`` or ``}`` breaks the format call.
    """
    return [
        sample["question"].strip().replace("{", "{{").replace("}", "}}") + cls._PROMPT_SUFFIX
        for sample in dataset
    ]

  def evaluator(self, Instruction):
    """Score ``Instruction`` on the sampled train subset.

    Returns:
      int: accuracy as a percentage, rounded to the nearest integer.
    """
    answer = asyncio.run(async_get_batch_answers(self.prompts, Instruction, self.API_key))
    accuracy = compute_accuracy(answer, self.sampled_dataset)
    score = int(round(accuracy*100))
    print(f"Instruction: {Instruction}, score={score}")
    return score

  def random_sample_exemplars(self):
    """Return 3 samples from the sampled subset, using a new shuffle seed per call."""
    sampled_samples = self.sampled_dataset.shuffle(seed=42 + self.sample_count).select(range(3))
    self.sample_count += 1
    return sampled_samples

  def test_env(self, Instruction):
    """Score ``Instruction`` on the full GSM8K test split.

    Returns:
      float: unrounded accuracy percentage. ``evaluator`` returns a rounded int
      instead. Only the printed value is rounded.
    """
    answer = asyncio.run(async_get_batch_answers(self.test_prompts, Instruction, self.API_key))
    accuracy = compute_accuracy(answer, self.test_dataset)
    print(f"Instruction: {Instruction}, score={int(round(accuracy*100))}")
    return accuracy*100

class Network(nn.Module):
  """Fully connected ReLU regressor mapping an input vector to a scalar.

  Architecture: ``input_dim -> hidden_size``, then ``depth - 1`` further
  ``hidden_size -> hidden_size`` layers, then ``hidden_size -> 1``. ReLU is
  applied after every layer except the last. There are always at least two
  layers, even when ``depth <= 0``.

  Args:
    input_dim: size of the last input dimension.
    hidden_size: width of the hidden layers.
    depth: number of hidden layers.
    init_params: optional flat list ``[w0, b0, w1, b1, ...]`` of tensors that
      replace the layers' weights and biases. When omitted, every weight and
      bias is drawn from N(0, 1).
  """

  def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
    super().__init__()

    self.activate = nn.ReLU()
    # Consecutive widths; max(depth, 1) keeps the input and output layers even for depth <= 0.
    widths = [input_dim] + [hidden_size] * max(depth, 1) + [1]
    self.layer_list = nn.ModuleList(
        nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
    )

    if init_params is not None:
      for idx, layer in enumerate(self.layer_list):
        layer.weight.data = init_params[2 * idx]
        layer.bias.data = init_params[2 * idx + 1]
    else:
      for layer in self.layer_list:
        torch.nn.init.normal_(layer.weight, mean=0, std=1.0)
        torch.nn.init.normal_(layer.bias, mean=0, std=1.0)

  def forward(self, x):
    """Apply the network; the trailing size-1 output dimension is squeezed away."""
    *hidden_layers, head = self.layer_list
    for layer in hidden_layers:
      x = self.activate(layer(x))
    return head(x).squeeze(-1)

class AdversarialBandit:
  """EXP3-style bandit that chooses which meta-prompt (``Z2``) to use next.

  A small MLP (``Network``) regresses the reward observed for each used
  meta-prompt from its embedding. On every ``select`` call the MLP scores all
  candidates in ``prompt_data``. These predictions are accumulated across
  calls, and an arm is sampled with probability proportional to
  ``exp(eta * cumulative_score)``.

  Args:
    prompt_data: DataFrame with ``Z2`` (meta-prompt text) and ``Embedding``
      columns. It must contain the default ``current_prompt`` below.
      ``select`` (over)writes a ``Score`` column.
    device: torch device for the reward model. Defaults to ``'cuda'`` when
      available, otherwise ``'cpu'``.
  """

  def __init__(self, prompt_data, device=None):
    self.prompt_data = prompt_data
    if device is None:
      device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = device
    # NOTE(review): input size 3072 assumes the stored embeddings are 3072-dim; confirm against the data file.
    self.model = Network(3072, 512, 1).to(self.device).double()
    self.lr = 1e-3
    self.epochs = 5000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()
    # Snapshot of the initial weights; restart_model() reloads it so every
    # train_model() call fits from the same starting point.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())
    self.eta = 50
    self.prompt_score_trace = []     # one list of predicted scores (all prompts) per select() call
    self.embedding_score_trace = []  # [embedding, reward] training pairs
    self.current_prompt = """Generate an instruction that is different from all the instructions <INS> above, and has a higher score than all the instructions <INS> above. The instruction should begin with <INS> and end with </INS>. The instruction should be concise, effective, and generally applicable to all problems above."""
    self.prompt_embedding = self.prompt_data.loc[self.prompt_data["Z2"] == self.current_prompt, "Embedding"].iloc[0]
    self.Stop_EXP3 = 0

  def restart_model(self):
    """Reset the reward model to its initial weights and build a fresh optimiser and scheduler.

    Weight decay is ``1 / len(embedding_score_trace)``, so regularisation
    weakens as observations accumulate. This raises ZeroDivisionError if no
    observation has been recorded. The LR is multiplied by 0.1 after epoch 3000.
    """
    self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def train_model(self):
    """Refit the reward model from scratch on all (embedding, reward / 100) pairs."""
    self.restart_model()
    self.model.to(self.device)

    embeddings = [item[0] for item in self.embedding_score_trace]
    # Rewards are percentages (0-100); scale them to [0, 1] as regression targets.
    scores = [(item[1])/100 for item in self.embedding_score_trace]

    embeddings_tensor = torch.stack([torch.tensor(e, dtype=torch.float64) for e in embeddings]).to(self.device)
    scores_tensor = torch.tensor(scores, dtype=torch.float64).to(self.device)

    dataset = torch.utils.data.TensorDataset(embeddings_tensor, scores_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    self.model.train()
    for epoch in tqdm(range(self.epochs), desc="Training Model"):
      for batch_embeddings, batch_scores in dataloader:
        self.optimizer.zero_grad()
        predictions = self.model(batch_embeddings)
        loss = self.loss_fn(predictions, batch_scores)
        loss.backward()
        self.optimizer.step()

      self.scheduler.step()

      if epoch % 100 == 0:
        # Reports the loss of the last mini-batch of this epoch.
        print(f"Epoch {epoch}/{self.epochs}, Loss: {loss.item()}")

    print("Training complete.")

  @staticmethod
  def _exp3_probabilities(cumulative_scores, eta):
    """Return ``exp(eta * s) / sum(exp(eta * s))`` computed stably.

    ``torch.softmax`` subtracts the maximum before exponentiating. Computing
    ``exp`` directly overflows float64 once ``eta * s`` exceeds about 709.
    With eta=50 that happens for a cumulative score of about 14, which turns
    the distribution into NaN.
    """
    return torch.softmax(eta * cumulative_scores, dim=0)

  def select(self, iteration):
    """Score all candidate prompts, update the EXP3 weights and sample the next prompt.

    ``iteration`` is accepted for caller compatibility but is unused. Sets
    ``best_index``, ``current_prompt`` and ``prompt_embedding``. If the
    probabilities are NaN/Inf, it sets ``Stop_EXP3 = 1`` and returns without
    changing the prompt.
    """
    embeddings = torch.stack(
      [torch.tensor(e, dtype=torch.float64) for e in self.prompt_data['Embedding']]
    ).to(self.device)

    with torch.no_grad():
      scores = self.model(embeddings).squeeze(-1)

    score_list = scores.cpu().tolist()
    self.prompt_data['Score'] = score_list
    # Store a plain-list snapshot, not the pandas Series: that Series is tied
    # to a column reassigned on every call, and torch.tensor() on Series
    # objects depends on their index labels.
    self.prompt_score_trace.append(list(score_list))
    cumulative_scores = torch.tensor(self.prompt_score_trace, dtype=torch.float64).sum(dim=0)
    normalized_exp_scores = self._exp3_probabilities(cumulative_scores, self.eta)

    # topk fails if k exceeds the number of candidates.
    top_k = min(10, normalized_exp_scores.numel())
    top_scores, top_indices = torch.topk(normalized_exp_scores, top_k)
    print("Top 10 probabilities: ", top_scores)
    print("Indices of Top 10 probabilities: ", top_indices)

    if not torch.isnan(normalized_exp_scores).any() and not torch.isinf(normalized_exp_scores).any() and torch.sum(normalized_exp_scores) > 0:
      self.best_index = torch.multinomial(normalized_exp_scores, 1).item()
    else:
      print("Encountered NaN or Inf in normalized_exp_scores. Stopping EXP3.")
      self.Stop_EXP3 = 1
      return

    print("The index of selected prompt is: %d"%(self.best_index))
    print("The selected prompt's cumulative score: ", cumulative_scores[self.best_index])
    self.current_prompt = self.prompt_data.iloc[self.best_index]['Z2']
    self.prompt_embedding = self.prompt_data.iloc[self.best_index]['Embedding']

class Template_Padding:
  """Runs the bandit-guided instruction search loop.

  Keeps every evaluated ``(instruction, score)`` pair in ``self.samples``. It
  renders the top pairs plus a few GSM8K exemplars into a meta-prompt, asks
  gpt-4-turbo (via OpenRouter) for new instructions, and scores them with
  ``env.evaluator``. The greedy instruction's score rewards an
  ``AdversarialBandit``, which picks the meta-instruction (``prompt_3``)
  appended on the next step.
  """

  # Block listing one previous instruction and its score in the meta-prompt.
  exemplars_template = """text:
{}
score:
{}

"""
  # Block showing one GSM8K problem with its ground-truth answer.
  exemplars_template_2 = """Problem:
Q: {}
A: <INS>

Ground truth Answer:
{}

"""
  prompt_1 = """Your task is to generate the instruction <INS>. Below are some previous instructions with their scores. The score ranges from 0 to 100.
"""
  prompt_2 = """Below are some problems.
"""
  # Maximum number of distinct (instruction, score) pairs shown in the meta-prompt.
  max_exemplars = 20

  def __init__(self, env, prompt_data):
    self.env = env
    initial_instruction = "Let's think step by step."
    self.samples = [(initial_instruction, self.env.evaluator(initial_instruction))]

    self.input_exemplars_1 = ""
    self.input_exemplars_2 = ""
    self.Add_Padding()

    self.API_key = "[API_KEY]"
    self.result_lst = []  # best score seen so far, appended once per step

    self.steps = 20

    self.bandit = AdversarialBandit(prompt_data=prompt_data)
    self.prompt_3 = self.bandit.current_prompt

  def _best_sample(self):
    """Return the highest-scoring ``(instruction, score)`` pair; ties go to the earliest entry."""
    return max(self.samples, key=lambda x: x[1])

  def Add_Padding(self):
    """Rebuild both exemplar sections of the meta-prompt.

    ``input_exemplars_1`` lists the ``max_exemplars`` best distinct
    ``(instruction, score)`` pairs in ascending score order.
    ``input_exemplars_2`` holds freshly sampled GSM8K problems from
    ``env.random_sample_exemplars``. ``self.samples`` is left sorted by
    ascending score.
    """
    self.samples.sort(key=lambda x: x[1])

    # Deduplicate identical pairs. Break score ties by instruction text so the
    # meta-prompt does not depend on (hash-randomised) set iteration order.
    sample_set = sorted(set(self.samples), key=lambda x: (x[1], x[0]))
    selected_samples = sample_set[-self.max_exemplars:]

    self.input_exemplars_1 = "".join(
        self.exemplars_template.format(instruction, score)
        for instruction, score in selected_samples
    )

    sampled_exemplars = self.env.random_sample_exemplars()
    self.input_exemplars_2 = "".join(
        self.exemplars_template_2.format(sample["question"], extract_groundtruth_number(sample["answer"]))
        for sample in sampled_exemplars
    )

  def Add_Sample(self, new_instruction, re=False):
    """Score ``new_instruction`` with ``env.evaluator`` and record it in ``self.samples``.

    Returns the score when ``re`` is truthy, otherwise ``None``. The parameter
    name shadows the ``re`` module inside this method; it is kept for caller
    compatibility.
    """
    score = self.env.evaluator(new_instruction)
    self.samples.append((new_instruction, score))
    if re:
      return score
    return None

  def Add_All(self, Meta_Prompt):
    """Assemble the full meta-prompt: header, instruction exemplars, problems, then ``Meta_Prompt``.

    Plain concatenation is used so braces in ``Meta_Prompt`` are kept
    verbatim. Previously it was spliced into a ``str.format`` template.
    """
    return (self.prompt_1 + "\n" + self.input_exemplars_1
            + self.prompt_2 + "\n" + self.input_exemplars_2
            + Meta_Prompt)

  @staticmethod
  def _extract_instruction(content):
    """Return the text of the first ``<INS>...</INS>`` pair, else the whole reply (stripped)."""
    content = content.strip()
    match = re.search(r"<INS>(.*?)</INS>", content, re.DOTALL)
    return match.group(1).strip() if match else content

  async def Prompt_LLM_Epoch(self, Meta_Prompt, epoch=7):
    """Concurrently sample ``epoch`` instructions at temperature 1 and score each.

    Each request is retried up to 3 times on any exception. Every extracted
    instruction is printed and recorded via ``Add_Sample``.

    Returns:
      list: one extracted instruction per request, or ``None`` for requests
      that failed all retries.
    """
    # Add_Sample does not touch the exemplar sections, so the meta-prompt is
    # constant for the whole epoch.
    meta_prompt = self.Add_All(Meta_Prompt)

    async def fetch_one():
      max_retries = 3
      for attempt in range(max_retries):
        try:
          response = await openai.ChatCompletion.acreate(
            model="openai/gpt-4-turbo",
            messages=[{"role": "user", "content": meta_prompt}],
            temperature=1,
          )
          extracted_content = self._extract_instruction(response.choices[0].message['content'])
        except Exception as e:
          print(f"[Retry {attempt+1}] Exception during LLM call: {e}")
          continue

        print(extracted_content)
        self.Add_Sample(extracted_content)
        return extracted_content

      print("[Warning] Failed after 3 attempts.")
      return None

    tasks = [fetch_one() for _ in range(epoch)]
    return await asyncio.gather(*tasks)

  def Prompt_LLM_Epoch_Temp_0(self, Meta_Prompt):
    """Ask for one instruction at temperature 0, score it and return the score.

    Retries up to 10 times on any exception (including failures while
    scoring). Returns ``None`` if every attempt fails.
    """
    max_retries = 10
    meta_prompt = self.Add_All(Meta_Prompt)
    print(meta_prompt)
    for attempt in range(max_retries):
      try:
        response = openai.ChatCompletion.create(
          model="openai/gpt-4-turbo",
          messages=[{"role": "user", "content": meta_prompt}],
          temperature=0
        )
        extracted_content = self._extract_instruction(response.choices[0].message['content'])
        return self.Add_Sample(extracted_content, re=True)
      except Exception as e:
        print(f"[Retry {attempt+1}] Exception during LLM call: {e}")

    print("[Error] Failed after 10 attempts. Returning None.")
    return None

  def Prompt_LLM_Step(self):
    """Run ``self.steps`` rounds of instruction generation and bandit updates.

    Each round does the following:
      1. Score one greedy instruction; its score is the bandit reward.
      2. Score 7 sampled instructions.
      3. Refresh the exemplars and log the best score so far.
      4. Train the bandit and let it choose the next meta-instruction, unless
         the reward is missing or EXP3 has stopped.

    Returns:
      The reward from the last round (``None`` if it failed or ``steps`` is 0).
    """
    openai.api_key = self.API_key
    openai.api_base = "https://openrouter.ai/api/v1"
    reward = None
    for i in tqdm(range(self.steps), desc="Prompting LLM " + str(self.steps) + " steps: " + "gpt-4-turbo"):
      reward = self.Prompt_LLM_Epoch_Temp_0(self.prompt_3)
      asyncio.run(self.Prompt_LLM_Epoch(self.prompt_3))
      self.Add_Padding()
      # Add_Padding leaves self.samples ascending, so samples[0] is the worst
      # entry; look up the real best explicitly.
      best_instruction, best_score = self._best_sample()
      print("Step %d: The Best Performance, INSTRUCTION=%s, score=%f"%(i + 1, best_instruction, best_score))
      self.result_lst.append(best_score)
      print(self.result_lst)
      if reward is None:
        # No score for the current meta-prompt this round; a None reward would
        # crash train_model (reward / 100), so skip the bandit update.
        print("Step %d: no reward obtained, skipping bandit update." % (i + 1))
      else:
        self.bandit.embedding_score_trace.append([self.bandit.prompt_embedding, reward])
        if not self.bandit.Stop_EXP3:
          self.bandit.train_model()
          self.bandit.select(i)
      self.prompt_3 = self.bandit.current_prompt
    print(f"Final Instruction: {self._best_sample()[0]}")
    return reward

# Script entry: build the GSM8K environment, then run the bandit-guided
# instruction search. prompt_data is deep-copied, presumably so that the
# "Score" column AdversarialBandit.select() writes into its frame does not
# modify the module-level prompt_data.
new_env = env()
new_Template_Padding = Template_Padding(new_env, copy.deepcopy(prompt_data))
new_Template_Padding.Prompt_LLM_Step()