import random
import os

from datasets import load_dataset, interleave_datasets
from torch.utils.data import Dataset
from leanfinder.retriever.arguments import DataArguments

import logging
from tqdm import tqdm

logger = logging.getLogger(__name__)

class DPODoubleDataset(Dataset):
    """Pair every DPO record with a randomly drawn contrastive-learning group.

    The dataset length equals the length of the DPO dataset. Indexing returns a
    ``(formatted_query, formatted_documents, dpo_item)`` triple where:

    * ``formatted_query`` is ``(query_text, None, None, None)``;
    * ``formatted_documents`` holds one positive passage followed by
      ``train_group_size - 1`` negatives, each as ``(text, None, None, None)``;
    * ``dpo_item`` is the raw DPO record at the requested index.

    The trailing ``None`` slots are placeholders whose meaning is defined by
    the consumer of these tuples (presumably a collator — not visible here).

    The contrastive group is drawn at random on every access; it is not tied to
    the DPO index. A trainer (passed in or set via ``set_trainer``) is required
    before indexing because sampling reads ``trainer.state.epoch`` and
    ``trainer.args.seed``.
    """

    def __init__(self,
                 data_args,
                 trainer=None,
                 dpo_dataset_path=None,
                 contrastive_dataset_path=None,
                 contrastive_dataset_name="/local-scratch1/.hf_cache/lean-finder-data",
                 train_group_size=5,
                 cache_dir="/local-scratch1/.hf_cache"):
        """Load both datasets and shuffle them with a fixed seed.

        Args:
            data_args: Stored on the instance; not otherwise read by this class.
            trainer: Object exposing ``state.epoch`` and ``args.seed``; may be
                supplied later through ``set_trainer``.
            dpo_dataset_path: ``data_files`` for the JSON DPO dataset.
            contrastive_dataset_path: ``data_files`` for the contrastive dataset.
            contrastive_dataset_name: Name/path passed to ``load_dataset`` for the
                contrastive dataset.
            train_group_size: Documents per query (one positive, the rest negatives).
            cache_dir: Cache directory used when loading the contrastive dataset.
        """
        logger.info("Initializing DPODoubleDataset")
        self.data_args = data_args
        self.trainer = trainer
        self.dpo_dataset_path = dpo_dataset_path
        self.contrastive_dataset_name = contrastive_dataset_name
        self.contrastive_dataset_path = contrastive_dataset_path
        self.train_group_size = train_group_size
        self.cache_dir = cache_dir

        logger.info(f"loading contrastive learning dataset from {self.contrastive_dataset_path}")
        self.contrastive_dataset = load_dataset(
            self.contrastive_dataset_name,
            data_files=self.contrastive_dataset_path,
            cache_dir=self.cache_dir,
            split="train",
            num_proc=1,
            trust_remote_code=True,
        )
        logger.info(f"loading contrastive learning dataset from {self.contrastive_dataset_path}, length: {len(self.contrastive_dataset)}")

        logger.info(f"loading dpo dataset from {self.dpo_dataset_path}")
        self.dpo_dataset = load_dataset(
            "json",
            data_files=self.dpo_dataset_path,
            split="train",
        )
        logger.info(f"loading dpo dataset from {self.dpo_dataset_path}, length: {len(self.dpo_dataset)}")

        self.dpo_dataset = self.dpo_dataset.shuffle(seed=42)
        self.contrastive_dataset = self.contrastive_dataset.shuffle(seed=42)
        self.contrastive_dataset_length = len(self.contrastive_dataset)

    def set_trainer(self, trainer):
        """Attach the trainer whose epoch and seed drive sampling."""
        self.trainer = trainer

    def __len__(self):
        return len(self.dpo_dataset)

    @staticmethod
    def _select_negatives(candidates, negative_size, epoch, seed):
        """Pick ``negative_size`` negatives from ``candidates``.

        * ``negative_size <= 0``: no negatives are returned.
        * Fewer candidates than needed: sample with replacement using the
          global ``random`` module (so not reproducible through ``seed``).
        * Otherwise: shuffle deterministically with ``seed`` and return a
          contiguous window that advances by ``negative_size`` each epoch,
          wrapping around the end of the shuffled list.

        Raises:
            IndexError: if negatives are required but ``candidates`` is empty.
        """
        if negative_size <= 0:
            return []
        if len(candidates) < negative_size:
            if not candidates:
                raise IndexError("no negative passages available to sample from")
            return random.choices(candidates, k=negative_size)
        shuffled = list(candidates)
        random.Random(seed).shuffle(shuffled)
        offset = epoch * negative_size % len(shuffled)
        # Doubling lets the window wrap past the end; since
        # negative_size <= len(shuffled), one extra copy is enough.
        return (shuffled * 2)[offset:offset + negative_size]

    def __getitem__(self, item):
        """Return ``(formatted_query, formatted_documents, dpo_item)`` for ``item``."""
        # Uniformly random contrastive group (shifting by ``item`` keeps the
        # draw uniform over all indices).
        index_for_contrastive = (item + random.randint(0, self.contrastive_dataset_length - 1)) % self.contrastive_dataset_length
        group = self.contrastive_dataset[index_for_contrastive]
        epoch = int(self.trainer.state.epoch)
        # Per-group seed: the negative shuffle for a group is reproducible for
        # a given training seed.
        _hashed_seed = hash(index_for_contrastive + self.trainer.args.seed)

        formatted_query = (group['query'], None, None, None)

        # Rotate through the available positives from one epoch to the next.
        positives = group['positive_passages']
        selected_positive = positives[(_hashed_seed + epoch) % len(positives)]
        formatted_documents = [(selected_positive['text'], None, None, None)]

        selected_negatives = self._select_negatives(
            group['negative_passages'], self.train_group_size - 1, epoch, _hashed_seed)
        formatted_documents.extend(
            (negative['text'], None, None, None) for negative in selected_negatives
        )

        return formatted_query, formatted_documents, self.dpo_dataset[item]


