import os, sys
import fitz
import re
import json
from datetime import datetime
from typing import Optional, List, Callable, Any, Tuple, Dict, Union
from abc import abstractmethod, ABC
import random
import numpy as np
import pandas as pd
import copy
import nltk
from nltk.corpus import stopwords
import pickle
from tqdm.autonotebook import tqdm
import itertools
import argparse
from dotenv import load_dotenv

# Make the repository root importable for the project packages imported
# further down (`dataset`, `bench_utils`, `models_utils`) and load
# environment variables from the repo-level .env file. Both paths resolve
# against the current working directory, not this file's location.
sys.path.append("../")
load_dotenv(dotenv_path="../.env")
# NOTE(review): `stopwords` is not used in the visible code; this download
# runs on every execution and may touch the network.
nltk.download('stopwords')

# Seed Python's and NumPy's global RNGs (the model run order is shuffled
# with `random` further down).
_RANDOM_SEED = 42
random.seed(_RANDOM_SEED)
np.random.seed(_RANDOM_SEED)

# Results layout: RESULTS_JSONL/<tag>/<run id>/<model name>.jsonl.
# A fixed run id lets reruns write into (and resume from) the same folder;
# use str(datetime.now().isoformat()) instead for a fresh folder per run.
BENCH_TAG = "basic_bench"
BENCH_ID = "simpleV4"
SAVE_LOC = os.path.join("RESULTS_JSONL", BENCH_TAG, BENCH_ID)
os.makedirs(SAVE_LOC, exist_ok=True)

from dataset.dataset_utils.reader import ADIQDataset
from dataset.dataset_utils.question import Question


# Dataset under test. The id is kept in one dict so the load path and the
# metadata cannot drift apart.
# NOTE(review): data_metadata is not referenced elsewhere in the visible code.
data_metadata = {"dataset_id": "simpleV4"}
ds = ADIQDataset(f"../dataset/datasets/{data_metadata['dataset_id']}")
 
from bench_utils.question import MultipleChoiceQuestion
from bench_utils.inference_calls import LLMConfiguration, ModelConfig, MODEL_MAP

# Prompt body shared by every benchmark question; filled in by
# question_templating below.
_template = """
## Asset Description:
{asset_type}: {asset_desc}

## Conditions:
{condition_str}

## How long the conditions were met:
{temporal_condition}

{question_prompt}
"""


def question_templating(q: "Question", desc_dict: Dict[str, str]) -> str:
    """Render a question's asset context, conditions and prompt into ``_template``.

    Args:
        q: Question supplying ``asset_type``, ``verberlized_conditions``,
            ``temporal_condition`` and ``question_prompt``.
        desc_dict: Maps an asset type to its description; asset types missing
            from the mapping render as ``"NONE"``.

    Returns:
        The filled-in template. Only the first temporal condition is shown;
        when there is none (empty sequence or ``None``) the literal ``"None"``
        is used instead.
    """
    temporal = q.temporal_condition
    # Guard against None as well as an empty sequence; len(None) used to raise.
    has_temporal = temporal is not None and len(temporal) > 0
    return _template.format(
        asset_type=q.asset_type,
        asset_desc=desc_dict.get(q.asset_type, "NONE"),
        condition_str=q.verberlized_conditions,
        temporal_condition=temporal[0] if has_temporal else "None",
        question_prompt=q.question_prompt,
    )


# Alternative model suite. This list is NOT what gets run: `models_to_test`
# is assigned a different list right after this one. It previously reused
# the name `models_to_test` and was silently overwritten. To run this suite,
# assign it to `models_to_test` instead.
# Each entry is (name, identifier). `name` is used as the MODEL_MAP key and as
# the results-file stem.
_alt_models_to_test: List[ModelConfig] = [
    ModelConfig(name=name, identifier=identifier)
    for name, identifier in [
        ('granite-3-3-8b-instruct', 'ibm/granite-3-3-8b-instruct'),
        ('mistral-small-3-1-24b-instruct-2503', 'mistralai/mistral-small-3-1-24b-instruct-2503'),
        # NOTE(review): name says "deepseek-v3" but identifier is deepseek-r1 — confirm which is intended.
        ('deepseek-v3-h200', 'deepseek-ai/deepseek-r1'),
        ('mistral-medium-2505', 'mistralai/mistral-medium-2505'),
        ('mistral-large', 'mistralai/mistral-large'),
        ('llama-3-3-70b-instruct', 'llama-3-3-70b-instruct'),
        ('qwen2-5-72b-instruct', 'Qwen/Qwen2.5-72B-Instruct'),
        ("llama-3-1-405b-instruct-fp8", "llama-3-1-405b-instruct-fp8"),
        ("llama-3-1-8b-instruct", "llama-3-1-8b-instruct"),
        ("microsoft-phi-4", "microsoft-phi-4"),
    ]
]

