# %%
import os
import fitz
import sys
import re
import json
from datetime import datetime
from typing import Optional, List, Callable, Any, Tuple, Dict
from abc import abstractmethod, ABC
import random
import numpy as np
import pandas as pd
import copy
import nltk
from nltk.corpus import stopwords
import pickle
import itertools
from dataclasses import dataclass, asdict
from enum import Enum
from dotenv import load_dotenv

# Put the parent directory on the import path, presumably so the project
# packages imported further down (dataset_utils, utils, models_utils,
# benchmarking) resolve — verify against the repo layout.
# NOTE(review): "../" is relative to the current working directory, not to
# this file, so the script presumably must be run from its own folder.
sys.path.append("../")

# Load environment variables from the parent directory's .env file.
load_dotenv(dotenv_path="../.env")
# Downloads the NLTK stopwords corpus on every run.
# NOTE(review): stopwords is not used anywhere in the visible code.
nltk.download('stopwords')


# JSONL output for generated rationales (written via file_handle.save_jsonl).
# It is also read back at startup so questions that already have a saved
# rationale are skipped.
SAVE_LOC = "rationale/saved_rationale_simple_contrast_prompt.jsonl"
# %%
# System message describing the assistant's role for asset-health guidance.
# NOTE(review): not referenced anywhere in the visible code (the rationale
# request below sends only the user prompt) — confirm whether it should be
# passed to the LLM client.
system_prompt = "You are an assistant that analyses industrial asset health and guides humans to maintain said assets"

# Per-question prompt section, filled via str.format with the placeholders:
# asset_type, asset_description, conditions, temporal_condition,
# question_prompt, options.
question_template = """
### Asset Description:
{asset_type}: {asset_description}

### Conditions:
{conditions}

### How long the conditions were met:
{temporal_condition}

{question_prompt}
{options}
"""

# Few-shot instruction prompt. It holds three example rationales
# (example1..example3), then the target question and its known answer.
# The model's completion follows the trailing "# Guidance Rationale:" header.
# NOTE(review): the instruction text mixes ” ” and “ ” quotes and refers to a
# "Guidance:" section, while the answer passed in is labelled "Answer:".
# Presumably the examples use "Guidance:" — confirm. The text (including
# trailing spaces) is kept verbatim because it is sent to the model as-is.
rational_template = """
Generate detailed asset rationales for guidance (”Guidance:”) based on the asset description (”Asset Description:”) 
and conditions (”Conditions:”) shown by the asset. These rationales should be the crucial cue for the guidance. 
Pretend that you don’t know the guidance (“Guidance:”). Compare and contrast with all the options to come to the answer
as the conclusion.

# Example 1
{example1}

# Example 2
{example2}

# Example 3
{example3}

# Question
{question}

# Answer
{answer}

# Guidance Rationale:
"""

# %%
from dataset_utils.reader import ADIQDataset

# Benchmark dataset. Later cells read ds.asset_descriptions and ds.questions.
# NOTE(review): the path is relative to the current working directory.
ds = ADIQDataset("../dataset/datasets/simpleV3.1")

# %%
from utils import file_handle

# Rated example rationales, keyed by asset type (see select_examples).
rated_examples = file_handle.load_json("rationale/rated_examples2.json")

# All rated examples across every asset type, highest rating first.
# sorted() is stable, so equally rated examples keep their file order.
flatten_rated_examples = sorted(
    itertools.chain.from_iterable(rated_examples.values()),
    key=lambda example: example['rating'],
    reverse=True,
)


def select_examples(asset_type, examples=None, *, num_examples=3, top_k=5):
    """Pick few-shot rationale examples for a question about ``asset_type``.

    Examples rated for the same asset type are preferred. If there are more
    than ``num_examples`` of them, a random subset is drawn. If there are
    fewer, the shortfall is filled by a random draw from the ``top_k``
    highest-rated examples of any asset type, excluding ids already chosen.

    Args:
        asset_type: Key into ``examples``.
        examples: Mapping of asset type -> list of example dicts, each with at
            least ``"id"`` and ``"rating"`` keys. ``None`` (the default) means
            the module-level ``rated_examples`` together with its pre-sorted
            ``flatten_rated_examples``.
        num_examples: Number of examples to return (the prompt expects 3).
        top_k: Size of the highest-rated pool used for backfilling.

    Returns:
        A new list of ``num_examples`` example dicts. The input mapping and its
        lists are never modified. Backfilled examples are deep copies.

    Raises:
        ValueError: if there are not enough distinct examples to fill the list.
    """
    if examples is None:
        examples = rated_examples
        ranked = flatten_rated_examples
    else:
        # Same ordering as flatten_rated_examples, derived from the given mapping.
        ranked = sorted(
            (x for group in examples.values() for x in group),
            key=lambda x: x["rating"],
            reverse=True,
        )

    try:
        own = examples[asset_type]
    except KeyError:
        print("No Examples found for Type:", asset_type)
        own = []

    if len(own) > num_examples:
        return random.sample(own, num_examples)

    # Work on a copy. Extending the list stored in `examples` (as before) froze
    # the first random backfill into the shared data for all later calls.
    chosen = list(own)
    missing = num_examples - len(chosen)
    if missing <= 0:
        return chosen

    taken = {x["id"] for x in chosen}
    pool = [x for x in ranked if x["id"] not in taken][:top_k]
    if len(pool) < missing:
        raise ValueError(
            f"Need {missing} extra example(s) for {asset_type!r} "
            f"but only {len(pool)} candidate(s) are available"
        )
    # Deep-copy only the sampled extras so callers can't alter the shared pool.
    chosen.extend(copy.deepcopy(x) for x in random.sample(pool, missing))
    return chosen

