import json
import random
import os
import transformers
import torch
from collections import defaultdict
from typing import Dict, Sequence
from torch.utils.data import Dataset
from torch import Tensor



# Instruction prefixes for embedding queries, keyed by the ``source`` field of
# a training sample. A value is either a single prompt string or a list of
# alternative prompts; EmbeddingDataset.load_data chooses a list entry based on
# the sample's index and raises KeyError for a ``source`` missing here.
E5_EMBEDDING_PROMPTS = {
    "allnli": [
        "Given a premise, retrieve a hypothesis that is entailed by the premise: ",
        "Retrieve semantically similar text: ",
    ],
    "dureader": "Given a Chinese search query, retrieve web documents that answer the question: ",
    "eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum: ",
    "fever": "Given a claim, retrieve documents that support or refute the claim: ",
    "hotpot_qa": "Given a multi-hop question, retrieve documents that can help answer the question: ",
    "miracl": "Given a question, retrieve Wikipedia documents that answer the question: ",
    "mrtydi": "Given a question, retrieve Wikipedia documents that answer the question: ",
    "msmarco_passage": "Given a web search query, retrieve relevant documents that answer the query: ",
    "msmarco_document": "Given a web search query, retrieve relevant documents that answer the query: ",
    "nq": "Given a question, retrieve Wikipedia documents that answer the question: ",
    "quora_duplicates": [
        "Given a question, retrieve questions that are semantically equivalent to the given question: ",
        "Find questions that have the same meaning as the input question: ",
    ],
    "squad": "Retrieve Wikipedia documents that answer the question: ",
    "t2ranking": "Given a Chinese search query, retrieve web documents that answer the question: ",
    "trivia_qa": "Retrieve Wikipedia documents that answer the question: ",
}


