import os
import json
import torch
import random
import numpy as np
import torch.distributed as dist
from tqdm import tqdm
from collections import defaultdict
# from transformers import BertTokenizer

# Load the token-id mapping table once, at import time. The path is relative to the
# current working directory. Keys are stringified source token ids; each value is a
# dict whose "bert_tok_id" entry lists the ids that source token maps to.
with open("mapping_30k_500k.json", "r") as f:
    MAPPING = json.load(f)
# One past the largest token id present in the mapping; sizes the lookup tables below.
max_tok_id = max(int(k) for k in MAPPING.keys()) + 1
# bert_map[i] holds the mapped ids for source token i (an empty array when i has no
# entry or no "bert_tok_id" field).
bert_map = [np.array(MAPPING.get(str(i), {}).get("bert_tok_id", []), dtype=np.int32)
            for i in range(max_tok_id)]
# bert_len[i] is the number of mapped ids for source token i.
bert_len = np.array([len(arr) for arr in bert_map], dtype=np.int32)
# Padding values: PAD_ID_ORIG fills the mapped-id outputs; PAD_ID_MAP fills the
# hash-id outputs and marks where a stored token row stops (see token_mapping_batch).
PAD_ID_ORIG = 30002
PAD_ID_MAP = 500002

class GenericMemmapDataset(torch.utils.data.Dataset):
    """Dataset of (query, passage) pairs read from memory-mapped token files.

    Each item pairs one query row with one passage row, as listed in
    ``qrels.tsv``. ``collate_fn`` assembles the batch and, when ``load_hash``
    is true, expands every stored token id through the module-level
    ``bert_map`` table (see ``token_mapping_batch``).
    """

    # Number of int32 columns per row in every memmap file.
    _ROW_WIDTH = 256

    def __init__(self, mode, data_dir, folder_name, max_queries, max_passages, max_query_len = 30, max_pas_len = 512, load_hash = True):
        """Open the memmap files and load the query/passage pairs.

        Args:
            mode: Split name embedded in the memmap file names.
            data_dir: Directory holding ``qrels.tsv`` and the ``folder_name``
                sub-directory.
            folder_name: Sub-directory of ``data_dir`` containing the
                ``query.*`` / ``corpus.*`` memmap files.
            max_queries: Number of rows in the query memmaps.
            max_passages: Number of rows in the corpus memmaps.
            max_query_len: Output width of mapped queries (hash mode only).
            max_pas_len: Output width of mapped passages (hash mode only).
            load_hash: Whether ``collate_fn`` runs ``token_mapping_batch``.
        """
        self.mode = mode
        self.max_query_len = max_query_len
        self.max_pas_len = max_pas_len
        self.load_hash = load_hash

        def open_rows(stem, num_rows):
            path = os.path.join(data_dir, f"{folder_name}/{stem}.{mode}.memmap")
            return np.memmap(path, dtype=np.int32, mode='r', shape=(num_rows, self._ROW_WIDTH))

        self.query_ids = open_rows("query.ids", max_queries)
        self.query_attn = open_rows("query.attn", max_queries)
        self.corpus_ids = open_rows("corpus.ids", max_passages)
        self.corpus_attn = open_rows("corpus.attn", max_passages)
        # ndmin=2 keeps the (num_pairs, 2) layout when the file has a single
        # line; without it loadtxt returns a 1-D array, so len() would report 2
        # and unpacking a row in __getitem__ would fail.
        self.qrels = np.loadtxt(
            os.path.join(data_dir, "qrels.tsv"),
            usecols=(0, 1),
            dtype=np.int32,
            delimiter='\t',
            ndmin=2,
        )

    def __len__(self):
        """Return the number of (query, passage) pairs."""
        return len(self.qrels)

    def __getitem__(self, idx):
        """Return the stored token rows and ids for pair ``idx``.

        The arrays are copied out of the memmaps so the returned item does not
        keep a view onto the mapped file.
        """
        qid, pid = self.qrels[idx]

        return {
            'query_input_ids': np.array(self.query_ids[qid]),
            'query_attention_mask': np.array(self.query_attn[qid]),
            'passage_input_ids': np.array(self.corpus_ids[pid]),
            'passage_attention_mask': np.array(self.corpus_attn[pid]),
            'qid' : qid,
            'pid' : pid,
            'qp_mat' : torch.tensor([[1]])
        }

    def token_mapping_batch(self, tokenmonster_tok_id, tokenmonster_attn_mask, length):
        """Expand stored token ids into mapped ids using ``bert_map``.

        Args:
            tokenmonster_tok_id: 2-D array of stored token ids, one row per sequence.
            tokenmonster_attn_mask: Array of the same layout with attention values.
            length: Width of every output row.

        Returns:
            Tuple ``(bert_tok_ids, bert_attn_masks, hash_ids)`` of int32 arrays
            shaped ``(batch_size, length)``. Unfilled positions hold
            ``PAD_ID_ORIG``, ``0`` and ``PAD_ID_MAP`` respectively. Every filled
            position carries a mapped id, the attention value of the stored
            token it came from, and that stored token's id.
        """
        batch_size = tokenmonster_tok_id.shape[0]
        bert_tok_ids = np.full((batch_size, length), PAD_ID_ORIG, dtype=np.int32)
        bert_attn_masks = np.zeros((batch_size, length), dtype=np.int32)
        hash_ids = np.full((batch_size, length), PAD_ID_MAP, dtype=np.int32)

        for b in range(batch_size):
            pos = 0
            for tok_id, attn_mask in zip(tokenmonster_tok_id[b], tokenmonster_attn_mask[b]):
                tok_id_int = int(tok_id)
                # The first pad id ends the row.
                if tok_id_int == PAD_ID_MAP:
                    break
                mapped_tokens = bert_map[tok_id_int]
                num_tokens = bert_len[tok_id_int]
                # Tokens without any mapped ids contribute nothing.
                if num_tokens == 0:
                    continue

                # Truncate the expansion at the output width.
                end_pos = min(pos + num_tokens, length)
                bert_tok_ids[b, pos:end_pos] = mapped_tokens[:end_pos-pos]
                bert_attn_masks[b, pos:end_pos] = attn_mask
                hash_ids[b, pos:end_pos] = tok_id_int
                pos = end_pos

                if pos >= length:
                    break
        return bert_tok_ids, bert_attn_masks, hash_ids

    def collate_fn(self, batch):
        """Stack items from ``__getitem__`` into one batch dictionary.

        In hash mode the stored rows are expanded with ``token_mapping_batch``
        and the ``*_hash_ids`` entries are tensors; otherwise the stored rows
        are returned as-is and the ``*_hash_ids`` entries are ``None``.
        ``qp_mat`` is an identity matrix: query ``i`` is paired with passage ``i``.
        """
        query_ids = np.vstack([x['query_input_ids'] for x in batch])
        query_mask = np.vstack([x['query_attention_mask'] for x in batch])
        passage_ids = np.vstack([x['passage_input_ids'] for x in batch])
        passage_mask = np.vstack([x['passage_attention_mask'] for x in batch])

        if self.load_hash:
            query_ids, query_mask, query_hash = self.token_mapping_batch(
                tokenmonster_tok_id=query_ids,
                tokenmonster_attn_mask=query_mask,
                length=self.max_query_len,
            )
            passage_ids, passage_mask, passage_hash = self.token_mapping_batch(
                tokenmonster_tok_id=passage_ids,
                tokenmonster_attn_mask=passage_mask,
                length=self.max_pas_len,
            )
            query_hash = torch.tensor(query_hash)
            passage_hash = torch.tensor(passage_hash)
        else:
            query_hash = None
            passage_hash = None

        batch_data = {
            'query_input_ids': torch.tensor(query_ids),
            'query_attention_mask': torch.tensor(query_mask),
            'query_hash_ids': query_hash,
            'passage_input_ids': torch.tensor(passage_ids),
            'passage_attention_mask': torch.tensor(passage_mask),
            'passage_hash_ids': passage_hash,
            'qid' : [x['qid'] for x in batch],
            'pid' : [x['pid'] for x in batch],
        }
        batch_data['qp_mat'] = torch.eye(n=len(batch_data['qid']), m=len(batch_data['pid']))

        return batch_data


