# %%
import os, sys
import fitz
import re
import json
from datetime import datetime
from typing import Optional, List, Callable, Any, Tuple, Dict, Self, Union, TypedDict
from abc import abstractmethod, ABC
import random
import numpy as np
import pandas as pd
import copy
import nltk
from nltk.corpus import stopwords
import pickle
from tqdm.autonotebook import tqdm
import itertools
from dotenv import load_dotenv
import string

# Make the repository root importable so the `dataset` and `bench_utils` packages resolve.
sys.path.append("../")

# Load environment variables (presumably API credentials for the LLM client — confirm) from the repo-level .env.
load_dotenv(dotenv_path="../.env")
# NOTE(review): the stopwords corpus is downloaded but not used anywhere in the visible code.
nltk.download('stopwords')

# %%
# Root output directory; run results are written under MAIN_DIR/<variant>/<model name>.
MAIN_DIR = "MRie_rationale"

# %%
from dataset.dataset_utils.reader import ADIQDataset
from dataset.dataset_utils.question import Question


# Benchmark dataset: supplies the questions, rule metadata, unique observations and
# asset descriptions used throughout this notebook.
ds = ADIQDataset("../dataset/datasets/simpleV3.2")
# Number of independent repetitions per scenario (DAC reshuffles the candidates on every call).
TIMES = 10

# %%
# NOTE(review): unq_obs is indexed with the ids listed in each rule's
# display_text['observations'] — presumably an id -> observation-text mapping; confirm.
unq_obs = ds.unique_observations
rule_info = ds.rule_info["rule_set"]

# %%
# Candidate pool per asset: the union of every observation referenced by any rule
# for that asset. The run loops rank this whole pool for each scenario.
asset_actions = {}

for r in rule_info:
    if r['asset'] in asset_actions:
        asset_actions[r['asset']] = asset_actions[r['asset']].union(set([unq_obs[x] for x in r['display_text']['observations']]))
    else:
        asset_actions[r['asset']] = set([unq_obs[x] for x in r['display_text']['observations']])

# %%
from dataset.utils import file_handle

# Scenarios to rank, keyed by scenario id. Each value is read for 'asset' and
# 'observations' here, and for 'conditions' and 'temporal' by the prompt builders.
load_data = file_handle.load_json("rule_logic_test2.json")

# %%
# Merge each scenario's own observations into its asset's candidate pool.
# setdefault creates the pool for assets that appear in the scenarios but not in
# ds.rule_info (the previous direct lookup raised KeyError for those).
for k,v in load_data.items():
    asset_actions.setdefault(v['asset'], set()).update(v["observations"])

# %%
from bench_utils.question import MultipleChoiceQuestion
from bench_utils.inference_calls import LLMConfiguration, ModelConfig, MODEL_MAP

# Alternative model under test (swap with the active configuration below):
#model_config = ModelConfig(**{
#        "name":'claude-3-7-sonnet',
#        "identifier" : 'GCP/claude-3-7-sonnet',
#   })

# Active model under test; model_config.name also names the output subdirectory.
model_config = ModelConfig(**{
        "name":'mistral-large',
        "identifier" : 'mistralai/mistral-large',
    })
# LLM client used by retry_model via client.get_response(prompt)
# (presumably returns the reply text — the parsers treat it as a str).
client = LLMConfiguration(model_config)

# %%
# Zero-shot (no examples) results go to MAIN_DIR/without_examples/<model name>.
SAVE_LOC = os.path.join(MAIN_DIR,"without_examples",model_config.name)

# %%
# Smoke test: check that the endpoint answers before starting the long runs.
response =  client.get_response("test are you online?")
response

# %%
import json
from json import JSONDecodeError

### Step 1: Extract the JSON code block
def extract_json_block(text):
    """Return the contents of the first ``{...}`` span in *text*, without the braces.

    The match is non-greedy (it stops at the first ``}``), so this only suits
    flat objects such as the requested ranking output
    ``{"option": [...], "score": [...], "rank": [...]}``. ``re.DOTALL`` lets the
    span cross line breaks, so objects the model spreads over several lines
    are still found.

    Args:
        text: Raw model response.

    Returns:
        The text between the braces (possibly empty), or ``None`` if *text*
        contains no ``{...}`` span.
    """
    match = re.search(r'{(.*?)}', text, re.DOTALL)
    if match:
        return match.group(1)
    return None

