import json
import os
import time
import threading
from collections import deque
import fire
import torch
import openai
from concurrent.futures import ThreadPoolExecutor, as_completed

from utils import DATA_PATH, clean_generation
from termcolor import colored
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
)
from tqdm.auto import tqdm

class RateLimiter:
    """Thread-safe sliding-window rate limiter.

    At most ``max_calls`` calls to :meth:`acquire` are admitted within any
    window of ``period`` seconds. Callers that would exceed the budget block
    until the oldest admitted call has aged out of the window.

    Attributes:
        max_calls: Maximum number of admissions per window.
        period: Window length in seconds.
        lock: Guards ``calls`` against concurrent modification.
        calls: Monotonic timestamps of the admissions still inside the window,
            oldest first.
    """

    def __init__(self, max_calls: int, period: float):
        """Create a limiter admitting ``max_calls`` calls per ``period`` seconds.

        Raises:
            ValueError: If ``max_calls`` is smaller than 1 (no call could ever
                be admitted).
        """
        if max_calls < 1:
            raise ValueError("max_calls must be at least 1")
        self.max_calls = max_calls
        self.period = period
        self.lock = threading.Lock()
        self.calls = deque()

    def acquire(self):
        """Block until a call slot is available, then record the call."""
        while True:
            with self.lock:
                # Monotonic clock: immune to wall-clock adjustments (NTP, DST).
                now = time.monotonic()
                # Drop admissions that have left the window.
                while self.calls and now - self.calls[0] >= self.period:
                    self.calls.popleft()
                if len(self.calls) < self.max_calls:
                    self.calls.append(now)
                    return
                # Time until the oldest admission expires and frees a slot.
                wait = self.period - (now - self.calls[0])
            # Sleep outside the lock so other threads are not stalled behind
            # this one, then loop to re-check: another waiter may have taken
            # the freed slot in the meantime.
            time.sleep(max(wait, 0.0))

# Module-level limiter shared by every call_gpt4o() invocation, including the
# worker threads spawned in main(): it is configured for 8,000 GPT-4o requests
# per 60-second period.
RATE_LIMITER = RateLimiter(max_calls=8_000, period=60.0)

# Ensure your OPENAI_API_KEY is set in the environment
def call_gpt4o(
    prompt: str,
    temperature: float,
    max_tokens: int,
    retries: int = 3,
    backoff_base: float = 1.0,
) -> str:
    """Request a single GPT-4o chat completion, retrying transient failures.

    Uses the pre-1.0 ``openai.ChatCompletion`` interface. Every attempt first
    waits on the module-level ``RATE_LIMITER``.

    Args:
        prompt: Text sent as the sole user message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Upper bound on generated tokens forwarded to the API.
        retries: Number of retries allowed after the first failed attempt.
        backoff_base: Base delay in seconds; the n-th retry waits
            ``backoff_base * 2 ** (n - 1)``.

    Returns:
        The message content of the first returned choice.

    Raises:
        openai.error.OpenAIError: For non-transient failures, or once the
            retry budget is exhausted.
    """
    # Statuses that normally signal throttling or a temporary server problem.
    retryable_statuses = {429, 500, 502, 503, 504}
    attempt = 0
    while True:
        attempt += 1
        RATE_LIMITER.acquire()
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except openai.error.OpenAIError as e:
            status = getattr(e, "http_status", None)
            # Timeouts and connection failures carry no HTTP status but are
            # just as transient as a 5xx response.
            transient = status in retryable_statuses or isinstance(
                e, (openai.error.Timeout, openai.error.APIConnectionError)
            )
            if attempt > retries or not transient:
                raise
            sleep_time = backoff_base * (2 ** (attempt - 1))
            reason = status if status is not None else type(e).__name__
            print(
                f"Received {reason}, retrying in {sleep_time}s "
                f"(attempt {attempt}/{retries})"
            )
            time.sleep(sleep_time)
        else:
            return response.choices[0].message.content


