import asyncio
import logging
from copy import deepcopy

import aiohttp
import numpy as np
import torch
from openai.types.completion import Completion
from tensordict import TensorDict

from verl.protocol import DataProto
from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length

logger = logging.getLogger(__name__)


def _repeat_interleave(value: torch.Tensor | np.ndarray, repeats: int) -> torch.Tensor | np.ndarray:
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    elif isinstance(value, np.ndarray):
        return np.repeat(value, repeats, axis=0)


async def poll_completions_openai(address: str, **completions_request) -> dict:
    """POST a completions request to an OpenAI-compatible server and return the JSON body.

    aiohttp is used directly instead of ``AsyncOpenAI`` to avoid potential blocking, and a
    fresh ``ClientSession`` is opened for every attempt.

    Args:
        address: ``"host:port"`` of the server; the request is sent to
            ``http://{address}/v1/completions``.
        **completions_request: JSON payload fields (``model``, ``prompt``, sampling params, ...).
            ``meta_info`` and ``extra_headers`` are stripped before sending.

    Returns:
        The decoded JSON response as a plain ``dict``. It is not converted to an
        ``openai`` ``Completion`` object.

    Raises:
        Exception: the error from the final attempt once all retries are exhausted. A non-200
            response raises with the status code and response text in the message.
    """
    url = f"http://{address}/v1/completions"
    headers = {"Content-Type": "application/json"}

    # Local bookkeeping keys, not part of the completions payload. ``completions_request`` is a
    # fresh dict built from **kwargs, so popping does not mutate the caller's mapping.
    completions_request.pop("meta_info", None)
    completions_request.pop("extra_headers", None)

    max_retries = 5
    retry_delay = 2  # seconds; doubled after every failed attempt (exponential backoff)
    timeout = aiohttp.ClientTimeout(total=2700)

    for attempt in range(max_retries):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, json=completions_request, headers=headers, timeout=timeout) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"API request failed with status {response.status}: {error_text}")
                    return await response.json()
        except Exception:
            logger.warning(
                "Completions request to %s failed (attempt %d/%d)",
                address,
                attempt + 1,
                max_retries,
                exc_info=True,
            )
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(retry_delay)
            retry_delay *= 2

    # Unreachable: the final attempt either returns or re-raises. Kept for type checkers.
    raise Exception("All retries failed")


