from datasets import load_dataset
from torch.utils.data import Dataset
import numpy as np
from source.controller.retriever.generate_passage_embeddings import main_modified
import os 
import pandas as pd 
import glob
from source.controller.retriever.index import Indexer
# import nltk
# nltk.download('punkt')  # This is needed for tokenization
from nltk.tokenize import word_tokenize, sent_tokenize
import json

def config(dataset_name, k):
    """Build the retrieval/indexing settings for the collection ``dataset_name``.

    Args:
        dataset_name: key of ``dict_path`` naming the passage collection.
        k: number of documents to retrieve (stored as ``n_docs``).

    Returns:
        A fresh settings dict. The embedding and passage locations come from
        ``dict_path[dataset_name]``.

    Raises:
        KeyError: if ``dataset_name`` is not a key of ``dict_path``.
    """
    paths = dict_path[dataset_name]
    return {
        "projection_size": 768,
        "n_subquantizers": 0,
        "n_bits": 8,
        "passages_embeddings": paths["passages_embeddings"],
        "indexing_batch_size": 1000000,
        "passages": paths["passages"],
        "output_dir": "",
        # The misspelled key is kept for any existing reader. The correctly
        # spelled key is provided alongside it.
        "validartion_workers": 6,
        "validation_workers": 6,
        "n_docs": k,
        "data": "",
        "per_gpu_batch_size": 64,
        "save_or_load_index": True,
        "question_maxlength": 512,
        "dataset": "none",
        "lowercase": True,
        "normalize_text": True,
    }

DATASET_DIR = "/mnt/data/tp531.data/"


def _collection_paths(embeddings_subdir, passages_subpath):
    """Return the ``passages_embeddings``/``passages`` locations of a collection under DATASET_DIR."""
    return {
        "passages_embeddings": DATASET_DIR + embeddings_subdir,
        "passages": DATASET_DIR + passages_subpath,
    }


# Collection name -> where its precomputed passage embeddings and raw passages live.
# The embedding directories keep their trailing "/" because a pattern ("*") is
# appended to them when listing the embedding files.
dict_path = {
    "wikipedia_dump_100": _collection_paths(
        "wikipedia_dump_contriever/wikipedia_embeddings/",
        "wikipedia_dump_contriever/psgs_w100.tsv",
    ),
    "med_records_100": _collection_paths(
        "ADE_med/embedding_100/",
        "ADE_med/data_split_100/",
    ),
    "wikipedia_dump_mini": _collection_paths(
        "wikipedia_dump_contriever_mini/wikipedia_embeddings/",
        "wikipedia_dump_contriever_mini/psgs_w100.tsv",
    ),
}

     
# Names of the question/answer tasks this module knows how to load.
_TASK_NAMES = ("boolq", "boolq_test", "ade_qa_med2", "ade_qa_med2_test")


def _render_query_text(record):
    """Return the query string for ``record``: its ``text`` field, formatted as a string.

    ``record["id"]`` is read as well, so a record without an ``id`` raises KeyError.
    """
    return "{text}".format(id=record["id"], text=record["text"])


# Task name -> function turning a record into the query string.
template_task_query = {name: _render_query_text for name in _TASK_NAMES}


# Task name -> field names of the query text and of the label in each record.
# Every task gets its own dict so mutating one entry cannot affect the others.
dict_voc = {name: {"query": "question", "label": "answer"} for name in _TASK_NAMES}

def recursive_gathering(target, dict_df):
    """Collect the leaf values of a nested dict into ``target`` and return it.

    Dicts (including subclasses such as ``OrderedDict``/``defaultdict``) are walked
    recursively in key order. A list is spliced into ``target`` as-is; its items
    are not inspected further. Any other value is appended as a single leaf.
    ``target`` is modified in place.

    Args:
        target: list receiving the gathered values.
        dict_df: a dict, a list, or a leaf value.

    Returns:
        ``target`` itself.
    """
    if isinstance(dict_df, list):
        target.extend(dict_df)
    elif isinstance(dict_df, dict):
        for value in dict_df.values():
            recursive_gathering(target, value)
    else:
        target.append(dict_df)
    return target

        
