import random
import os
import numpy as np
from tqdm import tqdm
from typing import List, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

from pyserini.index.lucene import LuceneIndexer
from pyserini.search import SimpleSearcher, FaissSearcher

from tevatron.retriever.arguments import DataArguments

import logging
logger = logging.getLogger(__name__)


def format_query(query: str, prefix: str = '') -> str:
    """Return the query text with surrounding whitespace removed, preceded by *prefix*.

    The prefix and the stripped query are joined by a single space and the
    combined string is stripped again, so an empty prefix leaves no leading
    space.
    """
    stripped_query = query.strip()
    combined = '{} {}'.format(prefix, stripped_query)
    return combined.strip()

def format_passage(text: str, title: str = '', prefix: str = '') -> str:
    """Return ``prefix``, ``title`` and ``text`` joined by single spaces.

    Title and text are stripped individually before joining, and the joined
    string is stripped once more. An empty title is not collapsed: when a
    prefix is present the result then contains two consecutive spaces between
    the prefix and the text.
    """
    pieces = (prefix, title.strip(), text.strip())
    combined = '{} {} {}'.format(*pieces)
    return combined.strip()


class TrainDataset(Dataset):
    """Training dataset yielding ``(query, [positive, *negatives])`` text groups.

    Records are loaded with ``datasets.load_dataset`` and are expected to carry
    the keys ``query``, ``positive_passages`` and ``negative_passages`` (each
    passage being a mapping with ``title`` and ``text``). When
    ``n_ic_examples > 0`` every query is rewritten once, at construction time,
    into an in-context prompt built from similar training queries.
    """

    def __init__(self, data_args: DataArguments, trainer = None, n_ic_examples=0):
        """Load (and optionally shard / rewrite) the training split.

        Args:
            data_args: Dataset configuration (name, config, path, split, cache
                dir, sharding options, prefixes, group size, shuffle flags).
            trainer: Object exposing ``state.epoch`` and ``args.seed``; it is
                read in ``__getitem__``, so it must be set before items are
                fetched.
            n_ic_examples: Number of in-context examples to prepend to each
                query; ``0`` leaves queries untouched.
        """
        self.data_args = data_args
        self.train_data = load_dataset(
            self.data_args.dataset_name,
            self.data_args.dataset_config,
            data_files=self.data_args.dataset_path,
            split=self.data_args.dataset_split,
            cache_dir=self.data_args.dataset_cache_dir,
            trust_remote_code=True
        )
        if self.data_args.dataset_number_of_shards > 1:
            # Shard the training split itself; this class never defines
            # ``encode_data``, so sharding that attribute raised AttributeError.
            self.train_data = self.train_data.shard(
                num_shards=self.data_args.dataset_number_of_shards,
                index=self.data_args.dataset_shard_index,
            )
        self.trainer = trainer
        self.n_ic_examples = n_ic_examples
        if n_ic_examples > 0:
            self.train_data = self.construct_ic_queries(self.train_data)

    def __len__(self):
        """Return the number of training groups."""
        return len(self.train_data)

    def __getitem__(self, item) -> Tuple[str, List[str]]:
        """Return the formatted query and its passages (positive first).

        The positive and the window of negatives depend on the item index, the
        trainer seed and the current epoch, so the selection rotates across
        epochs while staying reproducible for a given seed.
        """
        group = self.train_data[item]
        epoch = int(self.trainer.state.epoch)

        _hashed_seed = hash(item + self.trainer.args.seed)

        query = group['query']
        group_positives = group['positive_passages']
        group_negatives = group['negative_passages']

        formated_query = format_query(query, self.data_args.query_prefix)
        formated_passages = []

        if self.data_args.positive_passage_no_shuffle:
            pos_psg = group_positives[0]
        else:
            pos_psg = group_positives[(_hashed_seed + epoch) % len(group_positives)]

        formated_passages.append(format_passage(pos_psg['text'], pos_psg['title'], self.data_args.passage_prefix))

        negative_size = self.data_args.train_group_size - 1
        if len(group_negatives) < negative_size:
            # Not enough negatives: sample with replacement to fill the group.
            negs = random.choices(group_negatives, k=negative_size)
        elif self.data_args.train_group_size == 1:
            negs = []
        elif self.data_args.negative_passage_no_shuffle:
            negs = group_negatives[:negative_size]
        else:
            _offset = epoch * negative_size % len(group_negatives)
            negs = list(group_negatives)
            # Per-item deterministic shuffle; the epoch only moves the window.
            random.Random(_hashed_seed).shuffle(negs)
            # Doubling the list lets the window wrap around the end.
            negs = negs * 2
            negs = negs[_offset: _offset + negative_size]

        for neg_psg in negs:
            formated_passages.append(format_passage(neg_psg['text'], neg_psg['title'], self.data_args.passage_prefix))

        return formated_query, formated_passages

    def encode_ic_data(self, ic_data, index_dir='temp_index', type="dense", model_id='msmarco-MiniLM-L-6-v3'):
        """Index the queries of ``ic_data`` and return a searcher over them.

        Args:
            ic_data: Records whose ``query`` field is indexed.
            index_dir: Directory the index is written to.
            type: ``"dense"`` builds a FAISS index; any other value builds a
                Lucene (BM25) index.
            model_id: Sentence encoder name, used only for the dense index.
        """
        if type == "dense":
            return self.encode_ic_data_faiss(ic_data, index_dir, model_id)
        else:
            return self.encode_ic_data_bm25(ic_data, index_dir)

    def encode_ic_data_faiss(self, ic_data, index_dir='temp_index', model_id='msmarco-MiniLM-L-6-v3'):
        """Embed the in-context queries, write a flat L2 FAISS index and open it.

        Writes ``index`` and ``docid`` files into ``index_dir`` and returns a
        ``FaissSearcher`` created from that directory and ``model_id``.
        Document ids are ``doc<position>`` so hits map back to ``ic_data``.
        """
        # Imported lazily: only the dense path needs these packages and the
        # module does not import them at top level.
        import faiss
        from sentence_transformers import SentenceTransformer

        # Records are mappings (same access as the BM25 path), not objects.
        ic_questions = [item["query"] for item in ic_data]

        documents = [{"id": f"doc{idx}", "contents": question} for idx, question in enumerate(ic_questions)]

        model = SentenceTransformer(model_id).to('cuda')
        embeddings = model.encode([doc["contents"] for doc in documents], batch_size=512)

        logger.info("Writing FAISS index to %s", index_dir)
        os.makedirs(index_dir, exist_ok=True)

        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, os.path.join(index_dir, 'index'))

        # One id per line, in the same order as the vectors added above.
        with open(os.path.join(index_dir, 'docid'), 'w') as f_out:
            f_out.writelines(f"{doc['id']}\n" for doc in documents)

        searcher = FaissSearcher(index_dir, model_id)

        return searcher

    def encode_ic_data_bm25(self, ic_data, index_dir='temp_index'):
        """
        Build a Lucene index over the in-context queries and open a searcher on it.

        The index is written to ``index_dir`` (created if missing). Document
        ids are ``doc<position>`` so hits map back to ``ic_data``.

        Args:
        ic_data: list of in-context data [list of dicts with a "query" key]
        index_dir: directory the Lucene index is written to

        Returns:
        searcher: SimpleSearcher instance opened on ``index_dir``
        """
        ic_questions = [item["query"] for item in ic_data]

        documents = [{"id": f"doc{idx}", "contents": question} for idx, question in enumerate(ic_questions)]

        os.makedirs(index_dir, exist_ok=True)
        indexer = LuceneIndexer(index_dir=index_dir, threads=1)

        indexer.add_batch_dict(documents)
        indexer.close()

        searcher = SimpleSearcher(index_dir)

        return searcher

    def construct_ic_query(self, sample, n_ic_examples, searcher=None, ic_data=None, use_negatives=False, instruction=None):
        """Build an in-context prompt (and its label-flipped twin) for ``sample``.

        Args:
            sample: Record whose ``query`` becomes the final line of the prompt.
            n_ic_examples: Maximum number of in-context examples to include.
            searcher: Optional searcher whose hits carry ``doc<position>`` ids
                into ``ic_data``; without it examples are drawn at random.
            ic_data: Pool of candidate examples.
            use_negatives: Also show each example's first negative passage.
            instruction: Optional text emitted as a leading ``Instruct:`` line.

        Returns:
            ``(ic_query, ic_query_flipped)``; the flipped prompt swaps the
            positive and negative passages of every example.
        """
        query = sample["query"]

        if searcher:
            # Ask for one extra hit because the sample itself is usually
            # retrieved and is removed below.
            hits = searcher.search(query, k=n_ic_examples + 1)
            retrieved = [ic_data[int(hit.docid.split('doc')[-1])] for hit in hits]
            # Reverse so the best-ranked example ends up closest to the query.
            retrieved.reverse()
            ic_examples = [example for example in retrieved if example['query'] != query]
            # When the sample was not among the hits, all n + 1 survive the
            # filter; keep only the n best-ranked (the tail after reversing).
            ic_examples = ic_examples[-n_ic_examples:] if n_ic_examples > 0 else []
        else:
            ic_examples = random.sample(ic_data, n_ic_examples)

        header = f"Instruct: {instruction}\n" if instruction else ""
        parts = [header]
        parts_flipped = [header]

        for example in ic_examples:
            first_pos = example['positive_passages'][0]
            first_neg = example['negative_passages'][0]
            positive_passage = first_pos['title'] + " " + first_pos['text']
            negative_passage = first_neg['title'] + " " + first_neg['text']
            if use_negatives:
                parts.append(f"Query: {example['query']}\nPositive Document: {positive_passage}\nNegative Document: {negative_passage}\n\n")
                parts_flipped.append(f"Query: {example['query']}\nPositive Document: {negative_passage}\nNegative Document: {positive_passage}\n\n")
            else:
                parts.append(f"Query: {example['query']}\nPositive Document: {positive_passage}\n\n")
                parts_flipped.append(f"Query: {example['query']}\nPositive Document: {negative_passage}\n\n")

        parts.append(f"Query: {query}")
        parts_flipped.append(f"Query: {query}")
        return "".join(parts), "".join(parts_flipped)

    def construct_ic_queries(self, data):
        """Return a list of copies of ``data`` records with rewritten queries.

        Roughly 70% of the records (drawn with ``np.random.choice``) get an
        in-context prompt built from BM25-retrieved neighbours; the rest only
        get the ``"Query: "`` prefix. The Lucene index lives under
        ``$TRANSFORMERS_CACHE`` (or ``temp_train_index``) and is reused as-is
        whenever that directory already exists.
        """
        new_data = []
        index_dir = os.path.join(os.environ.get("TRANSFORMERS_CACHE", "temp_train_index"), "temp_indexes_msmarco_tevatron")
        if not os.path.exists(index_dir):
            searcher = self.encode_ic_data(data, index_dir=index_dir, type='sparse', model_id=None)
        else:
            # NOTE(review): an existing index is assumed to match ``data``
            # (same records, same order) - confirm when data or shards change.
            searcher = SimpleSearcher(index_dir)

        for idx in tqdm(range(len(data))):
            item = data[idx]
            if np.random.choice([0, 1], p=[0.3, 0.7]) == 1:
                # The label-flipped prompt is not used for training data.
                query, _ = self.construct_ic_query(item, self.n_ic_examples, searcher=searcher, ic_data=data)
            else:
                query = "Query: " + item['query']

            # Write the new query into a copy so the example pool (``data``)
            # keeps the original queries for later retrievals.
            new_data_item = item.copy()
            new_data_item['query'] = query
            new_data.append(new_data_item)
        return new_data


