from data.dataset import FSdata, Documents
from source.controller.mcts import MCTS
from source.controller.q_model import QMCTS, Q_search

import datetime
from source.language_models.lama import Llama2HF
import json
import os 
from source.controller.ranking import top_k
from source.controller.retriever.contriever import Contriever
from transformers import AutoTokenizer
import argparse
from source.models.model import QA_model
from source.models.model import QT_QA_model



def _resolve_sample_size(config, total):
    """Return how many leading queries of the query set should be evaluated.

    A missing, ``None`` or negative ``config["test_size"]`` means "use the
    whole query set". Without this guard a sentinel such as ``-1`` would be
    used directly as a slice bound and silently drop the last query.
    """
    requested = config.get("test_size")
    if requested is None or requested < 0:
        return total
    return requested


def _append_progress(config, output_dir):
    """Append the current config (including running scores) as one JSON line."""
    with open(output_dir + "config_w_result.json", "a+") as f:
        json.dump(config, f)
        f.write("\n")


def evaluate(config):
    """Run one evaluation pass over a query set and log running accuracy.

    A timestamped output directory is created under ``config["output_dir"]``.
    After every query the running accuracy is printed and ``config`` (updated
    in place with ``"results"``, ``"sample id"`` and, for search methods,
    ``"results baseline"``) is appended as a JSON line to
    ``config_w_result.json`` in that directory.

    Two paths exist, selected by ``config["method_name"]``:

    * name does not contain ``"LLM"``: a search controller (``MCTS``,
      ``Q_search`` or ``QMCTS``) is built and ``method.search`` is called
      per query;
    * name contains ``"LLM"``: a ``Llama2HF`` model is scored directly, and
      when the name also contains ``"RAG"`` the first retrieved context is
      passed along as the document.

    Args:
        config: dict of settings. Keys read here: ``output_dir``,
            ``method_name``, ``query_set_name``, ``collection_name``,
            ``document_rate``, ``n_bits``, ``size``, ``save_graph``,
            ``verbose``, ``p_document``, ``thought_sample_size``,
            ``threshold`` and, optionally, ``test_size`` (missing, ``None``
            or negative means the full query set) and ``training``
            (defaults to ``False``).

    Raises:
        ValueError: if a non-LLM ``method_name`` matches none of the known
            search methods.
    """
    label = None

    date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    output_dir = config["output_dir"] + date + config["method_name"] + "_" + config["query_set_name"] + "/"
    # makedirs also creates missing parents (e.g. a fresh "results/evaluate/").
    os.makedirs(output_dir, exist_ok=True)

    query_set = FSdata(config["query_set_name"]).get_query_set()
    sample_size = _resolve_sample_size(config, len(query_set))
    is_med_qa = "ade_qa_med2" in config["query_set_name"]

    results = []
    baseline_results = []

    if "LLM" not in config["method_name"]:

        if "med_records" in config["collection_name"]:
            collection = None
            model = "med"
        else:
            collection = Documents(
                config["collection_name"],
                config["document_rate"],
                n_bits=config["n_bits"],
            ).load_collection()
            model = "bool"

        if config["method_name"] == "mcts":
            method = MCTS(
                dataset_name=config["collection_name"],
                size=config["size"],
                collection=collection,
                save_graph=config["save_graph"],
                save_folder=output_dir,
                verbose=config["verbose"],
            )

        elif "q_greedy" in config["method_name"]:
            method = Q_search(
                dataset_name=config["collection_name"],
                collection=collection,
                save_graph=config["save_graph"],
                save_folder=output_dir,
                collect_data=True,
                verbose=config["verbose"],
                q_esti_parents=f"./model/gbm_model_parents_{model}.pt",
            )
            label = "Q"

        elif "qmcts" in config["method_name"]:
            method = QMCTS(
                dataset_name=config["collection_name"],
                collection=collection,
                size=config["size"],
                save_graph=config["save_graph"],
                save_folder=output_dir,
                collect_data=True,
                verbose=config["verbose"],
                q_esti_parents=f"./model/gbm_model_parents_{model}.pt",
            )
            # NOTE(review): attribute is spelled `treshold` here while
            # `threshold` is assigned from config below - confirm which
            # name QMCTS actually reads.
            method.treshold = 0.22
            label = "Q"

        else:
            # Fail early with a clear message instead of a NameError on `method`.
            raise ValueError(
                f"Unknown method_name {config['method_name']!r}: expected 'mcts', "
                "a name containing 'q_greedy' or 'qmcts', or a name containing 'LLM'."
            )

        if is_med_qa:
            method.scoring_model = QT_QA_model(method.LLM, config["verbose"])

        method.new_document_rate = config["document_rate"]
        method.p_document = config["p_document"]
        method.thought_sample_size = config["thought_sample_size"]
        method.threshold = config["threshold"]

        if config["document_rate"] == 0:
            method.p_document = 0

        # The command-line entry point may not define "training"; default to off.
        training = config.get("training", False)

        for i, entry in enumerate(query_set[:sample_size]):

            true_value = entry[2]
            if training:
                # In training mode the expected answer is handed to the search.
                label = true_value

            result = method.search(entry[1], label)

            if is_med_qa:
                # The expected value may list several accepted answers joined by "-".
                accepted = [v.strip() for v in str(true_value).split("-")]
                prediction = result["Best answer"][0]
                prediction_baseline = result["Baseline answer"]

                results.append(int(any(v == prediction for v in accepted)))
                baseline_results.append(int(any(v == prediction_baseline for v in accepted)))
            else:
                results.append(int(result["Best answer"][0] == true_value))
                baseline_results.append(int(result["Baseline answer"] == true_value))

            print(f"Test point {i}/{len(query_set)} : good answer percent : ", sum(results) / len(results), "baseline : ", sum(baseline_results) / len(baseline_results))

            config["results"] = sum(results) / len(results)
            config["results baseline"] = sum(baseline_results) / len(baseline_results)
            config["sample id"] = i
            _append_progress(config, output_dir)
    else:
        LLM = Llama2HF(config_path="source/language_models/config.json", model_name="TheBloke/Llama-2-70B-chat-AWQ", verbose=config["verbose"])

        use_rag = "RAG" in config["method_name"]
        if use_rag:
            retriever = Contriever.from_pretrained("facebook/contriever").to("cuda")
            retriever_tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
            collection = Documents(
                config["collection_name"],
                config["document_rate"],
                n_bits=config["n_bits"],
            ).load_collection()

        if is_med_qa:
            scoring_model = QT_QA_model(LLM, config["verbose"])
        else:
            scoring_model = QA_model(LLM, config["verbose"])

        for i, entry in enumerate(query_set[:sample_size]):
            label = entry[2]
            query = entry[1]

            if use_rag:
                # Only the text of the first retrieved context is used.
                doc = top_k(
                    [query],
                    dataset_name=config["collection_name"],
                    k=config["document_rate"],
                    retriever=retriever,
                    tokenizer=retriever_tokenizer,
                    collection=collection,
                )[0]["ctxs"][0]["text"]
            else:
                doc = ""

            if is_med_qa:
                result = scoring_model.evaluate(query, [doc], label)[0]
                results.append(result)
            else:
                result = scoring_model.predict(query, [doc])[0]
                results.append(int(result == label))

            print(f"Test point {i}/{len(query_set)} : good answer percent : ", sum(results) / len(results))
            config["results"] = sum(results) / len(results)
            config["sample id"] = i
            _append_progress(config, output_dir)
    