def split_into_passages(text, word_limit=500, *, sentence_tokenizer=None, word_tokenizer=None):
    """Split ``text`` into passages made of whole sentences, about ``word_limit`` words each.

    Sentences are packed greedily. A sentence starts a new passage when adding it
    would push the current passage past ``word_limit`` words. A single sentence
    longer than the limit becomes a passage on its own; sentences are never cut.
    Sentences within a passage are joined with one space.

    Args:
        text: the text to split.
        word_limit: maximum words per passage (exceeded only by a lone oversized sentence).
        sentence_tokenizer: callable returning the list of sentences of a string;
            defaults to NLTK ``sent_tokenize``.
        word_tokenizer: callable returning the list of words of a string, used only
            for counting; defaults to NLTK ``word_tokenize``.

    Returns:
        The list of passages (empty when ``text`` contains no sentences).
    """
    if sentence_tokenizer is None:
        sentence_tokenizer = sent_tokenize
    if word_tokenizer is None:
        word_tokenizer = word_tokenize

    passages = []
    current = []       # sentences of the passage being built
    current_words = 0  # running word count of `current` (avoids re-tokenizing it every step)

    for sentence in sentence_tokenizer(text):
        sentence_words = len(word_tokenizer(sentence))
        # Only flush a non-empty passage: previously an oversized first sentence
        # caused an empty "" passage to be emitted.
        if current and current_words + sentence_words > word_limit:
            passages.append(" ".join(current))
            current, current_words = [], 0
        current.append(sentence)
        current_words += sentence_words

    last = " ".join(current)
    if last.strip():
        passages.append(last)
    return passages

class FSdata():
    """Build question/answer query sets (and optional document passages) with a CSV cache.

    Results are cached under ``<dataset_dir>/<dataset_name>/`` as ``queries.csv`` and
    ``documents.csv``. When a cache file exists, it is read back instead of rebuilt.

    NOTE(review): a cache hit returns ``DataFrame.values`` (an array whose first column
    is the saved row index), while a fresh build returns the in-memory records. Callers
    therefore see different shapes depending on whether the cache exists.
    """

    # Rows of the ADE query sheet with index <= this value form the training split;
    # the remaining rows form the test split.
    _ADE_LAST_TRAIN_ROW = 2000

    def __init__(self, dataset_name, dataset_dir=DATASET_DIR, max_length=500, split_query=False):
        """
        Args:
            dataset_name: one of the keys of ``dict_voc``.
            dataset_dir: root directory holding the raw data and the CSV caches.
            max_length: word limit handed to ``split_into_passages``.
            split_query: when True, query texts are split into passages.

        Raises:
            KeyError: if ``dataset_name`` is not a key of ``dict_voc``.
        """
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir
        print("Datasets are loaded")
        # Field names of the query text and of the label inside each record.
        self.query = dict_voc[self.dataset_name]["query"]
        self.label = dict_voc[self.dataset_name]["label"]
        self.max_length = max_length
        self.split_query = split_query
        self.documents = []
        # Filled by get_documents(): passage texts and their embeddings.
        self.document_texts = []
        self.document_embedding = []
        # Filled by load_dataset(); None until then.
        self.query_set = None

    def load_dataset(self):
        """Populate ``self.query_set`` with ``{"question", "answer"}`` records for ``self.dataset_name``."""
        if self.dataset_name == "boolq":
            self.query_set = self._load_boolq("train")
        elif self.dataset_name == "boolq_test":
            self.query_set = self._load_boolq("validation")
        elif self.dataset_name == "ade_qa_med2":
            self.query_set = self._load_ade_med2(test=False)
        elif self.dataset_name == "ade_qa_med2_test":
            self.query_set = self._load_ade_med2(test=True)

    def _load_boolq(self, split):
        """Load BoolQ ``split`` into ``self.data``; return capitalised "?"-terminated questions with 0/1 answers."""
        self.data = load_dataset("boolq", split=split)
        records = []
        for x in self.data:
            question = x["question"]
            records.append({"question": question[:1].upper() + question[1:] + "?",
                            "answer": 1 if x["answer"] else 0})
        return records

    def _load_ade_med2(self, test):
        """Return ``{"question", "answer"}`` records from the ADE ``query2.xlsx`` sheet.

        Rows with index <= ``_ADE_LAST_TRAIN_ROW`` are the training split; later rows
        are the test split (``test=True``). The answer is the number of the first
        ``"<digits> mg"`` entry in the comma-separated ``Answer_Unstructured`` cell
        (parentheses removed). Rows without such an entry are skipped.
        """
        sheet = pd.read_excel(os.path.join(self.dataset_dir, "ADE_med/queries/annotations/query2.xlsx"))
        records = []
        for i, row in sheet.iterrows():
            if (i > self._ADE_LAST_TRAIN_ROW) != test:
                continue
            raw_answer = row["Answer_Unstructured"]
            if not isinstance(raw_answer, str):
                # Non-text cells (e.g. NaN from empty cells) hold no dose to extract.
                continue
            for answer in raw_answer.replace("(", "").replace(")", "").strip().split(","):
                tokens = answer.split()
                if len(tokens) == 2 and tokens[1].lower() == "mg" and tokens[0].isdigit():
                    records.append({"question": row["NL_Question"], "answer": tokens[0]})
                    break
        return records

    def generate_embeddings(self, data):
        """Embed ``data`` (a list of ``{"id", "text"}`` records) with the Contriever model.

        Returns whatever ``main_modified`` returns; ``get_documents`` unpacks it as
        ``(index, embeddings)``.
        """
        # Hub model ids use "/". The former "facebook\contriever" contained an
        # invalid escape sequence and named a non-existent model.
        return main_modified("facebook/contriever", data)

    def split_text(self, data):
        """Split the ``self.query`` field of every record in ``data`` into passages and return them all."""
        result = []
        for i, x in enumerate(data):
            print("split : ", i, " / ", len(data))
            result += split_into_passages(x[self.query], self.max_length)
        return result

    def get_documents(self):
        """Return ``[text, embedding]`` pairs for the document passages.

        Returns the cached ``documents.csv`` values when present. Otherwise splits
        ``self.documents`` into passages, embeds them and returns the pairs.
        Returns ``[]`` when there are no documents to process.
        """
        cache_file = os.path.join(self.dataset_dir, self.dataset_name, "documents.csv")
        if os.path.exists(cache_file):
            return pd.read_csv(cache_file, header=0).values
        if not self.documents:
            # Nothing to split or embed (this used to raise IndexError).
            return []
        print(self.documents[0])
        self.documents = self.split_text(self.documents)
        # The passages are both the new document list and the texts that get embedded.
        self.document_texts = self.documents
        self.document_index, self.document_embedding = self.generate_embeddings(
            [{"id": i, "text": text} for i, text in enumerate(self.document_texts)]
        )
        return [[text, self.document_embedding[i]] for i, text in enumerate(self.document_texts)]

    def get_query_set(self):
        """Return the query set, from ``queries.csv`` when cached, otherwise built and cached now."""
        cache_file = os.path.join(self.dataset_dir, self.dataset_name, "queries.csv")
        if os.path.exists(cache_file):
            return pd.read_csv(cache_file, header=0).values
        self.load_dataset()
        if self.split_query:
            self.query_set = self.split_text(self.query_set)
        # Save where the cache is looked up; the default output_dir ignored dataset_dir.
        self.save_dataset(self.dataset_dir)
        return self.query_set

    def save_dataset(self, output_dir = DATASET_DIR):
        """Write the documents and queries as CSV files into ``<output_dir>/<dataset_name>/``.

        ``documents.csv`` has ``text``/``embedding`` columns for the embedded passages.
        ``queries.csv`` has one column per query-record field, or a single
        ``self.query`` column when the queries were split into plain passages.
        """
        save_path = os.path.join(output_dir, self.dataset_name)
        os.makedirs(save_path, exist_ok=True)

        if self.documents is None:
            print("No training set")
        elif self.documents and not self.document_texts:
            # Raw documents exist but have not been split/embedded yet. Caching an
            # empty file here would make get_documents() skip them on later runs.
            print("Documents are not embedded yet; documents.csv not written")
        else:
            trainset = pd.DataFrame(
                [[text, embedding] for text, embedding in zip(self.document_texts, self.document_embedding)],
                columns=["text", "embedding"],
                dtype=object,
            )
            trainset.to_csv(os.path.join(save_path, "documents.csv"), index=True, header=True)

        if not self.query_set:
            print("No test set")
        else:
            first = self.query_set[0]
            # Split queries are plain strings, not dicts, so they get a single column.
            columns = list(first.keys()) if isinstance(first, dict) else [self.query]
            testset = pd.DataFrame(self.query_set, columns=columns, dtype=object)
            testset.to_csv(os.path.join(save_path, "queries.csv"), index=True, header=True)
  