class EmbeddingDataset(Dataset):
    """Contrastive embedding dataset loaded from a JSON-lines file.

    Every non-blank line must be a JSON object with the fields ``source``,
    ``query``, ``positive`` (a string, or a list whose first item is used)
    and ``negative`` (a string or a list of strings). Each item of the
    dataset is ``{"query": str, "document": [positive, *negatives]}``:

    * the query is prefixed with the instruction for its ``source`` taken
      from ``E5_EMBEDDING_PROMPTS``;
    * the negatives are truncated, or padded with empty strings, to exactly
      ``data_args.num_negatives``.

    Items are grouped by ``source`` and cut into full batches of
    ``batch_size``. A trailing partial batch of each source is dropped, and
    the batches are then shuffled. Consecutive runs of ``batch_size`` items
    therefore share one source. Presumably this is meant for a sequential
    (non-shuffling) sampler with the same batch size; confirm with the
    training script.
    """

    def __init__(
        self,
        data_args: Dict,
        batch_size: int = 32,
    ):
        # ``data_args`` is read by attribute (``data_path``, ``num_negatives``),
        # so in practice it is an argument object rather than a plain dict.
        self.data_args = data_args
        self.effective_batch_size = batch_size

        self.data = []
        self.load_data(data_args.data_path)

    def __len__(self):
        return len(self.data)

    def load_data(self, data_path: str = None):
        """Parse ``data_path`` and fill ``self.data`` in batch-grouped order.

        Args:
            data_path: path to the JSON-lines file. Blank lines are skipped.
        """
        # Decode as UTF-8 explicitly: several sources (e.g. dureader, t2ranking)
        # contain Chinese text, and the platform default encoding may differ.
        with open(data_path, "r", encoding="utf-8") as f:
            raw_samples = [json.loads(line) for line in f if line.strip()]

        num_negatives = self.data_args.num_negatives
        all_samples = []
        data_map = defaultdict(list)  # source -> indices into all_samples
        for i, sample in enumerate(raw_samples):
            task = sample["source"]
            data_map[task].append(i)

            prompts = E5_EMBEDDING_PROMPTS[task]
            # Sources with several prompts cycle through them by sample index.
            instruction = prompts if isinstance(prompts, str) else prompts[i % len(prompts)]
            query = f"Instruct: {instruction}\nQuery:{sample['query']}"

            positive = sample["positive"]
            pos = positive if isinstance(positive, str) else positive[0]

            negative = sample["negative"]
            # Build a new list so the parsed sample is never mutated in place.
            negs = [negative] if isinstance(negative, str) else list(negative)
            # Truncate, or pad with empty strings, to exactly ``num_negatives``.
            negs = negs[:num_negatives] + [""] * (num_negatives - len(negs))

            all_samples.append({
                "query": query,
                "document": [pos, *negs],
            })

        # Shuffle within each source, then shuffle the order of the batches.
        for indices in data_map.values():
            random.shuffle(indices)

        all_batches = []
        for dataset, indices in data_map.items():
            for start in range(0, len(indices), self.effective_batch_size):
                batch = indices[start : start + self.effective_batch_size]
                if len(batch) == self.effective_batch_size:
                    all_batches.append(batch)
                else:
                    print(f"Skip 1 batch for dataset {dataset}.")
        random.shuffle(all_batches)

        self.data = [all_samples[idx] for batch in all_batches for idx in batch]
        print(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
        return self.data[index]


class EmbeddingDataCollator:
    """Tokenize the queries and documents of ``EmbeddingDataset``-style samples.

    The tokenizer's pad token is appended to every text before tokenization.
    This is presumably so the last position is a fixed token for last-token
    pooling; confirm with the model code. The document lists of all samples
    are flattened in order, so each sample's documents occupy a contiguous
    run of rows in the document encodings.

    Keyword args:
        query_max_length: truncation length for queries (default 512).
        doc_max_length: truncation length for documents (default 1024).
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        # Fall back to the EOT token (if the tokenizer has one), else BOS, as pad.
        if not self.tokenizer.pad_token:
            if getattr(self.tokenizer, "eot_token", None):
                self.tokenizer.pad_token = self.tokenizer.eot_token
            else:
                self.tokenizer.pad_token = self.tokenizer.bos_token
        print(f"use ``{self.tokenizer.pad_token}`` as pad token for llm")
        self.query_max_length = kwargs.get("query_max_length", 512)
        self.doc_max_length = kwargs.get("doc_max_length", 1024)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Return ``{"query": encodings | None, "document": encodings | None}``.

        A key's value is ``None`` when the first instance lacks that field.
        """
        pad = self.tokenizer.pad_token

        if "query" in instances[0]:
            queries = [f"{instance['query']}" + pad for instance in instances]
            queries_encodings = self.tokenizer(
                queries,
                padding=True, truncation=True,
                max_length=self.query_max_length,
                return_tensors="pt",
            )
        else:
            queries_encodings = None

        if "document" in instances[0]:
            # A nested comprehension flattens the per-sample document lists in
            # order, avoiding the quadratic ``sum(lists, [])`` concatenation.
            docs = [f"{doc}" + pad for instance in instances for doc in instance["document"]]
            docs_encodings = self.tokenizer(
                docs,
                padding=True, truncation=True,
                max_length=self.doc_max_length,
                return_tensors="pt",
            )
        else:
            docs_encodings = None

        return {
            "query": queries_encodings,
            "document": docs_encodings,
        }


class PointwiseDataCollator:
    """Collate (query, document) pairs into yes/no prompts for a chat reranker.

    Each document of each sample is paired with that sample's query and
    wrapped in a system + user conversation (see ``format_instruction``).
    The conversation is rendered with the tokenizer's chat template plus its
    generation prompt, then tokenized and left-padded. ``true_token`` and
    ``false_token`` cache the first token ids of "yes" and "no".

    Keyword args:
        max_length: maximum token length per rendered pair (default 2048).
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        # Fall back to the EOT token (if the tokenizer has one), else BOS, as pad.
        if not self.tokenizer.pad_token:
            if getattr(self.tokenizer, "eot_token", None):
                self.tokenizer.pad_token = self.tokenizer.eot_token
            else:
                self.tokenizer.pad_token = self.tokenizer.bos_token
        print(f"use ``{self.tokenizer.pad_token}`` as pad token for llm")

        # End of the user turn plus the assistant generation prompt with thinking
        # disabled. This is assumed to match the tail produced by
        # ``apply_chat_template`` in ``__call__``; confirm against the template.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.max_length = kwargs.get('max_length', 2048)
        self.suffix_tokens = self.tokenizer.tokenize(self.suffix, add_special_tokens=False)
        # Cache commonly used token IDs
        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]

    def format_instruction(self, instruction, query, doc):
        """Build the chat messages for one pair.

        A ``(instruction, query)`` tuple passed as ``query`` overrides
        ``instruction``.
        """
        if isinstance(query, tuple):
            instruction = query[0]
            query = query[1]
        text = [
            {"role": "system", "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."},
            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"}
        ]
        return text

    def _truncate_keep_tail(self, ids):
        """Shorten ``ids`` to ``max_length``, keeping its last ``len(suffix_tokens)`` items."""
        if len(ids) <= self.max_length:
            return ids
        tail = min(len(self.suffix_tokens), self.max_length)
        head = self.max_length - tail
        # ``len(ids) - tail`` (rather than ``-tail``) keeps nothing when tail == 0.
        return ids[:head] + ids[len(ids) - tail:]

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Return ``{"generative": left-padded encodings, "ranking": zero-based tensor}``."""
        task = 'Given a web search query, retrieve relevant passages that answer the query'
        pairs = [(instance["query"], doc) for instance in instances for doc in instance["document"]]
        messages = [self.format_instruction(task, q, d) for q, d in pairs]
        messages = self.tokenizer.apply_chat_template(
            messages, tokenize=False,
            add_generation_prompt=True, enable_thinking=False,
        )
        # Tokenize without truncation. Right-side truncation would cut off the
        # generation prompt at the end of long pairs, so the final position
        # (presumably where the yes/no logits are read) would land inside the
        # document. Over-long sequences are shortened in the middle instead.
        encoded = self.tokenizer(messages, padding=False, truncation=False)
        features = {
            key: [self._truncate_keep_tail(list(seq)) for seq in values]
            for key, values in encoded.items()
        }
        messages_encodings = self.tokenizer.pad(
            features,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )

        # Stored rankings are 1-based; shift to 0-based indices.
        ranking = torch.tensor([instance["ranking"] for instance in instances]) - 1
        return {
            "generative": messages_encodings,
            "ranking": ranking,
        }



# Per-source instructions as ``(embedding_instruction, rerank_instruction)``
# pairs: index 0 prefixes embedding queries, index 1 heads listwise rerank
# prompts. Every tuple ends with a trailing comma so that a forgotten
# separator cannot silently concatenate two adjacent string literals into a
# single ``str``.
MIXED_RANK_DATA_INSTRUCTIONS = {
    "msmarco": (
        "Given a web search query, retrieval the documents that answer the query",
        "Given a web search query and some relevant documents, rerank the documents that answer the query",
    ),
    "nq": (
        "Given a question, retrieval Wikipedia documents that answer the question",
        "Given a question and some relevant documents, rerank the documents that answer the question",
    ),
    "hotpotqa": (
        "Given a multi-hop question, retrieval the documents that can help answer the question",
        "Given a multi-hop question and some relevant documents, rerank the documents that answer the question",
    ),
    "trivia": (
        "Retrieval Wikipedia documents that answer the question",
        "Given a question and some relevant Wikipedia documents, rerank the documents that answer the question",
    ),
    "t2ranking": (
        "Given a Chinese search query, retrieval the documents that answer the query",
        "Given a Chinese search query and some relevant documents, rerank the documents that answer the query",
    ),
    "dureader": (
        "Given a Chinese search query, retrieval the documents that answer the query",
        "Given a Chinese search query and some relevant documents, rerank the documents that answer the query",
    ),
    "mmarco_chinese": (
        "Given a Chinese web search query, retrieval the documents that answer the query",
        "Given a Chinese web search query and some relevant documents, rerank the documents that answer the query",
    ),
    "cMedQAv2": (
        "Given a Chinese medical question, retrieval the documents that answer the question",
        "Given a Chinese medical question and some relevant documents, rerank the documents that answer the question",
    ),
    "colliee": (
        "Given a legal case, retrieval the relevant legal articles that can help analyze the case",
        "Given a legal case and some relevant legal articles, rerank the legal articles that can help analyze the case",
    ),
    "law_gpt": (
        "Given a Chinese legal case, retrieval the relevant legal articles that can help analyze the case",
        "Given a Chinese legal case and some relevant legal articles, rerank the legal articles that can help analyze the case",
    ),
    "miracl": (
        "Given a question, retrieval Wikipedia documents that answer the question",
        "Given a question and some relevant Wikipedia documents, rerank the documents that answer the question",
    ),
}

# Fallback ``(embedding_instruction, rerank_instruction)`` for unknown sources.
DEFAULT_INSTRUCTION = (
    "Given a query, retrieval the documents that are relevant to the query",
    "Given a query and some relevant documents, rerank the documents that are the most relevant to the query",
)


class MixedRankDataset(Dataset):
    """Ranking dataset feeding both embedding and (optionally) listwise rerank inputs.

    Each non-blank line of the JSON-lines file must provide ``query``,
    ``document`` (a list of document strings) and ``ranking``. ``source`` is
    optional and defaults to ``"unknown"``. Instructions come from
    ``MIXED_RANK_DATA_INSTRUCTIONS``, with ``DEFAULT_INSTRUCTION`` as the
    fallback. When ``data_args.use_listwise`` is true, a listwise rerank
    prompt is also built for every sample and stored as ``pseudo_query``;
    otherwise that field is ``None``.

    Samples are grouped by ``source``, and at most ``per_dataset_max_samples``
    samples per source are kept. These are cut into full batches of
    ``batch_size`` (any trailing partial batch is dropped), and the batches
    are shuffled.
    """
    rank_prompt_template = """{instruction}:
Documents:
{documents}
Query: {query}"""
    query_instuction_template = "Instruct: {task_description}\nQuery:{query}"

    def __init__(
        self,
        data_args: Dict,
        batch_size: int = 32,
        per_dataset_max_samples: int = 10000,
    ):
        # ``data_args`` is read by attribute (``data_path``, ``use_listwise``).
        self.data_args = data_args
        self.effective_batch_size = batch_size
        self.per_dataset_max_samples = per_dataset_max_samples
        self.use_listwise = data_args.use_listwise

        self.data = []
        self.load_data(data_args.data_path)

    def __len__(self):
        return len(self.data)

    def format_query(self, task, query):
        """Prefix ``query`` with the embedding instruction for ``task``."""
        instruction = MIXED_RANK_DATA_INSTRUCTIONS.get(
            task, DEFAULT_INSTRUCTION)[0]
        return self.query_instuction_template.format(
            task_description=instruction,
            query=query
        )

    def format_rank_prompt(self, task, query, docs):
        """Build a listwise rerank prompt with the documents numbered from 1."""
        instruction = MIXED_RANK_DATA_INSTRUCTIONS.get(
            task, DEFAULT_INSTRUCTION)[1]
        return self.rank_prompt_template.format(
            instruction=instruction,
            documents="\n".join([f"[{i}] {doc}" for i, doc in enumerate(docs, start=1)]),
            query=query
        )

    def load_data(self, data_path: str = None):
        """Parse ``data_path`` and fill ``self.data`` in batch-grouped order.

        Args:
            data_path: path to the JSON-lines file. Blank lines are skipped.
        """
        # Decode as UTF-8 explicitly: several sources contain Chinese text, and
        # the platform default encoding may differ.
        with open(data_path, "r", encoding="utf-8") as f:
            raw_samples = [json.loads(line) for line in f if line.strip()]

        all_samples = []
        data_map = defaultdict(list)  # source -> indices into all_samples
        for i, sample in enumerate(raw_samples):
            task = sample.get("source", "unknown")
            data_map[task].append(i)
            if self.use_listwise:
                pseudo_query = self.format_rank_prompt(task, sample["query"], sample["document"])
            else:
                pseudo_query = None

            all_samples.append({
                "query": self.format_query(task, sample["query"]),
                "document": sample["document"],
                "pseudo_query": pseudo_query,
                "ranking": sample["ranking"],
            })

        # Shuffle within each source, then shuffle the order of the batches.
        for indices in data_map.values():
            random.shuffle(indices)

        all_batches = []
        for dataset, indices in data_map.items():
            # Apply the cap before batching so it is never exceeded. Slicing
            # batches from the uncapped list could overshoot the cap by up to
            # ``batch_size - 1`` samples.
            capped = indices[: self.per_dataset_max_samples]
            for start in range(0, len(capped), self.effective_batch_size):
                batch = capped[start : start + self.effective_batch_size]
                if len(batch) == self.effective_batch_size:
                    all_batches.append(batch)
                else:
                    print(f"Skip 1 batch for dataset {dataset}.")
        random.shuffle(all_batches)

        self.data = [all_samples[idx] for batch in all_batches for idx in batch]
        print(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
        return self.data[index]


class MixedRankDataCollator:
    """Collate ``MixedRankDataset`` samples into tokenized model inputs.

    The returned dict holds:
        query: encodings of the instructed queries, with the pad token appended.
        document: encodings of the documents of all samples, flattened in
            order, with the pad token appended to each.
        pseudo_query: chat-template encodings of the listwise rerank prompts,
            or ``None`` when the first sample has no ``pseudo_query``.
        ranking: the ``ranking`` lists shifted to zero-based indices, as a
            tensor (all lists in a batch must have equal length), or ``None``
            when the first sample has no ``ranking``.

    Keyword args:
        query_max_length (default 512), doc_max_length (default 1024) and
        pseudo_query_max_length (default 32768) set the truncation lengths.
    """

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        # Fall back to the EOT token (if the tokenizer has one), else BOS, as pad.
        if not self.tokenizer.pad_token:
            if getattr(self.tokenizer, "eot_token", None):
                self.tokenizer.pad_token = self.tokenizer.eot_token
            else:
                self.tokenizer.pad_token = self.tokenizer.bos_token
        print(f"use ``{self.tokenizer.pad_token}`` as pad token for llm")
        self.query_max_length = kwargs.get("query_max_length", 512)
        self.doc_max_length = kwargs.get("doc_max_length", 1024)
        self.pseudo_query_max_length = kwargs.get("pseudo_query_max_length", 32768)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad = self.tokenizer.pad_token

        queries = [instance["query"] + pad for instance in instances]
        queries_encodings = self.tokenizer(
            queries,
            padding=True, truncation=True,
            max_length=self.query_max_length,
            return_tensors="pt",
        )

        # A nested comprehension flattens the per-sample document lists in
        # order, avoiding the quadratic ``sum(lists, [])`` concatenation.
        docs = [f"{doc}" + pad for instance in instances for doc in instance["document"]]
        docs_encodings = self.tokenizer(
            docs,
            padding=True, truncation=True,
            max_length=self.doc_max_length,
            return_tensors="pt",
        )

        # ``.get`` tolerates samples that carry no ``pseudo_query`` key at all.
        if instances[0].get("pseudo_query") is not None:
            pseudo_queries = [instance["pseudo_query"] + pad for instance in instances]
            pseudo_query_encodings = self.tokenizer.apply_chat_template(
                [[{"role": "user", "content": prompt}] for prompt in pseudo_queries],
                tokenize=True, add_generation_prompt=True, enable_thinking=False,
                padding=True, truncation=True, max_length=self.pseudo_query_max_length,
                return_tensors="pt", return_dict=True
            )
        else:
            pseudo_query_encodings = None

        if "ranking" in instances[0]:
            ranking = torch.tensor([instance["ranking"] for instance in instances]) - 1  # Convert to zero-based indexing
        else:
            ranking = None

        return {
            "query": queries_encodings,
            "document": docs_encodings,
            "pseudo_query": pseudo_query_encodings,
            "ranking": ranking
        }
