import os, sys
import fitz
import re
import json
from datetime import datetime
from typing import Optional, List, Callable, Any, Tuple, Dict
from abc import abstractmethod, ABC
import random
import numpy as np
import pandas as pd
import copy
import nltk
from nltk.corpus import stopwords
import pickle
from tqdm.autonotebook import tqdm
import itertools
import argparse
from dotenv import load_dotenv

# Add the parent directory to the module search path. Presumably this is what
# lets the project packages imported further down (`dataset`, `bench_utils`,
# `models_utils`) resolve — TODO confirm the expected working directory.
sys.path.append("../")

# Read environment variables from the .env file one directory up.
load_dotenv(dotenv_path="../.env")
# Request the NLTK stopwords corpus; this call runs every time the module is loaded.
nltk.download('stopwords')

# Seed both the stdlib and NumPy generators with a fixed value.
random.seed(42)
np.random.seed(42)

# Output location: results/<BENCH_TAG>/<BENCH_ID>, created at import time.
BENCH_TAG = "basic_bench"
BENCH_ID = "complexV3.1"#str(datetime.now().isoformat())
SAVE_LOC = os.path.join("results",BENCH_TAG, BENCH_ID)
os.makedirs(SAVE_LOC, exist_ok=True)

from dataset.dataset_utils.reader import ADIQDataset
from dataset.dataset_utils.question import Question


# Dataset instance built at import time. ask_question_from_llm reads
# `ds.asset_descriptions` from it and the script entry point passes it to
# run_llm_suit.
ds = ADIQDataset("../dataset/datasets/complexV3.1")
# Dataset identifier; not referenced by any other code visible in this file.
data_metadata = {"dataset_id":"complexV3.1"}

from bench_utils.question import MultipleChoiceQuestion
from bench_utils.inference_calls import LLMConfiguration, ModelConfig, MODEL_MAP

# Prompt skeleton filled in by question_templating() via str.format. The
# placeholders are asset_type, asset_desc, condition_str, temporal_condition
# and question_prompt.
_template = """
## Asset Description:
{asset_type}: {asset_desc}

## Conditions:
{condition_str}

## How long the conditions were met:
{temporal_condition}

{question_prompt}
"""

def question_templating(q:Question, desc_dict:Dict[str,str]) -> str:
    """Render the prompt text for one question by filling ``_template``.

    Args:
        q: Question object; its ``asset_type``, ``condition_description``,
            ``temporal_condition`` and ``question_prompt`` attributes are read.
        desc_dict: Mapping from asset type to a description. An asset type
            with no entry is rendered as the literal ``"NONE"``.

    Returns:
        ``_template`` with every placeholder substituted.
    """
    fields = {
        "asset_type": q.asset_type,
        "asset_desc": desc_dict.get(q.asset_type, "NONE"),
        # One "- " bullet per condition, one condition per line.
        "condition_str": "\n".join(
            ["- " + condition for condition in q.condition_description]
        ),
        # Only the first temporal condition is shown; "None" when there is none.
        "temporal_condition": (
            q.temporal_condition[0] if len(q.temporal_condition) > 0 else "None"
        ),
        "question_prompt": q.question_prompt,
    }
    return _template.format(**fields)


# Models evaluated by this script, each described by a display name and the
# identifier handed to ModelConfig.
models_to_test:List[ModelConfig] = [
    ModelConfig(
        name='granite-3-3-8b-instruct',
        identifier='ibm/granite-3-3-8b-instruct',
    ),
    ModelConfig(
        name='mistral-small-3-1-24b-instruct-2503',
        identifier='mistralai/mistral-small-3-1-24b-instruct-2503',
    ),
    ModelConfig(
        name='llama-4-maverick-17b-128e-instruct-fp8',
        identifier='meta-llama/llama-4-maverick-17b-128e-instruct-fp8',
    ),
    # Currently disabled:
    # ModelConfig(
    #     name='deepseek-r1',
    #     identifier='deepseek-ai/deepseek-r1',
    # ),
    ModelConfig(
        name='mistral-medium-2505',
        identifier='mistralai/mistral-medium-2505',
    ),
]


# The bare string below is never assigned or used; it only parks ModelConfig
# entries that are not part of the active list above. Note that the "o1" entry
# appears twice inside it.
"""
ModelConfig(**{
        "name":'mistral-large',
        "identifier" : 'mistralai/mistral-large',
    }),
    ModelConfig(**{
        "name":'llama-3-3-70b-instruct',
        "identifier" : 'llama-3-3-70b-instruct',
    }),
    ModelConfig(**{
        "name":'qwen2-5-72b-instruct',
        'identifier':'Qwen/Qwen2.5-72B-Instruct'
    }),
    ModelConfig(**{
        "name":"llama-3-1-405b-instruct-fp8",
        "identifier":"llama-3-1-405b-instruct-fp8"
    }),
    ModelConfig(**{
        "name":"llama-3-1-8b-instruct",
        "identifier":"llama-3-1-8b-instruct"
    }),
    ModelConfig(**{
        "name":"microsoft-phi-4",
        "identifier":"microsoft-phi-4"
    }),
    #ModelConfig(**{
    #    "name":"mistral-small-3-1-24b-2503",
    #    "identifier":"mistral-small-3-1-24b-2503"
    #}),
    ModelConfig(**{
        "name":"o1",
        "identifier":"o1"
    }),
    #ModelConfig(**{
    #    "name":"o3-mini",
    #    "identifier":"o3-mini"
    #}),
    #ModelConfig(**{
    #    "name":"gpt-4.1-mini",
    #    "identifier":"gpt-4.1-mini"
    #}),
    #ModelConfig(**{
    #    "name":"gpt-4.1",
    #    "identifier":"gpt-4.1"
    #}),
    ModelConfig(**{
        "name":"o1",
        "identifier":"o1"
    })
"""

