from dataset.dataset_utils.reader import ADIQDataset
from dataset.dataset_utils.question import Question
from dataset.utils import file_handle
import os
import os, sys
#import fitz
import re
import json
from json import JSONDecodeError
from datetime import datetime
from typing import Optional, List, Callable, Any, Tuple, Dict
from abc import abstractmethod, ABC
import random
import numpy as np
import pandas as pd
import copy
from dotenv import load_dotenv

sys.path.append(".")
load_dotenv(dotenv_path=".env")

def get_all_metadata(root_fol:str, ds_loc:str):
    """Build per-model evaluation records for every results file in a folder.

    Every entry of ``root_fol`` is loaded as a JSONL results file, indexed by
    the ``"id"`` field of its rows, and scored against the questions of the
    dataset at ``ds_loc`` through ``get_evaluation_metadata``.

    Args:
        root_fol: Folder holding one results file per model.
        ds_loc: Dataset location handed to ``ADIQDataset``.

    Returns:
        Dict mapping a model id (the file name without a trailing
        ``.jsonl``) to the list produced by ``get_evaluation_metadata``.

    Raises:
        FileNotFoundError: If ``root_fol`` does not exist.
    """
    if not os.path.exists(root_fol):
        # Fail immediately: merely printing let execution continue until
        # os.listdir raised the same error type with less context.
        raise FileNotFoundError("Given folder is wrong: {}".format(root_fol))

    ds = ADIQDataset(ds_loc)
    ds_ques = {k.id:k for k in ds.questions}
    res_dict = {}
    suffix = ".jsonl"

    # os.listdir returns entries in arbitrary order; sorting keeps the model
    # order (and everything derived from it) reproducible between runs.
    for file in sorted(os.listdir(root_fol)):
        file_path = os.path.join(root_fol, file)
        results_list = file_handle.load_jsonl(file_path)
        results = {v["id"]:v for v in results_list}
        print(len(results.keys()), file_path)

        # Strip the extension only at the end of the name, so a model id that
        # happens to contain ".jsonl" elsewhere is left intact.
        model_id = file[:-len(suffix)] if file.endswith(suffix) else file
        res_dict[model_id] = get_evaluation_metadata(ds_ques, results)

    return res_dict

### Step 1: Extract the JSON code block
def extract_json_block(text):
    """Return the text between the first ``{`` / ``}`` pair found in ``text``.

    A pair sitting on a single line is preferred. If there is none, the
    search is repeated with ``.`` allowed to match newlines so that an object
    spread over several lines is still found.

    The match is non-greedy and stops at the first ``}``, so nested objects
    are cut short at their first closing brace.

    Args:
        text: Raw text to search.

    Returns:
        The content between the braces (braces excluded), or ``None`` when
        no brace pair exists.
    """
    pattern = r'{(.*?)}'
    # Keep the single-line match first so inputs that matched before the
    # multi-line fallback existed still yield exactly the same block.
    match = re.search(pattern, text) or re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1)
    return None

### Step 2: Clean the extracted string and convert to Python object
def clean_and_parse_json(json_str):
    """Resolve backslash escape sequences in ``json_str``.

    Despite the name, no JSON parsing happens here: the function only turns
    sequences such as ``\\"`` or ``\\n`` into the characters they denote and
    returns the resulting string.

    Args:
        json_str: Text that may contain literal backslash escapes.

    Returns:
        The unescaped string.

    Raises:
        UnicodeDecodeError: If ``json_str`` contains an invalid or truncated
            escape sequence (for example a trailing backslash).
    """
    # 'unicode_escape' decodes bytes as Latin-1, so encoding with the default
    # UTF-8 would turn every non-ASCII character into mojibake. Encoding as
    # Latin-1 and writing anything outside that range as a \uXXXX escape
    # round-trips all characters while still resolving existing escapes.
    unescaped = json_str.encode('latin-1', 'backslashreplace').decode('unicode_escape')
    return unescaped

