from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import numpy as np
from source.controller.retriever.generate_passage_embeddings import main_modified
import os 
import pandas as pd 
import glob
import json
import datetime
import subprocess
import torch
from source.controller.retriever.contriever import Contriever
from transformers import AutoTokenizer
from source.controller.retriever.generate_passage_embeddings import embed_passages
import pickle
import time

# Per-split folders of the ADE collection; each one is expected to hold a
# data.tsv. Sorted so folders are processed in a deterministic order.
folder_list = sorted(glob.glob("ADE_med/data_split_100/*"))
collection_name = "ade_qa_med2"  # "wikipedia" or "ade_qa_med2"

# Each collection is cut into num_shards contiguous slices; shards
# starting_shard .. num_shards-1 are embedded and pickled to "_<id:02d>".
num_shards = 1
starting_shard = 0


def _load_contriever():
    """Load the facebook/contriever encoder (eval mode, moved to GPU) and tokenizer."""
    model = Contriever.from_pretrained("facebook/contriever")
    tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
    model.eval()
    return model.cuda(), tokenizer


def _read_jsonl_passages(input_dir):
    """Parse every file one directory level below *input_dir* as JSON lines.

    Each line becomes one passage dict (presumably with the keys that
    embed_passages reads, e.g. "id"/"text" -- verify against the dump format).
    A file that fails to decode is reported and skipped from that point on;
    passages already read from it are kept.
    """
    passages = []
    for sub_dir in sorted(glob.glob(os.path.join(input_dir, "*"))):
        for path in sorted(glob.glob(os.path.join(sub_dir, "*"))):
            with open(path, encoding="utf-8") as f:
                try:
                    for line in f:
                        passages.append(json.loads(line))
                        if len(passages) % 100000 == 0:
                            print(len(passages))
                # JSONDecodeError and UnicodeDecodeError are both ValueErrors;
                # anything else (e.g. KeyboardInterrupt) must propagate.
                except ValueError as err:
                    print("error")
                    print(path, err)
    return passages


def _read_tsv_passages(tsv_path):
    """Load *tsv_path*, add a 0-based "id" column, write it back, return records.

    The file is rewritten in place so the ids used for the embeddings are
    persisted next to the text. NOTE(review): to_csv also writes the DataFrame
    index as an extra unnamed column, and running this twice on the same file
    would produce a duplicate "id" column -- callers should only run it once
    per file.
    """
    frame = (
        pd.read_csv(tsv_path, sep="\t")
        .reset_index()
        .rename(columns={"index": "id"})
    )
    frame.to_csv(tsv_path, sep="\t")
    return frame.to_dict("records")


def _embed_and_save(passages, model, tokenizer, output_dir,
                    num_shards=1, starting_shard=0):
    """Embed *passages* shard by shard and pickle (ids, embeddings) per shard.

    Shard i covers a contiguous slice of len(passages) // num_shards items;
    the last shard also takes the remainder. Each shard is written to
    ``output_dir/_<i:02d>``.
    """
    print("Total number of passages: ", len(passages))
    shard_size = len(passages) // num_shards
    for shard_id in range(starting_shard, num_shards):
        print("shard : ", shard_id)
        t1 = time.time()
        start_idx = shard_id * shard_size
        if shard_id == num_shards - 1:
            end_idx = len(passages)
        else:
            end_idx = start_idx + shard_size

        allids, allembeddings = embed_passages(
            passages[start_idx:end_idx], model, tokenizer
        )

        save_file = os.path.join(output_dir, f"_{shard_id:02d}")
        print(f"Saving {len(allids)} passage embeddings to {save_file}.")
        with open(save_file, mode="wb") as f:
            pickle.dump((allids, allembeddings), f)

        print(f"Total passages processed {len(allids)}. Written to {save_file}.")
        print("Time taken: ", time.time() - t1)


if collection_name == "wikipedia":
    INPUT_DIR = "data/wikipedia_dump/"
    # The "data_split" -> "embedding" rewrite does not match this path, so the
    # embeddings are written into INPUT_DIR itself, which already exists.
    OUTPUT_DIR = INPUT_DIR.replace("data_split", "embedding")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    model, tokenizer = _load_contriever()
    # The dump is stored as JSON lines (one passage object per line).
    passages = _read_jsonl_passages(INPUT_DIR)
    _embed_and_save(passages, model, tokenizer, OUTPUT_DIR,
                    num_shards, starting_shard)

else:
    # Load the encoder once and reuse it for every split folder.
    if folder_list:
        model, tokenizer = _load_contriever()

    for folder in folder_list:
        INPUT_DIR = os.path.join(folder, "data.tsv")
        OUTPUT_DIR = folder.replace("data_split", "embedding")
        # makedirs creates a missing parent (e.g. ADE_med/embedding_100) but,
        # without exist_ok, still refuses an existing output folder -- this
        # keeps a re-run from rewriting data.tsv or overwriting embeddings.
        os.makedirs(OUTPUT_DIR)

        passages = _read_tsv_passages(INPUT_DIR)
        _embed_and_save(passages, model, tokenizer, OUTPUT_DIR,
                        num_shards, starting_shard)

