import os, sys
import fitz
import re
import json
from datetime import datetime
from typing import Optional, List, Callable, Any, Tuple, Dict
from abc import abstractmethod, ABC
import random
import numpy as np
import pandas as pd
import copy
import nltk
from nltk.corpus import stopwords
import pickle
from tqdm.autonotebook import tqdm
import itertools
from dotenv import load_dotenv

# Make the parent directory importable, presumably so that the project-local
# packages imported below (``dataset``, ``bench_utils``, ...) resolve.
sys.path.append("../")

load_dotenv(dotenv_path="../.env")
# NOTE(review): ``stopwords`` does not appear to be used anywhere in this
# script; this download may be removable.
nltk.download('stopwords')

# Fixed seeds so the model-order shuffle further down is reproducible.
random.seed(42)
np.random.seed(42)

# Single source of truth for the dataset under test. It used to appear as three
# separate "simpleV3.1" literals (results id, dataset path, metadata) that had
# to be kept in sync by hand.
DATASET_ID = "simpleV3.1"

BENCH_TAG = "taxonomy_bench"
# Alternative: a per-run id such as str(datetime.now().isoformat()).
BENCH_ID = DATASET_ID
SAVE_LOC = os.path.join("results", BENCH_TAG, BENCH_ID)
os.makedirs(SAVE_LOC, exist_ok=True)

from dataset.dataset_utils.reader import ADIQDataset
from dataset.dataset_utils.question import Question


ds = ADIQDataset(f"../dataset/datasets/{DATASET_ID}")
data_metadata = {"dataset_id": DATASET_ID}

from bench_utils.question import MultipleChoiceQuestion
from bench_utils.inference_calls import LLMConfiguration, ModelConfig, MODEL_MAP

from sentence_transformers import SentenceTransformer, util
from dataset.utils import file_handle


# Sentence encoder shared by the taxonomy index built below and by the
# query-time lookups in get_search_terms().
emb_model = SentenceTransformer("all-mpnet-base-v2")

def get_db_data(loc = "../dataset/extracted/taxonomy/concepts_reduced.json"):
    """Load the taxonomy concept dictionary and embed its keys.

    Args:
        loc: path of the JSON file holding the concept dictionary.

    Returns:
        A tuple ``(keys_map, db, db_enc)``: ``keys_map`` maps a row index to
        a concept key, ``db`` is the loaded dictionary, and ``db_enc`` holds
        ``emb_model``'s encoding of every key, in the same order as
        ``keys_map``.
    """
    concepts = file_handle.load_json(loc)
    ordered_keys = list(concepts)
    index_to_key = dict(enumerate(ordered_keys))
    key_embeddings = emb_model.encode(ordered_keys)
    return index_to_key, concepts, key_embeddings

# Build the in-memory taxonomy index once, at import time.
DB_KEY_MAP, DB, DB_ENC = get_db_data()

