import random
from dataclasses import dataclass
from typing import List, Tuple
import collections

import torch
import datasets
from torch.utils.data import Dataset
from transformers import (
    PreTrainedTokenizer, 
    BatchEncoding, 
    DataCollatorWithPadding
)
import csv
import sys 
import json
from tqdm import tqdm
sys.path.append("..")
from arguments import DataArguments
from trainers import DenseTrainer

import logging
logger = logging.getLogger(__name__)


def load_beir_qrels(qrels_file):
    """Load a BEIR-format qrels TSV file into a nested relevance dict.

    The file must start with a header row containing (at least) the
    tab-separated columns ``query-id``, ``corpus-id`` and ``score``.

    Args:
        qrels_file: Path to the qrels ``.tsv`` file.

    Returns:
        dict mapping each query id (str) to a dict
        ``{corpus id (str): relevance score (int)}``. If the same
        (query, corpus) pair appears more than once, the last row wins.

    Raises:
        KeyError: if a required column is missing from the header.
        ValueError: if a ``score`` value is not an integer.
    """
    qrels = {}
    # Explicit UTF-8 so decoding does not depend on the platform locale
    # (matching the query/corpus loaders); newline="" as the csv docs require.
    with open(qrels_file, encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f, delimiter="\t"):
            qrels.setdefault(row["query-id"], {})[row["corpus-id"]] = int(row["score"])
    return qrels


def load_beir_queries(query_file, qrels, qry_template=None):
    """Load BEIR queries (JSON Lines) that appear in ``qrels``.

    Args:
        query_file: Path to a ``queries.jsonl`` file; each non-blank line is
            a JSON object with ``_id`` and ``text`` fields.
        qrels: Container of query ids to keep (e.g. the dict returned by
            ``load_beir_qrels``); queries whose id is not in it are dropped.
        qry_template: Optional template string in which every ``<text>``
            placeholder is replaced by the query text.

    Returns:
        dict mapping query id to the (optionally templated) query string.
    """
    queries = {}
    with open(query_file, "r", encoding="utf-8") as fi:
        for line in fi:
            # Tolerate blank lines (e.g. an extra empty line at EOF), which
            # json.loads would otherwise reject with JSONDecodeError.
            if not line.strip():
                continue
            data = json.loads(line)
            qid = data["_id"]
            if qid not in qrels:
                continue
            text = data["text"]
            queries[qid] = qry_template.replace("<text>", text) if qry_template else text
    return queries


def load_beir_corpus(corpus_file, psg_template=None, filter_fn=None, verbose=True):
    """Load a BEIR corpus (JSON Lines) into a dict of passage strings.

    Args:
        corpus_file: Path to a ``corpus.jsonl`` file; each non-blank line is
            a JSON object with an ``_id`` and ``title`` / ``text`` fields
            (a missing or null ``title``/``text`` is treated as empty).
        psg_template: Optional template string; its ``<title>`` and
            ``<text>`` placeholders are replaced with the stripped title and
            text. Without a template the passage is ``title + " " + text``.
        filter_fn: NOTE(review): accepted for API compatibility but not used
            by this function -- confirm the intended semantics before wiring
            it in.
        verbose: If true, show a tqdm progress bar while reading.

    Returns:
        dict mapping corpus id to passage string. Documents whose stripped
        title and text are both empty are skipped.
    """
    corpus = {}
    with open(corpus_file, "r", encoding="utf-8") as fi:
        # Only wrap in tqdm when a progress bar is actually wanted.
        lines = tqdm(fi) if verbose else fi
        for line in lines:
            # Tolerate blank lines, which json.loads would reject.
            if not line.strip():
                continue
            data = json.loads(line)

            corpus_id = data["_id"]
            # `or ""` covers both absent keys and explicit JSON nulls.
            title = (data.get("title") or "").strip()
            text = (data.get("text") or "").strip()

            # Skip documents with no content at all.
            if not title and not text:
                continue

            if psg_template:
                corpus[corpus_id] = psg_template.replace("<title>", title).replace("<text>", text)
            else:
                corpus[corpus_id] = title + " " + text

    return corpus



