# %%
import Levenshtein
import json
from rouge_score import rouge_scorer
from rouge_score.tokenizers import Tokenizer
from tqdm import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import logging

# Fetch the logger registered under the name 'rouge_scorer' so that its
# output can be suppressed for the rest of this script.
# NOTE(review): presumably meant to quiet the ROUGE scoring library — confirm
# that the library actually logs under this name.
rouge_logger = logging.getLogger('rouge_scorer')

# A NullHandler accepts records and discards them.
rouge_logger.addHandler(logging.NullHandler())

# Stop records from being forwarded to ancestor loggers (e.g. the root logger).
rouge_logger.propagate = False

def _is_valid_json(candidate):
    """Return True when *candidate* can be parsed by ``json.loads``."""
    try:
        json.loads(candidate)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers json.JSONDecodeError; TypeError covers non-text
        # input; RecursionError covers pathologically nested documents.
        return False
    return True


def response_text_without_think(response_text):
    """Strip reasoning/wrapper markup from a model response and isolate its JSON.

    The text is first trimmed using three markers, in this order:

    1. everything from the first ``<|im_end|>`` onwards is dropped;
    2. everything up to and including the last ``</think>`` is dropped;
    3. everything up to and including the last ``## Final Response`` heading
       (followed by a blank line) is dropped.

    Then the first candidate that parses as JSON is returned:

    1. the whole trimmed text;
    2. its last line;
    3. the span from the first ``{`` to the last ``}``.

    Args:
        response_text: Raw model output. Non-string values are returned
            unchanged.

    Returns:
        The extracted JSON text, or the trimmed text when no candidate parses.
    """
    if not isinstance(response_text, str):
        return response_text

    if "<|im_end|>" in response_text:
        response_text = response_text.split("<|im_end|>")[0].strip()
    if "</think>" in response_text:
        response_text = response_text.split("</think>")[-1].strip()
    if '## Final Response\n\n' in response_text:
        response_text = response_text.split('## Final Response\n\n')[-1].strip()

    if _is_valid_json(response_text):
        return response_text

    last_line = response_text.split('\n')[-1].strip()
    if _is_valid_json(last_line):
        return last_line

    # Fall back to the outermost brace-delimited span, if there is one.
    start = response_text.find('{')
    end = response_text.rfind('}') + 1
    if start != -1 and end > start:
        braced = response_text[start:end]
        if _is_valid_json(braced):
            return braced

    return response_text

def event_filter(data, event_type):
    """Tell whether a record's ground truth has the given event type.

    Args:
        data: Record indexed by ``'ground_truth'``; that value is either a
            JSON string or an already-decoded object.
        event_type: Value compared against the ground truth's ``'event_type'``.

    Returns:
        The result of comparing the ground truth's ``'event_type'`` with
        *event_type*. ``False`` when the record has no usable ground truth
        (missing key, non-indexable record, invalid JSON), when the ground
        truth is not a dict, or when it has no ``'event_type'`` key.
    """
    try:
        ground_truth = data['ground_truth']
        if isinstance(ground_truth, str):
            ground_truth = json.loads(ground_truth)
    except (KeyError, TypeError, json.JSONDecodeError):
        # A record without a readable ground truth simply does not match.
        return False

    if isinstance(ground_truth, dict) and 'event_type' in ground_truth:
        return ground_truth['event_type'] == event_type
    return False

def calculate_event_type(data):
    """Return the ``'event_type'`` stored in a record's ground truth.

    Args:
        data: Record indexed by ``'ground_truth'``; that value is either a
            JSON string or an already-decoded object.

    Returns:
        The ground truth's ``'event_type'`` value, or ``None`` when the record
        has no usable ground truth (missing key, non-indexable record, invalid
        JSON), when the ground truth is not a dict, or when it has no
        ``'event_type'`` key.
    """
    try:
        ground_truth = data['ground_truth']
        if isinstance(ground_truth, str):
            ground_truth = json.loads(ground_truth)
    except (KeyError, TypeError, json.JSONDecodeError):
        # No readable ground truth means there is no event type to report.
        return None

    if isinstance(ground_truth, dict) and 'event_type' in ground_truth:
        return ground_truth['event_type']
    return None

