import copy
import json
import os
import time

from tqdm import tqdm

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
BASE_URL = "https://api.openai.com/v1"


class OpenaiClient:
    def __init__(
        self,
        keys=None,
        start_id=None,
        proxy=None,
        base_url=BASE_URL,
    ):
        import openai
        from openai import OpenAI

        if isinstance(keys, str):
            keys = [keys]
        if OPENAI_API_KEY:
            keys = [OPENAI_API_KEY]
        if keys is None:
            raise "Please provide OpenAI Key."

        self.key = keys
        self.key_id = start_id or 0
        self.key_id = self.key_id % len(self.key)
        self.api_key = self.key[self.key_id % len(self.key)]
        self.base_url = base_url
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def chat(self, *args, return_text=False, reduce_length=False, **kwargs):
        while True:
            try:
                completion = self.client.chat.completions.create(
                    *args, **kwargs, timeout=30
                )
                break
            except Exception as e:
                print(str(e))
                if "This model's maximum context length is" in str(e):
                    print("reduce_length")
                    return "ERROR::reduce_length"
                time.sleep(0.1)
        if return_text:
            completion = completion.choices[0].message.content
        return completion

    def text(self, *args, return_text=False, reduce_length=False, **kwargs):
        while True:
            try:
                completion = self.client.completions.create(*args, **kwargs)
                break
            except Exception as e:
                print(e)
                if "This model's maximum context length is" in str(e):
                    print("reduce_length")
                    return "ERROR::reduce_length"
                time.sleep(0.1)
        if return_text:
            completion = completion.choices[0].text
        return completion


class ClaudeClient:
    def __init__(self, keys):
        from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic

        self.anthropic = Anthropic(api_key=keys)

    def chat(self, messages, return_text=True, max_tokens=300, *args, **kwargs):
        system = " ".join(
            [turn["content"] for turn in messages if turn["role"] == "system"]
        )
        messages = [turn for turn in messages if turn["role"] != "system"]
        if len(system) == 0:
            system = None
        completion = self.anthropic.beta.messages.create(
            messages=messages,
            system=system,
            max_tokens=max_tokens,
            *args,
            **kwargs,
        )
        if return_text:
            completion = completion.content[0].text
        return completion

    def text(self, max_tokens=None, return_text=True, *args, **kwargs):
        completion = self.anthropic.beta.messages.create(
            max_tokens_to_sample=max_tokens, *args, **kwargs
        )
        if return_text:
            completion = completion.completion
        return completion


class LitellmClient:
    #  https://github.com/BerriAI/litellm
    def __init__(self, keys=None):
        self.api_key = keys

    def chat(self, return_text=True, *args, **kwargs):
        from litellm import completion

        response = completion(api_key=self.api_key, *args, **kwargs)
        if return_text:
            response = response.choices[0].message.content
        return response


def convert_messages_to_prompt(messages):
    #  convert chat message into a single prompt; used for completion model (eg davinci)
    prompt = ""
    for turn in messages:
        if turn["role"] == "system":
            prompt += f"{turn['content']}\n\n"
        elif turn["role"] == "user":
            prompt += f"{turn['content']}\n\n"
        else:  # 'assistant'
            pass
    prompt += "The ranking results of the 20 passages (only identifiers) is:"
    return prompt


def run_retriever(topics, searcher, qrels=None, k=100, qid=None):
    ranks = []
    if isinstance(topics, str):
        hits = searcher.search(topics, k=k)
        ranks.append({"query": topics, "hits": []})
        rank = 0
        for hit in hits:
            rank += 1
            content = json.loads(searcher.doc(hit.docid).raw())
            if "title" in content:
                content = (
                    "Title: "
                    + content["title"]
                    + " "
                    + "Content: "
                    + content["text"]
                )
            else:
                content = content["contents"]
            content = " ".join(content.split())
            ranks[-1]["hits"].append(
                {
                    "content": content,
                    "qid": qid,
                    "docid": hit.docid,
                    "rank": rank,
                    "score": hit.score,
                }
            )
        return ranks[-1]

    for qid in tqdm(topics):
        if qid in qrels:
            query = topics[qid]["title"]
            ranks.append({"query": query, "hits": []})
            hits = searcher.search(query, k=k)
            rank = 0
            for hit in hits:
                rank += 1
                content = json.loads(searcher.doc(hit.docid).raw())
                if "title" in content:
                    content = (
                        "Title: "
                        + content["title"]
                        + " "
                        + "Content: "
                        + content["text"]
                    )
                else:
                    content = content["contents"]
                content = " ".join(content.split())
                ranks[-1]["hits"].append(
                    {
                        "content": content,
                        "qid": qid,
                        "docid": hit.docid,
                        "rank": rank,
                        "score": hit.score,
                    }
                )
    return ranks