### Step 2: Decode backslash escapes in the extracted string (json.loads is applied by the caller)
def clean_and_parse_json(json_str):
    r"""Decode Python-style backslash escapes in *json_str* and return the string.

    Despite the name, nothing is parsed here: ``load_json_from_response``
    passes the result to ``json.loads``. Escapes such as ``\"`` become ``"``
    and ``\u00e9`` becomes the character it names.

    The string is encoded as Latin-1 with ``backslashreplace`` rather than
    UTF-8. The ``unicode_escape`` codec decodes its input bytes as Latin-1, so
    UTF-8 bytes would come back as mojibake (e.g. an e-acute turning into two
    characters). Characters outside Latin-1 are first written as ``\uXXXX``
    escapes, which the decoder turns back into the original character.

    Args:
        json_str: Text that may contain backslash escapes.

    Returns:
        The unescaped text.
    """
    return json_str.encode('latin-1', 'backslashreplace').decode('unicode_escape')

def extract_json_block2(text):
    """Extract the last ``{"result": ...}`` object from *text* as a string.

    The function finds the last occurrence of the literal ``{"result":`` and the
    first ``}`` after it. The value must not itself contain ``}`` (nested
    objects are not supported).

    The previous version had three slicing bugs:
    - it searched for ``}`` from the start of the text, not after the marker;
    - it mis-sliced when the marker was absent or sat at the very end;
    - it dropped the closing brace.
    It also printed debug output.

    Args:
        text: Raw model response.

    Returns:
        The ``{"result": ...}`` substring including both braces, or ``None``
        if the marker, or a ``}`` after it, is missing.
    """
    marker = '{"result":'
    begin = text.rfind(marker)
    if begin == -1:
        return None
    value_start = begin + len(marker)
    end = text.find('}', value_start)
    if end == -1:
        return None
    return marker + text[value_start:end] + '}'

def load_json_from_response(_r):
    """Parse an LLM response string into a Python object.

    The response is first parsed as strict JSON. If that fails, the function:
    1. extracts the first ``{...}`` span (``extract_json_block``);
    2. replaces single quotes with double quotes, so Python-dict-style output parses;
    3. decodes backslash escapes (``clean_and_parse_json``);
    4. parses the result again.

    Args:
        _r: Raw model response text.

    Returns:
        The decoded JSON value (normally a dict).

    Raises:
        JSONDecodeError: if the response contains no ``{...}`` span (the
            original response is kept as the error's ``doc`` and the strict
            parse error is chained), or if the repaired span still is not valid
            JSON.
    """
    try:
        return json.loads(_r)
    except JSONDecodeError as err:
        body = extract_json_block(_r)
        if not body:
            raise JSONDecodeError("no JSON object found in model response", _r, 0) from err
        # NOTE: blanket replacement also rewrites apostrophes inside string values.
        body = body.replace("'", '"')
        return json.loads(clean_and_parse_json("{" + body + "}"))

def retry_model(prompt,times=5):
    """Query the LLM until its reply parses as JSON, making at most *times* attempts.

    Uses the module-level ``client`` and ``load_json_from_response``. Only
    ``JSONDecodeError`` triggers a retry; any other exception propagates.

    Args:
        prompt: Prompt text sent to ``client.get_response``.
        times: Maximum number of attempts.

    Returns:
        ``{"error": False, "res": parsed}`` on success. Otherwise
        ``{"error": True}``, after printing the last raw response (or ``None``
        if no attempt was made).
    """
    # Bound up front so the failure path also works when times <= 0
    # (previously an UnboundLocalError).
    response = None
    attempts = 0
    while attempts < times:
        try:
            response = client.get_response(prompt)
            response = load_json_from_response(response)
            return {"error": False, "res": response}
        except JSONDecodeError:
            attempts += 1

    print("###", response)
    return {"error": True}

    

