import asyncio
import concurrent
import concurrent.futures
import functools
import json
import os
import socket
import subprocess
import sys
import time
from collections import defaultdict

import openai
from openai import OpenAI, Timeout
from tqdm import tqdm

from deploy_utils import get_available_gpus

def wait_for_port(host, port, timeout=60.0, poll_interval=1, connect_timeout=1.0):
    """Block until ``host:port`` accepts TCP connections.

    Args:
        host: Hostname or IP address to connect to.
        port: TCP port number.
        timeout: Total number of seconds to keep retrying before giving up.
        poll_interval: Seconds to sleep between failed attempts.
        connect_timeout: Timeout in seconds for each individual connection attempt.

    Returns:
        True as soon as a connection succeeds.

    Raises:
        TimeoutError: if no connection succeeded within ``timeout`` seconds.
        socket.gaierror: if ``host`` cannot be resolved (a configuration error,
            so it is not retried).
    """
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=connect_timeout):
                print(f"Service on {host}:{port} is up!")
                return True
        except socket.gaierror:
            # Name resolution failure will not fix itself by waiting.
            raise
        except OSError as err:
            # Refused, reset, timed out, ... -- the service is not ready yet.
            if time.monotonic() >= deadline:
                raise TimeoutError(f"Timeout waiting for {host}:{port}") from err
            time.sleep(poll_interval)

class LLMBackend:
    """Base interface for something that chat-completion queries can be sent to.

    Subclasses override the hooks they need; every default implementation is a
    no-op that returns None.
    """

    def wait_done(self):
        """Block until the backend accepts connections (no-op here)."""
        return None

    def get_urls(self):
        """Return the API base URLs of the backend (None in this base class)."""
        return None

    def close(self):
        """Release any resources held by the backend (no-op here)."""
        return None

class OpenAIBackend(LLMBackend):
    """Backend for already-running OpenAI-compatible endpoint(s).

    When ``ports`` is None, ``base_url`` is used verbatim as the only URL.
    Otherwise ``base_url`` is a ``str.format`` template containing a ``{port}``
    field and is expanded once per entry of ``ports``.
    """

    def __init__(self, base_url, api_key, ports=None):
        # NOTE: despite the plural name, this holds the single base_url/template given.
        self.base_urls = base_url
        self.api_key = api_key
        self.ports = ports
        self.urls = (
            [base_url]
            if ports is None
            else [base_url.format(port=port_number) for port_number in ports]
        )

    def get_urls(self):
        """Return the list of endpoint base URLs built in ``__init__``."""
        return self.urls

