import re
import json
import os
import time
import random
import logging
import argparse
import tempfile
import requests
import warnings
from functools import partial
from pathlib import Path
from tqdm import trange, tqdm
from typing import List, Iterable, Mapping, Dict, Any

from multiprocessing import Queue, Process, Pool
from subprocess import STDOUT, check_output
from concurrent.futures import ThreadPoolExecutor

import openai
from openai import OpenAI
import google.generativeai as genai

# Process-wide setup performed once at import time.
# Fix the module-level RNG seed; `random` is used for the retry sleep jitter in worker_aug.
random.seed(42)
# Hide every warning category for the whole process.
warnings.simplefilter("ignore")
# Capacity of each multiprocessing queue created in build_training_aug.
QUEUE_SIZE = 1_000_000


def query_gpt(prompt: str, history: List[Dict[str, str]] = None, model: str = 'gpt-4o', num_samples: int = 1,
              temperature: float = 0.8, top_p: float = 1.0, max_tokens: int = 2048, max_retry: int = 3,
              system_prompt: str = None):
    """Query an OpenAI chat model with *prompt*, optionally preceded by few-shot turns.

    Args:
        prompt: The final user message.
        history: Optional list of prior turns, each a dict with "prompt" and "response"
            keys, replayed as alternating user/assistant messages. ``None`` (the default)
            or an empty list means no history.
        model: OpenAI model name.
        num_samples: Number of completions to request (the API's ``n``).
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling parameter forwarded to the API.
        max_tokens: Maximum number of tokens to generate per completion.
        max_retry: Maximum number of attempts when the API fails transiently.
        system_prompt: Optional system message placed before the history.

    Returns:
        The completion text when ``num_samples == 1``; a list with every completion's
        text when ``num_samples > 1`` (matching ``query_gemini``); ``None`` if every
        attempt failed with a transient error.

    Raises:
        KeyError: If the ``OPENAI_API_KEY`` environment variable is not set.
        openai.OpenAIError: Non-transient API errors (e.g. authentication or bad
            request) are not retried and propagate to the caller.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # `history` defaults to None; treat that exactly like an empty history.
    for turn in history or []:
        messages.append({"role": "user", "content": turn["prompt"]})
        messages.append({"role": "assistant", "content": turn["response"]})
    messages.append({"role": "user", "content": prompt})

    client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])

    # The SDK reports timeouts/connection drops/throttling/5xx through its own exception
    # types (not the built-in TimeoutError), so those are what must be retried.
    retryable = (
        TimeoutError,
        openai.APITimeoutError,
        openai.APIConnectionError,
        openai.RateLimitError,
        openai.InternalServerError,
    )
    for attempt in range(max_retry):
        try:
            chat_completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                n=num_samples,
            )
        except retryable as e:
            if attempt == max_retry - 1:
                print(f"Failed to query GPT after {max_retry} attempts: {e}")
                return None
            time.sleep(2 ** attempt)  # exponential back-off before the next attempt
            continue
        texts = [choice.message.content for choice in chat_completion.choices]
        return texts if num_samples > 1 else texts[0]
    return None


def query_gemini(prompt: str, system_prompt: str = None, history: List[Dict[str, str]] = None,
                 model: str = 'gemini-1.5-flash', num_samples: int = 1,
                 temperature: float = 0.8, top_p: float = 1.0,
                 max_tokens: int = 2048, max_retry: int = 3):
    """Query a Gemini model with *prompt*, a system instruction and optional few-shot turns.

    Args:
        prompt: The final user message.
        system_prompt: System instruction for the model (used for the task description).
        history: Optional list of prior turns, each a dict with "prompt" and "response"
            keys, replayed as alternating user/model turns for in-context learning.
        model: Gemini model name.
        num_samples: Number of candidates to request (``candidate_count``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter.
        max_tokens: Maximum number of output tokens.
        max_retry: Maximum number of attempts; any exception triggers a retry.

    Returns:
        The response text when ``num_samples == 1``; a list with one text per candidate
        when ``num_samples > 1``; ``None`` if every attempt raised.

    Raises:
        KeyError: If the ``GOOGLE_API_KEY`` environment variable is not set.
    """
    genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
    gemini_model = genai.GenerativeModel(
        model_name=model,
        system_instruction=system_prompt,
    )

    # Build the whole conversation (history, then the new prompt) and send it through
    # generate_content. A ChatSession is not used because its send_message refuses
    # candidate_count > 1, which made num_samples > 1 fail on every attempt.
    contents = []
    for turn in history or []:
        contents.append({"role": "user", "parts": [turn["prompt"]]})
        contents.append({"role": "model", "parts": [turn["response"]]})
    contents.append({"role": "user", "parts": [prompt]})

    generation_config = genai.types.GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        max_output_tokens=max_tokens,
        candidate_count=num_samples,
    )

    for attempt in range(max_retry):
        try:
            response = gemini_model.generate_content(contents, generation_config=generation_config)
            if num_samples > 1:
                # Candidate objects expose their text via content.parts, not a `.text` attribute.
                texts = [
                    "".join(part.text for part in candidate.content.parts)
                    for candidate in response.candidates
                ]
                if not texts:
                    raise ValueError("Gemini response contained no candidates")
                return texts
            return response.text
        except Exception as e:
            if attempt == max_retry - 1:  # last attempt
                print(f"Failed to query Gemini after {max_retry} attempts: {str(e)}")
                return None
            time.sleep(2 ** attempt)  # exponential back-off before the next attempt

    return None


def worker_aug(task_queue, done_queue, worker_func, max_retry: int = 3):
    """Process items from *task_queue* until the "STOP" sentinel, forwarding results.

    Each item is passed to ``worker_func`` up to ``max_retry`` times. An attempt fails if
    it raises (the error is printed) or returns ``None``; a random 1-3 s pause separates
    failed attempts. The first non-``None`` result is put on *done_queue*. Items whose
    attempts all fail are dropped. When "STOP" is read, "COMPLETE" is put on
    *done_queue* so the consumer can count finished workers.

    Args:
        task_queue: Queue-like object with ``get()``; terminated by the string "STOP".
        done_queue: Queue-like object with ``put()`` that receives results and "COMPLETE".
        worker_func: Callable applied to each item; ``None`` signals failure.
        max_retry: Attempts per item.
    """
    for item in iter(task_queue.get, "STOP"):
        result = None
        for attempt in range(max_retry):
            try:
                result = worker_func(item)
            except Exception as e:
                print("error:", e)
                result = None
            if result is not None:
                break
            # Back off before retrying (also after exceptions), but not after the last attempt.
            if attempt < max_retry - 1:
                time.sleep(random.randint(1, 3))

        if result is None:
            continue  # every attempt failed: drop this item

        # Forward what worker_func returned (not the raw input), so functions that build a
        # new object instead of mutating the item still have their output saved.
        done_queue.put(result)

    done_queue.put("COMPLETE")


def _read_data_into_queue(task: str, data_dir: str, task_queue, num_stop: int):
    """Enqueue every training example of *task*, then ``num_stop`` "STOP" sentinels.

    Reads ``<data_dir or ./data>/<task>/train.json``. For ``table_mwp`` the file is a dict
    keyed by problem id, and the id is copied into each record as "pid". Other tasks are
    read as a list of records. ``anli`` is rejected.

    Module-level (not nested) so it can be pickled as a ``Process`` target under the
    "spawn" start method. The sentinels are sent in ``finally``: if reading fails, the
    workers still terminate instead of blocking on ``get()`` forever.
    """
    try:
        if task == "anli":
            raise ValueError("ANLI is not currently supported for reverse augmentation")
        base_dir = data_dir if data_dir is not None else "./data"
        input_file = os.path.join(base_dir, task, "train.json")
        with open(input_file, "r", encoding="utf-8") as r:
            data = json.load(r)
        if task == "table_mwp":
            for pid, example in data.items():
                example["pid"] = pid  # add pid as an additional field
                task_queue.put(example)
        else:
            for item in data:
                task_queue.put(item)
        print(f"read {len(data)} examples from {input_file}")
    finally:
        for _ in range(num_stop):
            task_queue.put('STOP')


def build_training_aug(task: str, rule_idx: int, worker_func, num_processes: int = 10, data_dir: str = None):
    """Run ``worker_func`` over the training set of *task* in parallel; write the results.

    One reader process fills a task queue. ``num_processes`` ``worker_aug`` processes
    apply ``worker_func`` to each example. This process writes every successful result
    as one JSON line to ``./output/<task>/rule_<rule_idx>/reverse_aug_questions.jsonl``.
    An existing file is overwritten.

    Args:
        task: Dataset name; selects the input file and its layout.
        rule_idx: Index of the rewriting rule, used only in the output path.
        worker_func: Picklable callable applied to each example (see ``worker_aug``).
        num_processes: Number of worker processes; must be at least 1.
        data_dir: Root folder containing ``<task>/train.json``; defaults to ``./data``.

    Raises:
        ValueError: If ``num_processes`` is less than 1.
    """
    if num_processes < 1:
        # With no workers nobody drains the queue or sends "COMPLETE", so the run would hang.
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")

    task_queue, done_queue = Queue(maxsize=QUEUE_SIZE), Queue(maxsize=QUEUE_SIZE)

    workers = []
    for _ in range(num_processes):
        worker = Process(target=worker_aug, args=(task_queue, done_queue, worker_func))
        worker.start()
        workers.append(worker)

    reader = Process(target=_read_data_into_queue, args=(task, data_dir, task_queue, num_processes))
    reader.start()

    # create output folder based on the task name if not exists
    output_folder = f"./output/{task}/rule_{rule_idx}"
    os.makedirs(output_folder, exist_ok=True)
    output_file = f"{output_folder}/reverse_aug_questions.jsonl"

    num_save = 0
    # utf-8 explicitly: records are dumped with ensure_ascii=False.
    with open(output_file, 'w', encoding="utf-8") as w, tqdm() as progress_bar:
        num_finished = 0
        while num_finished < num_processes:
            item = done_queue.get()
            if item == 'COMPLETE':
                num_finished += 1
            else:
                w.write(json.dumps(item, ensure_ascii=False) + '\n')
                w.flush()
                num_save += 1
                progress_bar.update(1)

    # Every queued item has been consumed at this point, so joining cannot deadlock on
    # unflushed queue buffers.
    reader.join()
    for worker in workers:
        worker.join()

    if reader.exitcode != 0:
        print(f"warning: data reader for task {task} exited with code {reader.exitcode}")
    print(f"saved {num_save} examples to {output_file}")


def _build_reverse_aug_user_prompt(x, task: str) -> str:
    """Render example *x* of *task* as the user message: question (+ table/choices) and answer.

    Raises:
        ValueError: If *task* has no prompt template.
    """
    if task == "strategy_qa":
        answer_str = "Yes" if x["answer"] else "No"
        return f"Original Question: {x['question']}\n\nCorrect Answer: {answer_str}"
    if task in ("commonsense_qa", "arc_challenge", "date"):
        # Each option is rendered as " (<label>)<text>", concatenated in order.
        choices_str = "".join(
            f" ({label}){text}" for label, text in zip(x["choices"]["label"], x["choices"]["text"])
        )
        answer_str = f" ({x['answerKey']}){x['answer']}"
        return f"Original Question: {x['question']}\n\nChoices: {choices_str}\n\nCorrect Answer: {answer_str}"
    if task == "table_mwp":
        return (f"Table to use: {x['table']}\n\n"
                f"Original Question: {x['question']}\n\nCorrect Answer: {x['answer']}")
    if task in ("gsm8k", "math"):
        return f"Original Question: {x['question']}\n\nCorrect Answer: {x['answer']}"
    raise ValueError(f"Task {task} is not supported for reverse augmentation")


def reverse_aug_prompt_response(
    x,
    backbone_model: str = "gemini-1.5-flash",
    response_key: str = "response",
    num_generations: int = 1,               # Number of samples to generate
    temperature: float = 0.8,               # Temperature for sampling
    top_p: float = 1.0,                     # Top-p for sampling
    max_tokens: int = 1024,                 # Maximum number of tokens to generate
    task: str = "strategy_qa",
    rule: str = None,
):
    """Ask the backbone LLM to reformulate example *x*'s question according to *rule*.

    Args:
        x: Example dict with the task's fields (e.g. "question", "answer", and "choices",
            "answerKey" or "table" depending on *task*). An optional "history" list of
            {"prompt", "response"} turns is forwarded as in-context examples.
        backbone_model: Model name; routed to GPT if it contains "gpt", to Gemini if it
            contains "gemini".
        response_key: Key under which the model output is stored in *x*.
        num_generations: Number of samples to request.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling parameter.
        max_tokens: Maximum number of generated tokens.
        task: Dataset name selecting the prompt template.
        rule: Rewriting rule text (required).

    Returns:
        *x*, mutated in place with ``x[response_key]`` set, or ``None`` when the model
        query failed, so that ``worker_aug`` retries and drops the example instead of
        saving it with a null response.

    Raises:
        ValueError: If *rule* is missing, or *task* / *backbone_model* is unsupported.
    """
    if rule is None:
        raise ValueError("A rule is required for reverse augmentation")

    system_prompt = "Your task is to reformulate the given question following the provided rule, based on the input question and its correct answer."+"\n\n"+"Rule: "+rule
    user_prompt = _build_reverse_aug_user_prompt(x, task)
    # Normalize a missing history to an empty list so both backends accept it.
    history = x.get("history") or []

    if "gpt" in backbone_model:
        # Prepend the task description and rule to the user turn, so the GPT path gets the
        # same instructions that Gemini receives as its system instruction.
        gpt_prompt = system_prompt + "\n\n" + user_prompt
        response = query_gpt(gpt_prompt, history, backbone_model, num_generations, temperature, top_p, max_tokens)
    elif "gemini" in backbone_model:
        response = query_gemini(user_prompt, system_prompt, history,
                                backbone_model, num_generations,
                                temperature, top_p, max_tokens)
    else:
        raise ValueError(f"Backbone model {backbone_model} is not supported for reverse augmentation")

    if response is None:
        return None

    x[response_key] = response
    return x

if __name__ == '__main__':
    # Datasets that have a prompt template in reverse_aug_prompt_response.
    all_tasks = ["strategy_qa", "commonsense_qa", "arc_challenge", "date", "table_mwp", "gsm8k", "math"]

    parser = argparse.ArgumentParser(
        description="Reverse-augment training questions with an LLM, one output file per (task, rule)."
    )
    parser.add_argument("--num_process", type=int, default=10)
    parser.add_argument("--backbone", type=str, default="gemini-1.5-flash")
    parser.add_argument("--response_key", type=str, default="response")
    parser.add_argument("--data_dir", type=str, default=None)
    # LLM generation parameters
    parser.add_argument("--num_generations", type=int, default=1)
    parser.add_argument('--temperature', type=float, default=0.8)           # Temperature for sampling, default is 0.8
    parser.add_argument('--top_p', type=float, default=1.0)                 # Top-p for sampling, default is 1.0
    parser.add_argument('--max_tokens', type=int, default=1024)             # Maximum number of tokens to generate, default is 1024
    # Task parameters
    parser.add_argument("--rule", type=str, default=None,
                        help="JSONL file; each non-blank line is an object with a 'rule' field")
    parser.add_argument("--tasks", type=str, nargs="+", choices=all_tasks, default=all_tasks,
                        help="Subset of tasks to augment (default: all)")
    args = parser.parse_args()

    # Validate through argparse rather than `assert`, which is stripped under `python -O`.
    if args.rule is None:
        parser.error("Rule is required for reverse augmentation")

    # Load the rules, keeping only the 'rule' field; blank lines are skipped.
    with open(args.rule, "r", encoding="utf-8") as f:
        rules = [json.loads(line)["rule"] for line in f if line.strip()]

    tasks = args.tasks
    print(f"\nStarting augmentation with {len(tasks)} tasks and {len(rules)} rules per task")
    for task in tasks:
        for rule_idx, rule in enumerate(rules):
            print(f"\nStarting augmentation for task {task} with rule {rule_idx}")
            build_training_aug(
                task=task,
                rule_idx=rule_idx,
                worker_func=partial(reverse_aug_prompt_response, backbone_model=args.backbone,
                                    response_key=args.response_key, num_generations=args.num_generations,
                                    temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens,
                                    task=task, rule=rule),
                num_processes=args.num_process,
                data_dir=args.data_dir,
            )