def get_value_complexity(obj):
    """Count the leaf values nested inside *obj*.

    Dicts contribute the summed complexity of their values and lists the
    summed complexity of their items. Anything else, as well as an empty dict
    or empty list, counts as a single unit.
    """
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        return 1

    # An empty container still weighs one unit.
    if not obj:
        return 1

    total = 0
    for child in children:
        total += get_value_complexity(child)
    return total

def numeric_similarity(v1: float, v2: float) -> float:
    """Score how close two numbers are, on a 0.0-1.0 scale.

    Equal values score 1.0. Otherwise the score is one minus the absolute
    difference relative to the larger magnitude, clamped at 0.0 (so values of
    opposite sign score 0.0).

    Args:
        v1: First number.
        v2: Second number.

    Returns:
        A float in ``[0.0, 1.0]``. Any NaN operand, or an infinite operand
        that differs from the other value, yields 0.0.
    """
    # NaN is the only value that is not equal to itself. Reject it up front:
    # max() silently ignores a NaN in second position, which previously made
    # (0.0, NaN) look like a perfect match. The self-comparison is used rather
    # than math.isnan so arbitrarily large ints do not overflow.
    if v1 != v1 or v2 != v2:
        return 0.0
    if v1 == v2:
        return 1.0

    # The values differ and neither is NaN, so the larger magnitude is > 0.
    denominator = max(abs(v1), abs(v2))
    if denominator == float('inf'):
        # A differing infinite operand cannot be meaningfully close.
        return 0.0
    return max(0.0, 1.0 - abs(v1 - v2) / denominator)

_ROUGE_TYPES = ('rouge1', 'rouge2', 'rougeL')
_cached_rouge_scorer = None


def _get_rouge_scorer():
    """Return the shared ROUGE scorer, building it on first use."""
    global _cached_rouge_scorer
    if _cached_rouge_scorer is None:
        _cached_rouge_scorer = rouge_scorer.RougeScorer(
            list(_ROUGE_TYPES), use_stemmer=True
        )
    return _cached_rouge_scorer


def value_string_similarity(s1: str, s2: str) -> float:
    """Score the similarity of two string values on a 0.0-1.0 scale.

    The score is the mean of four components: normalized Levenshtein
    similarity and the ROUGE-1, ROUGE-2 and ROUGE-L F-measures.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        0.0 if either argument is not a string or exactly one of them is
        empty, 1.0 if they are equal, otherwise the four-way mean. If ROUGE
        scoring fails, its three components count as 0.0.
    """
    if not isinstance(s1, str) or not isinstance(s2, str):
        return 0.0
    if s1 == s2:
        return 1.0
    if not s1 or not s2:
        return 0.0

    # Both strings are non-empty here, so the longer length is at least 1.
    distance = Levenshtein.distance(s1, s2)
    lev_sim = 1.0 - (distance / max(len(s1), len(s2)))

    try:
        # The scorer is reused across calls instead of being rebuilt (with its
        # stemmer) for every pair of strings.
        scores = _get_rouge_scorer().score(s1, s2)
        rouge1_f = scores['rouge1'].fmeasure
        rouge2_f = scores['rouge2'].fmeasure
        rougeL_f = scores['rougeL'].fmeasure
    except Exception:
        # Deliberate best effort: a ROUGE failure must not abort scoring, so
        # the Levenshtein component alone carries the result.
        rouge1_f, rouge2_f, rougeL_f = 0.0, 0.0, 0.0

    return (lev_sim + rouge1_f + rouge2_f + rougeL_f) / 4