class VLLMBackend(LLMBackend):
    """Launch and manage local ``vllm serve`` processes, one per GPU group.

    The available GPUs are split into consecutive groups of
    ``tensor_parallel_size`` (read from ``model_kwargs``, default 1); GPUs that
    do not fill a whole group are left unused. Server ``i`` is restricted to
    group ``i`` via ``CUDA_VISIBLE_DEVICES`` and listens on ``base_port + i``.
    """

    def __init__(self, model_name, api_key, base_port=8000, offline=False, model_kwargs=None, available_gpus=None, wait_done=False):
        """Start the vLLM server processes.

        Args:
            model_name: Model repo id passed to ``vllm serve``.
            api_key: Stored as ``self.api_key``; not passed to ``vllm serve``.
            base_port: Port of the first server; server ``i`` uses ``base_port + i``.
            offline: If True, start the servers with Hugging Face offline
                environment variables set and, unless ``model_kwargs`` already
                contains ``revision``, pin the most recently modified cached
                revision of ``model_name``.
            model_kwargs: Extra ``vllm serve`` options; ``foo_bar=v`` becomes
                ``--foo-bar v``. ``True`` emits a bare flag, ``False`` omits it,
                dicts are JSON-encoded. The caller's dict is never modified.
            available_gpus: GPU ids to distribute; defaults to ``get_available_gpus()``.
            wait_done: If True, block until every server port accepts connections.

        Raises:
            ValueError: if the model is not found exactly once in the local HF
                cache (offline mode), or there are fewer GPUs than one
                tensor-parallel group needs.
            TimeoutError: if ``wait_done`` is True and a server does not come up.
        """
        # Work on a private copy: a shared mutable default (or the caller's own
        # dict) must not accumulate the entries added below across instances.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        self.offline = offline
        if offline and "revision" not in model_kwargs:
            model_kwargs["revision"] = self._latest_cached_revision(model_name)

        self.api_key = api_key
        available_gpus = get_available_gpus() if available_gpus is None else available_gpus
        tp = model_kwargs.get("tensor_parallel_size", 1)
        gpus_per_process = [available_gpus[i * tp:(i + 1) * tp] for i in range(len(available_gpus) // tp)]
        if not gpus_per_process:
            raise ValueError(
                f"Need at least {tp} GPU(s) for tensor_parallel_size={tp}, "
                f"but only {len(available_gpus)} available"
            )

        if "mistralai" in model_name:
            model_kwargs["tokenizer_mode"] = "mistral"
            model_kwargs["config_format"] = "mistral"
            model_kwargs["load_format"] = "mistral"

        executable = os.path.join(sys.exec_prefix, "bin", "vllm")
        fixed_args = ["--disable-log-requests", "--gpu-memory-utilization", "0.9", "--uvicorn-log-level", "error"]
        # User options come last so they can override the fixed defaults above.
        extra_args = self._model_kwargs_to_args(model_kwargs)

        def command_for(port):
            # argv is built as a list (no str.format / str.split round-trip), so
            # option values containing braces or whitespace stay intact.
            return [executable, "serve", model_name, "--port", str(port), *fixed_args, *extra_args]

        # Human-readable command template kept for backward compatibility;
        # it is informational only and not used to launch the servers.
        self.vllm_command = " ".join(command_for("{port}"))
        self.ports = list(range(base_port, base_port + len(gpus_per_process)))

        self.processes = []
        try:
            for port, proc_gpus in zip(self.ports, gpus_per_process):
                proc_env = os.environ.copy()
                if self.offline:
                    proc_env["TRANSFORMERS_OFFLINE"] = "1"
                    # huggingface_hub-level offline switch; TRANSFORMERS_OFFLINE
                    # alone does not cover hub calls made outside transformers.
                    proc_env["HF_HUB_OFFLINE"] = "1"
                proc_env["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in proc_gpus)

                cmd = command_for(port)
                print(f"Starting server on port {port}: {' '.join(cmd)}")
                self.processes.append(subprocess.Popen(cmd, env=proc_env))

            if wait_done:
                self.wait_done()
        except BaseException:
            # The constructor failed, so the caller never gets an object to
            # close(); stop any servers already started before re-raising.
            self.close()
            raise

    @staticmethod
    def _latest_cached_revision(model_name):
        """Return the commit hash of the most recently modified cached revision of a model repo."""
        import huggingface_hub as hh

        info = hh.scan_cache_dir()
        # Dataset/space repos can share a repo_id with a model; only models matter here.
        repo_info = [x for x in info.repos if x.repo_id == model_name and x.repo_type == "model"]
        if len(repo_info) != 1:
            raise ValueError(f"Error finding cache for repo '{model_name}'")
        repo_info = repo_info[0]
        print(f"Found {len(repo_info.revisions)} revisions for repo '{model_name}'")
        rev = max(repo_info.revisions, key=lambda x: x.last_modified)
        return rev.commit_hash

    @staticmethod
    def _model_kwargs_to_args(model_kwargs):
        """Translate ``{"foo_bar": v}`` into ``vllm serve`` CLI tokens ``["--foo-bar", "v"]``."""
        args = []
        for key, value in model_kwargs.items():
            flag = "--" + key.replace("_", "-")
            if isinstance(value, bool):
                # NOTE(review): assumes store_true-style flags, where False means
                # "leave the flag off" rather than "--no-<flag>".
                if value:
                    args.append(flag)
            elif isinstance(value, dict):
                # Structured options are passed as JSON; str(dict) would give a
                # Python repr (single quotes), which is not valid JSON.
                args.extend([flag, json.dumps(value)])
            else:
                args.extend([flag, str(value)])
        return args

    def get_urls(self):
        """Return one OpenAI-style API base URL per server port."""
        api_bases = [f"http://0.0.0.0:{p}/v1/" for p in self.ports]
        return api_bases

    def wait_done(self):
        """Block until every server port on localhost accepts connections (600 s each)."""
        for p in self.ports:
            wait_for_port("localhost", p, timeout=600)

    def close(self):
        """Best-effort shutdown of all server processes.

        Calls ``terminate()`` on every process, waits up to 10 s for each to
        exit, and kills (then reaps) any that are still running. Errors are
        printed rather than raised.
        """
        for p in self.processes:
            try:
                p.terminate()
            except Exception as e:
                print(f"Warning: error terminating process {getattr(p, 'pid', 'unknown')}: {e}")

        for p in self.processes:
            try:
                p.wait(timeout=10)
            except Exception as e:
                print(f"Warning: process {getattr(p, 'pid', 'unknown')} did not terminate in time: {e}")
                try:
                    p.kill()
                    # Reap the killed child so it does not linger as a zombie.
                    p.wait(timeout=10)
                except Exception as e2:
                    print(f"Warning: error killing process {getattr(p, 'pid', 'unknown')}: {e2}")


from copy import deepcopy

def get_default_dict_val():
    """Return 0, the initial value for unseen keys in a counting ``defaultdict``.

    Defined at module level (not as a lambda) so that defaultdicts using it as
    their default factory can still be pickled.
    """
    return 0

class MultiOpenAIClient:
    '''
    Utility class wrapping vLLM, sglang and OpenAI backends for chat completions.

    One OpenAI client is created per backend URL. Batched requests are spread
    round-robin across those clients and executed concurrently in a thread pool.
    Can be used as a context manager, which calls ``close()`` on exit.
    # TODO: sglang, openai
    '''

    def __init__(self, model_name, backend="vllm", max_parallel_requests_per_client=None, **backend_kwargs):
        """Create (or adopt) a backend and one OpenAI client per backend URL.

        Args:
            model_name: Model name sent with every request; also passed to
                ``VLLMBackend`` when ``backend == "vllm"``.
            backend: ``"vllm"``, ``"openai"``, or an ``LLMBackend`` instance.
            max_parallel_requests_per_client: Upper bound on concurrent
                in-flight requests per client; defaults to 1024.
            **backend_kwargs: Forwarded to the backend constructor; ignored
                when a backend instance is given.

        Raises:
            ValueError: if ``backend`` is not recognized.
        """
        self.model_name = model_name
        if isinstance(backend, LLMBackend):
            if backend_kwargs:
                print("Backend instance given, ignoring `backend_kwargs`")
            self.backend = backend
        elif backend == "vllm":
            self.backend = VLLMBackend(model_name, **backend_kwargs)
        elif backend == "openai":
            self.backend = OpenAIBackend(**backend_kwargs)
        else:
            raise ValueError("unknown backend " + str(backend))

        # `chat` reads this attribute, so it must be set whether or not the
        # caller supplied a limit.
        if max_parallel_requests_per_client is None:
            max_parallel_requests_per_client = 1024  # heuristic default, not tuned
        self.max_parallel_requests_per_client = max_parallel_requests_per_client
        # Legacy alias for the same limit, kept for callers using the old name.
        self.max_parallel_requests_per_worker = max_parallel_requests_per_client

        # No read timeout so long generations are not cut off; the other
        # phases get generous bounds.
        timeout = Timeout(connect=600, read=None, write=600, pool=600)
        self.clients = [
            OpenAI(api_key=self.backend.api_key, base_url=url, timeout=timeout)
            for url in self.backend.get_urls()
        ]

        self.completion_token_usage = defaultdict(get_default_dict_val)
        self.call_count = 0

    def reset_token_usage(self):
        """Clear the per-key completion-token counters and the call count."""
        self.completion_token_usage = defaultdict(get_default_dict_val)
        self.call_count = 0

    @functools.wraps(openai.resources.chat.completions.completions.Completions.create)
    def chat(self, token_usage_key=None, use_tqdm=False, return_format="openai", **kwargs):
        # NOTE: functools.wraps replaces this method's __doc__/__name__ with
        # those of `Completions.create`, so the contract is documented here:
        #   messages (kwarg): one conversation (list of dicts) or a batch of
        #       conversations (list of lists of dicts).
        #   token_usage_key: bucket in `completion_token_usage` to add the
        #       completion tokens of this call to.
        #   use_tqdm: show progress bars while dispatching/gathering.
        #   return_format: "openai" -> list of response objects;
        #       "list" -> per conversation, the list of choice contents;
        #       "chatml" -> per conversation, one message list per choice with
        #       the assistant reply appended.
        #   **kwargs: forwarded to `chat.completions.create` (`model` is set here).
        # The result always has one entry per conversation, even for a single one.
        # Raises ValueError for an unknown return_format, empty messages, or no clients.
        if return_format not in {"openai", "list", "chatml"}:
            raise ValueError("Allowed return formats are `openai`, `list`, `chatml`")

        messages = kwargs.pop("messages")
        if len(messages) == 0:
            raise ValueError("Empty messages")
        if isinstance(messages[0], dict):
            messages = [messages]  # single conversation -> batch of one

        n_clients = len(self.clients)
        if n_clients == 0:
            raise ValueError("Backend provided no URLs to send requests to")
        num_workers = min(len(messages), self.max_parallel_requests_per_client * n_clients)

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = []
            for i in tqdm(range(len(messages)), disable=not use_tqdm, desc="Dispatching requests"):
                client = self.clients[i % n_clients]  # round-robin over backends
                futures.append(executor.submit(client.chat.completions.create, messages=messages[i], model=self.model_name, **kwargs))
            results = [
                future.result()
                for future in tqdm(futures, disable=not use_tqdm, total=len(futures), desc="Gathering results")
            ]

        # `usage` is optional on chat completion responses; count missing as 0.
        self.completion_token_usage[token_usage_key] += sum(
            r.usage.completion_tokens for r in results if r.usage is not None
        )
        self.call_count += sum(len(r.choices) for r in results)

        if return_format == "list":
            results = [[c.message.content for c in r.choices] for r in results]
        elif return_format == "chatml":
            results = [
                [m + [{"role": "assistant", "content": c.message.content}] for c in r.choices]
                for m, r in zip(messages, results)
            ]

        return results

    def wait_done(self):
        """Block until the backend reports it is ready (delegates to the backend)."""
        self.backend.wait_done()

    def close(self):
        """Shut down the backend (delegates to the backend)."""
        self.backend.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the backend (e.g. local vLLM servers); never suppress errors.
        self.close()
        return False