def cosine_similarity_matrix(a, b, eps: float = 1e-12) -> np.ndarray:
    """Return the pairwise cosine similarities between the rows of *a* and *b*.

    Args:
        a: 2-D array-like of shape (m, d).
        b: 2-D array-like of shape (k, d).
        eps: added to every row norm so all-zero rows give 0 instead of NaN.

    Returns:
        Array of shape (m, k) whose entry (i, j) is cos(a[i], b[j]).
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a / (np.linalg.norm(a, axis=1, keepdims=True) + eps)
    b = b / (np.linalg.norm(b, axis=1, keepdims=True) + eps)
    return a @ b.T


def get_search_terms(terms:List[str], n =5):
    """Look up the taxonomy entries most similar to each query term.

    Each term is embedded with ``emb_model`` and compared with the key
    embeddings ``DB_ENC`` by cosine similarity. For every term, up to ``n``
    best-matching entries whose ``"meaning"`` is truthy are collected, as deep
    copies of the ``DB`` records with an added ``"sim"`` score.

    Args:
        terms: query strings (asset name, condition descriptions, ...).
        n: maximum number of entries kept per term.

    Returns:
        All collected entries, sorted by ``"sim"`` in *descending* order, so
        callers slicing ``[:k]`` get the best matches. An entry may appear
        more than once if it matches several terms.
    """
    if not terms:
        return []

    # Real cosine similarity: L2-normalise both sides. The old code divided
    # the score matrix by its column norms plus ``10-16`` (i.e. minus 6, a
    # typo for 1e-16), which flipped signs and scrambled the ranking.
    sim = cosine_similarity_matrix(emb_model.encode(terms), DB_ENC)

    sel_term = []
    for sim_row in sim:
        taken = 0
        for e in np.argsort(sim_row)[::-1]:
            # Checked first so that n=0 selects nothing.
            if taken >= n:
                break
            entry = DB[DB_KEY_MAP[e]]
            if not entry["meaning"]:
                continue
            # Copy only the selected entries, never the shared DB record.
            t = copy.deepcopy(entry)
            t['sim'] = float(sim_row[e])
            sel_term.append(t)
            taken += 1

    # Best match first: get_helpful_words() keeps the head of this list.
    return sorted(sel_term, key=lambda x: x['sim'], reverse=True)
    

def format_search_terms(terms):
    """Render search hits as text, one newline-terminated line per hit.

    Each line has the form ``"<word1>,<word2>,... : <meaning>"``.

    Args:
        terms: iterable of dicts with at least ``"words"`` (iterable of str)
            and ``"meaning"`` keys.

    Returns:
        The concatenated lines, or ``""`` when *terms* is empty.
    """
    rendered = (
        "{} : {}\n".format(",".join(entry["words"]), entry["meaning"])
        for entry in terms
    )
    return "".join(rendered)

def get_helpful_words(asset, conditions = None, topk=5):
    """Build the "Helpful words" glossary text for a question.

    The asset name and every condition are used as search queries against
    the taxonomy. The first ``topk`` hits returned by ``get_search_terms``
    are rendered with ``format_search_terms``.

    Args:
        asset: asset type, used as the first query.
        conditions: optional iterable of extra query strings.
        topk: number of hits kept overall.

    Returns:
        The formatted glossary text (``""`` when there are no hits).
    """
    queries = [asset, *(conditions or ())]
    top_hits = get_search_terms(queries)[:topk]
    return format_search_terms(top_hits)

from bench_utils.question import MultipleChoiceQuestion
from bench_utils.inference_calls import LLMConfiguration, ModelConfig, MODEL_MAP

# Prompt skeleton rendered by question_templating() via str.format; each
# {placeholder} below is filled from the keyword argument of the same name.
_template = """
## Asset Description:
{asset_type}: {asset_desc}

## Helpful words:
{helpful_words}

## Conditions:
{condition_str}

## How long the conditions were met:
{temporal_condition}