def key_string_similarity(s1: str, s2: str) -> float:
    """Return the normalized Levenshtein similarity of two key names.

    Non-string arguments score 0.0, identical strings score 1.0, and a pair
    where exactly one string is empty scores 0.0. Otherwise the result is one
    minus the edit distance divided by the longer string's length.
    """
    both_are_strings = isinstance(s1, str) and isinstance(s2, str)
    if not both_are_strings:
        return 0.0
    if s1 == s2:
        return 1.0
    if not (s1 and s2):
        return 0.0

    edit_distance = Levenshtein.distance(s1, s2)
    longest = max(len(s1), len(s2))
    return 1.0 if longest == 0 else 1.0 - (edit_distance / longest)

# %%
def _calculate_similarity_recursive(obj1, obj2, fuzzy_key_threshold):
    """Recursively score how closely *obj2* matches the ground truth *obj1*.

    Args:
        obj1: Ground-truth value (dict, list, str, number, or other scalar).
        obj2: Predicted value to compare against *obj1*.
        fuzzy_key_threshold: Minimum ``key_string_similarity`` (less a 1e-9
            tolerance) for two differing dict keys to be paired.

    Returns:
        A similarity score. Values of different exact types score 0.0; dicts
        and lists are scored structurally, strings via
        ``value_string_similarity``, numbers via ``numeric_similarity``, and
        anything else by equality.
    """
    # NOTE(review): the exact-type check means an int never matches a float
    # (1 vs 1.0 scores 0.0) — confirm this strictness is intended.
    if type(obj1) is not type(obj2):
        return 0.0

    if isinstance(obj1, dict):
        return _dict_similarity(obj1, obj2, fuzzy_key_threshold)
    if isinstance(obj1, list):
        return _list_similarity(obj1, obj2, fuzzy_key_threshold)
    if isinstance(obj1, str):
        return value_string_similarity(obj1, obj2)
    if isinstance(obj1, (int, float)):
        return numeric_similarity(obj1, obj2)
    return 1.0 if obj1 == obj2 else 0.0


def _matched_value_credit(gt, pred, key_pairs, fuzzy_key_threshold):
    """Sum similarity * ground-truth complexity over matched key pairs."""
    credit = 0.0
    for gt_key, pred_key in key_pairs:
        # NOTE(review): this is a substring test, so keys such as "width" or
        # "video" are also treated as ids and always earn full credit.
        if "id" in gt_key.lower():
            sim = 1.0
        else:
            sim = _calculate_similarity_recursive(
                gt[gt_key], pred[pred_key], fuzzy_key_threshold
            )
        credit += sim * get_value_complexity(gt[gt_key])
    return credit


def _dict_similarity(gt, pred, fuzzy_key_threshold):
    """Score two dicts as (key overlap score) * (weighted value score)."""
    if not gt and not pred:
        return 1.0

    # Walk keys in dict insertion order rather than set order: set iteration
    # depends on string hash randomization, which made the greedy fuzzy
    # pairing (and float summation order) vary from run to run.
    exact_keys = [key for key in gt if key in pred]
    unmatched_gt_keys = [key for key in gt if key not in pred]
    unmatched_pred_keys = [key for key in pred if key not in gt]
    union_size = len(gt) + len(unmatched_pred_keys)

    # Greedily pair each leftover ground-truth key with the most similar
    # leftover predicted key that clears the threshold.
    fuzzy_matches = []
    for gt_key in unmatched_gt_keys:
        best_key = None
        best_score = fuzzy_key_threshold - 1e-9
        for pred_key in unmatched_pred_keys:
            score = key_string_similarity(gt_key, pred_key)
            if score > best_score:
                best_score = score
                best_key = pred_key
        if best_key is not None:
            fuzzy_matches.append((gt_key, best_key, best_score))
            unmatched_pred_keys.remove(best_key)

    # Exact key matches count 1 each, fuzzy matches count their similarity.
    key_score = (
        len(exact_keys) + sum(score for _, _, score in fuzzy_matches)
    ) / union_size

    fuzzy_pairs = [(gt_key, pred_key) for gt_key, pred_key, _ in fuzzy_matches]

    if "text" in gt and "text" in pred:
        # A shared "text" field carries half of the value score; all other
        # ground-truth fields share the remaining half.
        text_sim = _calculate_similarity_recursive(
            gt["text"], pred["text"], fuzzy_key_threshold
        )
        other_pairs = [(key, key) for key in exact_keys if key != "text"]
        other_pairs += fuzzy_pairs
        other_total = sum(
            get_value_complexity(value)
            for key, value in gt.items()
            if key != "text"
        )
        if other_total > 0:
            other_sim = _matched_value_credit(
                gt, pred, other_pairs, fuzzy_key_threshold
            ) / other_total
        else:
            # "text" is the only ground-truth field.
            other_sim = 1.0
        value_score = 0.5 * text_sim + 0.5 * other_sim
    else:
        all_pairs = [(key, key) for key in exact_keys] + fuzzy_pairs
        total = sum(get_value_complexity(value) for value in gt.values())
        if total > 0:
            value_score = _matched_value_credit(
                gt, pred, all_pairs, fuzzy_key_threshold
            ) / total
        else:
            # Empty ground truth: only an empty prediction earns credit.
            value_score = 0.0 if pred else 1.0

    return key_score * value_score