def _parse_model_output(raw):
    """Turn one raw model output into a dict that has an ``'answer'`` key.

    Raises ``JSONDecodeError`` for anything that cannot be interpreted that
    way, so callers need to handle a single exception type.
    """
    if isinstance(raw, dict):
        parsed = raw
    elif not isinstance(raw, str):
        # e.g. None: json.loads would raise TypeError, which is not a
        # JSONDecodeError and would abort the whole evaluation.
        raise JSONDecodeError("custom", str(raw), 0)
    else:
        try:
            parsed = json.loads(raw)
        except JSONDecodeError:
            # Fall back to the first {...} block embedded in free text.
            inner = extract_json_block(raw)
            if not inner:
                raise JSONDecodeError("custom", str(inner), 0)
            try:
                inner = clean_and_parse_json(inner)
            except UnicodeDecodeError:
                # Invalid escape sequence inside the extracted block.
                raise JSONDecodeError("custom", str(inner), 0)
            parsed = json.loads("{"+inner+"}")

    if not isinstance(parsed, dict) or 'answer' not in parsed:
        raise JSONDecodeError("custom", str(parsed), 0)
    return parsed


def get_evaluation_metadata(ds_ques, results):
    """Score every question in ``ds_ques`` against the model outputs.

    Args:
        ds_ques: Dict mapping question id to a question object exposing
            ``option_ids`` and ``correct`` (``correct`` is used to index
            ``option_ids`` through numpy).
        results: Dict mapping question id to a row containing
            ``'model_output'``.

    Returns:
        One dict per question, in ``ds_ques`` order, with the keys ``id``,
        ``q``, ``decode-error``, ``not-in-options``, ``no-answer``, ``true``,
        ``pred``, ``correct`` and ``in-correct``. Outputs that cannot be
        parsed into a dict with a usable ``'answer'`` collection are flagged
        with ``decode-error`` and keep both scores at 0.

    Raises:
        KeyError: If a question id is absent from ``results`` or its row has
            no ``'model_output'`` entry.
    """
    evaluation_data = []
    for q_id, question in ds_ques.items():
        data = {"id":q_id, "q":question, "decode-error":False, 'not-in-options':False, "no-answer":False, "true":None, "pred":None, "correct":0, "in-correct":0}
        try:
            parsed = _parse_model_output(results[q_id]['model_output'])

            pred = parsed['answer']
            data['pred'] = pred
            try:
                pred_set = set(pred)
            except TypeError:
                # 'answer' is null, a number, or holds unhashable items:
                # a format error rather than a reason to stop evaluating.
                raise JSONDecodeError("custom", str(pred), 0)

            if not pred_set.intersection(question.option_ids):
                data['not-in-options'] = True

                if isinstance(pred, list) and len(pred) == 0:
                    data['no-answer'] = True

            cor = np.array(question.option_ids)[np.array(question.correct)].tolist()
            data['true'] = cor

            hits = pred_set.intersection(cor)

            # Strict score: exactly one option predicted and it is correct.
            if hits and len(pred) == 1:
                data['correct'] = 1

            # NOTE(review): 'in-correct' is set whenever any correct option
            # appears in the prediction, so it is also 1 whenever 'correct'
            # is 1 -- confirm this lenient reading is the intended metric.
            if hits:
                data['in-correct'] = 1

        except JSONDecodeError:
            data['decode-error'] = True

        evaluation_data.append(data)

    return evaluation_data


        
    
def asset_wise_metrics(_data) -> dict:
    """Summarise the ``"correct"`` scores per asset type.

    Args:
        _data: Evaluation records; each has a ``"q"`` object exposing
            ``asset_type`` and a numeric ``"correct"`` entry.

    Returns:
        Dict mapping each asset type to ``{'count': ..., 'perc': ...}``,
        where ``count`` is the sum of the ``correct`` values for that asset
        and ``perc`` is that sum divided by the number of its records.
    """
    # Bucket the scores in one pass instead of rescanning per asset type.
    scores_by_asset = {}
    for record in _data:
        scores_by_asset.setdefault(record["q"].asset_type, []).append(record["correct"])

    return {
        asset: {'count': sum(scores), 'perc': sum(scores) / len(scores)}
        for asset, scores in scores_by_asset.items()
    }