class Documents():
    def __init__(self, collection_name, document_rate, n_bits=None):
        """Record which collection to load and its retrieval settings.

        Args:
            collection_name: key of ``dict_path`` naming the passage collection.
            document_rate: passed to ``config`` as its ``k`` (stored there as ``n_docs``).
            n_bits: value stored as ``self.n_bits``; when None it is taken from
                ``config(collection_name, document_rate)["n_bits"]``.

        Raises:
            KeyError: if ``n_bits`` is None and ``collection_name`` is not a key of ``dict_path``.
        """
        self.collection_name = collection_name
        self.document_rate = document_rate
        if n_bits is None:
            # ``config`` is a factory taking (dataset_name, k). The former
            # ``config["n_bits"]`` subscripted the function and raised TypeError.
            n_bits = config(collection_name, document_rate)["n_bits"]
        self.n_bits = n_bits
    def load_collection(self):
        
        if "wikipedia" in self.collection_name:
            path = dict_path[self.collection_name]["passages"]
            
            print("Load wikipedia_dump")
            wiki = pd.read_csv(path, sep="\t").to_dict("records")
            print("Load Index")
            
            args = config(self.collection_name, self.document_rate)
            index = Indexer(args["projection_size"], args["n_subquantizers"],  args["n_bits"])
            input_paths = glob.glob(args["passages_embeddings"] +"*")
            input_paths = sorted(input_paths)
            embeddings_dir = os.path.dirname(input_paths[0])
            index.deserialize_from(embeddings_dir)
            
            collection = (wiki, {int(x["id"]): x for x in wiki}, index)
            return collection
            