class EncodeDataset(Dataset):
    """Dataset of ``(id, formatted text)`` pairs for encoding queries or passages.

    Whether records are treated as queries or passages is decided by
    ``data_args.encode_is_query``.
    """

    def __init__(self, data_args: DataArguments):
        """Load the configured split and, if requested, keep only one shard of it."""
        self.data_args = data_args
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config,
            data_files=data_args.dataset_path,
            split=data_args.dataset_split,
            cache_dir=data_args.dataset_cache_dir,
        )
        shard_count = data_args.dataset_number_of_shards
        if shard_count > 1:
            dataset = dataset.shard(
                num_shards=shard_count,
                index=data_args.dataset_shard_index,
            )
        self.encode_data = dataset

    def __len__(self):
        """Return the number of records to encode."""
        return len(self.encode_data)

    def __getitem__(self, item) -> Tuple[str, str]:
        """Return the record's id together with its formatted text.

        Queries use ``query_id``/``query`` and the query prefix; passages use
        ``docid``/``text``/``title`` and the passage prefix.
        """
        record = self.encode_data[item]
        args = self.data_args
        if not args.encode_is_query:
            return record['docid'], format_passage(record['text'], record['title'], args.passage_prefix)
        return record['query_id'], format_query(record['query'], args.query_prefix)
