# %%
import os
import fitz
import sys
import re
import json
from datetime import datetime
from typing import Optional, List, Callable, Any, Tuple, Dict
from abc import abstractmethod, ABC
import random
import numpy as np
import pandas as pd
import copy
import nltk
from nltk.corpus import stopwords
import pickle
import itertools
from dataclasses import dataclass, asdict
from enum import Enum
from tqdm import tqdm
from dotenv import load_dotenv

# Make the project packages imported in the next cell (utils, skyspark, dataset_utils)
# resolvable — presumably "../../" is the repository root; confirm.
# NOTE(review): this path and the .env path below are relative to the current working
# directory, not to this file, so the script must be run from its own directory.
sys.path.append("../../")

# Load environment variables (credentials/config) from the project's .env file.
load_dotenv(dotenv_path="../.env")
# Fetch the NLTK stopwords corpus.
# NOTE(review): `stopwords` is imported but not referenced elsewhere in this script —
# confirm this download is still needed.
nltk.download('stopwords')

# %%
from utils import file_handle
from skyspark.utils.sky_spark_wrapper import Sensor, AssetDescription
from utils import file_handle
from utils import tree
from utils.tree import Node
from dataset_utils.reader import ADIQDataset

# %%
# SkySpark exports: per-asset descriptions keyed "<site>_<asset>" (used that way below),
# the raw sensor listing, and the LLM-extracted sensor entities per rule condition.
asset_desc = file_handle.load_json("skyspark/data3/asset_desc.json")
# NOTE(review): `sensor_data` is not referenced elsewhere in this script.
sensor_data = file_handle.load_json("skyspark/data3/sensors.json")
extracted_sensors = file_handle.load_json("skyspark/extracted/extracted_sensors_llm.json")


# Rule tree, re-keyed by each rule's '#n' field.
# NOTE(review): load_pickle presumably unpickles, which can execute arbitrary code — only
# point it at trusted, locally produced files. `ds` is also not used later in this script.
ds = file_handle.load_pickle("extracted/TreeStruct.pkl")
ds = {v['#n']:v for v in ds['rule_set']}
# Question dataset; its `questions` supply the rule forms built below.
dataset = ADIQDataset("../dataset/datasets/simpleV3.1")

# Records with "id", "condition" and "entity" fields, queried by search_sensor_map.
sensor_map = file_handle.load_jsonl("skyspark/extracted/sensor-map.jsonl")

# %%
# Root of the DiagIQ data directory. Defaults to the original hard-coded location; set the
# DIAGIQ_MAIN_DIR environment variable (e.g. in the .env file loaded at startup) to move it.
MAIN_DIR = os.environ.get("DIAGIQ_MAIN_DIR", '/mnt/data/DiagIQ')
# Directory of raw observation files named "<site>_<asset>_<sensor>.json".
RAW_DATA = os.path.join(MAIN_DIR, 'processed', 'raw')

# %%
# Deep-copied snapshot of extracted_sensors as loaded from disk, taken before it is re-keyed.
# NOTE(review): `temp` is reassigned further down before any visible read, so this copy
# only seems useful for interactive inspection — confirm before removing.
temp = copy.deepcopy(extracted_sensors)

# %%
# Index raw observation files by filename stem, parsing "<site>_<asset>_<sensor>" into fields.
# maxsplit=2 keeps underscores inside the sensor part: splitting on every "_" and zipping
# against three field names silently truncated e.g. "S_A_Supply_Air_Temp" to sensor "Supply".
# Stems with fewer than three parts get fewer fields (and are dropped later by the
# asset_desc lookup, which raises KeyError for them).
available_obs = {
    stem: dict(zip(['site_name', 'asset_name', "sensor_name"], stem.split("_", 2)))
    for stem in (x.replace(".json", "") for x in os.listdir(RAW_DATA))
}
# Re-key LLM extractions as "<rule_id>_<record's 'original' field>" so the matching loop can
# look them up from (rule, condition) pairs; the previous key is kept under "id". Fields of
# the record itself win over "id"/"rule_id" on a name clash, and later records silently
# overwrite earlier ones that map to the same new key.
extracted_sensors = {
    f"{k.split('_')[0]}_{v['original']}": {"id": k, "rule_id": k.split("_")[0], **v}
    for k, v in extracted_sensors.items()
}

# %%
# One record per dataset question, reduced to the fields that define a "rule form".
rule_forms = [
    dict(
        asset_type=question.asset_type,
        conditions=question.condition_description,
        rule=question.rule_id,
        temporal_condition=question.temporal_condition,
    )
    for question in dataset.questions
]

