import dataclasses
import os
import pprint
import re
import time
import urllib
from functools import partial
from threading import Lock
from typing import List, Optional, Union

import absl.logging
import gradio as gr
import mlxu
import numpy as np
import requests
import uvicorn
from fastapi import FastAPI
from ml_collections import ConfigDict
from pydantic import BaseModel
from requests.exceptions import ConnectionError, Timeout
from tqdm import tqdm, trange


class InferenceRequest(BaseModel):
    # Request body shared by the /loglikelihood, /loglikelihood-rolling,
    # /generate and /greedy-until endpoints. Every field is optional at the
    # schema level; each endpoint reads only the fields it needs.

    # Conditioning text, one entry per example. /loglikelihood substitutes
    # empty strings when this is None; /generate and /greedy-until iterate
    # over it directly.
    prefix_text: Optional[List[str]] = None
    # Text to score, one entry per example (read by the two loglikelihood
    # endpoints).
    text: Optional[List[str]] = None
    # Stop sequence(s) for /greedy-until, sliced per batch alongside
    # prefix_text.
    until: Optional[Union[List[str], List[List[str]]]] = None
    # Sampling temperature for /generate; the server falls back to its
    # configured default_temperature when this is None.
    temperature: Optional[float] = None


class ChatRequest(BaseModel):
    # Request body for the /chat endpoint.

    # The new user message for this turn.
    prompt: str
    # Conversation text accumulated so far; the /chat response returns the
    # updated context, which a client can send back on the next turn.
    context: str = ""
    # Sampling temperature; the server falls back to its configured
    # default_temperature when this is None.
    temperature: Optional[float] = None