# %%
# Zero-shot prompt used by ranking_templating; the JSON answer-format
# instructions are appended after the final "follow this format:" line.
selection_no_example_template = """
Please RANK the options that MOST likely gives the reason for the conditions?

## Asset Description:
{asset_type}

## Conditions:
{condition_str}

## How long the conditions were met:
{temporal_condition}

Analyse the given conditions of the presented asset and rank the options that MOST likely gives the reason for the conditions?
{observations}
Your output must strictly follow this format:
"""
# Option tags: the i-th candidate shown in a prompt is labelled capital_letters[i].
capital_letters = string.ascii_uppercase


def ranking_templating(_dic:Dict[str,Any], observation:List[str]) -> str:
    """Build the zero-shot ranking prompt for one chunk of candidate observations.

    Args:
        _dic: Scenario record. Reads ``asset``, ``conditions`` (strings, rendered
            as a bullet list) and ``temporal`` (an empty value is rendered as
            ``"None"``).
        observation: Candidates in display order. Each is a dict with a
            ``text`` key, despite the ``List[str]`` annotation. They are tagged
            A, B, C, ...

    Returns:
        The filled template followed by the JSON output-format instructions.
        The format hint lists one example tag, score (``1 - i/10``) and rank
        (``10 - i``) per candidate.
    """
    asset_name = _dic['asset']
    bullets = "\n".join("- " + cond for cond in _dic['conditions'])

    temporal = _dic["temporal"]
    if len(temporal) == 0:
        temporal = "None"

    tagged = "\n".join(
        f"{capital_letters[idx]}. {item['text']}" for idx, item in enumerate(observation)
    )

    filled = selection_no_example_template.format(
        asset_type=asset_name,
        condition_str=bullets,
        temporal_condition=temporal,
        observations=tagged,
    )

    count = len(observation)
    tag_hint = str([capital_letters[idx] for idx in range(count)])
    score_hint = str([1 - idx / 10 for idx in range(count)])
    rank_hint = str([10 - idx for idx in range(count)])
    format_hint = (
        '\n{"option": <list of the option tag each option tag should be in double qoutes e.g. '
        + tag_hint
        + '>,"score":<list of scoring value inline with rank ranging from 1,-1 eg: '
        + score_hint
        + '>,"rank":<list of the rank eg: '
        + rank_hint
        + '>}\n    Your output in a single line:'
    )
    return filled + format_hint


# Prompt builder called by select_observations; rebound later for the few-shot runs.
TEMP_FUNC = ranking_templating

def select_observations(dic, obs, topk):
    """Ask the LLM to score one chunk of candidates and split it at *topk*.

    The prompt comes from the module-level ``TEMP_FUNC`` and is sent through
    ``retry_model``. The reply's ``option`` tags (A, B, ...) are mapped back to
    positions in *obs*. Each matching ``score`` is written onto that candidate
    dict in place. Candidates the model does not score get ``-1``.

    Args:
        dic: Scenario record forwarded to the prompt builder.
        obs: Candidate dicts (``text``, ``tier`` and possibly a ``score`` from
            an earlier round); mutated in place.
        topk: Number of highest-scoring candidates to return as selected.

    Returns:
        ``(selected, rest)``: the *topk* highest-scoring candidates and the
        remainder, each ordered by descending score.

    Raises:
        RuntimeError: if no parseable JSON reply was obtained.
        AssertionError: if the reply's ``score`` and ``option`` lists differ in
            length (``DAC`` catches this and retries).
        KeyError: if the reply lacks ``option`` or ``score``.
    """
    prompt = TEMP_FUNC(dic, obs)
    res = retry_model(prompt)

    if res["error"]:
        raise RuntimeError("model did not return parseable JSON after retries")

    options = res["res"]['option']
    scores = res["res"]['score']
    # Raised explicitly (not a bare ``assert``) so DAC's retry still fires
    # when Python runs with -O.
    if len(scores) != len(options):
        raise AssertionError(
            f"model returned {len(options)} options but {len(scores)} scores"
        )

    # Clear scores from an earlier round: a candidate the model omits now
    # should fall back to -1, not keep a score from a different chunk.
    for o in obs:
        o.pop('score', None)

    for opt, score in zip(options, scores):
        # Accept only a single-letter tag that maps to a candidate shown in
        # this chunk. str.find alone returns -1 for unknown tags (silently
        # hitting obs[-1]) and also matches multi-letter substrings.
        if not isinstance(opt, str):
            continue
        tag = opt.strip().upper()
        if len(tag) != 1:
            continue
        ind = capital_letters.find(tag)
        if 0 <= ind < len(obs):
            obs[ind]["score"] = score

    for o in obs:
        o.setdefault('score', -1)

    order = np.argsort([o["score"] for o in obs])[::-1]
    return [obs[x] for x in order[:topk]], [obs[x] for x in order[topk:]]