# Model suite that is actually benchmarked (shuffled further down).
# Each entry is (name, identifier). `name` is used as the MODEL_MAP key and as
# the results-file stem.
models_to_test: List[ModelConfig] = [
    ModelConfig(name=model_name, identifier=model_identifier)
    for model_name, model_identifier in (
        ('o1', 'o1'),
        ('gemini-2.0-flash', 'GCP/gemini-2.0-flash'),
        ('gemini-1.5-pro', 'GCP/gemini-1.5-pro'),
        ('claude-3-5-haiku', 'GCP/claude-3-5-haiku'),
        ('claude-3-7-sonnet', 'GCP/claude-3-7-sonnet'),
    )
]

# NOTE(review): the triple-quoted string below is an unused expression
# statement. It is not assigned to anything and is not a docstring, so it has
# no effect at runtime; it only archives older model lists. It contains
# duplicate entries (e.g. 'o1', 'mistral-large' and 'microsoft-phi-4' appear
# more than once). Consider deleting it, or turning it into a named list like
# the model lists above.
"""
ModelConfig(**{
    "name":'granite-3-3-8b-instruct',
    "identifier" : 'ibm/granite-3-3-8b-instruct',
}),
ModelConfig(**{
    "name":'mistral-small-3-1-24b-instruct-2503',
    "identifier" : 'mistralai/mistral-small-3-1-24b-instruct-2503',
}),
ModelConfig(**{
    "name":'deepseek-v3-h200',
    "identifier" : 'deepseek-ai/deepseek-r1',
}),
ModelConfig(**{
    "name":'mistral-medium-2505',
    "identifier" : 'mistralai/mistral-medium-2505',
}),
ModelConfig(**{
    "name":'mistral-large',
    "identifier" : 'mistralai/mistral-large',
}),
ModelConfig(**{
    "name":'llama-3-3-70b-instruct',
    "identifier" : 'llama-3-3-70b-instruct',
}),
ModelConfig(**{
    "name":'qwen2-5-72b-instruct',
    'identifier':'Qwen/Qwen2.5-72B-Instruct'
}),
ModelConfig(**{
    "name":"llama-3-1-405b-instruct-fp8",
    "identifier":"llama-3-1-405b-instruct-fp8"
}),
ModelConfig(**{
    "name":"llama-3-1-8b-instruct",
    "identifier":"llama-3-1-8b-instruct"
}),
ModelConfig(**{
    "name":"microsoft-phi-4",
    "identifier":"microsoft-phi-4"
}),
ModelConfig(**{
        "name":'mistral-large',
        "identifier" : 'mistralai/mistral-large',
    }),
    ModelConfig(**{
        "name":'llama-3-3-70b-instruct',
        "identifier" : 'llama-3-3-70b-instruct',
    }),
    ModelConfig(**{
        "name":'qwen2-5-72b-instruct',
        'identifier':'Qwen/Qwen2.5-72B-Instruct'
    }),
    ModelConfig(**{
        "name":"llama-3-1-405b-instruct-fp8",
        "identifier":"llama-3-1-405b-instruct-fp8"
    }),
    ModelConfig(**{
        "name":"llama-3-1-8b-instruct",
        "identifier":"llama-3-1-8b-instruct"
    }),
    ModelConfig(**{
        "name":"microsoft-phi-4",
        "identifier":"microsoft-phi-4"
    }),
    #ModelConfig(**{
    #    "name":"mistral-small-3-1-24b-2503",
    #    "identifier":"mistral-small-3-1-24b-2503"
    #}),
    ModelConfig(**{
        "name":"o1",
        "identifier":"o1"
    }),
    #ModelConfig(**{
    #    "name":"o3-mini",
    #    "identifier":"o3-mini"
    #}),
    #ModelConfig(**{
    #    "name":"gpt-4.1-mini",
    #    "identifier":"gpt-4.1-mini"
    #}),
    #ModelConfig(**{
    #    "name":"gpt-4.1",
    #    "identifier":"gpt-4.1"
    #}),
    ModelConfig(**{
        "name":"o1",
        "identifier":"o1"
    })
"""

# Shuffle the order in which models are benchmarked. `random` was seeded with
# 42 near the top of this module, so the printed permutation is repeatable
# as long as nothing in between consumes the RNG differently.
_ind = list(range(len(models_to_test)))
random.shuffle(_ind)
print(_ind)
models_to_test = [models_to_test[position] for position in _ind]

# Run configuration; its "model_to_test" entry is run_llm_suit's default
# model list (captured when that function is defined).
config = dict(model_to_test=models_to_test)