def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value - including ``"False"`` - becomes ``True``. This parser
    interprets the text instead. An empty string maps to ``False``, matching
    what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--method_name", type=str, default="mcts")

    # Model params
    parser.add_argument("--query_set_name", type=str, default="boolq_test")
    parser.add_argument("--collection_name", type=str, default="boolq_test")
    parser.add_argument("--document_rate", type=int, default=5)
    parser.add_argument("--size", type=int, default=10)
    # Boolean options accept an explicit value ("--verbose False") or can be
    # given bare ("--verbose"), which means True.
    parser.add_argument("--oracle", type=str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--save_graph", type=str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--output_dir", type=str, default="results/evaluate/")
    parser.add_argument("--verbose", type=str2bool, nargs="?", const=True, default=False)
    # None means "evaluate the full query set"; the previous default of -1 was
    # used as a slice bound by evaluate() and dropped the last query.
    parser.add_argument("--test_size", type=int, default=None)
    parser.add_argument("--n_bits", type=int, default=8)
    parser.add_argument("--p_document", type=float, default=0.5)
    parser.add_argument("--thought_sample_size", type=int, default=5)
    parser.add_argument("--threshold", type=float, default=0.5)
    # evaluate() reads config["training"]; expose it so the key always exists.
    parser.add_argument("--training", type=str2bool, nargs="?", const=True, default=False)

    args = parser.parse_args()

    config = vars(args)
    evaluate(config)
    
    

