from loguru import logger
from base import BaseClass
import numpy as np
import os

import aiohttp

import json
import time
import asyncio
import numpy as np
import lmql
import asyncio
from transformers import set_seed
import transformers
from utils import get_best_device
import torch
from model_loader import load_model

# Module-level side effect: fix the global random seed once, at import time, via
# transformers.set_seed. Presumably this makes sampled generations (do_sample=True)
# reproducible between runs — NOTE(review): it affects every module importing this file.
set_seed(42)

class Query(BaseClass):
    """Base class for prompt runners.

    Subclasses implement :meth:`run_string_prompts`. :meth:`run` renders prompt
    templates into concrete prompts first and then delegates to it.
    """

    def __init__(self, **kwargs) -> None:
        # All configuration is forwarded to BaseClass unchanged.
        super().__init__(**kwargs)

    async def run_string_prompts(self, string_prompts):
        """Run already-rendered prompts and return one output per prompt.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    async def run(self, prompts, args_for_prompts):
        """Render each prompt with its arguments, then run the rendered prompts.

        Args:
            prompts: objects exposing ``get(*args)``, which returns the prompt to send.
            args_for_prompts: one entry per prompt. A tuple/list entry is unpacked
                as positional arguments; any other value is passed as the single
                argument. The caller's sequence is not modified.

        Returns:
            Whatever :meth:`run_string_prompts` returns for the rendered prompts.
        """
        # Build a normalised copy instead of rewriting the caller's list in place
        # (this also lets callers pass a tuple of arguments).
        normalized_args = [
            args if isinstance(args, (tuple, list)) else (args,)
            for args in args_for_prompts
        ]
        string_prompts = [prompt.get(*args) for prompt, args in zip(prompts, normalized_args)]
        return await self.run_string_prompts(string_prompts)
    
class LMQLParallelQuery(Query):
    """Run LMQL queries concurrently while staying under a tokens-per-minute budget."""

    def __init__(self, tpm=70000, max_tokens=1000, **kwargs) -> None:
        """
        Args:
            tpm: token budget per rolling 60-second window.
            max_tokens: tokens reserved per in-flight query when deciding how many
                queries to launch at once.
            **kwargs: forwarded to ``BaseClass``.
        """
        super().__init__(tpm=tpm, max_tokens=max_tokens, **kwargs)

    async def run_completions(self, queries):
        """Run every query concurrently with ``lmql.run`` (output silenced); results keep input order."""
        return await asyncio.gather(*[lmql.run(query, output_writer=lmql.silent) for query in queries])

    async def run_string_prompts(self, string_prompts):
        """Run all LMQL queries within the token budget, retrying unusable results.

        Args:
            string_prompts: LMQL query sources.

        Returns:
            The ``variables`` of the first result of each query, in input order.

        Raises:
            ValueError: if ``tpm`` is smaller than ``max_tokens``, i.e. not even a
                single query could ever be launched.
        """
        succeeded_requests = [False] * len(string_prompts)
        outputs = [None] * len(string_prompts)
        # (launch time, approximate token count) for every successful query.
        generated_tokens = []
        while not all(succeeded_requests):
            start_time = time.time()
            generated_tokens_last_min = sum(
                tokens for launched_at, tokens in generated_tokens if start_time - launched_at < 60
            )
            # Clamp the remaining budget at 0: a negative count would turn the slice
            # below into "all but the last few pending queries".
            async_requests = (self.tpm - min(generated_tokens_last_min, self.tpm)) // self.max_tokens
            if async_requests == 0:
                if generated_tokens_last_min == 0:
                    # The budget is already at its maximum and still too small: waiting
                    # would loop forever.
                    raise ValueError(f"tpm ({self.tpm}) must be at least max_tokens ({self.max_tokens})")
                # Yield to the event loop instead of blocking it while the window drains.
                await asyncio.sleep(0.2)
                continue
            indices = np.where(np.logical_not(succeeded_requests))[0][:async_requests]
            ret = await self.run_completions([string_prompts[index] for index in indices])
            for results, index in zip(ret, indices):
                if results is None:
                    continue
                try:
                    variables = results[0].variables
                    n_tokens = results[0].prompt.count(" ")  # NOTE: not ideal, but lmql doesnt allow for more
                except Exception:
                    # Left pending, so the query is retried on the next pass.
                    logger.warning(f"LMQL returned an unexpected result: {results} \n On parameters {string_prompts[index]}")
                    continue
                # Only record success once every field was extracted.
                outputs[index] = variables
                succeeded_requests[index] = True
                generated_tokens.append((start_time, n_tokens))
        return outputs
    
class LMQLQuery(Query):
    """Run LMQL queries one at a time, in order."""

    def __init__(self):
        # No settings of its own: the parent is initialised without keyword arguments.
        super().__init__()

    async def run_string_prompts(self, string_prompts):
        """Run each LMQL query sequentially and collect its variables.

        Args:
            string_prompts: LMQL query sources.

        Returns:
            A list holding the ``variables`` of the first result of each query,
            in the same order as ``string_prompts``.
        """
        collected = []
        for query_source in string_prompts:
            results = await lmql.run(query_source, output_writer=lmql.silent)
            first_result = results[0]
            collected.append(first_result.variables)
        return collected

class TransformerGeneratorQuery(Query):
    """Generate completions locally with a Hugging Face ``text-generation`` pipeline.

    The model is loaded lazily by :meth:`run_string_prompts` and released again
    when that call finishes, whether it succeeds or raises. The tokenizer is kept.
    """

    def __init__(self, model="gpt2", max_tokens=128, batch_size=1, stop_sequence="\n", temperature=1.0,
                 top_p=1.0, load_in_8bit=False, force_cuda=True, **kwargs) -> None:
        """Store generation settings; no model is loaded here.

        Args:
            model: model identifier passed to ``load_model``.
            max_tokens: maximum number of new tokens generated per prompt.
            batch_size: prompts per pipeline call; lowered automatically when a call fails.
            stop_sequence: each completion is cut at its first occurrence. The first
                token of its encoding is also passed to the pipeline as ``eos_token_id``.
            temperature: sampling temperature.
            top_p: nucleus-sampling threshold.
            load_in_8bit: forwarded to ``BaseClass``; not read in this class.
            force_cuda: forwarded to ``BaseClass``; not read in this class.
        """
        self.model_call = None
        self.tokenizer = None
        super().__init__(model=model, max_tokens=max_tokens, batch_size=batch_size, stop_sequence=stop_sequence, temperature=temperature, top_p=top_p,
                         load_in_8bit=load_in_8bit, force_cuda=force_cuda, **kwargs)

    async def run_string_prompts(self, string_prompts):
        """Generate one sampled completion per prompt.

        If a pipeline call raises, the batch size is lowered and generation
        restarts from the first prompt.

        Args:
            string_prompts: prompt strings.

        Returns:
            Generated texts (prompt excluded), each cut at the first occurrence of
            ``stop_sequence``, in input order.

        Raises:
            ValueError: if a pipeline call still fails with a batch size of 1.
        """
        if self.model_call is None:
            self.model_call, self.tokenizer = load_model(self.model, return_tokenizer=True, dtype=torch.bfloat16, load_dtype=True)
            # NOTE: moving the model to CUDA when force_cuda is set is disabled;
            # device placement is left to device_map="auto" in the pipeline below.

        pipeline = None
        try:
            self._log_max_length()
            # NOTE(review): this is the first token of the encoded stop sequence. A
            # tokenizer that prepends special tokens (e.g. a BOS token) would yield
            # that token instead — confirm for non-GPT-2 models.
            eos_token_id = self.tokenizer.encode(self.stop_sequence)[0]

            self.model_call.eval()
            pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_call,
                tokenizer=self.tokenizer,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
            )

            outputs = None
            while outputs is None:  # None means: batch size was lowered, start over
                outputs = self._generate_all(pipeline, string_prompts, eos_token_id)
            return outputs
        finally:
            # The pipeline holds a reference to the model, so drop it as well before
            # emptying the CUDA cache; otherwise the model's memory is not released.
            del pipeline
            self.model_call = None
            torch.cuda.empty_cache()

    def _log_max_length(self):
        """Log the context length from the first model-config attribute that defines one."""
        for length_setting in ("n_positions", "max_position_embeddings", "seq_length"):
            max_length = getattr(self.model_call.config, length_setting, None)
            if max_length:
                logger.info(f"Found max length: {max_length}, {length_setting}")
                break

    def _generate_all(self, pipeline, string_prompts, eos_token_id):
        """Generate completions for every prompt in batches of ``self.batch_size``.

        Returns:
            The list of truncated completions, or ``None`` if a batch failed and
            the batch size was lowered (the caller then starts over).
        """
        outputs = []
        for index in range(0, len(string_prompts), self.batch_size):
            logger.debug(f"Progress: {index / len(string_prompts):.4f}")
            batch_data = string_prompts[index:index + self.batch_size]
            try:
                with torch.no_grad():
                    gen_texts = pipeline(
                        text_inputs=batch_data,
                        max_new_tokens=self.max_tokens,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=eos_token_id,
                        temperature=self.temperature,
                        top_p=self.top_p,
                        batch_size=self.batch_size,
                        return_full_text=False,
                    )
            except Exception as e:
                self._reduce_batch_size(e)
                return None
            # Each element is indexed as a list of candidates; the first candidate's text is used.
            outputs.extend(self._truncate_at_stop(gen_text[0]["generated_text"]) for gen_text in gen_texts)
        return outputs

    def _reduce_batch_size(self, error):
        """Lower ``self.batch_size`` by 2 (never below 1) after a failed pipeline call.

        Raises:
            ValueError: if the batch size is already 1 (chained to ``error``).
        """
        logger.error(f"Batch size {self.batch_size} too high, reducing... {error}")
        if self.batch_size == 1:
            raise ValueError("Not enough memory") from error
        self.batch_size = max(1, self.batch_size - 2)

    def _truncate_at_stop(self, text):
        """Cut ``text`` at the first ``stop_sequence``; return it unchanged when absent."""
        stop_at = text.find(self.stop_sequence)
        # find() returns -1 when absent; slicing with it would drop the last character.
        return text if stop_at == -1 else text[:stop_at]

class OpenAIQuery(Query):
    """Send prompts to the OpenAI HTTP API concurrently under a client-side tokens-per-minute budget."""

    def __init__(self, model="gpt-3.5-turbo", tpm=30000, timeout=100, temperature=1.2, max_tokens=1000, error_stop=10 ** 8, **kwargs) -> None:
        """
        Args:
            model: OpenAI model name.
            tpm: token budget per rolling 60-second window, enforced client-side.
            timeout: ``total`` of the aiohttp ``ClientTimeout`` used for each batch's session, in seconds.
            temperature: sampling temperature.
            max_tokens: also reserved per in-flight request when deciding how many
                requests to send at once.
            error_stop: give up (``ValueError``) after this many failed requests or
                this many unparsable responses.
            **kwargs: extra fields.

        Every argument except ``tpm``, ``timeout`` and ``error_stop`` is sent as a
        request-body field by :meth:`run_string_prompts`. NOTE(review): this assumes
        BaseClass stores the constructor kwargs in ``self.kwargs``.
        """
        super().__init__(model=model, tpm=tpm, timeout=timeout, temperature=temperature, max_tokens=max_tokens, error_stop=error_stop, **kwargs)

    async def run_string_prompts(self, string_prompts):
        """Turn each prompt into a request body and run them all.

        A ``str`` prompt is sent as ``prompt`` (completions endpoint); anything
        else is sent as chat ``messages``.

        Returns:
            ``choices[0]`` of each response, in input order.
        """
        # tpm / timeout / error_stop only drive this client and must not reach the API.
        request_fields = {
            key: value for key, value in self.kwargs.items()
            if key not in ("tpm", "timeout", "error_stop")
        }
        openai_queries = [
            {"prompt": prompt, **request_fields} if isinstance(prompt, str) else {"messages": prompt, **request_fields}
            for prompt in string_prompts
        ]
        return await self.get_completions(openai_queries)

    async def get_completion_async(self, arguments, session):
        """POST one request and return the raw response body.

        The chat-completions endpoint is used unless ``arguments`` has a ``prompt`` key.

        Returns:
            The response body as bytes (error responses included), or ``None`` if
            posting raised any exception (logged as a warning).

        Raises:
            ValueError: if ``OPENAI_API_KEY`` is not set.
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError("OPENAI_API_KEY not found in environment variables")
        url = "https://api.openai.com/v1/completions" if "prompt" in arguments else "https://api.openai.com/v1/chat/completions"
        try:
            async with session.post(
                url,
                headers={
                    "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
                    "Content-Type": "application/json",
                },
                json=arguments
            ) as response:
                return await response.read()
        except Exception as e:
            logger.warning(f"Error occurred while posting to openai API: {e}. Posted: {arguments}")
            return None

    async def get_completions_async(self, list_arguments):
        """Send all requests concurrently over one aiohttp session; return their bodies in order."""
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            return await asyncio.gather(*[self.get_completion_async(argument, session) for argument in list_arguments])

    async def get_completions(self, list_arguments):
        """Run all requests within the token budget, retrying failed ones.

        Returns:
            ``choices[0]`` of each successful response, in input order.

        Raises:
            ValueError: once ``error_stop`` request errors or ``error_stop`` parse
                errors have accumulated, or when ``tpm`` is smaller than
                ``max_tokens`` so that no request could ever be sent.
        """
        succeeded_requests = [False] * len(list_arguments)
        outputs = [None] * len(list_arguments)
        generated_tokens = []  # (batch start time, total_tokens) per successful request
        n_errors = 0
        n_parse_errors = 0
        n_new_errors = 0  # request errors since the last back-off pause
        while not all(succeeded_requests) and n_errors < self.error_stop and n_parse_errors < self.error_stop:
            start_time = time.time()
            generated_tokens_last_min = sum(
                tokens for sent_at, tokens in generated_tokens if start_time - sent_at < 60
            )
            async_requests = (self.tpm - min(generated_tokens_last_min, self.tpm)) // self.max_tokens
            if async_requests == 0:
                if generated_tokens_last_min == 0:
                    # The budget is already at its maximum and still too small: waiting
                    # would loop forever.
                    raise ValueError(f"tpm ({self.tpm}) must be at least max_tokens ({self.max_tokens})")
                # Yield to the event loop instead of blocking it while the window drains.
                await asyncio.sleep(0.2)
                continue

            indices = np.where(np.logical_not(succeeded_requests))[0][:async_requests]
            arguments_async = [list_arguments[index] for index in indices]
            logger.debug(f"Running {len(arguments_async)} requests to openai API. tokens last minute: {generated_tokens_last_min}. percentage done: {np.count_nonzero(succeeded_requests) / len(succeeded_requests) * 100:.2f}%")
            # This coroutine always runs inside an event loop, so await directly.
            ret = await self.get_completions_async(arguments_async)

            for results, index in zip(ret, indices):
                if results is None:
                    n_errors += 1
                    n_new_errors += 1
                    continue
                try:
                    response = json.loads(results)
                    api_error = "error" in response
                    if not api_error:
                        choice = response["choices"][0]
                        total_tokens = response["usage"]["total_tokens"]
                except (ValueError, TypeError, KeyError, IndexError):
                    logger.warning(f"OpenAI API returned invalid json: {results} \n On parameters {list_arguments[index]}")
                    n_parse_errors += 1
                    continue
                if api_error:
                    logger.warning(f"OpenAI API returned an error: {response} \n On parameters {list_arguments[index]}")
                    n_errors += 1
                    n_new_errors += 1
                    continue
                # Only record success once the response was fully validated.
                outputs[index] = choice
                succeeded_requests[index] = True
                generated_tokens.append((start_time, total_tokens))

            if n_new_errors >= 20:
                # Back off after a burst of errors without blocking the event loop.
                await asyncio.sleep(10)
                n_new_errors = 0

        if n_errors >= self.error_stop or n_parse_errors >= self.error_stop:
            raise ValueError("OpenAI API returned too many errors. Stopping requests.")

        return outputs