def _list_similarity(gt_items, pred_items, fuzzy_key_threshold):
    """Score two lists by greedily pairing predictions with ground-truth items."""
    if not gt_items and not pred_items:
        return 1.0

    weights = [get_value_complexity(item) for item in gt_items]
    total = sum(weights)
    if total == 0:
        return 1.0 if not pred_items else 0.0

    credit = 0.0
    remaining = list(range(len(gt_items)))  # indices of unclaimed ground truth
    for pred_item in pred_items:
        if not remaining:
            break  # surplus predictions earn nothing further

        best_index = None
        best_sim = -1.0
        for index in remaining:
            sim = _calculate_similarity_recursive(
                gt_items[index], pred_item, fuzzy_key_threshold
            )
            if sim > best_sim:
                best_sim = sim
                best_index = index

        if best_index is not None:
            credit += best_sim * weights[best_index]
            remaining.remove(best_index)

    return credit / total

# %%

def env_pred_score(
    ground_truth: str,
    response_text_without_think: str,
    fuzzy_key_threshold: float = 0.8
) -> float:
    """Score a predicted JSON document against its ground-truth JSON.

    Args:
        ground_truth: JSON text of the expected value.
        response_text_without_think: JSON text of the predicted value.
        fuzzy_key_threshold: Passed through to
            ``_calculate_similarity_recursive`` for fuzzy key pairing.

    Returns:
        The similarity computed by ``_calculate_similarity_recursive`` on the
        two decoded values, or 0.0 when either input cannot be decoded.
    """
    try:
        gt_obj = json.loads(ground_truth)
        res_obj = json.loads(response_text_without_think)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers json.JSONDecodeError; TypeError covers non-text
        # input; RecursionError covers pathologically nested documents.
        return 0.0
    return _calculate_similarity_recursive(gt_obj, res_obj, fuzzy_key_threshold)

def plot_score_distributions(scores_dict, event_filter_type=None):
    """Draw one density curve of ``env_pred_score`` per model.

    Args:
        scores_dict: Mapping of model name to a list of score records, each
            indexed by ``'env_pred_score'``.
        event_filter_type: Label shown in the plot title; ``'All Events'`` is
            used when it is falsy.

    Prints a message and returns without plotting when there is nothing to
    draw.
    """
    if not scores_dict:
        print("No scores data available to plot.")
        return

    # Flatten the per-model records into one long-form table.
    rows = [
        {'Score': record['env_pred_score'], 'Model': model_name}
        for model_name, records in scores_dict.items()
        for record in records
    ]
    if not rows:
        print("No data available for plotting.")
        return

    frame = pd.DataFrame(rows)

    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(12, 8))

    # common_norm=False lets each model's curve be normalized independently.
    sns.kdeplot(
        data=frame,
        x='Score',
        hue='Model',
        fill=True,
        common_norm=False,
        alpha=0.4,
        linewidth=1.5,
    )

    plt.xlabel('env_pred_score', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)

    event_label = event_filter_type or 'All Events'
    plt.title(f"Scores Distribution ({event_label})", fontsize=16)
    plt.show()