def get_prefix_prompt(query, num):
    return [
        {
            "role": "system",
            "content": "You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.",
        },
        {
            "role": "user",
            "content": f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}.",
        },
        {"role": "assistant", "content": "Okay, please provide the passages."},
    ]


def get_post_prompt(query, num):
    return f"Search Query: {query}. \nRank the {num} passages above based on their relevance to the search query. The passages should be listed in descending order using identifiers. The most relevant passages should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only response the ranking results, do not say any word or explain."


def create_permutation_instruction(
    item=None, rank_start=0, rank_end=100, model_name="gpt-4.1-nano"
):
    query = item["query"]
    num = len(item["hits"][rank_start:rank_end])

    max_length = 300

    messages = get_prefix_prompt(query, num)
    rank = 0
    for hit in item["hits"][rank_start:rank_end]:
        rank += 1
        content = hit["content"]
        content = content.replace("Title: Content: ", "")
        content = content.strip()
        # For Japanese should cut by character: content = content[:int(max_length)]
        content = " ".join(content.split()[: int(max_length)])
        messages.append({"role": "user", "content": f"[{rank}] {content}"})
        messages.append(
            {"role": "assistant", "content": f"Received passage [{rank}]."}
        )
    messages.append({"role": "user", "content": get_post_prompt(query, num)})

    return messages


def run_llm(
    messages,
    api_key=None,
    base_url=BASE_URL,
    model_name="gpt-4.1-nano",
):
    if "gpt" in model_name:
        Client = OpenaiClient
    elif "claude" in model_name:
        Client = ClaudeClient
    else:
        Client = LitellmClient

    agent = Client(api_key, base_url=base_url)
    response = agent.chat(
        model=model_name, messages=messages, temperature=0, return_text=True
    )
    return response


def clean_response(response: str):
    new_response = ""
    for c in response:
        if not c.isdigit():
            new_response += " "
        else:
            new_response += c
    new_response = new_response.strip()
    return new_response


def remove_duplicate(response):
    new_response = []
    for c in response:
        if c not in new_response:
            new_response.append(c)
    return new_response


def receive_permutation(item, permutation, rank_start=0, rank_end=100):
    response = clean_response(permutation)
    response = [int(x) - 1 for x in response.split()]
    response = remove_duplicate(response)
    cut_range = copy.deepcopy(item["hits"][rank_start:rank_end])
    original_rank = [tt for tt in range(len(cut_range))]
    response = [ss for ss in response if ss in original_rank]
    response = response + [tt for tt in original_rank if tt not in response]
    for j, x in enumerate(response):
        item["hits"][j + rank_start] = copy.deepcopy(cut_range[x])
        if "rank" in item["hits"][j + rank_start]:
            item["hits"][j + rank_start]["rank"] = cut_range[j]["rank"]
        if "score" in item["hits"][j + rank_start]:
            item["hits"][j + rank_start]["score"] = cut_range[j]["score"]
    return item


def permutation_pipeline(
    item=None,
    rank_start=0,
    rank_end=100,
    model_name="gpt-4.1-nano",
    api_key=None,
    base_url=BASE_URL,
):
    messages = create_permutation_instruction(
        item=item,
        rank_start=rank_start,
        rank_end=rank_end,
        model_name=model_name,
    )  # chan
    permutation = run_llm(
        messages, api_key=api_key, base_url=base_url, model_name=model_name
    )
    item = receive_permutation(
        item, permutation, rank_start=rank_start, rank_end=rank_end
    )
    return item


def sliding_windows(
    item=None,
    rank_start=0,
    rank_end=100,
    window_size=20,
    step=10,
    model_name="gpt-4.1-nano",
    api_key=None,
    base_url=None,
):
    item = copy.deepcopy(item)
    end_pos = rank_end
    start_pos = rank_end - window_size
    while start_pos >= rank_start:
        start_pos = max(start_pos, rank_start)
        item = permutation_pipeline(
            item,
            start_pos,
            end_pos,
            model_name=model_name,
            api_key=api_key,
            base_url=base_url,
        )
        end_pos = end_pos - step
        start_pos = start_pos - step
    return item


def write_eval_file(rank_results, file):
    with open(file, "w") as f:
        for i in range(len(rank_results)):
            rank = 1
            hits = rank_results[i]["hits"]
            for hit in hits:
                f.write(
                    f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n"
                )
                rank += 1
    return True