def make_hashable(obj):
    """Recursively convert dicts/lists/tuples/sets into hashable equivalents.

    - dict -> tuple of ``(key, value)`` pairs sorted by key, so key insertion order does
      not matter (keys must be mutually orderable).
    - list/tuple -> tuple, element order preserved.
    - set/frozenset -> frozenset, so equal sets always produce equal results.

    Any other value is returned unchanged and must already be hashable.
    """
    if isinstance(obj, dict):
        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
    if isinstance(obj, (set, frozenset)):
        # A set's iteration order is not part of its value (for str elements it even varies
        # between runs), so use an order-insensitive frozenset rather than a tuple.
        return frozenset(make_hashable(i) for i in obj)
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(i) for i in obj)
    return obj  # leave other types unchanged

def unique_dicts(dict_list):
    """Drop structural duplicates from ``dict_list``, keeping first occurrences in order.

    Two items are duplicates when ``make_hashable`` maps them to equal values. The result
    contains the original objects, not copies.
    """
    first_by_key = {}
    for item in dict_list:
        # setdefault stores only the first item per key; dicts preserve insertion order.
        first_by_key.setdefault(make_hashable(item), item)
    return list(first_by_key.values())

# Several questions can share the same rule form; keep only the first of each distinct form.
rule_forms = unique_dicts(rule_forms)
# Number of distinct rule forms to process.
print(len(rule_forms))


# %%
# Keep only observations whose "<site>_<asset>" has an entry in asset_desc, merging that
# description into the observation record (asset_desc values win on key clashes).
temp = {}
for k, v in available_obs.items():
    try:
        desc = asset_desc[f"{v['site_name']}_{v['asset_name']}"]
    except KeyError:
        # Filename lacked site/asset parts, or the asset has no description: skip it.
        continue
    temp[k] = {**v, **desc}

available_obs = temp

# %%

from sentence_transformers import SentenceTransformer, util

# Sentence-embedding model used by sensor_ranking to compare requested entity names with
# available sensor names; loaded once here at import time.
SEMANTIC_MODEL = SentenceTransformer("all-mpnet-base-v2")

# %%
from collections import Counter

def search_sensor_map(rule, cond, entity):
    """Return the single ``sensor_map`` record for a (rule, condition, entity) triple.

    A record matches when the integer before the first ``"_"`` in its ``"id"`` equals
    ``int(rule)`` and its ``"condition"`` / ``"entity"`` fields equal ``cond`` / ``entity``.
    A found match is printed for debugging.

    Returns:
        The matching record, or ``None`` when no record matches.

    Raises:
        ValueError: if more than one record matches; the map is expected to hold at most
            one record per triple.
    """
    rule = int(rule)

    def match_rule(x):
        # Read every field up front so malformed records fail loudly rather than being
        # skipped by short-circuit evaluation.
        i = int(x["id"].split("_")[0])
        c = x['condition']
        e = x['entity']
        return i == rule and c == cond and e == entity

    matches = [x for x in sensor_map if match_rule(x)]
    if len(matches) > 1:
        raise ValueError(
            f"sensor_map has {len(matches)} records for rule={rule}, "
            f"condition={cond!r}, entity={entity!r}; expected at most one"
        )
    if not matches:
        return None
    print(rule, cond, entity, matches[0])
    return matches[0]
    
def sensor_ranking(req_sen:List[str], availble_sen:List[str]):
    """Rank the available sensor names by semantic similarity to each requested name.

    Both lists are embedded with ``SEMANTIC_MODEL`` and compared by dot product.

    Args:
        req_sen: Requested sensor/entity descriptions; one ranking is produced per entry.
        availble_sen: Candidate sensor names to rank (parameter name kept for callers).

    Returns:
        A list aligned with ``req_sen``; each element is ``availble_sen`` reordered from most
        to least similar. An empty ``req_sen`` gives ``[]``; an empty ``availble_sen`` gives
        one empty ranking per request (both cases previously crashed in the matmul).
    """
    if len(req_sen) == 0:
        return []
    if len(availble_sen) == 0:
        return [[] for _ in req_sen]

    req_enc = SEMANTIC_MODEL.encode(req_sen)
    avail_enc = SEMANTIC_MODEL.encode(availble_sen)

    # NOTE(review): the dot product equals cosine similarity only if the model returns
    # unit-normalised embeddings — confirm for SEMANTIC_MODEL.
    sim = req_enc @ avail_enc.T
    # Only the per-row order matters. A row-wise softmax is strictly increasing, so the
    # former softmax step could not change the ranking; it has been dropped.
    order = np.argsort(sim, axis=-1)[:, ::-1]
    return np.asarray(availble_sen)[order].tolist()
    