# %% [markdown]
# ### Retrieve Answers

# %%
from models_utils.utils.concurrency import concurrent_dict_execution
from tqdm import tqdm
from benchmarking.bench_utils.inference_calls import LLMConfiguration, ModelConfig, MODEL_MAP
from dataset_utils.question import Question

# Model used to write the rationales; its config is also stored with each
# saved record.
model_config = ModelConfig(
    name='mistral-large',
    identifier='mistralai/mistral-large',
)
client = LLMConfiguration(model_config)

def _format_question_prompt(que: Question) -> str:
    """Render ``que`` into ``question_template``."""
    conditions = "\n".join(f"- {cond}" for cond in que.condition_description)
    # Only the first temporal condition is shown in the prompt.
    temporal = que.temporal_condition[0] if que.temporal_condition else "NONE"
    options = "\n".join(
        f"{op_id}. {op}" for op_id, op in zip(que.option_ids, que.options)
    )
    return question_template.format(
        asset_type=que.asset_type,
        asset_description=ds.asset_descriptions.get(que.asset_type, "NONE"),
        conditions=conditions,
        temporal_condition=temporal,
        question_prompt=que.question_prompt,
        options=options,
    )


def _format_answer(que: Question) -> str:
    """Return the ``Answer: <id>. <text>`` line for the first correct option."""
    correct_idx = next((i for i, ok in enumerate(que.correct) if ok), None)
    if correct_idx is None:
        # Previously an opaque IndexError from indexing an empty list.
        raise ValueError(f"Question {que.id} has no option marked correct")
    return f"\nAnswer: {que.option_ids[correct_idx]}. {que.answer_str}\n"


def generate_rationale_for_question(que: Question):
    """Ask the LLM for a rationale for ``que`` and save it to ``SAVE_LOC``.

    Builds a few-shot prompt from three examples chosen by ``select_examples``
    for the question's asset type, the formatted question and its correct
    answer. Sends it with ``client.get_response`` and writes the record with
    ``file_handle.save_jsonl``.

    Raises:
        ValueError: if no option is marked correct, or if the model returns an
            empty/None response. For an empty response, the prompt is printed
            first for debugging.
    """
    que_prompt = _format_question_prompt(que)
    answer = _format_answer(que)

    ex1, ex2, ex3 = select_examples(que.asset_type)

    rational_example = rational_template.format(
        example1=ex1['text'],
        example2=ex2['text'],
        example3=ex3['text'],
        question=que_prompt,
        answer=answer,
    )

    response = client.get_response(rational_example)
    if not response:
        print("None response")
        print(rational_example)
        raise ValueError("please check")

    # NOTE(review): this runs under concurrent_dict_execution. Whether
    # concurrent save_jsonl writes to the same file are safe depends on
    # file_handle (not visible here) — verify.
    file_handle.save_jsonl({
        "id": que.id,
        "full_id": que.question_id,
        "examples": [ex1, ex2, ex3],
        "prompt": rational_example,
        "model_config": {**model_config.to_dict()},
        "rationale": response,
    }, SAVE_LOC)


# Resume support: collect the ids of questions whose rationale is already
# saved in SAVE_LOC, so only the remaining questions are sent to the model.
if os.path.exists(SAVE_LOC):
    keys = [x["id"] for x in file_handle.load_jsonl_generator(SAVE_LOC)]
else:
    keys = []
# Set for O(1) membership. Checking `in keys` on the list scanned every saved
# id for each question.
done_ids = set(keys)

params = {q.id: [q] for q in ds.questions if q.id not in done_ids}

if params:
    # Keep the results under a name instead of building and discarding a dict.
    # Iterating also fully consumes the result, in case
    # concurrent_dict_execution returns a lazy iterable (not visible here).
    rationale_results = {
        k: v
        for k, v in concurrent_dict_execution(generate_rationale_for_question, params)
    }
    

    


# %%


# %%