def question_type(_data) -> dict:
    """Summarise the ``"correct"`` scores per question type.

    Records are grouped and counted by the same attribute,
    ``x["q"].question_type``. Grouping on one attribute while filtering on
    ``asset_type`` produced mismatched (usually empty) buckets and a
    division by zero.

    Args:
        _data: Evaluation records; each has a ``"q"`` object exposing
            ``question_type`` and a numeric ``"correct"`` entry.

    Returns:
        Dict mapping each question type to ``{'count': ..., 'perc': ...}``,
        where ``count`` is the sum of the ``correct`` values for that type
        and ``perc`` is that sum divided by the number of its records.
    """
    scores_by_type = {}
    for record in _data:
        scores_by_type.setdefault(record["q"].question_type, []).append(record["correct"])

    # Every bucket holds at least one record, so the division is safe.
    return {
        q_type: {'count': sum(scores), 'perc': sum(scores) / len(scores)}
        for q_type, scores in scores_by_type.items()
    }

def evaluate(_data):
    """Compute summary metrics for one model's evaluation records.

    Args:
        _data: Sized collection of evaluation records (dicts) carrying the
            ``'correct'``, ``'decode-error'`` and ``'no-answer'`` entries.

    Returns:
        Dict with ``overall_acc``, ``format_errors`` and ``no_answers``
        (each ``{'count': ..., 'perc': ...}``) plus ``asset_wise_metrics``
        as returned by ``asset_wise_metrics``. For empty input every
        ``perc`` is ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total = len(_data)

    def condition_count_percentage(d):
        # Count the records whose flag/score under key ``d`` is truthy.
        counts = sum(1 for x in _data if x[d])
        perc = counts/total if total else 0.0

        return {'count':counts, 'perc':perc}

    metrics = {}

    metrics['overall_acc'] = condition_count_percentage('correct')
    metrics['format_errors'] = condition_count_percentage('decode-error')
    metrics['no_answers'] = condition_count_percentage('no-answer')

    metrics['asset_wise_metrics'] = asset_wise_metrics(_data)

    return metrics
    
def metadata_evaluation(_all_metadata:dict):
    """Run ``evaluate`` on the records of every model.

    Args:
        _all_metadata: Dict mapping model id to its evaluation records.

    Returns:
        Dict mapping each model id to the metrics returned by ``evaluate``.
    """
    return {
        model_id: evaluate(records)
        for model_id, records in _all_metadata.items()
    }



def get_eval_data(root_fol:str, ds_loc:str):
    """Load all results under ``root_fol`` and return per-model metrics.

    Args:
        root_fol: Folder of results files, forwarded to ``get_all_metadata``.
        ds_loc: Dataset location, forwarded to ``get_all_metadata``.

    Returns:
        The dict produced by ``metadata_evaluation`` for the loaded records.
    """
    return metadata_evaluation(get_all_metadata(root_fol, ds_loc))

# Score every model's results on the simple and complex V3.1 datasets.
data_simple = get_all_metadata(
    "benchmarking/RESULTS_JSONL/basic_bench/simpleV3.1",
    "dataset/datasets/simpleV3.1",
)
data_complex = get_all_metadata(
    "benchmarking/RESULTS_JSONL/basic_bench/complexV3.1",
    "dataset/datasets/complexV3.1",
)
#data_simple_pert = get_all_metadata("benchmarking/RESULTS_JSONL/basic_bench/simplePertV3.1", "dataset/datasets/simplePertV3.1")

from dataset.dataset_utils.question import Question

# Maps each distinct question prompt to an integer id, in first-seen order.
q_prompt_map = {}

# One flat record per (dataset, model, question), later turned into a table.
metrics = []
for n, data in (("simpleV", data_simple), ("complexV", data_complex)):
    for model, model_data in data.items():
        for dp in model_data:
            que: Question = dp["q"]
            # The next free id is the current map size; an already-known
            # prompt keeps the id it was given first.
            prompt_id = q_prompt_map.setdefault(que.question_prompt, len(q_prompt_map))

            record = {
                "model": model,
                "ds": n,
                "q_id": que.id,
                "q_type": que.question_type,
                "q_prompt": prompt_id,
                "asset_type": que.asset_type,
                "correct": dp["correct"],
                "in-correct": dp["in-correct"],
            }
            metrics.append(record)

# Persist the flat table and echo it for a quick visual check.
metrics = pd.DataFrame.from_records(metrics)
metrics.to_csv('metric_data.csv', index=False)
print(metrics)

import pandas as pd

# Long-format report: one row per (model, asset_type) with its accuracy on
# the simpleV dataset, written as a markdown table.
# NOTE(review): a later section of this script writes to the same output
# path, so this table is overwritten -- confirm whether a distinct file name
# was intended.

# Filter only simpleV rows
# simplev_res = metrics.loc[metrics["ds"] == "complexV"]
simplev_res = metrics.loc[metrics["ds"] == "simpleV"]

# Group by model + asset_type, compute accuracy
asset_wise = (
    simplev_res.groupby(["model", "asset_type"])["correct"]
    .agg(lambda x: x.sum() / x.count())
    .reset_index(name="accuracy")
)

# Drop unwanted asset_type. The explicit copy makes the column assignment
# below act on an independent frame rather than on a filtered selection,
# which avoids pandas' SettingWithCopyWarning.
asset_wise = asset_wise.loc[asset_wise["asset_type"] != "AHU Humidity"].copy()

# Format accuracy as percentage
asset_wise["accuracy"] = (asset_wise["accuracy"] * 100).round(2).astype(str) + "%"

# Convert to markdown
md_output = asset_wise.to_markdown(index=False)

# Save to file
output_path = "asset_wise_accuracy_simple.md"
with open(output_path, "w", encoding="utf-8") as f:
    f.write("# Per-Model Per-Asset Accuracy (ds = simpleV)\n\n")
    f.write(md_output)

# Report the path that was actually written; the message used to name a
# different file.
print("Markdown file created: {}".format(output_path))

import pandas as pd

# Wide-format report: one row per model and one column per asset type, each
# cell holding the accuracy on the simpleV dataset, written as markdown.

# Filter only simpleV rows
simplev_res = metrics.loc[metrics["ds"] == "simpleV"]

# Group by model + asset_type, compute accuracy
asset_wise = (
    simplev_res.groupby(["model", "asset_type"])["correct"]
    .agg(lambda x: x.sum() / x.count())
    .reset_index(name="accuracy")
)

# Drop unwanted asset_type
asset_wise = asset_wise.loc[asset_wise["asset_type"] != "AHU Humidity"]

# Pivot so each asset is a column
pivot_table = asset_wise.pivot(index="model", columns="asset_type", values="accuracy")

# Convert accuracy to percentage strings
# (a model with no rows for some asset type shows up as "nan%")
pivot_table = (pivot_table * 100).round(2).astype(str) + "%"

# Reset index so "model" is a column again
pivot_table = pivot_table.reset_index()

# Convert to markdown
md_output = pivot_table.to_markdown(index=False)

# Save to file
output_path = "asset_wise_accuracy_simple.md"
with open(output_path, "w", encoding="utf-8") as f:
    f.write("# Per-Model Accuracy by Asset (ds = simpleV)\n\n")
    f.write(md_output)

# Report the path that was actually written; the message used to name a
# different file.
print("Markdown file created: {}".format(output_path))