def DAC(_dic, observations, n=10, topk=3):
    """Rank the candidate observations for one scenario by divide-and-conquer.

    The candidates are shuffled, wrapped as ``{"text": ..., "tier": 0}`` dicts
    and reduced with ``recursive_selection``. Up to 3 attempts are made,
    retrying on ``AssertionError``. The result is sorted with the sort key
    ``tier * 10000 + score``, highest first. Candidates that survived more
    rounds therefore come first, provided scores stay far below 10000 in
    magnitude.

    Args:
        _dic: Scenario record passed through to the prompt builder.
        observations: Candidate observation texts. Not modified; a shuffled
            copy is used.
        n: Chunk size per LLM call.
        topk: Candidates kept per chunk.

    Returns:
        ``(ranking, error)``:
        - *ranking* is the sorted list of candidate dicts, empty if every
          attempt failed.
        - *error* is ``{"error": "AssertionError"}`` if any attempt raised
          AssertionError (even when a later retry succeeded), else ``None``.
    """
    pool = list(observations)
    random.shuffle(pool)

    terms = []
    error = None
    for _ in range(3):
        # Build fresh dicts on every attempt: recursive_selection increments
        # "tier" and writes "score" in place, so reusing dicts from a failed
        # attempt would carry partially-updated tiers into the retry.
        candidates = [{"text": x, "tier": 0} for x in pool]
        try:
            terms = recursive_selection(_dic, candidates, n, topk=topk)
            break
        except AssertionError:
            error = {"error": "AssertionError"}

    for o in terms:
        o.setdefault('score', -1)

    terms = sorted(terms, key=lambda x: x["tier"] * 10000 + x["score"], reverse=True)
    return terms, error



def recursive_selection(_dic, observations, n=10, topk = 3):
    """Reduce *observations* to *topk* winners in knockout rounds.

    Each round scores the candidates in chunks of *n* with
    ``select_observations``. Each chunk's top *topk* get their ``tier``
    incremented and advance to the next round. Rounds repeat until at most
    *topk* candidates remain.

    Args:
        _dic: Scenario record forwarded to the prompt builder.
        observations: Candidate dicts with at least a ``tier`` key.
        n: Chunk size per LLM call.
        topk: Candidates kept per chunk.

    Returns:
        The final winners followed by every candidate eliminated along the
        way. ``DAC`` sorts this list by ``tier`` and ``score``.

    Raises:
        ValueError: if more than *topk* candidates remain but ``n <= topk``.
            Chunks would then never shrink and the recursion would never end,
            making LLM calls at every level.
    """
    if len(observations) <= topk:
        return observations
    if n <= topk:
        raise ValueError(f"chunk size n={n} must be greater than topk={topk}")

    sel_obser_list = []
    unsel_obser_list = []
    for start in range(0, len(observations), n):
        # Slicing clamps at the end, so the last chunk may simply be shorter.
        chunk = observations[start:start + n]

        sel_obser, unsel_obser = select_observations(_dic, chunk, topk=topk)
        for o in sel_obser:
            o["tier"] += 1

        sel_obser_list.extend(sel_obser)
        unsel_obser_list.extend(unsel_obser)

    selected = recursive_selection(_dic, sel_obser_list, n=n, topk=topk)
    return selected + unsel_obser_list

# %%
os.makedirs(SAVE_LOC, exist_ok=True)

# %%