def load_model_and_tokenizer(
    model_name: str,
):
    """Load a causal LM in eval mode together with its tokenizer.

    The model is loaded in float16 with ``device_map="auto"`` and
    ``trust_remote_code=True``; FlashAttention-2 is requested only when the
    name contains ``"Phi-3"``.

    Tokenizer handling depends on the model name:
      * ``"Mistral"``: tokenizer revision ``pr/51``.
      * ``"Meta-Llama"``: left padding.
      * anything else: default tokenizer, returned untouched.
    For the first two, the EOS token doubles as the padding token and the
    model's generation config is pointed at it.

    Args:
        model_name: Hugging Face model identifier.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    wants_flash_attention = "Phi-3" in model_name
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2" if wants_flash_attention else None,
    )
    lm = lm.eval()

    # Family-specific tokenizer arguments; None means "no special handling".
    if "Mistral" in model_name:
        tokenizer_kwargs = {"revision": "pr/51"}
    elif "Meta-Llama" in model_name:
        tokenizer_kwargs = {"padding_side": "left"}
    else:
        tokenizer_kwargs = None

    if tokenizer_kwargs is None:
        return lm, AutoTokenizer.from_pretrained(model_name)

    tok = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    # Reuse EOS as the padding token for batched generation.
    tok.pad_token = tok.eos_token
    lm.generation_config.pad_token_id = tok.pad_token_id
    return lm, tok


def load_data(
    data_path: str,
) -> list[dict]:
    """Read a JSON Lines file into a list of decoded records.

    The file is decoded as UTF-8 regardless of the platform locale, and
    blank / whitespace-only lines (e.g. a trailing empty line) are skipped.

    Args:
        data_path: Path to the ``.jsonl`` file.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(data_path, "r", encoding="utf-8") as fin:
        return [json.loads(line) for line in fin if line.strip()]


def build_respond_prompts(
    texts: list[str],
    dataset_name: str,
) -> list[str]:
    """Build one generation prompt per input text for the given dataset.

    Each prompt embeds the source text and asks for an output of roughly the
    same length, measured in whitespace-separated words.

    Args:
        texts: Source texts (Reddit comments, Amazon reviews or blog snippets).
        dataset_name: One of ``"reddit"``, ``"amazon"`` or ``"blogs"``;
            selects the prompt template.

    Returns:
        The formatted prompts, in the same order as ``texts``.

    Raises:
        KeyError: If ``dataset_name`` has no template.
    """
    RESPOND_REDDIT_PROMPT = """Write a response to this Reddit comment: {}

Keep the response around {} words.

Do not include the original comment in your response.

Only output the comment, do not include any other details.

Response:
"""
    RESPOND_AMAZON_PROMPT = """Here's an Amazon review: {}

Please write another review, of about {} words, but about something different.

Do not include the original review in your response.

Only output the review, do not include any other details.

Response:
"""
    # NOTE: the second line of this template consists of four spaces. It is
    # spelled out with escapes so the exact prompt text is preserved and the
    # whitespace cannot be silently stripped by an editor.
    RESPOND_BLOG_PROMPT = (
        "Here's a snippet of a Blog post: {}\n"
        "    \n"
        "Please write another snippet, of about {} words, "
        "but about something different.\n"
        "\n"
        "Do not include the original snippet in your response.\n"
        "\n"
        "Only output the snippet, do not include any other details.\n"
        "\n"
        "Response:\n"
    )

    templates = {
        "reddit": RESPOND_REDDIT_PROMPT,
        "amazon": RESPOND_AMAZON_PROMPT,
        "blogs": RESPOND_BLOG_PROMPT,
    }
    # Resolve the template once instead of once per text.
    template = templates[dataset_name]

    # str.split() with no argument splits on any whitespace run, so repeated
    # spaces, tabs and newlines do not distort the word count.
    return [template.format(text, len(text.split())) for text in texts]