class GenericMemmapDatasetFT(torch.utils.data.Dataset):
    """Dataset of (query, positive passage, sampled negatives) triples.

    Token rows come from memory-mapped files; the triples come from
    ``negatives.jsonl``. ``collate_fn`` places all positives first and then
    every sample's negatives in the passage tensors, and, when ``load_hash``
    is true, expands stored token ids through the module-level ``bert_map``
    table (see ``token_mapping_batch``).
    """

    # Number of int32 columns per row in every memmap file.
    _ROW_WIDTH = 256

    def __init__(self, mode, data_dir, folder_name, max_queries, max_passages, max_query_len = 30, max_pas_len = 512, load_hash = True, num_negatives = 7):
        """Open the memmap files and load the training triples.

        Args:
            mode: Split name embedded in the memmap file names.
            data_dir: Directory containing the ``folder_name`` sub-directory.
            folder_name: Sub-directory of ``data_dir`` holding the memmap files
                and ``negatives.jsonl``.
            max_queries: Number of rows in the query memmaps.
            max_passages: Number of rows in the corpus memmaps.
            max_query_len: Output width of mapped queries (hash mode only).
            max_pas_len: Output width of mapped passages (hash mode only).
            load_hash: Whether ``collate_fn`` runs ``token_mapping_batch``.
            num_negatives: How many negatives ``__getitem__`` samples per item.
        """
        self.mode = mode
        self.max_query_len = max_query_len
        self.max_pas_len = max_pas_len
        self.load_hash = load_hash
        self.num_negatives = num_negatives

        def open_rows(stem, num_rows):
            path = os.path.join(data_dir, f"{folder_name}/{stem}.{mode}.memmap")
            return np.memmap(path, dtype=np.int32, mode='r', shape=(num_rows, self._ROW_WIDTH))

        self.query_ids = open_rows("query.ids", max_queries)
        self.query_attn = open_rows("query.attn", max_queries)
        self.corpus_ids = open_rows("corpus.ids", max_passages)
        self.corpus_attn = open_rows("corpus.attn", max_passages)
        self.qrels = self.load_qrels(os.path.join(data_dir, f"{folder_name}/negatives.jsonl"))

    def load_qrels(self, path):
        """Read a JSON-lines file into a list of ``(qid, pid, neg_pids)`` tuples.

        Each line must be a JSON object with the keys ``"query"``,
        ``"corpus"`` and ``"negative"``.
        """
        data = []
        with open(path) as f:
            for line in tqdm(f, desc = f'Loading {path}'):
                itm = json.loads(line)
                data.append((itm["query"], itm["corpus"], itm["negative"]))
        print(f'{path} loaded.')
        return data

    def __len__(self):
        """Return the number of training triples."""
        return len(self.qrels)

    def __getitem__(self, idx):
        """Return token rows for triple ``idx`` with freshly sampled negatives.

        Raises:
            ValueError: If the entry lists fewer than ``num_negatives`` negatives.
        """
        qid, pid, neg_pids = self.qrels[idx]
        if len(neg_pids) < self.num_negatives:
            # random.sample would raise ValueError as well; say which entry is short.
            raise ValueError(
                f"qrels entry {idx} has {len(neg_pids)} negatives, "
                f"but {self.num_negatives} are required"
            )
        # Sampling without replacement; a new draw is made on every access.
        neg_pids = random.sample(neg_pids, self.num_negatives)

        return {
            'query_input_ids': np.array(self.query_ids[qid]),
            'query_attention_mask': np.array(self.query_attn[qid]),
            'passage_input_ids': np.array(self.corpus_ids[pid]),
            'passage_attention_mask': np.array(self.corpus_attn[pid]),
            'neg_passage_input_ids': np.array(self.corpus_ids[neg_pids]),
            'neg_passage_attention_mask': np.array(self.corpus_attn[neg_pids]),
            'qid' : qid,
            'pid' : pid,
            'neg_pid' : neg_pids,
        }

    def token_mapping_batch(self, tokenmonster_tok_id, tokenmonster_attn_mask, length):
        """Expand stored token ids into mapped ids using ``bert_map``.

        Args:
            tokenmonster_tok_id: 2-D array of stored token ids, one row per sequence.
            tokenmonster_attn_mask: Array of the same layout with attention values.
            length: Width of every output row.

        Returns:
            Tuple ``(bert_tok_ids, bert_attn_masks, hash_ids)`` of int32 arrays
            shaped ``(batch_size, length)``. Unfilled positions hold
            ``PAD_ID_ORIG``, ``0`` and ``PAD_ID_MAP`` respectively. Every filled
            position carries a mapped id, the attention value of the stored
            token it came from, and that stored token's id.
        """
        batch_size = tokenmonster_tok_id.shape[0]
        bert_tok_ids = np.full((batch_size, length), PAD_ID_ORIG, dtype=np.int32)
        bert_attn_masks = np.zeros((batch_size, length), dtype=np.int32)
        hash_ids = np.full((batch_size, length), PAD_ID_MAP, dtype=np.int32)

        for b in range(batch_size):
            pos = 0
            for tok_id, attn_mask in zip(tokenmonster_tok_id[b], tokenmonster_attn_mask[b]):
                tok_id_int = int(tok_id)
                # The first pad id ends the row.
                if tok_id_int == PAD_ID_MAP:
                    break
                mapped_tokens = bert_map[tok_id_int]
                num_tokens = bert_len[tok_id_int]
                # Tokens without any mapped ids contribute nothing.
                if num_tokens == 0:
                    continue

                # Truncate the expansion at the output width.
                end_pos = min(pos + num_tokens, length)
                bert_tok_ids[b, pos:end_pos] = mapped_tokens[:end_pos-pos]
                bert_attn_masks[b, pos:end_pos] = attn_mask
                hash_ids[b, pos:end_pos] = tok_id_int
                pos = end_pos

                if pos >= length:
                    break
        return bert_tok_ids, bert_attn_masks, hash_ids

    def collate_fn(self, batch):
        """Stack items from ``__getitem__`` into one batch dictionary.

        Passage rows are ordered as all positives followed by each sample's
        negatives, which is the same order as the ``'pid'`` list and the
        columns of ``qp_mat``. ``qp_mat[i, i]`` is 1 because positive ``i``
        belongs to query ``i``; every other entry is 0. In hash mode the rows
        are expanded with ``token_mapping_batch``; otherwise the stored rows
        are returned as-is and the ``*_hash_ids`` entries are ``None``.
        """
        query_ids = np.vstack([x['query_input_ids'] for x in batch])
        query_mask = np.vstack([x['query_attention_mask'] for x in batch])
        # Negatives are included in both modes so the passage tensors always
        # have one row per entry of 'pid' (and per column of qp_mat).
        passage_ids = np.concatenate([
            np.vstack([x['passage_input_ids'] for x in batch]),
            np.concatenate([x['neg_passage_input_ids'] for x in batch]),
        ])
        passage_mask = np.concatenate([
            np.vstack([x['passage_attention_mask'] for x in batch]),
            np.concatenate([x['neg_passage_attention_mask'] for x in batch]),
        ])

        if self.load_hash:
            query_ids, query_mask, query_hash = self.token_mapping_batch(
                tokenmonster_tok_id=query_ids,
                tokenmonster_attn_mask=query_mask,
                length=self.max_query_len,
            )
            passage_ids, passage_mask, passage_hash = self.token_mapping_batch(
                tokenmonster_tok_id=passage_ids,
                tokenmonster_attn_mask=passage_mask,
                length=self.max_pas_len,
            )
            query_hash = torch.tensor(query_hash)
            passage_hash = torch.tensor(passage_hash)
        else:
            query_hash = None
            passage_hash = None

        batch_data = {
            'query_input_ids': torch.tensor(query_ids),
            'query_attention_mask': torch.tensor(query_mask),
            'query_hash_ids': query_hash,
            'passage_input_ids': torch.tensor(passage_ids),
            'passage_attention_mask': torch.tensor(passage_mask),
            'passage_hash_ids': passage_hash,
            'qid' : [x['qid'] for x in batch],
            'pid' : [x['pid'] for x in batch] + [neg for x in batch for neg in x['neg_pid']],
        }

        batch_size = len(batch)
        qp_mat = torch.zeros(len(batch_data['qid']), len(batch_data['pid']))
        diagonal = torch.arange(batch_size)
        qp_mat[diagonal, diagonal] = 1
        batch_data['qp_mat'] = qp_mat

        return batch_data