class TrainDataset(Dataset):
    """Training dataset yielding ``(formatted_query, formatted_documents)`` pairs.

    Two record schemas are supported:

    * Inline passages: records with ``query``, ``positive_passages`` and
      ``negative_passages``; each passage is a mapping with ``text`` and an
      optional ``title``.
    * Document references: records with an optional ``query_text`` plus
      ``positive_document_ids`` and ``negative_document_ids``; ids are resolved
      through the corpus dataset's ``docid`` column.

    Query and passage texts are prefixed with ``data_args.query_prefix`` and
    ``data_args.passage_prefix``. Each formatted entry is a 4-tuple
    ``(text, None, None, None)``; the meaning of the trailing ``None`` slots is
    defined by the consumer (presumably a collator — not visible here).

    ``formatted_documents`` holds one positive followed by
    ``data_args.train_group_size - 1`` negatives. A trainer (passed in or set
    via ``set_trainer``) is required before indexing because sampling reads
    ``trainer.state.epoch`` and ``trainer.args.seed``.
    """

    def __init__(self,
                 data_args: DataArguments,
                 trainer=None,
                 dataset_name=None,
                 corpus_name=None,
                 dataset_path=None,
                 corpus_path=None,
                 corpus_assets_path=None,
                 sampling_exponent=0.5):
        """Load (and optionally interleave) training data plus the optional corpus.

        Args:
            data_args: Loader options, prefixes and ``train_group_size``.
            trainer: Object exposing ``state.epoch`` and ``args.seed``; may be
                supplied later through ``set_trainer``.
            dataset_name: Overrides ``data_args.dataset_name`` when given.
            corpus_name: Overrides ``data_args.corpus_name`` when given.
            dataset_path: Overrides ``data_args.dataset_path`` when given (used
                only when ``data_args.dataset_path_list`` is empty).
            corpus_path: Overrides ``data_args.corpus_path`` when given.
            corpus_assets_path: Accepted for interface compatibility; unused here.
            sampling_exponent: When several datasets are listed, their
                size-proportional sampling probabilities are raised to this
                power and renormalised (values < 1 up-weight small datasets).
        """
        logger.info("Initializing TrainDataset")
        self.data_args = data_args
        self.trainer = trainer

        resolved_dataset_name = self.data_args.dataset_name if dataset_name is None else dataset_name
        path_list = self.data_args.dataset_path_list

        if not path_list:
            data_files = self.data_args.dataset_path if dataset_path is None else dataset_path
            logger.info(f"Loading dataset {resolved_dataset_name} from {data_files}")
            self.train_data = self._load_train_split(resolved_dataset_name, data_files)
        else:
            logger.info(f"Loading {len(path_list)} datasets from the list: {path_list}")
            dataset_list = [self._load_train_split(resolved_dataset_name, path) for path in path_list]
            lengths = [len(dataset) for dataset in dataset_list]
            logger.info(f"Lengths of the datasets: {lengths}")
            total_length = sum(lengths)
            probs = [length / total_length for length in lengths]
            logger.info(f"Initial probabilities: {probs}")
            # Flatten the size-proportional distribution, then renormalise.
            weights = [prob ** sampling_exponent for prob in probs]
            weight_sum = sum(weights)
            probs = [weight / weight_sum for weight in weights]
            logger.info(f"Final probabilities of sampling after adjustment: {probs}")
            self.train_data = interleave_datasets(dataset_list, probabilities=probs, seed=42, stopping_strategy="all_exhausted")
            logger.info(f"Total train data length: {len(self.train_data)}")

        self.train_data = self.train_data.shuffle(seed=42)

        if self.data_args.corpus_name is None and corpus_name is None:
            self.corpus = None
        else:
            self.corpus = load_dataset(
                self.data_args.corpus_name if corpus_name is None else corpus_name,
                self.data_args.corpus_config,
                data_files=self.data_args.corpus_path if corpus_path is None else corpus_path,
                split=self.data_args.corpus_split,
                cache_dir=self.data_args.dataset_cache_dir,
                num_proc=self.data_args.num_proc,
                trust_remote_code=True,
            )
            logger.info("Train data: %s", self.train_data)
            logger.info("Train data keys: %s", self.train_data.column_names)
            logger.info("Train data first item: %s", self.train_data[0])

        # Map each corpus docid to its row index for O(1) lookup in
        # _get_info_from_docid.
        self.docid_to_index = {}
        if self.corpus is not None:
            docids = self.corpus.select_columns(['docid'])['docid']
            self.docid_to_index = {docid: index for index, docid in enumerate(tqdm(docids))}

    def _load_train_split(self, name, data_files):
        """Load one training split using the loader options in ``data_args``."""
        return load_dataset(
            name,
            self.data_args.dataset_config,
            data_files=data_files,
            split=self.data_args.dataset_split,
            cache_dir=self.data_args.dataset_cache_dir,
            num_proc=self.data_args.num_proc,
            trust_remote_code=True,
        )

    def set_trainer(self, trainer):
        """Attach the trainer whose epoch and seed drive sampling."""
        self.trainer = trainer

    def __len__(self):
        return len(self.train_data)

    def _get_info_from_docid(self, docid, prefix):
        """Return ``(prefix + text, None, None, None)`` for corpus document ``docid``.

        Raises:
            KeyError: if ``docid`` is not present in the corpus index.
            ValueError: if the corpus row at the indexed position carries a
                different docid (index and corpus are out of sync).
        """
        document_info = self.corpus[self.docid_to_index[docid]]
        if document_info['docid'] != docid:
            raise ValueError(
                f"corpus row for docid {docid!r} has docid {document_info['docid']!r}")
        text = document_info.get('text', '') or ''
        return prefix + text, None, None, None

    @staticmethod
    def _passage_text(passage):
        """Return ``title + ' ' + text`` if the passage has a title key, else ``text``."""
        if 'title' in passage:
            return passage['title'] + ' ' + passage['text']
        return passage['text']

    @staticmethod
    def _select_negatives(candidates, negative_size, epoch, seed):
        """Pick ``negative_size`` negatives from ``candidates``.

        * ``negative_size <= 0``: no negatives are returned.
        * Fewer candidates than needed: sample with replacement using the
          global ``random`` module (so not reproducible through ``seed``).
        * Otherwise: shuffle deterministically with ``seed`` and return a
          contiguous window that advances by ``negative_size`` each epoch,
          wrapping around the end of the shuffled list.

        Raises:
            IndexError: if negatives are required but ``candidates`` is empty.
        """
        if negative_size <= 0:
            return []
        if len(candidates) < negative_size:
            if not candidates:
                raise IndexError("no negatives available to sample from")
            return random.choices(candidates, k=negative_size)
        shuffled = list(candidates)
        random.Random(seed).shuffle(shuffled)
        offset = epoch * negative_size % len(shuffled)
        # Doubling lets the window wrap past the end; since
        # negative_size <= len(shuffled), one extra copy is enough.
        return (shuffled * 2)[offset:offset + negative_size]

    def __getitem__(self, item):
        """Return ``(formatted_query, formatted_documents)`` for record ``item``."""
        group = self.train_data[item]
        epoch = int(self.trainer.state.epoch)
        # Per-item seed: the negative shuffle for an item is reproducible for a
        # given training seed.
        _hashed_seed = hash(item + self.trainer.args.seed)
        negative_size = self.data_args.train_group_size - 1
        passage_prefix = self.data_args.passage_prefix

        if 'positive_passages' in group:
            # Inline-passage schema.
            formatted_query = (self.data_args.query_prefix + group['query'],
                               None, None, None)
            positives = group['positive_passages']
            # Rotate through the available positives from one epoch to the next.
            selected_positive = positives[(_hashed_seed + epoch) % len(positives)]
            formatted_documents = [
                (passage_prefix + self._passage_text(selected_positive), None, None, None)
            ]
            selected_negatives = self._select_negatives(
                group['negative_passages'], negative_size, epoch, _hashed_seed)
            formatted_documents.extend(
                (passage_prefix + self._passage_text(negative), None, None, None)
                for negative in selected_negatives
            )
            return formatted_query, formatted_documents

        # Document-id schema: texts are resolved through the corpus.
        query_text = group.get('query_text', '') or ''
        formatted_query = (self.data_args.query_prefix + query_text,
                           None, None, None)

        positive_document_ids = group['positive_document_ids']
        selected_positive_docid = positive_document_ids[(_hashed_seed + epoch) % len(positive_document_ids)]
        formatted_documents = [
            self._get_info_from_docid(selected_positive_docid, passage_prefix)
        ]
        selected_negative_docids = self._select_negatives(
            group['negative_document_ids'], negative_size, epoch, _hashed_seed)
        formatted_documents.extend(
            self._get_info_from_docid(neg_docid, passage_prefix)
            for neg_docid in selected_negative_docids
        )
        return formatted_query, formatted_documents