def word_to_char_vector(word):
    """Return a bag-of-characters count for ``word`` (character -> number of occurrences)."""
    char_counts = Counter()
    char_counts.update(word)
    return char_counts

def manhattan_distance(a, b):
    """Return the L1 (Manhattan) distance between the character-count vectors of a and b.

    A character missing from one string counts as 0 occurrences there. The result is the
    number of characters that must be added or removed, ignoring order, to turn one
    character multiset into the other.
    """
    counts_a = word_to_char_vector(a)
    counts_b = word_to_char_vector(b)
    distance = 0
    for ch in counts_a.keys() | counts_b.keys():
        distance += abs(counts_a[ch] - counts_b[ch])
    return distance

def match_entity_rank(entity, rank0):
    """Accept or reject each top-ranked sensor name for its requested entity.

    ``entity`` and ``rank0`` are walked in parallel, stopping at the shorter one. A
    candidate is accepted when it equals the entity, when either string contains the
    other, or when their character-count Manhattan distance is below 30% of the entity's
    length.

    Returns:
        A list holding, for each pair, the accepted candidate or ``None``.
    """
    def is_acceptable(requested, candidate):
        if requested == candidate or requested in candidate or candidate in requested:
            return True
        # Unreachable for an empty `requested` ("" is contained in any string).
        return manhattan_distance(requested, candidate) / len(requested) < 0.3

    return [
        candidate if is_acceptable(requested, candidate) else None
        for requested, candidate in zip(entity, rank0)
    ]

def search_sensors_in_available(sensors, sen_n):
    """Match each requested entity name in ``sen_n`` against one asset's sensors.

    Each sensor's short name is the text after the last occurrence of its ``asset_name``
    inside ``sensor_name`` (the whole name if the asset name is absent), stripped of
    surrounding whitespace. The short names are ranked per entity with ``sensor_ranking``,
    and only the top candidate is passed to ``match_entity_rank`` for acceptance.

    Returns:
        A list aligned with ``sen_n`` holding the accepted sensor name or ``None``.
    """
    short_names = []
    for sensor in sensors:
        full_name = sensor['sensor_name']
        short_names.append(full_name.split(sensor['asset_name'])[-1].strip())

    top_candidates = [ranked[0] for ranked in sensor_ranking(sen_n, short_names)]
    return match_entity_rank(sen_n, top_candidates)
    

# %%
triples = []

# Sensors belonging to each "<site>_<asset>", built once instead of re-scanning all of
# available_obs for every (rule, asset) pair. Within a group, order follows available_obs,
# which is the order a per-asset filter over available_obs would produce.
sensors_by_asset = {}
for obs in available_obs.values():
    sensors_by_asset.setdefault(f'{obs["site_name"]}_{obs["asset_name"]}', []).append(obs)

# For every rule form and every asset of the rule's asset_type, record for each condition
# whether each extracted entity has an acceptable matching sensor on that asset:
#   r['analysis'][asset][f"{rule}_{cond}"][entity] -> bool
# Every (condition, entities, matches) result is also collected in `triples`.
for r in tqdm(rule_forms, total=len(rule_forms), desc="processed:"):
    r['analysis'] = {}
    # Sorted so output order does not depend on set iteration order, which varies between
    # runs for str elements because of hash randomisation.
    ind_assets = sorted({
        f'{so["site_name"]}_{so["asset_name"]}'
        for so in available_obs.values()
        if so['asset_type'] == r["asset_type"]
    })

    for ia in ind_assets:
        r['analysis'][ia] = {}
        sen_is = sensors_by_asset[ia]

        for cond in r["conditions"]:
            cond_key = f"{r['rule']}_{cond}"
            cond_info = extracted_sensors[cond_key]
            cond_analysis = {}
            r['analysis'][ia][cond_key] = cond_analysis

            entities = [e for extraction in cond_info["extracted"] for e in extraction['entities']]
            if not entities:
                continue

            matches = search_sensors_in_available(sen_is, entities)
            for ent, mat in zip(entities, matches):
                # A falsy match (None or "") means no acceptable sensor was found.
                cond_analysis[ent] = bool(mat)

            triples.append((cond, entities, matches))


# %%
# Persist rule_forms. Each entry now carries the per-asset "analysis" map built by the
# matching loop: analysis[asset][f"{rule}_{cond}"][entity] -> whether a sensor matched.
file_handle.save_json(rule_forms, "temp_rule_forms_llm.json")