def _generate_local(
    prompts: list[str],
    model_name: str,
    batch_size: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> list[str]:
    """Sample one completion per prompt from a locally loaded causal LM."""
    model, tokenizer = load_model_and_tokenizer(model_name)
    generation_config = GenerationConfig(
        do_sample=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    completions: list[str] = []
    for start in tqdm(range(0, len(prompts), batch_size), desc="Local gen"):
        batch = prompts[start : start + batch_size]
        inputs = tokenizer(
            batch,
            max_length=256,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to("cuda")
        out = model.generate(**inputs, generation_config=generation_config)
        # generate() returns prompt tokens followed by the new tokens. Cut at
        # the (padded) prompt width in token space: slicing the decoded string
        # by the prompt's character length breaks whenever the prompt was
        # truncated or does not round-trip exactly through the tokenizer.
        prompt_width = inputs["input_ids"].shape[1]
        completions.extend(
            tokenizer.batch_decode(out[:, prompt_width:], skip_special_tokens=True)
        )
    return completions


def _generate_gpt4o(
    prompts: list[str],
    temperature: float,
    max_new_tokens: int,
    max_workers: int,
) -> list[str]:
    """Fetch one GPT-4o completion per prompt using a thread pool."""
    completions = [None] * len(prompts)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(call_gpt4o, prompt, temperature, max_new_tokens): idx
            for idx, prompt in enumerate(prompts)
        }
        # Futures finish out of order; the stored index restores prompt order.
        for fut in tqdm(as_completed(futures), total=len(prompts), desc="GPT-4o gen"):
            completions[futures[fut]] = fut.result()
    return completions


def main(
    data_name: str = "MTD_reddit.jsonl",
    dataset_name: str = "reddit",
    model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
    batch_size: int = 64,
    max_new_tokens: int = 128 + 32,
    temperature: float = 0.7,
    top_p: float = 0.9,
    debug: bool = False,
    max_workers: int = 40,
):
    """Generate model responses for a dataset and write them as JSON Lines.

    Reads ``<DATA_PATH>/mtd/<data_name>``, builds one prompt per record from
    its ``content_text`` field, generates a completion with either a local
    Hugging Face model or GPT-4o, and writes every record, extended with a
    ``respond_reddit`` field, to ``<DATA_PATH>/mtd/generations/``.

    Args:
        data_name: File name of the input ``.jsonl`` inside the ``mtd`` folder.
        dataset_name: Prompt family: ``"reddit"``, ``"amazon"`` or ``"blogs"``.
        model_name: Supported local model identifier, or ``"gpt-4o"``.
        batch_size: Prompts per batch for local generation.
        max_new_tokens: Maximum number of generated tokens per completion.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold (local models only).
        debug: If true, process only the first 10 records and write a
            ``.debug`` file instead of ``.jsonl``.
        max_workers: Number of threads used for GPT-4o requests.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If ``dataset_name`` or ``model_name`` is not supported.
    """
    # Snapshot the arguments before any other local is defined.
    args = dict(locals())

    # Explicit exceptions instead of assert: asserts vanish under `python -O`.
    if dataset_name not in ("reddit", "amazon", "blogs"):
        raise ValueError("Invalid dataset name")
    if model_name not in (
        "mistralai/Mistral-7B-Instruct-v0.3",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "gpt-4o",
    ):
        raise ValueError("Invalid model name")

    for k, v in args.items():
        print(colored(k, "green"), "=", colored(v, "yellow"))

    mtd_path = os.path.join(DATA_PATH, "mtd")
    data = load_data(os.path.join(mtd_path, data_name))
    if debug:
        data = data[:10]

    prompts = build_respond_prompts(
        [d["content_text"] for d in data], dataset_name
    )

    os.makedirs(os.path.join(mtd_path, "generations"), exist_ok=True)
    tmpstr = data_name.replace(".jsonl", "").replace("_", "-")
    savename = os.path.join(
        mtd_path,
        "generations",
        f"{tmpstr}_{model_name.replace('/', '-')}_temp={temperature}_top-p={top_p}",
    ) + (".debug" if debug else ".jsonl")
    print(colored(f"savename={savename}", "yellow"))

    if model_name == "gpt-4o":
        outputs = _generate_gpt4o(prompts, temperature, max_new_tokens, max_workers)
    else:
        outputs = _generate_local(
            prompts, model_name, batch_size, max_new_tokens, temperature, top_p
        )

    is_reddit = dataset_name == "reddit"
    with open(savename, "w", encoding="utf-8") as fout:
        for d, gen in zip(data, outputs):
            # Best effort: keep the raw generation if cleaning fails. Only
            # Exception is caught so Ctrl-C / SystemExit still propagate.
            try:
                d["respond_reddit"] = clean_generation(gen, is_reddit=is_reddit)
            except Exception:
                d["respond_reddit"] = gen
            fout.write(json.dumps(d) + "\n")

    return 0


# Script entry point: python-fire exposes main()'s parameters as command-line
# flags (e.g. --model_name=gpt-4o --debug=True).
if __name__ == "__main__":
    fire.Fire(main)