class RandomDataLoader:
    """Iterate over several dataloaders, drawing each batch from a random one.

    On every step one of the not-yet-exhausted dataloaders is picked uniformly
    at random and its next batch is returned. When ``torch.distributed`` is
    initialised, rank 0's pick is broadcast so that every rank reads from the
    same dataloader. One pass ends after all dataloaders are exhausted; the
    iterators are then re-created so the object can be iterated again.
    """

    def __init__(self, dataloaders):
        """Store the dataloaders and create one iterator per dataloader.

        Args:
            dataloaders: Sequence of iterables that also support ``len()``.
        """
        self.dataloaders = dataloaders
        self.itrs = [iter(dl) for dl in dataloaders]
        # Indices of the dataloaders that still have batches in this pass.
        self.active_dl = list(range(len(dataloaders)))
        distributed = self._is_distributed()
        self.rank = dist.get_rank() if distributed else 0
        self.world_size = dist.get_world_size() if distributed else 1
        # Device of the tensor used to broadcast the pick. Querying the current
        # CUDA device raises on hosts without CUDA, so fall back to the CPU there.
        if torch.cuda.is_available():
            self.device = torch.device("cuda", torch.cuda.current_device())
        else:
            self.device = torch.device("cpu")

    @staticmethod
    def _is_distributed():
        """Return True when a torch.distributed process group is initialised."""
        return dist.is_available() and dist.is_initialized()

    def __iter__(self):
        """Return self; iteration state lives on the instance."""
        return self

    def _get_synced_index(self):
        """Pick an active dataloader index, shared by all ranks if distributed."""
        # Every rank draws, but only rank 0's draw survives the broadcast.
        local_choice = np.random.choice(self.active_dl)
        if not self._is_distributed():
            return int(local_choice)
        choice_tensor = torch.tensor(local_choice, dtype=torch.int64, device=self.device)
        dist.broadcast(choice_tensor, src=0)
        return choice_tensor.item()

    def __next__(self):
        """Return the next batch from a randomly chosen active dataloader.

        Raises:
            StopIteration: Once every dataloader is exhausted. The iterators
                are re-created first, so the next call starts a new pass.
        """
        while self.active_dl:
            dl_idx = self._get_synced_index()
            try:
                return next(self.itrs[dl_idx])
            except StopIteration:
                # This dataloader is done for the pass; pick among the rest.
                self.active_dl.remove(dl_idx)

        self.itrs = [iter(dl) for dl in self.dataloaders]
        self.active_dl = list(range(len(self.dataloaders)))
        raise StopIteration

    def __len__(self):
        """Return the total number of batches across all dataloaders."""
        return sum(len(dl) for dl in self.dataloaders)