{question_prompt}
"""



def question_templating(q:Question, desc_dict:Dict[str,str]) -> str:
    """Render question *q* into the prompt text defined by ``_template``.

    Args:
        q: dataset question. Its ``asset_type``, ``condition_description``,
            ``temporal_condition`` and ``question_prompt`` are read.
        desc_dict: maps an asset type to its description. ``"NONE"`` is used
            when the asset type is missing.

    Returns:
        The filled-in template. Conditions are rendered as a ``"- "`` bullet
        list, and only the first temporal condition is shown (``"None"`` if
        there is none).
    """
    conditions = q.condition_description
    bullet_list = "\n".join("- " + cond for cond in conditions)
    temporal = q.temporal_condition

    fields = dict(
        asset_type=q.asset_type,
        asset_desc=desc_dict.get(q.asset_type, "NONE"),
        condition_str=bullet_list,
        temporal_condition=temporal[0] if len(temporal) > 0 else "None",
        question_prompt=q.question_prompt,
        helpful_words=get_helpful_words(q.asset_type, conditions),
    )
    return _template.format(**fields)



# (name, identifier) of every model benchmarked in this run, in declaration order.
_MODEL_SPECS: List[Tuple[str, str]] = [
    ("mistral-large", "mistralai/mistral-large"),
    ("llama-3-3-70b-instruct", "llama-3-3-70b-instruct"),
    ("qwen2-5-72b-instruct", "Qwen/Qwen2.5-72B-Instruct"),
    ("llama-3-1-405b-instruct-fp8", "llama-3-1-405b-instruct-fp8"),
    ("llama-3-1-8b-instruct", "llama-3-1-8b-instruct"),
    ("microsoft-phi-4", "microsoft-phi-4"),
]

models_to_test:List[ModelConfig] = [
    ModelConfig(name=name, identifier=identifier)
    for name, identifier in _MODEL_SPECS
]


# Inactive alternative model list, kept for reference only: this triple-quoted
# string is an expression statement whose value is discarded. NOTE: it lists
# "o1" twice; if re-enabled, both entries would target the same results file.
"""
ModelConfig(**{
        "name":'mistral-large',
        "identifier" : 'mistralai/mistral-large',
    }),
    ModelConfig(**{
        "name":'llama-3-3-70b-instruct',
        "identifier" : 'llama-3-3-70b-instruct',
    }),
    ModelConfig(**{
        "name":'qwen2-5-72b-instruct',
        'identifier':'Qwen/Qwen2.5-72B-Instruct'
    }),
    ModelConfig(**{
        "name":"llama-3-1-405b-instruct-fp8",
        "identifier":"llama-3-1-405b-instruct-fp8"
    }),
    ModelConfig(**{
        "name":"llama-3-1-8b-instruct",
        "identifier":"llama-3-1-8b-instruct"
    }),
    ModelConfig(**{
        "name":"microsoft-phi-4",
        "identifier":"microsoft-phi-4"
    }),
    #ModelConfig(**{
    #    "name":"mistral-small-3-1-24b-2503",
    #    "identifier":"mistral-small-3-1-24b-2503"
    #}),
    ModelConfig(**{
        "name":"o1",
        "identifier":"o1"
    }),
    #ModelConfig(**{
    #    "name":"o3-mini",
    #    "identifier":"o3-mini"
    #}),
    #ModelConfig(**{
    #    "name":"gpt-4.1-mini",
    #    "identifier":"gpt-4.1-mini"
    #}),
    #ModelConfig(**{
    #    "name":"gpt-4.1",
    #    "identifier":"gpt-4.1"
    #}),
    ModelConfig(**{
        "name":"o1",
        "identifier":"o1"
    })
