import re
import sys
import os
import json
import time
import random
import logging
import argparse
import tempfile
import requests
import warnings
from functools import partial
from pathlib import Path
from tqdm import trange, tqdm
from datetime import datetime
from typing import List, Iterable, Mapping, Dict, Any

from multiprocessing import Queue, Process, Pool
from subprocess import STDOUT, check_output
from concurrent.futures import ThreadPoolExecutor

import openai
from openai import OpenAI
import google.generativeai as genai
from google.cloud import storage
import vertexai
from vertexai.batch_prediction import BatchPredictionJob
from vertexai.generative_models import GenerativeModel

# Seed the global `random` generator with a fixed value at import time.
random.seed(42)
# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")
# NOTE(review): not referenced anywhere in the visible code — presumably a
# queue capacity left over from a multiprocessing variant; confirm or remove.
QUEUE_SIZE = 1000000


def reverse_aug_prompt_response(
    x,
    task: str = "strategy_qa",
    rule: str = None,
):
    """Build the prompts that ask a model to reformulate one example by a rule.

    Args:
        x: Example mapping. Always read: ``"question"`` and ``"answer"``.
            Multiple-choice tasks (``commonsense_qa``, ``arc_challenge``,
            ``date``) also read ``x["choices"]["text"]``,
            ``x["choices"]["label"]`` and ``x["answerKey"]``; ``table_mwp``
            also reads ``x["table"]``. An optional ``"history"`` entry is
            passed through untouched.
        task: Dataset name selecting the layout of the user prompt.
        rule: Reformulation rule appended to the system prompt. It is
            required even though the signature defaults it to ``None``.

    Returns:
        A ``(user_prompt, system_prompt, history)`` tuple. ``history`` is
        ``None`` when ``x`` has no ``"history"`` key.

    Raises:
        TypeError: If ``rule`` is not a string (including the ``None`` default).
        ValueError: If ``task`` is not supported.
    """
    # Fail fast with a readable message instead of the opaque
    # "can only concatenate str" error the missing rule used to trigger.
    if not isinstance(rule, str):
        raise TypeError(
            f"rule must be a string for reverse augmentation, got {type(rule).__name__}"
        )

    history = x["history"] if "history" in x else None

    system_prompt = (
        "Your task is to reformulate the given question following the provided "
        "rule, based on the input question and its correct answer."
        f"\n\nRule: {rule}"
    )

    # f-strings are used throughout so that non-string answers (e.g. numeric
    # answers loaded from JSON) are formatted instead of raising on `str + int`.
    if task == "strategy_qa":
        answer_str = "Yes" if x["answer"] else "No"
        user_prompt = (
            f"Original Question: {x['question']}\n\nCorrect Answer: {answer_str}"
        )
    elif task in ("commonsense_qa", "arc_challenge", "date"):
        texts = x["choices"]["text"]
        labels = x["choices"]["label"]
        # Each option is rendered as " (KEY)text"; indexing `labels` (rather
        # than zipping) keeps a missing label a loud IndexError.
        choices_str = "".join(
            f" ({labels[i]}){text}" for i, text in enumerate(texts)
        )
        answer_str = f" ({x['answerKey']}){x['answer']}"
        user_prompt = (
            f"Original Question: {x['question']}\n\n"
            f"Choices: {choices_str}\n\n"
            f"Correct Answer: {answer_str}"
        )
    elif task == "table_mwp":
        user_prompt = (
            f"Table to use: {x['table']}\n\n"
            f"Original Question: {x['question']}\n\n"
            f"Correct Answer: {x['answer']}"
        )
    elif task in ("gsm8k", "math"):
        user_prompt = (
            f"Original Question: {x['question']}\n\nCorrect Answer: {x['answer']}"
        )
    else:
        raise ValueError(f"Task {task} is not supported for reverse augmentation")

    return user_prompt, system_prompt, history