class BEIRDataset(Dataset):
    """Map-style dataset over a list of texts, optionally paired with ids.

    Indexing yields ``(text_id, text)`` when ``text_ids`` was supplied and
    the bare text otherwise. ``text_ids``, when given, must have the same
    length as ``text_lst``.
    """

    def __init__(self, text_lst: List[str], text_ids: List[int]=None):
        self.text_lst, self.text_ids = text_lst, text_ids
        has_no_ids = text_ids is None
        assert has_no_ids or len(text_lst) == len(text_ids)

    def __len__(self):
        # Length is defined by the texts; ids (if any) match it.
        return len(self.text_lst)

    def __getitem__(self, item):
        ids = self.text_ids
        return self.text_lst[item] if ids is None else (ids[item], self.text_lst[item])
        
        
        
# @dataclass
# class BEIRCollator(DataCollatorWithPadding):
#     def __call__(self, features):
        
#         text_features = self.tokenizer(features, padding=True, truncation=True, max_length=self.max_length)
#         collated_features = super().__call__(text_features)
#         return collated_features


def _encode_bracketed_batch(collator, features, bos_ids, eos_ids):
    """Tokenize raw texts and wrap each sequence in extra bracket tokens.

    Shared by the query and passage bracket collators, which differ only in
    which bracket token ids they add around the tokenizer's own encoding.

    Args:
        collator: The calling collator; its ``tokenizer`` and ``max_length``
            attributes are used.
        features: Iterable of raw text strings. NOTE(review): a BEIRDataset
            built with ``text_ids`` yields ``(id, text)`` tuples, which this
            function does not accept -- confirm callers pass plain strings.
        bos_ids: Token ids prepended to every sequence (outside the
            tokenizer's special tokens); assumed to be a list of ints.
        eos_ids: Token ids appended to every sequence; assumed to be a list
            of ints.

    Returns:
        The output of ``tokenizer.pad(..., padding=True, return_tensors="pt")``
        over the batch's ``input_ids`` and ``attention_mask``.
    """
    tokenizer = collator.tokenizer
    batch_tokens = collections.defaultdict(list)
    # Bracket tokens are real tokens, so every one of them is attended to;
    # sizing the masks from the lists keeps ids and mask the same length.
    bos_mask = [1] * len(bos_ids)
    eos_mask = [1] * len(eos_ids)
    for txt in features:
        txt = txt.replace("\n", " ")
        token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(txt))
        # Truncation applies to the text tokens only; the tokenizer's special
        # tokens and the bracket tokens are added on top of max_length.
        input_dict = tokenizer.prepare_for_model(
            token_ids[: collator.max_length], add_special_tokens=True
        )
        batch_tokens["input_ids"].append(bos_ids + input_dict["input_ids"] + eos_ids)
        batch_tokens["attention_mask"].append(bos_mask + input_dict["attention_mask"] + eos_mask)
    # Pad explicitly so the batch is rectangular even if the collator itself
    # was configured with padding=False.
    return tokenizer.pad(batch_tokens, padding=True, return_tensors="pt")


@dataclass
class BEIRQryBracketCollator(DataCollatorWithPadding):
    """Collate raw query strings, wrapping each encoded query in the
    tokenizer's ``bos_token_q`` / ``eos_token_q`` bracket token ids."""

    def __call__(self, features):
        batch_tokens = _encode_bracketed_batch(
            self, features, self.tokenizer.bos_token_q, self.tokenizer.eos_token_q
        )
        return super().__call__(batch_tokens)


@dataclass
class BEIRPsgBracketCollator(DataCollatorWithPadding):
    """Collate raw passage strings, wrapping each encoded passage in the
    tokenizer's ``bos_token_d`` / ``eos_token_d`` bracket token ids."""

    def __call__(self, features):
        batch_tokens = _encode_bracketed_batch(
            self, features, self.tokenizer.bos_token_d, self.tokenizer.eos_token_d
        )
        return super().__call__(batch_tokens)
    
    
    