# Put the models into a shuffled order using the module-level `random`
# generator, and print the index permutation so the chosen order is visible.
_ind = [*range(len(models_to_test))]
random.shuffle(_ind)
print(_ind)
models_to_test = list(map(models_to_test.__getitem__, _ind))

# Run configuration; "model_to_test" supplies the default `llm_list` of
# run_llm_suit.
config = dict(model_to_test=models_to_test)


def ask_question_from_llm(q: Question, model:LLMConfiguration) -> dict[str,Any]:
    """Ask a model one dataset question as a multiple-choice prompt.

    The question object is modified in place: ``question_first``,
    ``text_type`` and ``question`` are overwritten before it is converted
    into a ``MultipleChoiceQuestion``.

    Args:
        q: Question to ask. Its prompt text is rebuilt with
            ``question_templating`` and the module-level dataset's
            ``asset_descriptions``.
        model: Configured client; ``get_response`` is called once with the
            generated prompt.

    Returns:
        A record holding the question id, the prompt, question and option
        text, the correct answer, the model's ``__dict__`` (the live mapping,
        not a copy) and the response. The response is stored under both
        ``model_output`` and ``model_original_output``.
    """
    # Prompt settings are written onto the question before conversion.
    q.question_first = True
    q.text_type = "choice"
    q.question = question_templating(q, ds.asset_descriptions)

    mc_question = MultipleChoiceQuestion()
    mc_question.load_dict(q.to_dict())

    prompt_text = mc_question.get_prompt()
    answer = model.get_response(prompt_text)

    return {
        "id": q.id,
        "prompt": prompt_text,
        "question_text": mc_question.question,
        "options_text": mc_question.options,
        "true_answer": mc_question.correct,
        "model": model.__dict__,
        "model_output": answer,
        "model_original_output": answer,
    }

from models_utils.utils.concurrency import concurrent_dict_execution
from dataset.utils import file_handle

def run_bench_for_llm(llm:ModelConfig, dataset:ADIQDataset, already:List[str]) -> dict[int, dict[str,Any]]:
    """Run every not-yet-answered question of a dataset against one model.

    Args:
        llm: Configuration of the model to query.
        dataset: Dataset whose ``questions`` are iterated.
        already: Question ids, as strings, that already have a stored result
            and must be skipped.

    Returns:
        The key/value pairs produced by ``concurrent_dict_execution`` for the
        pending questions (presumably question id -> record returned by
        ``ask_question_from_llm`` — verify against that helper), or an empty
        dict when nothing is pending.
    """
    print("Run Experiment for {llm}".format(llm=llm))
    client = LLMConfiguration(llm)

    # Build the lookup once so the skip test is O(1) per question instead of a
    # linear scan of `already` for every question.
    done_ids = set(already)
    param_dict = {
        question.id: [question, client]
        for question in dataset.questions
        if str(question.id) not in done_ids
    }
    if not param_dict:
        return {}

    # MODEL_MAP.get() yields None for a model without an entry; fall back to an
    # empty mapping so the default worker count below applies instead of
    # failing with AttributeError.
    model_data = MODEL_MAP.get(llm.name) or {}
    return {
        key: value
        for key, value in concurrent_dict_execution(
            ask_question_from_llm,
            param_dict,
            num_max_workers=model_data.get("num_workers", 4),
        )
    }

def run_llm_suit(ds:ADIQDataset, save_loc_data:str, llm_list=config['model_to_test']):
    """Run the benchmark for each model and store one JSON file per model.

    For every model the results file ``<save_loc_data>/<model name>.json`` is
    consulted. If it exists, only the questions missing from its ``"results"``
    mapping are run and merged into it. Otherwise a new payload is started
    from a deep copy of the model configuration's ``__dict__``.

    Args:
        ds: Dataset to evaluate.
        save_loc_data: Directory that holds the per-model result files; it is
            created if missing.
        llm_list: Models to evaluate. The default is bound once, when the
            function is defined, to ``config['model_to_test']``.
    """
    os.makedirs(save_loc_data, exist_ok=True)
    for llm in llm_list:
        # Look for earlier results in the same directory the results are
        # written to, so resuming works for any `save_loc_data` and not only
        # when it equals the module-level SAVE_LOC.
        result_path = os.path.join(save_loc_data, f'{llm.name}.json')

        if os.path.exists(result_path):
            print("Experiment exist for {llm}".format(llm=llm.name))
            payload = file_handle.load_json(result_path)
            already_done = list(payload["results"].keys())
        else:
            payload = copy.deepcopy(llm.__dict__)
            payload["results"] = {}
            already_done = []

        # Only questions without a stored result are executed; the new records
        # are merged into whatever was loaded.
        payload["results"].update(run_bench_for_llm(llm, ds, already_done))
        file_handle.save_json(payload, result_path)


# Script entry point: evaluate the default model list on the module-level
# dataset and write the results under SAVE_LOC.
if __name__ == "__main__":
    run_llm_suit(ds,SAVE_LOC)