def is_json_serializable(obj) -> bool:
    """Return True if ``json.dumps`` can encode ``obj``, otherwise False.

    ``json.dumps`` signals failure with TypeError (unsupported type or key),
    ValueError (circular reference) or OverflowError. All three are mapped to
    False so callers can filter arbitrary attribute dicts without crashing.
    """
    try:
        json.dumps(obj)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def ask_question_from_llm(q: Question, model:LLMConfiguration, model_config:ModelConfig, metadata:Dict[str,Union[str,int]]) -> dict[str,Any]:
    """Ask one benchmark question to an LLM as a multiple-choice prompt.

    Side effect: ``q`` is mutated in place. ``question_first`` and
    ``text_type`` are set, and ``question`` is overwritten with the templated
    text built from the module-level ``ds.asset_descriptions``. This happens
    before ``q`` is converted to a MultipleChoiceQuestion.

    Args:
        q: Dataset question to ask.
        model: Client whose ``get_response`` receives the prompt.
        model_config: Its ``prompt_params`` are forwarded as keyword arguments
            to ``get_response``.
        metadata: Extra fields merged last into the returned record, so its
            keys override same-named record keys.

    Returns:
        A record with the prompt, question and options text, the correct
        answer, the JSON-serializable attributes of ``model``, and the raw
        response (stored under both output keys).
    """
    # Prompt-formatting switches read by MultipleChoiceQuestion; these must be
    # set before q.to_dict() below.
    q.question_first = True
    q.text_type = "choice"
    q.question = question_templating(q, ds.asset_descriptions)

    choice_question = MultipleChoiceQuestion()
    choice_question.load_dict(q.to_dict())
    rendered_prompt = choice_question.get_prompt()

    llm_answer = model.get_response(rendered_prompt, **model_config.prompt_params)

    return {
        "id": q.id,
        "prompt": rendered_prompt,
        "question_text": choice_question.question,
        "options_text": choice_question.options,
        "true_answer": choice_question.correct,
        # Keep only client attributes that can be written to JSONL.
        "model": {
            attr_name: attr_value
            for attr_name, attr_value in model.__dict__.items()
            if is_json_serializable(attr_value)
        },
        "model_output": llm_answer,
        "model_original_output": llm_answer,
        **metadata,
    }

from models_utils.utils.concurrency import concurrent_dict_execution
from dataset.utils import file_handle

def run_bench_for_llm_inst(llm:ModelConfig, dataset:ADIQDataset, path:str, already:List[str], metadata: Optional[Dict[str, Any]] = None)-> None:
    """Ask every not-yet-answered dataset question to one model and save the results.

    Args:
        llm: Model to benchmark. It is used to build the LLMConfiguration
            client, and its ``name`` is looked up in ``MODEL_MAP`` for
            ``num_workers``.
        dataset: Dataset whose ``questions`` are asked.
        path: JSONL file each result record is written to via
            ``file_handle.save_jsonl`` (presumably appended; run_llm_suit
            reloads it to resume).
        already: Ids of questions already answered; these are skipped.
        metadata: Extra fields merged into every result record. Defaults to
            none.
    """
    print("Run Experiment for {llm}".format(llm=llm))
    if metadata is None:
        metadata = {}
    client = LLMConfiguration(llm)

    # Set lookup keeps the filter O(1) per question instead of O(len(already)).
    answered_ids = set(already)
    param_dict = {
        question.id: [question, client, llm, metadata]
        for question in dataset.questions
        if question.id not in answered_ids
    }
    if not param_dict:
        return

    # MODEL_MAP may have no entry for this model; fall back to 4 workers
    # instead of crashing on None.get(...).
    model_data = MODEL_MAP.get(llm.name) or {}
    results = concurrent_dict_execution(
        ask_question_from_llm,
        param_dict,
        num_max_workers=model_data.get("num_workers", 4),
    )
    for _question_id, record in results:
        # Persist each record as it arrives so an interrupted run can resume.
        file_handle.save_jsonl(record, path)

def run_llm_suit(ds:ADIQDataset, save_loc_data:str, llm_list=config['model_to_test']):
    """Benchmark each model in ``llm_list`` on ``ds``, resuming partial runs.

    Each model's results go to ``<save_loc_data>/<model name>.jsonl``. If that
    file already exists, the question ids recorded in it are skipped.

    Args:
        ds: Dataset to benchmark on.
        save_loc_data: Directory holding the per-model JSONL files; created
            if missing.
        llm_list: Models to run, in order. Defaults to
            ``config['model_to_test']`` as it was when this function was
            defined.
    """
    os.makedirs(save_loc_data, exist_ok=True)
    for llm in llm_list:
        # Read from and write to the caller-supplied directory. Previously the
        # module-level SAVE_LOC was used here and save_loc_data was ignored.
        result_path = os.path.join(save_loc_data, f'{llm.name}.jsonl')
        if os.path.exists(result_path):
            print("Experiment exist for {llm}".format(llm=llm.name))
            already_answered = [row["id"] for row in file_handle.load_jsonl(result_path)]
        else:
            already_answered = []

        run_bench_for_llm_inst(
            llm,
            ds,
            result_path,
            already_answered,
            # Deep copy so per-record metadata cannot alias the config object.
            metadata=copy.deepcopy(llm.__dict__),
        )

        


if __name__ == "__main__":
    # Script entry point: benchmark every configured model on the module-level
    # dataset, saving results under SAVE_LOC.
    run_llm_suit(ds, save_loc_data=SAVE_LOC)