class LMServer(object):
    """HTTP server for serving langauge models."""

    @staticmethod
    def get_default_config(updates=None):
        """Return the server configuration, with ``updates`` merged over defaults.

        Args:
            updates: Optional overrides. When given, they are wrapped in a
                ``ConfigDict``, references are resolved, and the result is
                applied on top of the default values.

        Returns:
            A ``ConfigDict`` containing every server option.
        """
        defaults = {
            # Network binding.
            "host": "0.0.0.0",
            "port": 5007,
            # Batching, logging and warm-up behaviour.
            "batch_size": 1,
            "logging": False,
            "pre_compile": "loglikelihood",
            "default_temperature": 1.0,
            "greedy_until_max_length": 5000,
            # Text wrapped around request prefixes / texts.
            "prepend_to_prefix": "",
            "append_to_prefix": "",
            "prepend_to_text": "",
            "append_to_text": "",
            # Chat prompt template pieces.
            "chat_prepend_text": "",
            "chat_user_prefix": "",
            "chat_user_suffix": "",
            "chat_lm_prefix": "",
            "chat_lm_suffix": "",
            # Markdown shown in the chat UI.
            "notes": "",
        }
        config = ConfigDict()
        for option, value in defaults.items():
            setattr(config, option, value)

        if updates is None:
            return config

        overrides = ConfigDict(updates).copy_and_resolve_references()
        config.update(overrides)
        return config

    def __init__(self, config):
        self.config = self.get_default_config(config)
        self.lock = Lock()
        self.app = FastAPI()
        self.app.post("/loglikelihood")(self.serve_loglikelihood)
        self.app.post("/loglikelihood-rolling")(self.serve_loglikelihood_rolling)
        self.app.post("/generate")(self.serve_generate)
        self.app.post("/greedy-until")(self.serve_greedy_until)
        self.app.post("/chat")(self.serve_chat)
        self.app.get("/ready")(self.serve_ready)
        self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), "/")

    @staticmethod
    def loglikelihood(prefix_text, text):
        """Model hook for ``/loglikelihood``; this base version is a stub.

        ``serve_loglikelihood`` calls it with two lists of strings (a batch of
        prefixes and a batch of texts, padded to ``config.batch_size``) and
        unpacks the result as ``(log_likelihood, is_greedy)``, each of which
        is passed through ``to_list`` and sliced per example.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @staticmethod
    def loglikelihood_rolling(text):
        """Model hook for ``/loglikelihood-rolling``; this base version is a stub.

        ``serve_loglikelihood_rolling`` calls it with a list of strings padded
        to ``config.batch_size`` and unpacks the result as
        ``(log_likelihood, is_greedy)``, each of which is passed through
        ``to_list`` and sliced per example.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @staticmethod
    def generate(text, temperature):
        """Model hook for text generation; this base version is a stub.

        ``serve_generate`` calls it with a list of prompts padded to
        ``config.batch_size`` and indexes the result per prompt;
        ``process_chat`` calls it with a single-element list and takes
        element ``0`` of the result.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @staticmethod
    def greedy_until(prefix_text, until, max_length):
        """Model hook for ``/greedy-until``; this base version is a stub.

        ``serve_greedy_until`` calls it with a batch of prefixes, the matching
        slice of the request's ``until`` field and
        ``config.greedy_until_max_length``, and indexes the result per prefix.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @staticmethod
    def to_list(x):
        if isinstance(x, np.ndarray):
            return x.tolist()
        return x

    def serve_ready(self):
        """Handle ``GET /ready`` by returning a fixed readiness string.

        Does not take ``self.lock``, so it does not wait on the handlers
        that hold it.
        """
        return "Ready!\n"

    def serve_loglikelihood(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    "\n========= Serving Log Likelihood Request ========= \n"
                    + pprint.pformat(data)
                    + "\n"
                )

            if data.prefix_text is None:
                data.prefix_text = ["" for _ in data.text]

            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]
            text = [
                self.config.prepend_to_text + t + self.config.append_to_text
                for t in data.text
            ]

            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), self.config.batch_size, ncols=0):
                batch_prefix_text = prefix_text[i : i + self.config.batch_size]
                batch_text = text[i : i + self.config.batch_size]
                batch_size = len(batch_text)

                if batch_size < self.config.batch_size:
                    extra = self.config.batch_size - batch_size
                    batch_prefix_text.extend(["a" for _ in range(extra)])
                    batch_text.extend(["a" for _ in range(extra)])

                batch_log_likelihood, batch_is_greedy = self.loglikelihood(
                    batch_prefix_text, batch_text
                )
                batch_log_likelihood = self.to_list(batch_log_likelihood)
                batch_is_greedy = self.to_list(batch_is_greedy)
                log_likelihood.extend(batch_log_likelihood[:batch_size])
                is_greedy.extend(batch_is_greedy[:batch_size])

            output = {
                "prefix_text": data.prefix_text,
                "text": data.text,
                "log_likelihood": log_likelihood,
                "is_greedy": is_greedy,
            }
            if self.config.logging:
                absl.logging.info(
                    "\n========= Output ========= \n" + pprint.pformat(output) + "\n"
                )

        return output

    def serve_loglikelihood_rolling(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    "\n========= Serving Log Likelihood Request ========= \n"
                    + pprint.pformat(data)
                    + "\n"
                )

            text = [
                self.config.prepend_to_text + t + self.config.append_to_text
                for t in data.text
            ]
            log_likelihood = []
            is_greedy = []
            for i in trange(0, len(text), self.config.batch_size, ncols=0):
                batch_text = text[i : i + self.config.batch_size]
                batch_size = len(batch_text)

                if batch_size < self.config.batch_size:
                    extra = self.config.batch_size - batch_size
                    batch_text.extend(["a" for _ in range(extra)])

                batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
                    batch_text
                )
                batch_log_likelihood = self.to_list(batch_log_likelihood)
                batch_is_greedy = self.to_list(batch_is_greedy)
                log_likelihood.extend(batch_log_likelihood[:batch_size])
                is_greedy.extend(batch_is_greedy[:batch_size])

            output = {
                "text": data.text,
                "log_likelihood": log_likelihood,
                "is_greedy": is_greedy,
            }
            if self.config.logging:
                absl.logging.info(
                    "\n========= Output ========= \n" + pprint.pformat(output) + "\n"
                )

        return output

    def serve_generate(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    "\n========= Serving Generate Request ========= \n"
                    + pprint.pformat(data)
                    + "\n"
                )
            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]

            if data.temperature is None:
                data.temperature = self.config.default_temperature

            output_text = []
            for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
                batch_prefix_text = prefix_text[i : i + self.config.batch_size]
                batch_size = len(batch_prefix_text)

                if batch_size < self.config.batch_size:
                    extra = self.config.batch_size - batch_size
                    batch_prefix_text.extend(["a" for _ in range(extra)])

                batch_output_text = self.generate(
                    batch_prefix_text,
                    temperature=data.temperature,
                )
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                "prefix_text": data.prefix_text,
                "output_text": output_text,
                "temperature": data.temperature,
            }
            if self.config.logging:
                absl.logging.info(
                    "\n========= Output ========= \n" + pprint.pformat(output) + "\n"
                )
        return output

    def serve_greedy_until(self, data: InferenceRequest):
        with self.lock:
            if self.config.logging:
                absl.logging.info(
                    "\n========= Serving Greedy Until Request ========= \n"
                    + pprint.pformat(data)
                    + "\n"
                )
            prefix_text = [
                self.config.prepend_to_prefix + p + self.config.append_to_prefix
                for p in data.prefix_text
            ]
            until = data.until
            max_length = self.config.greedy_until_max_length

            output_text = []
            for i in range(0, len(prefix_text), self.config.batch_size):
                batch_prefix_text = prefix_text[i : i + self.config.batch_size]
                batch_until = until[i : i + self.config.batch_size]
                batch_size = len(batch_prefix_text)

                batch_output_text = self.greedy_until(
                    batch_prefix_text, batch_until, max_length
                )
                output_text.extend(self.to_list(batch_output_text)[:batch_size])

            output = {
                "prefix_text": data.prefix_text,
                "until": data.until,
                "max_length": max_length,
                "output_text": output_text,
            }
            if self.config.logging:
                absl.logging.info(
                    "\n========= Output ========= \n" + pprint.pformat(output) + "\n"
                )
        return output

    def process_chat(self, prompt, context, temperature):
        context = (
            context
            + self.config.chat_user_prefix
            + prompt
            + self.config.chat_user_suffix
            + self.config.chat_lm_prefix
        )
        response = self.generate(
            [self.config.chat_prepend_text + context],
            temperature=float(temperature),
        )[0]
        context = context + response + self.config.chat_lm_suffix
        return response, context

    def serve_chat(self, data: ChatRequest):
        if data.temperature is None:
            data.temperature = self.config.default_temperature
        response, context = self.process_chat(
            data.prompt,
            data.context,
            temperature=data.temperature,
        )
        return {
            "response": response,
            "context": context,
            "temperature": data.temperature,
        }

    def create_chat_app(self):
        """Build the Gradio chat UI that ``__init__`` mounts on the server.

        The UI holds a two-element ``context_state`` list:
        index 0 is the conversation context *before* the latest user turn and
        index 1 is the context *after* the latest model reply. Keeping both
        lets "Regenerate" re-run the last turn from the pre-turn context.

        Returns:
            The ``gr.Blocks`` app, with its queue enabled.
        """
        with gr.Blocks(analytics_enabled=False, title="EasyLM Chat") as gradio_chatbot:
            gr.Markdown(
                "# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)"
            )
            # Free-form notes from the server configuration.
            gr.Markdown(self.config.notes)
            chatbot = gr.Chatbot(label="Chat history")
            msg = gr.Textbox(placeholder="Type your message here...", show_label=False)
            with gr.Row():
                send = gr.Button("Send")
                # Stays disabled until a model reply exists (see model_fn).
                regenerate = gr.Button("Regenerate", interactive=False)
                clear = gr.Button("Reset")

            temp_slider = gr.Slider(
                label="Temperature",
                minimum=0,
                maximum=2.0,
                value=self.config.default_temperature,
            )

            # [context before the latest turn, context after the latest reply]
            context_state = gr.State(["", ""])

            def user_fn(user_message, history, context):
                # Step 1 of a turn: lock the controls, append the user message
                # with an empty reply slot, and promote the post-reply context
                # to be the starting context of the new turn.
                return {
                    msg: gr.update(value="", interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                    chatbot: history + [[user_message, None]],
                    context_state: [context[1], context[1]],
                }

            def model_fn(history, context, temperature):
                # Step 2 of a turn: generate the reply for the last user
                # message from the pre-turn context (context[0]), fill it into
                # the history, and unlock the controls.
                history[-1][1], new_context = self.process_chat(
                    history[-1][0], context[0], temperature
                )
                return {
                    msg: gr.update(value="", interactive=True),
                    clear: gr.update(interactive=True),
                    send: gr.update(interactive=True),
                    chatbot: history,
                    context_state: [context[0], new_context],
                    regenerate: gr.update(interactive=True),
                }

            def regenerate_fn():
                # Only locks the controls; history and context are untouched so
                # the following model_fn call overwrites the last reply.
                return {
                    msg: gr.update(value="", interactive=False),
                    clear: gr.update(interactive=False),
                    send: gr.update(interactive=False),
                    regenerate: gr.update(interactive=False),
                }

            def clear_fn():
                # Reset the conversation to its initial, empty state.
                return {
                    chatbot: None,
                    msg: "",
                    context_state: ["", ""],
                    regenerate: gr.update(interactive=False),
                }

            # Pressing Enter and clicking "Send" are wired identically:
            # user_fn is registered with queue=False, then model_fn with
            # queue=True.
            msg.submit(
                user_fn,
                inputs=[msg, chatbot, context_state],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=False,
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True,
            )
            send.click(
                user_fn,
                inputs=[msg, chatbot, context_state],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=False,
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True,
            )
            # Regenerate: lock the controls, then re-run the model on the last
            # user message.
            regenerate.click(
                regenerate_fn,
                inputs=None,
                outputs=[msg, clear, send, regenerate],
                queue=False,
            ).then(
                model_fn,
                inputs=[chatbot, context_state, temp_slider],
                outputs=[msg, clear, send, chatbot, context_state, regenerate],
                queue=True,
            )
            clear.click(
                clear_fn,
                inputs=None,
                outputs=[chatbot, msg, context_state, regenerate],
                queue=False,
            )

        # NOTE(review): concurrency_count=1 presumably makes queued events
        # (the model_fn calls) run one at a time - confirm against the
        # installed Gradio version.
        gradio_chatbot.queue(concurrency_count=1)
        return gradio_chatbot

    def run(self):
        if self.config.pre_compile != "":
            if self.config.pre_compile == "all":
                pre_compile = ["loglikelihood", "generate", "greedy_until", "chat"]
            else:
                pre_compile = self.config.pre_compile.split(",")

            pre_compile_data = ["a" for _ in range(self.config.batch_size)]
            for task in pre_compile:
                if task == "loglikelihood":
                    self.loglikelihood(pre_compile_data, pre_compile_data)
                    self.loglikelihood_rolling(pre_compile_data)
                elif task == "generate":
                    self.generate(pre_compile_data, 1.0)
                elif task == "greedy_until":
                    self.greedy_until(
                        pre_compile_data,
                        pre_compile_data,
                        self.config.greedy_until_max_length,
                    )
                elif task == "chat":
                    self.process_chat("a", "a", 1.0)
                else:
                    raise ValueError(f"Invalid precompile task: {task}!")

        uvicorn.run(self.app, host=self.config.host, port=self.config.port)


class LMClient(object):
    """A simple client for the LM server."""

    @staticmethod
    def get_default_config(updates=None):
        """Return the client configuration, with ``updates`` merged over defaults.

        Args:
            updates: Optional overrides. When given, they are wrapped in a
                ``ConfigDict``, references are resolved, and the result is
                applied on top of the default values.

        Returns:
            A ``ConfigDict`` with ``url``, ``batch_size``, ``wait_for_ready``
            and ``dummy``.
        """
        defaults = {
            "url": "http://localhost:5007",
            "batch_size": 1,
            "wait_for_ready": True,
            "dummy": False,
        }
        config = ConfigDict()
        for option, value in defaults.items():
            setattr(config, option, value)

        if updates is None:
            return config

        overrides = ConfigDict(updates).copy_and_resolve_references()
        config.update(overrides)
        return config

    def __init__(self, config=None):
        """Build the client configuration and optionally wait for the server.

        Args:
            config: Overrides passed to ``get_default_config``.
        """
        self.config = self.get_default_config(config)
        if not self.config.wait_for_ready:
            return
        self.wait_for_ready()

    def wait_for_ready(self):
        """Block until the server's ``ready`` endpoint answers.

        Returns immediately in dummy mode. Otherwise the endpoint is polled;
        a refused connection or a timed-out request is followed by a
        10-second pause and another attempt. Any HTTP response, whatever its
        status code, counts as ready.
        """
        if self.config.dummy:
            return
        while True:
            try:
                # Without an explicit timeout ``requests`` waits indefinitely
                # for a server that accepts the connection but never replies,
                # and the ``Timeout`` handler below could never fire.
                requests.get(
                    urllib.parse.urljoin(self.config.url, "ready"),
                    timeout=10,
                )
            except (Timeout, ConnectionError):
                time.sleep(10)
            else:
                return

    @staticmethod
    def batched(iterator, batch_size):
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def loglikelihood(self, prefix, text):
        """Score ``text`` given ``prefix`` through the server's ``loglikelihood`` endpoint.

        Args:
            prefix: Iterable of prefix strings.
            text: Iterable of texts to score, paired with ``prefix`` by position.

        Returns:
            ``(log_likelihood, is_greedy)`` lists. In dummy mode no request
            is made and every text gets ``-1.0`` and ``False``.
        """
        prefix, text = list(prefix), list(text)
        if self.config.dummy:
            return [-1.0] * len(text), [False] * len(text)

        size = self.config.batch_size
        paired_batches = list(
            zip(self.batched(prefix, size), self.batched(text, size))
        )

        log_likelihood, is_greedy = [], []
        for batch_prefix, batch_text in tqdm(paired_batches, ncols=0):
            endpoint = urllib.parse.urljoin(self.config.url, "loglikelihood")
            payload = {"prefix_text": batch_prefix, "text": batch_text}
            reply = requests.post(endpoint, json=payload).json()
            log_likelihood.extend(reply["log_likelihood"])
            is_greedy.extend(reply["is_greedy"])

        return log_likelihood, is_greedy

    def loglikelihood_rolling(self, text):
        """Score each text through the server's ``loglikelihood-rolling`` endpoint.

        Args:
            text: Iterable of texts to score.

        Returns:
            ``(log_likelihood, is_greedy)`` lists. In dummy mode no request
            is made and every text gets ``-1.0`` and ``False``.
        """
        text = list(text)
        if self.config.dummy:
            return [-1.0] * len(text), [False] * len(text)

        batches = list(self.batched(text, self.config.batch_size))

        log_likelihood, is_greedy = [], []
        for batch_text in tqdm(batches, ncols=0):
            endpoint = urllib.parse.urljoin(self.config.url, "loglikelihood-rolling")
            reply = requests.post(endpoint, json={"text": batch_text}).json()
            log_likelihood.extend(reply["log_likelihood"])
            is_greedy.extend(reply["is_greedy"])
        return log_likelihood, is_greedy

    def greedy_until(self, prefix, until):
        """Request output for each prefix from the server's ``greedy-until`` endpoint.

        Args:
            prefix: Iterable of prompt strings.
            until: Iterable of stop specifications, paired with ``prefix`` by
                position; each is a string or a sequence of strings.

        Returns:
            List of output strings. In dummy mode no request is made and each
            result is ``"dummy text "`` followed by the stop string (or by the
            first stop string of a sequence).
        """
        prefix, until = list(prefix), list(until)
        if self.config.dummy:
            return [
                "dummy text " + (stop if isinstance(stop, str) else stop[0])
                for stop in until
            ]

        size = self.config.batch_size
        paired_batches = list(
            zip(self.batched(prefix, size), self.batched(until, size))
        )

        output_text = []
        for batch_prefix, batch_until in tqdm(paired_batches, ncols=0):
            endpoint = urllib.parse.urljoin(self.config.url, "greedy-until")
            payload = {"prefix_text": batch_prefix, "until": batch_until}
            reply = requests.post(endpoint, json=payload).json()
            output_text.extend(reply["output_text"])
        return output_text

    def generate(self, prefix, temperature=None):
        """Sample a continuation for each prefix through the ``generate`` endpoint.

        Args:
            prefix: Iterable of prompt strings.
            temperature: Sampling temperature forwarded to the server; ``None``
                is sent as-is, leaving the choice to the server.

        Returns:
            List of generated strings. In dummy mode no request is made and
            every result is the empty string.
        """
        prefix = list(prefix)
        if self.config.dummy:
            return [""] * len(prefix)

        batches = list(self.batched(prefix, self.config.batch_size))

        output_text = []
        for batch_prefix in tqdm(batches, ncols=0):
            endpoint = urllib.parse.urljoin(self.config.url, "generate")
            payload = {"prefix_text": batch_prefix, "temperature": temperature}
            reply = requests.post(endpoint, json=payload).json()
            output_text.extend(reply["output_text"])
        return output_text

    def chat(self, prompt, context, temperature=None):
        """Send one chat turn to the server's ``chat`` endpoint.

        Args:
            prompt: The new user message.
            context: Conversation context returned by the previous turn
                (empty string for a new conversation).
            temperature: Sampling temperature forwarded to the server; ``None``
                is sent as-is, leaving the choice to the server.

        Returns:
            ``(response, context)``: the model reply and the updated context.
            In dummy mode no request is made; the reply is the empty string
            and the context is returned unchanged.
        """
        if self.config.dummy:
            # Same tuple shape as the real path, so callers that unpack
            # ``response, context = client.chat(...)`` also work in dummy mode.
            return "", context
        response = requests.post(
            urllib.parse.urljoin(self.config.url, "chat"),
            json={
                "prompt": prompt,
                "context": context,
                "temperature": temperature,
            },
        ).json()
        return response["response"], response["context"]