# Zero-shot runs: for every repetition T and scenario k, rank the asset's whole
# candidate pool with DAC and save one JSON file per (T, k). Pairs whose output
# file already exists (per file_handle.file_exist) are skipped, so re-running the
# cell continues where it stopped.
for T in range(TIMES):
    for k,v in tqdm(load_data.items()):
        if file_handle.file_exist(SAVE_LOC, f"{T}_{k}.json"):
            continue
        # Chunk size per LLM call (DAC's n); the random alternative is disabled.
        n_obs = 10 #random.choice([12,11,10,9,8])
        # Candidates kept per chunk (DAC's topk).
        n_top = 3 #random.choice([4,3,2])

        r, e = DAC(v,list(asset_actions[v['asset']]),n_obs, n_top)
        file_handle.save_json(
            {"ranking":r, "data":v, "n_obs": n_obs, "n_tops":n_top,"error":e},
            os.path.join(SAVE_LOC, f"{T}_{k}.json"))

# %% [markdown]
# ### With Examples

# %%
from sentence_transformers import SentenceTransformer

# 1. Load a pretrained Sentence Transformer model
# (used to embed dataset questions here and scenario conditions in get_examples).
model = SentenceTransformer("all-MiniLM-L6-v2")

# %%
# One string per dataset question: its condition descriptions joined by newlines.
qa_cond_str = ["\n".join(x.condition_description) for x in ds.questions]
# Question embeddings; get_examples relies on row i corresponding to ds.questions[i].
QA_COND_ENC = model.encode(qa_cond_str)
# Question-to-question dot-product similarities (equal to cosine similarity only if
# the embeddings are normalized, which is not enforced here).
QA_QA_ENC = (QA_COND_ENC @ QA_COND_ENC.T)
# NOTE(review): argsort along axis 0 (ascending); QA_QA_ENC is not used in the visible code.
QA_QA_ENC = np.argsort(QA_QA_ENC, axis=0)


# %%
# Few-shot ranking prompt; {examples} is filled with the text built by get_examples,
# and the JSON answer-format instructions are appended after the last line.
selection_exe_template = """
Your a helper in industrial asset maintainance, Please RANK the options that MOST likely gives the reason for the conditions?

## Use following Questions and answers as help for the ranking.
{examples}

# Question

## Asset Description:
{asset_type}

## Conditions:
{condition_str}

## How long the conditions were met:
{temporal_condition}

Analyse the given conditions of the presented asset and rank the options that MOST likely gives the reason for the conditions?
{observations}
  
Your output must strictly follow this format:
"""

# Renders one retrieved dataset question as a worked example; get_examples appends
# the correct answer and the guidance rationale after it.
question_template = """
You are given the Answer and the Guidance Rationale for the question.

### Asset Description:
{asset_type}: {asset_description}

### Conditions:
{conditions}

### How long the conditions were met:
{temporal_condition}

{question_prompt}
{options}
"""

def get_examples(asset, conditions, n=None):
    """Build the few-shot example section for an asset and its conditions.

    Steps:
    1. Embed the newline-joined *conditions* with the module-level ``model``.
    2. Rank every dataset question by dot-product similarity against
       ``QA_COND_ENC``.
    3. Keep the *n* most similar questions whose ``asset_type`` equals *asset*.
    4. Render each with ``question_template``, followed by its correct answer
       and guidance rationale.

    Args:
        asset: Asset type used to filter candidate questions.
        conditions: Condition strings of the scenario being ranked.
        n: Number of examples to include. Defaults to the module-level ``N``
            bound by the few-shot experiment loop (the previous, implicit
            behaviour).

    Returns:
        The concatenated example text (empty if no question matches).

    Raises:
        NameError: if *n* is omitted and no module-level ``N`` exists.
        IndexError: if a selected question has no truthy entry in ``q.correct``.
    """
    if n is None:
        n = N

    # A set makes each membership test O(1) instead of a list scan per candidate.
    same_asset = {i for i, x in enumerate(ds.questions) if x.asset_type == asset}
    conditions_string = "\n".join(conditions)

    con_enc = model.encode(conditions_string).reshape(1, -1)
    sim = (con_enc @ QA_COND_ENC.T).flatten()
    max_sort = np.argsort(sim)[::-1]  # most similar first

    question_sel = []
    for m in max_sort:
        if len(question_sel) >= n:
            break
        if m in same_asset:
            question_sel.append(m)

    print("sim {} FROM {}".format(question_sel, len(same_asset)))

    parts = []
    for i, _ind in enumerate(question_sel):
        q: Question = ds.questions[_ind]
        parts.append(f"\n### Example {i+1}\n")
        parts.append(question_template.format(asset_type = q.asset_type,
            asset_description = ds.asset_descriptions.get(q.asset_type, "NONE"),
            conditions = "\n".join(list(map(lambda x:"- "+x, q.condition_description))),
            temporal_condition = q.temporal_condition[0] if len(q.temporal_condition)>0 else "NONE",
            question_prompt = q.question_prompt,
            options = "\n".join(["{}. {}".format(op_id, op) for op_id, op in zip(q.option_ids,q.options)])))

        # Position of the (first) correct option.
        _ind_cor = [j for j, ok in enumerate(q.correct) if ok][0]
        parts.append(f"\nAnswer: {q.option_ids[_ind_cor]}. {q.answer_str}\n")
        parts.append(f"\nGuidance Rationale: {q.rationale}\n")

    return "".join(parts)