"""

# Randomise the order in which models are benchmarked. This is reproducible
# because the global ``random`` module is seeded earlier in this script.
random.shuffle(models_to_test)

config = dict(model_to_test=models_to_test)


def ask_question_from_llm(q: Question, model:ModelConfig) -> dict[str,Any]:
    """Ask *model* one benchmark question and return the result record.

    The question is rendered with ``question_templating`` and converted into
    a ``MultipleChoiceQuestion`` prompt. That prompt is sent through an
    ``LLMConfiguration`` client built for *model*.

    The prompt settings are applied to a shallow copy of *q*. The dataset's
    ``Question`` objects are therefore not modified: the same objects are
    re-used for every model in the suite.

    Args:
        q: question taken from the dataset.
        model: configuration of the model to query.

    Returns:
        Dict with the question ``"id"``, the ``"prompt"`` sent, the rendered
        ``"question_text"`` / ``"options_text"``, the ``"true_answer"``, the
        model's attributes under ``"model"``, and the response under both
        ``"model_output"`` and ``"model_original_output"``.
    """
    # Shallow copy: only attributes are reassigned below, so the original
    # Question (shared across models) is left untouched.
    q = copy.copy(q)

    # Prompt configuration for the multiple-choice rendering.
    q.question_first = True
    q.text_type = "choice"
    q.question = question_templating(q, ds.asset_descriptions)

    mcq = MultipleChoiceQuestion()
    mcq.load_dict(q.to_dict())

    prompt = mcq.get_prompt()
    client = LLMConfiguration(model)
    response = client.get_response(prompt)

    return {
        "id": q.id,
        "prompt": prompt,
        "question_text": mcq.question,
        "options_text": mcq.options,
        "true_answer": mcq.correct,
        "model": model.__dict__,
        "model_output": response,
        # Second copy of the raw response, presumably kept so that later
        # post-processing can rewrite "model_output" — confirm with consumers.
        "model_original_output": response,
    }

from models_utils.utils.concurrency import concurrent_dict_execution
from dataset.utils import file_handle

def run_bench_for_llm(llm:ModelConfig, dataset:ADIQDataset, already:List[str]) -> dict[int, dict[str,Any]]:
    """Ask *llm* every question of *dataset* that has not been answered yet.

    Args:
        llm: model to benchmark. ``llm.name`` selects its entry in
            ``MODEL_MAP``, whose ``"num_workers"`` (default 4) sets the
            concurrency.
        dataset: dataset whose ``questions`` are asked.
        already: ids of questions that already have results and are skipped.
            Ids are compared as strings.

    Returns:
        The pairs produced by ``concurrent_dict_execution`` as a dict,
        presumably question id -> record from ``ask_question_from_llm``.
        Returns ``{}`` when there is nothing left to ask.
    """
    print("Run Experiment for {llm}".format(llm=llm))

    # Set for O(1) membership; normalise to str to match the str(v.id) test.
    done = {str(i) for i in already}
    param_dict = {v.id: [v, llm] for v in dataset.questions if str(v.id) not in done}
    if not param_dict:
        return {}

    # Models missing from MODEL_MAP would otherwise crash on None.get(...).
    model_data = MODEL_MAP.get(llm.name) or {}
    return {k: v for k, v in concurrent_dict_execution(
        ask_question_from_llm,
        param_dict,
        num_max_workers=model_data.get("num_workers", 4),
    )}

def run_llm_suit(ds:ADIQDataset, save_loc_data:str, llm_list=config['model_to_test']):
    """Benchmark every model in *llm_list* on *ds*, resuming from saved results.

    For each model, ``<save_loc_data>/<model name>.json`` is both the
    checkpoint and the output.
    - If the file exists, only questions whose ids are not yet in its
      ``"results"`` mapping are asked, and the new answers are merged in.
    - Otherwise a fresh record is created: the model's attributes plus
      ``"results"``.
    In both cases the record is written back to the same file.

    Args:
        ds: dataset to benchmark.
        save_loc_data: directory holding one JSON results file per model.
            It is created if missing.
        llm_list: models to run. Defaults to ``config['model_to_test']``,
            captured when this function is defined.
    """
    os.makedirs(save_loc_data, exist_ok=True)
    for llm in llm_list:
        # Check, load and save the same path. The existence check and the load
        # used to read the global SAVE_LOC while the save wrote to
        # save_loc_data, so any other directory would never resume and could
        # mix results from two locations.
        out_path = os.path.join(save_loc_data, f'{llm.name}.json')

        if os.path.exists(out_path):
            print("Experiment exist for {llm}".format(llm=llm.name))
            record = file_handle.load_json(out_path)
            # Tolerate a checkpoint that was saved without a "results" section.
            results = record.setdefault("results", {})
            results.update(run_bench_for_llm(llm, ds, list(results.keys())))
        else:
            record = copy.deepcopy(llm.__dict__)
            record["results"] = run_bench_for_llm(llm, ds, [])

        file_handle.save_json(record, out_path)
    
# Script entry point: run the whole suite, writing one results file per model under SAVE_LOC.
run_llm_suit(ds, save_loc_data=SAVE_LOC)