class Router:
    """Route completion requests to a static pool of OpenAI-compatible servers.

    Each call picks the server with the fewest in-flight ``generate_sequences`` calls. A known
    ``application_id`` prefers the server it used last (data locality) unless load has become
    skewed, in which case it moves to the least-used server.

    Usage counts are guarded by an ``asyncio.Lock``. That serialises coroutines on one event
    loop only; it does not coordinate across processes or threads.
    """

    # Move a known application off its current server once that server has at least this many
    # more in-flight requests than the least-used one.
    _REBALANCE_SKEW = 4

    def __init__(self, config, tokenizer, addresses: list[str]):
        # List of "ip:port" strings.
        self.addresses = addresses
        self.tensor_parallel_size = config.actor_rollout_ref.rollout.get("tensor_model_parallel_size", 1)
        self._lock = asyncio.Lock()
        # In-flight request count per address (duplicate addresses collapse to one key).
        self._usage: dict[str, int] = {addr: 0 for addr in self.addresses}
        # Address each application_id was most recently routed to.
        self._application_id_to_address: dict[str, str] = {}
        self.counter = 0
        self.config = config
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        # The model path is sent verbatim as the request's ``model`` field.
        self.model_name = config.actor_rollout_ref.model.path

    async def get_address(self, application_id: str) -> str:
        """Choose a server for ``application_id`` and increment its usage count.

        A new application goes to the least-used server. A known application stays on its
        previous server unless that server is busy (usage > 0) AND either some server is idle
        or the gap to the least-used server is at least ``_REBALANCE_SKEW``. In that case it is
        re-pinned to the least-used server.

        Raises:
            ValueError: if the router has no addresses (``min`` over an empty mapping).
        """
        async with self._lock:
            min_address, min_usage = min(self._usage.items(), key=lambda item: item[1])
            cur_address = self._application_id_to_address.get(application_id)
            if cur_address is None:
                chosen = min_address
            else:
                cur_usage = self._usage[cur_address]
                skewed = min_usage == 0 or cur_usage - min_usage >= self._REBALANCE_SKEW
                chosen = min_address if (skewed and cur_usage > 0) else cur_address
            self._application_id_to_address[application_id] = chosen
            self._usage[chosen] += 1
            return chosen

    async def release_address(self, addr: str, application_id: str) -> None:
        """Decrement the usage count for ``addr`` (clamped at 0).

        ``application_id`` is accepted for symmetry with ``get_address`` but is not used; the
        application keeps its locality mapping after release.
        """
        async with self._lock:
            self._usage[addr] = max(0, self._usage.get(addr, 0) - 1)

    def _build_request_kwargs(self, batch: DataProto, sampling_params: dict) -> dict:
        """Build completion-request sampling kwargs from config, ``batch.meta_info`` and overrides.

        Precedence (later wins): rollout config defaults, ``meta_info["max_tokens"]``,
        ``meta_info["agent_rollout"]`` (forces ``n=1``), caller ``sampling_params``, and finally
        the validation settings when ``meta_info["validate"]`` is set.
        """
        rollout_cfg = self.config.actor_rollout_ref.rollout
        kwargs = dict(
            n=rollout_cfg.n,
            max_tokens=rollout_cfg.response_length,
            temperature=rollout_cfg.temperature,
            top_p=rollout_cfg.top_p,
            logprobs=0,
            return_token_ids=True,
            repetition_penalty=1.0,
        )

        if batch.meta_info.get("max_tokens") is not None:
            kwargs["max_tokens"] = batch.meta_info["max_tokens"]

        if batch.meta_info.get("agent_rollout", False):
            kwargs["n"] = 1

        kwargs.update(sampling_params)

        if batch.meta_info.get("validate", False):
            kwargs.update(
                {
                    "top_p": rollout_cfg.val_kwargs.top_p,
                    "temperature": rollout_cfg.val_kwargs.temperature,
                    "n": 1,
                    "repetition_penalty": 1.0,
                }
            )
        return kwargs

    @staticmethod
    def _parse_completions(completions_list: list[dict]) -> tuple[list, list, list]:
        """Split raw completion responses into ``[prompt][choice]`` nested lists.

        Returns ``(prompt_token_ids, response_token_ids, response_logprobs)``. Missing fields
        become empty lists; a ``null`` ``logprobs`` / ``token_logprobs`` is treated as missing.
        """
        prompt_ids, response_ids, response_logprobs = [], [], []
        for completions in completions_list:
            choices = completions.get("choices", [])
            prompt_ids.append([choice.get("prompt_token_ids", []) for choice in choices])
            response_ids.append([choice.get("token_ids", []) for choice in choices])
            response_logprobs.append(
                [(choice.get("logprobs") or {}).get("token_logprobs") or [] for choice in choices]
            )
        return prompt_ids, response_ids, response_logprobs

    async def generate_sequences(self, batch: DataProto, application_id: str, **sampling_params):
        """Generate completions for every prompt in ``batch`` on a single routed server.

        Reads prompts from ``batch.non_tensor_batch["formatted_prompts"]``, sends them
        concurrently, and converts the results via ``postprocess_batch``.

        Args:
            batch: input batch (see ``postprocess_batch`` for the fields read).
            application_id: routing key used for server locality.
            **sampling_params: overrides for the request kwargs (validation settings still win).

        Returns:
            The ``DataProto`` built by ``postprocess_batch``.
        """
        kwargs = self._build_request_kwargs(batch, sampling_params)

        address = await self.get_address(application_id)
        try:
            completions_list = await asyncio.gather(
                *(
                    self.submit_completions(
                        address=address,
                        model=self.model_name,
                        prompt=formatted_prompt,
                        **kwargs,
                    )
                    for formatted_prompt in batch.non_tensor_batch["formatted_prompts"]
                )
            )
        finally:
            # Release even on failure; otherwise this server's usage count leaks upward forever.
            await self.release_address(address, application_id)

        prompt_ids, response_ids, response_logprobs = self._parse_completions(completions_list)
        return await self.postprocess_batch(batch, prompt_ids, response_ids, response_logprobs, kwargs["n"])

    async def submit_completions(self, address, model, prompt, **kwargs):
        """Send one completions request to ``address`` and return the raw JSON dict."""
        return await poll_completions_openai(address=address, model=model, prompt=prompt, **kwargs)

    @staticmethod
    def _flatten_choices(nested: list) -> list:
        """Flatten ``[prompt][choice]`` lists into one list of choices, skipping ``None`` groups."""
        return [item for group in nested if group is not None for item in group]

    @staticmethod
    def _to_object_array(rows: list) -> np.ndarray:
        """Pack ``rows`` into a 1-D object array (one element per row), even if rows are ragged."""
        # np.array(ragged_lists) raises ValueError on NumPy >= 1.24, so fill element-wise.
        arr = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            arr[i] = row
        return arr

    async def postprocess_batch(
        self,
        batch: DataProto,
        prompt_token_ids: list[list[list[int]]],
        response_ids: list[list[list[int]]],
        response_logprobs: list[list[list[float]]] | None,
        n: int,
    ) -> DataProto:
        """Assemble prompt + response tensors into a ``DataProto``.

        Layout per the rollout convention (prompts left-padded, responses right-padded;
        ``input_ids`` = prompt + response):
            attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
            position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

        Args:
            batch: reads ``batch.batch["input_ids" | "attention_mask" | "position_ids"]``,
                ``non_tensor_batch["formatted_prompts"]`` and ``meta_info``.
            prompt_token_ids, response_ids, response_logprobs: ``[prompt][choice]`` nested lists
                parsed from the server responses.
            n: choices per prompt; prompt-side tensors are repeated ``n`` times to align.

        Returns:
            DataProto with tensors ``responses``, ``input_ids``, ``attention_mask`` and
            ``position_ids`` (plus ``rollout_log_probs`` when log-probs are given). Its only
            non-tensor field is ``prompt_token_ids``, with one entry per output row.
            NOTE(review): the input ``non_tensor_batch`` is not carried over; confirm callers
            merge it themselves.
        """
        idx = batch.batch["input_ids"]  # (bs, max_prompt_length)
        attention_mask = batch.batch["attention_mask"]
        position_ids = batch.batch["position_ids"]
        expected_rows = len(batch.non_tensor_batch["formatted_prompts"]) * n
        max_response_length = self.config.actor_rollout_ref.rollout.response_length

        response = self._flatten_choices(response_ids)
        assert len(response) == expected_rows, f"got {len(response)} responses, expected {expected_rows}"
        response_tensor = pad_2d_list_to_length(response, self.pad_token_id, max_length=max_response_length).to(
            idx.device
        )

        rollout_log_probs_tensor = None
        if response_logprobs:
            response_lp = self._flatten_choices(response_logprobs)
            assert len(response_lp) == expected_rows, f"got {len(response_lp)} logprob rows, expected {expected_rows}"
            # Pad with -1 so padding is distinguishable from real (always <= 0) log-probs.
            rollout_log_probs_tensor = (
                pad_2d_list_to_length(response_lp, -1, max_length=max_response_length)
                .to(idx.device)
                .to(torch.float32)
            )

        if n > 1:
            # Align prompt-side tensors with the n responses generated per prompt.
            idx = _repeat_interleave(idx, n)
            attention_mask = _repeat_interleave(attention_mask, n)
            position_ids = _repeat_interleave(position_ids, n)

        batch_size = len(idx)
        seq = torch.cat([idx, response_tensor], dim=-1)

        response_length = response_tensor.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)

        # TODO(sgm): fix position_ids on right_pad
        # Response positions continue from the last prompt position.
        response_position_ids = position_ids[..., -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_response_mask(
            response_id=response_tensor, eos_token=self.eos_token_id, dtype=attention_mask.dtype
        )
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        output_dict = {
            "responses": response_tensor,
            "input_ids": seq,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
        if rollout_log_probs_tensor is not None:
            output_dict["rollout_log_probs"] = rollout_log_probs_tensor

        new_non_tensor_batch = {}
        if prompt_token_ids:
            # One entry per choice so rows line up with the n-times-repeated tensors; a prompt
            # with no choices contributes a single empty entry.
            flat_prompt_ids = []
            for choices in prompt_token_ids:
                flat_prompt_ids.extend(choices if choices else [[]])
            new_non_tensor_batch["prompt_token_ids"] = self._to_object_array(flat_prompt_ids)

        output = TensorDict(output_dict, batch_size=batch_size)
        return DataProto(batch=output, meta_info=batch.meta_info, non_tensor_batch=new_non_tensor_batch)