def build_training_aug(
    task: str,
    rule: str,
    rule_idx: int,
    data_dir: str = None,
    output_dir: str = "./output",
):
    """Write a batch-request JSONL file for one (task, rule) pair.

    Reads ``<data_dir>/<task>/train.json`` (``./data/<task>/train.json`` when
    ``data_dir`` is ``None``), builds one request per example via
    ``reverse_aug_prompt_response`` and writes them, one JSON object per line,
    to ``<output_dir>/<task>/rule_<rule_idx>/batch_requests.jsonl``.

    Args:
        task: Dataset name; also the sub-folder holding ``train.json``.
        rule: Rule text forwarded to ``reverse_aug_prompt_response``.
        rule_idx: Index of the rule, used only to name the output folder.
        data_dir: Root folder of the datasets, or ``None`` for ``./data``.
        output_dir: Root folder for the generated request files.

    Returns:
        The path of the JSONL file that was written.

    Raises:
        ValueError: If ``task`` is ``"anli"`` (or otherwise rejected by
            ``reverse_aug_prompt_response``).
    """
    if task == "anli":
        raise ValueError("ANLI is not currently supported for reverse augmentation")

    # Every supported task keeps its training split at <root>/<task>/train.json.
    relative_path = f"{task}/train.json"
    if data_dir is not None:
        input_file = os.path.join(data_dir, relative_path)
    else:
        input_file = f"./data/{relative_path}"

    # Explicit UTF-8: the platform default encoding is locale-dependent.
    with open(input_file, "r", encoding="utf-8") as r:
        data = json.load(r)

    if task == "table_mwp":
        # table_mwp is read as a mapping of pid -> example; the pid is copied
        # into each example as an additional field.
        dataset = []
        for pid, example in data.items():
            example["pid"] = pid
            dataset.append(example)
    else:
        dataset = list(data)
    print(f"read {len(data)} examples from {input_file}")

    # Construct prompt requests.
    # A "generationConfig" entry can be added next to "contents" to control
    # the model output, e.g.:
    #   "generationConfig": {
    #       "temperature": 0.8,
    #       "topP": 1.0,
    #       "topK": 40,
    #       "candidateCount": 1,
    #       "maxOutputTokens": 1024,
    #   }
    batch_requests = []
    for item in dataset:
        # `history` is returned by the prompt builder but not used in the request.
        user_prompt, system_prompt, _history = reverse_aug_prompt_response(item, task, rule)
        batch_requests.append({
            "request": {
                "contents": [{
                    "role": "user",
                    "parts": [
                        {"text": user_prompt},
                    ]}
                ],
                "systemInstruction": {
                    "role": "user",
                    "parts": [{
                        "text": system_prompt
                    }]
                }
            }
        })
    print(f"built {len(batch_requests)} requests")

    # Create the per-task / per-rule output folder if it does not exist yet.
    output_folder = f"{output_dir}/{task}/rule_{rule_idx}"
    os.makedirs(output_folder, exist_ok=True)
    output_file = f"{output_folder}/batch_requests.jsonl"

    print(f"Saving {len(batch_requests)} requests to {output_file}")
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8. The file is flushed once on close rather than per line.
    with open(output_file, "w", encoding="utf-8") as w:
        for request in tqdm(batch_requests, desc="Saving requests"):
            w.write(json.dumps(request, ensure_ascii=False) + "\n")

    return output_file


# Script entry point: build the batch-request files for every (task, rule)
# pair, then submit a Vertex AI batch prediction job and wait for it to end.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_process", type=int, default=10)
    parser.add_argument("--backbone", type=str, default="gemini-1.5-flash")
    parser.add_argument("--response_key", type=str, default="response")
    parser.add_argument("--data_dir", type=str, default=None)
    # LLM generation parameters
    parser.add_argument("--num_generations", type=int, default=1)
    parser.add_argument('--temperature', type=float, default=0.8)           # Temperature for sampling, default is 0.8
    parser.add_argument('--top_p', type=float, default=1.0)                 # Top-p for sampling, default is 1.0
    parser.add_argument('--max_tokens', type=int, default=1024)             # Maximum number of tokens to generate, default is 1024
    # Task parameters
    parser.add_argument("--rule", type=str, default=None)
    args = parser.parse_args()

    # Validate explicitly: an `assert` would be stripped under `python -O`.
    if args.rule is None:
        parser.error("Rule is required for reverse augmentation")

    # Load the rules (JSON Lines); only the 'rule' field of each record is kept.
    # Blank lines (e.g. a trailing newline) are skipped instead of crashing.
    with open(args.rule, "r", encoding="utf-8") as f:
        rules = [json.loads(line)['rule'] for line in f if line.strip()]

    # Get a list of tasks
    tasks = ["strategy_qa", "commonsense_qa", "arc_challenge", "date", "table_mwp", "gsm8k", "math"]
    print(f"\nStarting augmentation with {len(tasks)} tasks and {len(rules)} rules per task")
    for task in tqdm(tasks, desc="Tasks", position=0):
        print(f"\nProcessing task: {task}")
        for rule_idx, rule in tqdm(enumerate(rules), desc=f"Rules for {task}", total=len(rules), position=1, leave=False):
            # This will store the batch requests prompt in the output folder
            build_training_aug(
                task=task,
                rule=rule,
                rule_idx=rule_idx,
                data_dir=args.data_dir
            )

    PROJECT_ID = "gen-lang-client-0908767230"
    LOCATION = "us-central1"

    # TODO: Need to change the input and output uri to the correct path
    # Initialize vertexai (use the LOCATION constant rather than repeating the literal)
    vertexai.init(project=PROJECT_ID, location=LOCATION)

    # TODO: Need to loop this for all tasks and rules
    input_uri = "gs://batch_query_gemini/core_think/batch_requests.jsonl"
    output_uri = "gs://batch_query_gemini/core_think/batch_requests_output.jsonl"

    # Submit a batch prediction job with Gemini model
    batch_prediction_job = BatchPredictionJob.submit(
        source_model=args.backbone,
        input_dataset=input_uri,
        output_uri_prefix=output_uri,
    )

    # Check job status
    print(f"Job resource name: {batch_prediction_job.resource_name}")
    print(f"Model resource name with the job: {batch_prediction_job.model_name}")
    print(f"Job state: {batch_prediction_job.state.name}")

    # Poll every 5 seconds until the job reports that it has ended
    while not batch_prediction_job.has_ended:
        time.sleep(5)
        batch_prediction_job.refresh()

    # Check if the job succeeds
    if batch_prediction_job.has_succeeded:
        print("Job succeeded!")
    else:
        print(f"Job failed: {batch_prediction_job.error}")

    # Check the location of the output
    print(f"Job output location: {batch_prediction_job.output_location}")