def create_dataloaders(data_dir, dataset_info, folder_name, do_pretrain = False, batch_size = 512, num_workers = 16, val_split = 0.02, load_hash = True):
    """Build one training dataloader per dataset and wrap them in a RandomDataLoader.

    Args:
        data_dir: Root directory; each dataset lives in ``data_dir/<name>``.
        dataset_info: Mapping from dataset name to a dict of keyword arguments
            forwarded to the dataset constructor.
        folder_name: Sub-directory name passed through to the dataset constructor.
        do_pretrain: Use ``GenericMemmapDataset`` when true, otherwise
            ``GenericMemmapDatasetFT``.
        batch_size: Batch size of every dataloader.
        num_workers: Worker process count of every dataloader.
        val_split: Accepted for interface compatibility; not used here.
        load_hash: Forwarded to the datasets; also decides whether a
            ``DistributedSampler`` is attached (see the note below).

    Returns:
        A ``RandomDataLoader`` over the per-dataset dataloaders.
    """
    dataset_cls = GenericMemmapDataset if do_pretrain else GenericMemmapDatasetFT

    datasets = []
    for name, info in dataset_info.items():
        datasets.append(dataset_cls('train', os.path.join(data_dir, name), folder_name, load_hash=load_hash, **info))
        print(f'{name} loaded.')

    # DistributedSampler looks up world size and rank from the default process
    # group and raises when none is initialised. For single-process runs, pass
    # them explicitly so the same shuffling sampler works without a process group.
    if dist.is_available() and dist.is_initialized():
        replica_kwargs = {}
    else:
        replica_kwargs = {'num_replicas': 1, 'rank': 0}

    train_dataloaders = []
    for dataset in datasets:
        # NOTE(review): the sampler is only attached when load_hash is true, so
        # with load_hash=False batches are read in order and unsharded — confirm
        # this coupling is intended.
        sampler = None
        if load_hash:
            sampler = torch.utils.data.DistributedSampler(
                dataset,
                shuffle=True,
                drop_last=True,
                **replica_kwargs,
            )
        train_dataloaders.append(torch.utils.data.DataLoader(
            dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
            sampler=sampler,
        ))
    return RandomDataLoader(train_dataloaders)