def ranking_example_templating(_dic:Dict[str,Any], observation) -> str:
    """Build the few-shot ranking prompt for one chunk of candidate observations.

    The layout matches the zero-shot prompt. In addition, retrieved worked
    examples (from ``get_examples``) fill the help section of
    ``selection_exe_template``.

    Args:
        _dic: Scenario record. Reads ``asset``, ``conditions`` (strings) and
            ``temporal`` (an empty value is rendered as ``"None"``).
        observation: Candidate dicts with a ``text`` key, tagged A, B, C, ...
            in order.

    Returns:
        The complete prompt, ending with the JSON output-format instructions.
    """
    asset_name = _dic['asset']
    bullets = "\n".join("- " + cond for cond in _dic['conditions'])

    temporal = _dic["temporal"]
    if len(temporal) == 0:
        temporal = "None"

    tagged = "\n".join(
        f"{capital_letters[idx]}. {item['text']}" for idx, item in enumerate(observation)
    )

    worked_examples = get_examples(asset_name, _dic['conditions'])

    filled = selection_exe_template.format(
        asset_type=asset_name,
        condition_str=bullets,
        temporal_condition=temporal,
        observations=tagged,
        examples=worked_examples,
    )

    count = len(observation)
    tag_hint = str([capital_letters[idx] for idx in range(count)])
    score_hint = str([1 - idx / 10 for idx in range(count)])
    rank_hint = str([10 - idx for idx in range(count)])
    format_hint = (
        '\n{"option": <list of the option tag e.g. '
        + tag_hint
        + '>,"score":<list of scoring value inline with rank ranging from 1,-1 eg: '
        + score_hint
        + '>,"rank":<list of the rank eg: '
        + rank_hint
        + '>}\n    Your output in a single line:'
    )
    return filled + format_hint

# %%
# Few-shot runs: write to MAIN_DIR/with_examples/<model name> and switch the prompt
# builder used by select_observations to the example-augmented one.
SAVE_LOC = os.path.join(MAIN_DIR,"with_examples",model_config.name)
TEMP_FUNC = ranking_example_templating

os.makedirs(SAVE_LOC, exist_ok=True)

# %%
# Number of scenarios processed per (T, N) sweep.
len(load_data.items())

# %%

# Sweep repetitions T and example counts N = 1..9, one JSON file per (T, k, N).
# N is not passed explicitly: get_examples reads it as a module-level global, so it
# must be bound by this loop. Pairs whose output file already exists (per
# file_handle.file_exist) are skipped, so re-running the cell continues where it stopped.
for T in range(TIMES):
    for N in range(1,10):
        for k,v in tqdm(load_data.items()):
            if file_handle.file_exist(SAVE_LOC, f"{T}_{k}_N{N}.json"):
                continue
            # Chunk size per LLM call (DAC's n); the random alternative is disabled.
            n_obs = 10 #random.choice([12,11,10,9,8])
            # Candidates kept per chunk (DAC's topk).
            n_top = 3 #random.choice([4,3,2])

            r, e = DAC(v,list(asset_actions[v['asset']]),n_obs, n_top)
            file_handle.save_json(
                {"ranking":r, "data":v, "n_obs": n_obs, "n_tops":n_top,"error":e}, 
                os.path.join(SAVE_LOC, f"{T}_{k}_N{N}.json"))

# %%


